Skip to content

Commit

Permalink
adapt to new model structure
Browse files Browse the repository at this point in the history
  • Loading branch information
tcm390 committed Jan 4, 2025
1 parent 2af9a09 commit f736ba3
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 25 deletions.
9 changes: 6 additions & 3 deletions packages/client-discord/src/actions/chat_with_attachments.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { composeContext } from "@elizaos/core";
import { composeContext, getModelSettings } from "@elizaos/core";
import { generateText, trimTokens } from "@elizaos/core";
import { models } from "@elizaos/core";

Check failure on line 3 in packages/client-discord/src/actions/chat_with_attachments.ts

View workflow job for this annotation

GitHub Actions / check

'models' is defined but never used. Allowed unused vars must match /^_/u
import { parseJSONObjectFromText } from "@elizaos/core";
Expand Down Expand Up @@ -185,8 +185,11 @@ const summarizeAction = {

let currentSummary = "";

const model = models[runtime.character.modelProvider];
const chunkSize = model.settings.maxOutputTokens;
const modelSettings = getModelSettings(
runtime.modelProvider,
ModelClass.SMALL
);
const chunkSize = modelSettings.maxOutputTokens;

state.attachmentsWithText = attachmentsWithText;
state.objective = objective;
Expand Down
9 changes: 6 additions & 3 deletions packages/client-discord/src/actions/summarize_conversation.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { composeContext } from "@elizaos/core";
import { composeContext, getModelSettings } from "@elizaos/core";
import { generateText, splitChunks, trimTokens } from "@elizaos/core";
import { getActorDetails } from "@elizaos/core";
import { models } from "@elizaos/core";

Check failure on line 4 in packages/client-discord/src/actions/summarize_conversation.ts

View workflow job for this annotation

GitHub Actions / check

'models' is defined but never used. Allowed unused vars must match /^_/u
Expand Down Expand Up @@ -247,8 +247,11 @@ const summarizeAction = {

let currentSummary = "";

const model = models[runtime.character.settings.model];
const chunkSize = model.settings.maxContextLength - 1000;
const modelSettings = getModelSettings(
runtime.modelProvider,
ModelClass.SMALL
);
const chunkSize = modelSettings.maxOutputTokens - 1000;

const chunks = await splitChunks(formattedMemories, chunkSize, 0);

Expand Down
8 changes: 6 additions & 2 deletions packages/client-slack/src/actions/chat_with_attachments.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
generateText,
trimTokens,
parseJSONObjectFromText,
getModelSettings,
} from "@elizaos/core";
import { models } from "@elizaos/core";
import {
Expand Down Expand Up @@ -194,8 +195,11 @@ const summarizeAction: Action = {

let currentSummary = "";

const model = models[runtime.character.modelProvider];
const chunkSize = model.settings.maxOutputTokens;
const modelSettings = getModelSettings(
runtime.modelProvider,
ModelClass.SMALL
);
const chunkSize = modelSettings.maxOutputTokens;

currentState.attachmentsWithText = attachmentsWithText;
currentState.objective = objective;
Expand Down
8 changes: 6 additions & 2 deletions packages/client-slack/src/actions/summarize_conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
splitChunks,
trimTokens,
parseJSONObjectFromText,
getModelSettings,
} from "@elizaos/core";
import { models } from "@elizaos/core";
import { getActorDetails } from "@elizaos/core";
Expand Down Expand Up @@ -265,8 +266,11 @@ const summarizeAction: Action = {

let currentSummary = "";

const model = models[runtime.character.modelProvider];
const chunkSize = model.settings.maxOutputTokens;
const modelSettings = getModelSettings(
runtime.modelProvider,
ModelClass.SMALL
);
const chunkSize = modelSettings.maxOutputTokens;

const chunks = await splitChunks(formattedMemories, chunkSize, 0);

Expand Down
8 changes: 4 additions & 4 deletions packages/core/src/embedding.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import path from "node:path";
import { models } from "./models.ts";
import { getEndpoint, models } from "./models.ts";
import { IAgentRuntime, ModelProviderName } from "./types.ts";
import settings from "./settings.ts";
import elizaLogger from "./logger.ts";
Expand Down Expand Up @@ -202,7 +202,7 @@ export async function embed(runtime: IAgentRuntime, input: string) {
model: config.model,
endpoint:
runtime.character.modelEndpointOverride ||
models[ModelProviderName.OLLAMA].endpoint,
getEndpoint(ModelProviderName.OLLAMA),
isOllama: true,
dimensions: config.dimensions,
});
Expand All @@ -213,7 +213,7 @@ export async function embed(runtime: IAgentRuntime, input: string) {
model: config.model,
endpoint:
runtime.character.modelEndpointOverride ||
models[ModelProviderName.GAIANET].endpoint ||
getEndpoint(ModelProviderName.GAIANET) ||
settings.SMALL_GAIANET_SERVER_URL ||
settings.MEDIUM_GAIANET_SERVER_URL ||
settings.LARGE_GAIANET_SERVER_URL,
Expand All @@ -239,7 +239,7 @@ export async function embed(runtime: IAgentRuntime, input: string) {
model: config.model,
endpoint:
runtime.character.modelEndpointOverride ||
models[runtime.character.modelProvider].endpoint,
getEndpoint(runtime.character.modelProvider),
apiKey: runtime.token,
dimensions: config.dimensions,
});
Expand Down
19 changes: 12 additions & 7 deletions packages/core/src/generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@ import { AutoTokenizer } from "@huggingface/transformers";
import Together from "together-ai";
import { ZodSchema } from "zod";
import { elizaLogger } from "./index.ts";
import { models, getModelSettings, getImageModelSettings } from "./models.ts";
import {
models,
getModelSettings,
getImageModelSettings,
getEndpoint,
} from "./models.ts";
import {
parseBooleanFromText,
parseJsonArrayFromText,
Expand Down Expand Up @@ -200,7 +205,7 @@ export async function generateText({

const provider = runtime.modelProvider;
const endpoint =
runtime.character.modelEndpointOverride || models[provider].endpoint;
runtime.character.modelEndpointOverride || getEndpoint(provider);
const modelSettings = getModelSettings(runtime.modelProvider, modelClass);
let model = modelSettings.name;

Expand Down Expand Up @@ -563,7 +568,7 @@ export async function generateText({

case ModelProviderName.REDPILL: {
elizaLogger.debug("Initializing RedPill model.");
const serverUrl = models[provider].endpoint;
const serverUrl = getEndpoint(provider);
const openai = createOpenAI({
apiKey,
baseURL: serverUrl,
Expand Down Expand Up @@ -594,7 +599,7 @@ export async function generateText({

case ModelProviderName.OPENROUTER: {
elizaLogger.debug("Initializing OpenRouter model.");
const serverUrl = models[provider].endpoint;
const serverUrl = getEndpoint(provider);
const openrouter = createOpenAI({
apiKey,
baseURL: serverUrl,
Expand Down Expand Up @@ -628,7 +633,7 @@ export async function generateText({
elizaLogger.debug("Initializing Ollama model.");

const ollamaProvider = createOllama({
baseURL: models[provider].endpoint + "/api",
baseURL: getEndpoint(provider) + "/api",
fetch: runtime.fetch,
});
const ollama = ollamaProvider(model);
Expand Down Expand Up @@ -686,7 +691,7 @@ export async function generateText({
case ModelProviderName.GAIANET: {
elizaLogger.debug("Initializing GAIANET model.");

var baseURL = models[provider].endpoint;
var baseURL = getEndpoint(provider);
if (!baseURL) {
switch (modelClass) {
case ModelClass.SMALL:
Expand Down Expand Up @@ -1866,7 +1871,7 @@ async function handleOllama({
provider,
}: ProviderOptions): Promise<GenerateObjectResult<unknown>> {
const ollamaProvider = createOllama({
baseURL: models[provider].endpoint + "/api",
baseURL: getEndpoint(provider) + "/api",
});
const ollama = ollamaProvider(model);
return await aiGenerateObject({
Expand Down
13 changes: 9 additions & 4 deletions packages/plugin-node/src/services/image.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { elizaLogger, models } from "@elizaos/core";
import { elizaLogger, getEndpoint, models } from "@elizaos/core";
import { Service } from "@elizaos/core";
import {
IAgentRuntime,
Expand Down Expand Up @@ -187,7 +187,12 @@ export class ImageDescriptionService
): Promise<string> {
for (let attempt = 0; attempt < 3; attempt++) {
try {
const shouldUseBase64 = (isGif || isLocalFile)&& !(this.runtime.imageModelProvider === ModelProviderName.OPENAI);
const shouldUseBase64 =
(isGif || isLocalFile) &&
!(
this.runtime.imageModelProvider ===
ModelProviderName.OPENAI
);
const mimeType = isGif
? "png"
: path.extname(imageUrl).slice(1) || "jpeg";
Expand All @@ -209,8 +214,8 @@ export class ImageDescriptionService
// If model provider is openai, use the endpoint, otherwise use the default openai endpoint.
const endpoint =
this.runtime.imageModelProvider === ModelProviderName.OPENAI
? models[this.runtime.imageModelProvider].endpoint
: "https://api.openai.com/v1";
? getEndpoint(this.runtime.imageModelProvider)
: "https://api.openai.com/v1";
const response = await fetch(endpoint + "/chat/completions", {
method: "POST",
headers: {
Expand Down

0 comments on commit f736ba3

Please sign in to comment.