Commit

fix: fix and test
opendeeple committed Dec 17, 2024
1 parent e60cc9d commit bd9a22e
Showing 6 changed files with 86 additions and 65 deletions.
Binary file added src/.DS_Store
Binary file not shown.
15 changes: 6 additions & 9 deletions src/anthropic/index.ts
@@ -1,14 +1,11 @@
 import Anthropic from "@anthropic-ai/sdk";
-import { omit } from "../utils";
-import {
-  AnthropicChatCompletationParams,
-  FetchAIChatCompletation,
-} from "../type";
+import { omit } from "../utils/index";
+import { AnthropicChatCompletionParams, FetchAIChatCompletion } from "../type";
 
 export default class AnthropicChatRepository {
   constructor(readonly provider: Anthropic) {}
 
-  async create(body: AnthropicChatCompletationParams) {
+  async create(body: AnthropicChatCompletionParams) {
     try {
       const completion = await this.provider.messages.create(
         omit(body, ["prediction_tokens"])
@@ -18,7 +15,7 @@ export default class AnthropicChatRepository {
           content.type == "text"
             ? content.text
            : `${content.name}(id=${content.id}, value=${content.input})`;
-      const result: FetchAIChatCompletation = {
+      const result: FetchAIChatCompletion = {
         provider: "Anthropic",
         success: true,
         prediction: {
@@ -39,7 +36,7 @@
       return result;
     } catch (error) {
       const message = error instanceof Error ? error.message : String(error);
-      const result: FetchAIChatCompletation = {
+      const result: FetchAIChatCompletion = {
         provider: "Anthropic",
         success: false,
         error: message,
@@ -54,7 +51,7 @@
     }
   }
 
-  async countTokens(body: AnthropicChatCompletationParams) {
+  async countTokens(body: AnthropicChatCompletionParams) {
     const response = await this.provider.beta.messages.countTokens({
       betas: ["token-counting-2024-11-01"],
       ...omit(body, ["max_tokens"]),
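
For orientation, a minimal usage sketch of the renamed API; the model id and environment setup below are illustrative assumptions, not part of this commit:

import Anthropic from "@anthropic-ai/sdk";
import AnthropicChatRepository from "./anthropic/index";

// Illustrative setup: the Anthropic SDK reads ANTHROPIC_API_KEY from the
// environment by default; the model id is an assumption, not from the diff.
const repo = new AnthropicChatRepository(new Anthropic());

const result = await repo.create({
  model: "claude-3-5-sonnet-20241022",
  system: "You are a terse assistant.",
  messages: [{ role: "user", content: "Say hello." }],
  max_tokens: 64,
});

// result is a FetchAIChatCompletion: prediction.content holds the text (or a
// tool-use summary) and usage carries the token counts.
if (result.success) {
  console.log(result.prediction.content, result.usage);
}
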
37 changes: 28 additions & 9 deletions src/fetchai/index.ts
@@ -3,9 +3,12 @@ import Anthropic from "@anthropic-ai/sdk";
 import AnthropicChatRepository from "../anthropic";
 import OpenAIChatRepository from "../openai";
 import {
-  AnthropicChatCompletationParams,
+  AnthropicChatCompletionParams,
   ClientOptions,
-  FetchAIChatCompletationParams,
+  FetchAIChatCompletionParams,
+  FetchAIChatCount,
+  FetchAIChatModels,
+  FetchAIProviders,
 } from "../type";
 
 export default class FetchAIChatRepository {
@@ -19,23 +22,39 @@ export default class FetchAIChatRepository {
     );
   }
 
+  private isAnthropicModel(model: FetchAIChatModels) {
+    return /^claude/.test(model);
+  }
+
   private isAnthropic(
-    body: FetchAIChatCompletationParams
-  ): body is AnthropicChatCompletationParams {
-    return /^claude/.test(body.model);
+    body: FetchAIChatCompletionParams
+  ): body is AnthropicChatCompletionParams {
+    return this.isAnthropicModel(body.model);
   }
 
-  async create(body: FetchAIChatCompletationParams) {
+  async create(body: FetchAIChatCompletionParams) {
     if (this.isAnthropic(body)) {
       return this.Anthropic.create(body);
     }
     return this.OpenAI.create(body);
   }
 
-  async countTokens(body: FetchAIChatCompletationParams) {
+  async countTokens(
+    body: FetchAIChatCompletionParams
+  ): Promise<FetchAIChatCount> {
     if (this.isAnthropic(body)) {
-      return this.Anthropic.countTokens(body);
+      return {
+        provider: "Anthropic",
+        input_tokens: await this.Anthropic.countTokens(body),
+      };
     }
-    return this.OpenAI.countTokens(body);
+    return {
+      provider: "OpenAI",
+      input_tokens: await this.OpenAI.countTokens(body),
+    };
   }
+
+  provider(model: FetchAIChatModels): FetchAIProviders {
+    return this.isAnthropicModel(model) ? "Anthropic" : "OpenAI";
+  }
 }
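
A sketch of how the router behaves after this change. The constructor options are only partially visible in the diff (just the anthropic key of ClientOptions), so the option shape and model ids here are assumptions:

import FetchAIChatRepository from "./fetchai/index";

// Assumed option shape; only the `anthropic` key is visible in this diff.
const repo = new FetchAIChatRepository({
  openai: { apiKey: process.env.OPENAI_API_KEY },
  anthropic: { apiKey: process.env.ANTHROPIC_API_KEY },
});

// Routing is by model prefix: anything starting with "claude" is Anthropic.
repo.provider("claude-3-5-sonnet-20241022"); // "Anthropic"
repo.provider("gpt-4o");                     // "OpenAI"

// countTokens now resolves to a FetchAIChatCount instead of a bare number,
// so callers can tell which provider produced the count.
const count = await repo.countTokens({
  model: "gpt-4o",
  messages: [{ role: "user", content: "How many tokens is this?" }],
  max_tokens: 128,
});
// count => { provider: "OpenAI", input_tokens: <number> }
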
6 changes: 3 additions & 3 deletions src/index.ts
@@ -1,5 +1,5 @@
-import AnthropicChatRepository from "./anthropic";
-import FetchAIChatRepository from "./fetchai";
-import OpenAIChatRepository from "./openai";
+import AnthropicChatRepository from "./anthropic/index";
+import FetchAIChatRepository from "./fetchai/index";
+import OpenAIChatRepository from "./openai/index";
 export { OpenAIChatRepository, AnthropicChatRepository };
 export default FetchAIChatRepository;
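
For consumers nothing changes; the same symbols are re-exported and only the internal paths now point at explicit index files. The package name below is a placeholder:

// "fetchai-chat" is a placeholder for whatever this package is published as.
import FetchAIChatRepository, {
  OpenAIChatRepository,
  AnthropicChatRepository,
} from "fetchai-chat";
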
30 changes: 13 additions & 17 deletions src/openai/index.ts
@@ -1,15 +1,11 @@
 import OpenAI from "openai";
 import { encoding_for_model } from "tiktoken";
-import {
-  FetchAIChatCompletation,
-  OpenAIChatCompletationParams,
-  OpenAIChatMessageParams,
-} from "../type";
+import { FetchAIChatCompletion, OpenAIChatCompletionParams } from "../type";
 
 export default class OpenAIChatRepository {
   constructor(readonly provider: OpenAI) {}
 
-  async create(body: OpenAIChatCompletationParams) {
+  async create(body: OpenAIChatCompletionParams) {
     const prediction = body.prediction;
     try {
       const completion = await this.provider.chat.completions.create({
@@ -23,7 +19,7 @@ export default class OpenAIChatRepository {
             }
           : undefined,
       });
-      const result: FetchAIChatCompletation = {
+      const result: FetchAIChatCompletion = {
         provider: "OpenAI",
         success: true,
         prediction: {
@@ -43,7 +39,7 @@
       return result;
     } catch (error) {
       const message = error instanceof Error ? error.message : String(error);
-      const result: FetchAIChatCompletation = {
+      const result: FetchAIChatCompletion = {
         provider: "OpenAI",
         success: false,
         error: message,
@@ -58,10 +54,10 @@
     }
   }
 
-  async countTokens(body: OpenAIChatCompletationParams) {
+  async countTokens(body: OpenAIChatCompletionParams) {
     const modelEncoding = encoding_for_model(body.model);
 
-    let tokens = modelEncoding.encode(body.system).length;
+    let tokens = body.system ? modelEncoding.encode(body.system).length : 0;
 
     const messages = body.messages;
     for (const message of messages) {
@@ -80,14 +76,14 @@
     return tokens;
   }
 
-  private buildMessages(body: OpenAIChatCompletationParams) {
-    const messages: OpenAIChatMessageParams[] = [
-      {
+  private buildMessages(body: OpenAIChatCompletionParams) {
+    const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [];
+    if (body.system) {
+      messages.push({
         role: "system",
         content: body.system,
-      },
-      ...body.messages,
-    ];
-    return messages;
+      });
+    }
+    return [...messages, ...body.messages];
   }
 }
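
A small sketch of what the new guards buy: with no system prompt, buildMessages omits the system message entirely and countTokens no longer calls encode on an undefined value. The model id and client setup are illustrative assumptions:

import OpenAI from "openai";
import OpenAIChatRepository from "./openai/index";

// Assumes OPENAI_API_KEY is set; the OpenAI SDK reads it by default.
const repo = new OpenAIChatRepository(new OpenAI());

// No `system` field: previously modelEncoding.encode(body.system) would fail
// on undefined; now a missing system prompt simply contributes zero tokens.
const tokens = await repo.countTokens({
  model: "gpt-4o",
  messages: [{ role: "user", content: "Summarize this commit in one sentence." }],
  max_tokens: 64,
});
console.log(tokens); // plain number of prompt tokens
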
63 changes: 36 additions & 27 deletions src/type.ts
@@ -6,23 +6,29 @@ import { TiktokenModel } from "tiktoken";
 
 export type Maybe<T> = T | null | undefined;
 
+export interface FetchAIChatMessageParams {
+  role: "assistant" | "user";
+  content: string;
+}
+
 export type AnthropicChatModels = Anthropic.Messages.Model;
-export type AnthropicChatMessageParams = Anthropic.Messages.MessageParam;
+export type AnthropicChatMessageParams = FetchAIChatMessageParams;
 
 export type AnthropicChatCompletionParams = {
   model: AnthropicChatModels;
-  system?: string; // Optional system message
+  system?: string;
   messages: AnthropicChatMessageParams[];
   prediction?: string;
   max_tokens: number;
   prediction_tokens?: number;
 };
 
 export type OpenAIChatModels = TiktokenModel;
-export type OpenAIChatMessageParams = OpenAI.Chat.ChatCompletionMessageParam;
+export type OpenAIChatMessageParams = FetchAIChatMessageParams;
 
 export type OpenAIChatCompletionParams = {
   model: OpenAIChatModels;
-  system?: string; // Optional system message
+  system?: string;
   messages: OpenAIChatMessageParams[];
   prediction?: string;
   max_tokens: number;
@@ -34,36 +40,39 @@ export interface ClientOptions {
   anthropic?: AnthropicClientOptions;
 }
 
-export type FetchAIChatMessageParams =
-  | (OpenAIChatMessageParams & {
-      role: "system" | "user" | "assistant";
-    })
-  | (AnthropicChatMessageParams & {
-      role: "user" | "assistant";
-    });
-
 export type FetchAIChatModels = OpenAIChatModels | AnthropicChatModels;
 
 export type FetchAIChatCompletionParams =
   | OpenAIChatCompletionParams
   | AnthropicChatCompletionParams;
 
+export type FetchAIProviders = "OpenAI" | "Anthropic";
+
+export interface FetchAIChatPrediction {
+  id?: string;
+  content: string;
+  openai?: OpenAI.Chat.Completions.ChatCompletion;
+  anthropic?: Omit<Anthropic.Messages.Message, "id">;
+}
+
+export interface FetchAIUsage {
+  input_tokens?: number;
+  output_tokens?: number;
+  total_tokens?: number;
+  predicted_tokens: number;
+  openai?: OpenAI.Completions.CompletionUsage;
+  anthropic?: Anthropic.Messages.Usage;
+}
+
 export interface FetchAIChatCompletion {
-  provider: "OpenAI" | "Anthropic";
+  provider: FetchAIProviders;
   success: boolean;
   error?: string;
-  prediction: {
-    id?: string;
-    content: string;
-    openai?: OpenAI.Chat.Completions.ChatCompletion;
-    anthropic?: Omit<Anthropic.Messages.Message, "id">;
-  };
-  usage: {
-    input_tokens?: number;
-    output_tokens?: number;
-    total_tokens?: number;
-    predicted_tokens: number;
-    openai?: OpenAI.Completions.CompletionUsage;
-    anthropic?: Anthropic.Messages.Usage;
-  };
+  prediction: FetchAIChatPrediction;
+  usage: FetchAIUsage;
 }
+
+export interface FetchAIChatCount {
+  provider: FetchAIProviders;
+  input_tokens: number;
+}
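
A sketch of how the extracted types compose, using only names introduced in this file; the helper below is a hypothetical consumer, not part of the commit:

import type {
  FetchAIChatCompletion,
  FetchAIChatCount,
  FetchAIChatMessageParams,
} from "./type";

// The shared message shape is now provider-agnostic: user/assistant roles with
// plain string content, instead of the old per-SDK unions.
const history: FetchAIChatMessageParams[] = [
  { role: "user", content: "ping" },
  { role: "assistant", content: "pong" },
];

// A tiny reporter over the extracted result and count shapes.
function report(result: FetchAIChatCompletion, count: FetchAIChatCount): string {
  const used = result.usage.total_tokens ?? 0;
  const status = result.success ? "ok" : result.error ?? "error";
  return `${result.provider}: ${status}, ${count.input_tokens} prompt tokens counted, ${used} used`;
}

export { history, report };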
