diff --git a/.env.example b/.env.example index ddb00f0723e..23ab612820f 100644 --- a/.env.example +++ b/.env.example @@ -184,6 +184,12 @@ MEDIUM_GOOGLE_MODEL= # Default: gemini-1.5-flash-latest LARGE_GOOGLE_MODEL= # Default: gemini-1.5-pro-latest EMBEDDING_GOOGLE_MODEL= # Default: text-embedding-004 +# Mistral Configuration +MISTRAL_MODEL= +SMALL_MISTRAL_MODEL= # Default: mistral-small-latest +MEDIUM_MISTRAL_MODEL= # Default: mistral-large-latest +LARGE_MISTRAL_MODEL= # Default: mistral-large-latest + # Groq Configuration GROQ_API_KEY= # Starts with gsk_ SMALL_GROQ_MODEL= # Default: llama-3.1-8b-instant diff --git a/agent/src/index.ts b/agent/src/index.ts index b3f8fb0fcf0..d83913bf960 100644 --- a/agent/src/index.ts +++ b/agent/src/index.ts @@ -398,6 +398,11 @@ export function getTokenForProvider( character.settings?.secrets?.GOOGLE_GENERATIVE_AI_API_KEY || settings.GOOGLE_GENERATIVE_AI_API_KEY ); + case ModelProviderName.MISTRAL: + return ( + character.settings?.secrets?.MISTRAL_API_KEY || + settings.MISTRAL_API_KEY + ); case ModelProviderName.LETZAI: return ( character.settings?.secrets?.LETZAI_API_KEY || diff --git a/docs/docs/advanced/fine-tuning.md b/docs/docs/advanced/fine-tuning.md index 7822e9010ff..2a3220ddac6 100644 --- a/docs/docs/advanced/fine-tuning.md +++ b/docs/docs/advanced/fine-tuning.md @@ -22,6 +22,7 @@ enum ModelProviderName { LLAMACLOUD, LLAMALOCAL, GOOGLE, + MISTRAL, REDPILL, OPENROUTER, HEURIST, diff --git a/packages/core/package.json b/packages/core/package.json index 3a1b74388fe..ccea2942a4b 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -69,6 +69,7 @@ "@ai-sdk/google": "0.0.55", "@ai-sdk/google-vertex": "0.0.43", "@ai-sdk/groq": "0.0.3", + "@ai-sdk/mistral": "^1.0.8", "@ai-sdk/openai": "1.0.5", "@anthropic-ai/sdk": "0.30.1", "@fal-ai/client": "1.2.0", diff --git a/packages/core/src/generation.ts b/packages/core/src/generation.ts index 2dcec84eac5..f1dbe0d4699 100644 --- a/packages/core/src/generation.ts +++ b/packages/core/src/generation.ts @@ -1,5 +1,6 @@ import { createAnthropic } from "@ai-sdk/anthropic"; import { createGoogleGenerativeAI } from "@ai-sdk/google"; +import { createMistral } from "@ai-sdk/mistral"; import { createGroq } from "@ai-sdk/groq"; import { createOpenAI } from "@ai-sdk/openai"; import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"; @@ -528,6 +529,27 @@ export async function generateText({ break; } + case ModelProviderName.MISTRAL: { + const mistral = createMistral(); + + const { text: mistralResponse } = await aiGenerateText({ + model: mistral(model), + prompt: context, + system: + runtime.character.system ?? + settings.SYSTEM_PROMPT ?? + undefined, + temperature: temperature, + maxTokens: max_response_length, + frequencyPenalty: frequency_penalty, + presencePenalty: presence_penalty, + }); + + response = mistralResponse; + elizaLogger.debug("Received response from Mistral model."); + break; + } + case ModelProviderName.ANTHROPIC: { elizaLogger.debug("Initializing Anthropic model with Cloudflare check"); const baseURL = getCloudflareGatewayBaseURL(runtime, 'anthropic') || "https://api.anthropic.com/v1"; @@ -1863,6 +1885,8 @@ export async function handleProvider( }); case ModelProviderName.GOOGLE: return await handleGoogle(options); + case ModelProviderName.MISTRAL: + return await handleMistral(options); case ModelProviderName.REDPILL: return await handleRedPill(options); case ModelProviderName.OPENROUTER: @@ -2019,6 +2043,31 @@ async function handleGoogle({ }); } +/** + * Handles object generation for Mistral models. + * + * @param {ProviderOptions} options - Options specific to Mistral. + * @returns {Promise>} - A promise that resolves to generated objects. + */ +async function handleMistral({ + model, + schema, + schemaName, + schemaDescription, + mode, + modelOptions, +}: ProviderOptions): Promise> { + const mistral = createMistral(); + return await aiGenerateObject({ + model: mistral(model), + schema, + schemaName, + schemaDescription, + mode, + ...modelOptions, + }); +} + /** * Handles object generation for Redpill models. * diff --git a/packages/core/src/models.ts b/packages/core/src/models.ts index 663aaa518a6..b543419762d 100644 --- a/packages/core/src/models.ts +++ b/packages/core/src/models.ts @@ -378,6 +378,46 @@ export const models: Models = { }, }, }, + [ModelProviderName.MISTRAL]: { + model: { + [ModelClass.SMALL]: { + name: + settings.SMALL_MISTRAL_MODEL || + settings.MISTRAL_MODEL || + "mistral-small-latest", + stop: [], + maxInputTokens: 128000, + maxOutputTokens: 8192, + frequency_penalty: 0.4, + presence_penalty: 0.4, + temperature: 0.7, + }, + [ModelClass.MEDIUM]: { + name: + settings.MEDIUM_MISTRAL_MODEL || + settings.MISTRAL_MODEL || + "mistral-large-latest", + stop: [], + maxInputTokens: 128000, + maxOutputTokens: 8192, + frequency_penalty: 0.4, + presence_penalty: 0.4, + temperature: 0.7, + }, + [ModelClass.LARGE]: { + name: + settings.LARGE_MISTRAL_MODEL || + settings.MISTRAL_MODEL || + "mistral-large-latest", + stop: [], + maxInputTokens: 128000, + maxOutputTokens: 8192, + frequency_penalty: 0.4, + presence_penalty: 0.4, + temperature: 0.7, + } + }, + }, [ModelProviderName.REDPILL]: { endpoint: "https://api.red-pill.ai/v1", // Available models: https://docs.red-pill.ai/get-started/supported-models diff --git a/packages/core/src/types.ts b/packages/core/src/types.ts index bb0b081b964..45e11262f6a 100644 --- a/packages/core/src/types.ts +++ b/packages/core/src/types.ts @@ -210,6 +210,7 @@ export type Models = { [ModelProviderName.TOGETHER]: Model; [ModelProviderName.LLAMALOCAL]: Model; [ModelProviderName.GOOGLE]: Model; + [ModelProviderName.MISTRAL]: Model; [ModelProviderName.CLAUDE_VERTEX]: Model; [ModelProviderName.REDPILL]: Model; [ModelProviderName.OPENROUTER]: Model; @@ -242,6 +243,7 @@ export enum ModelProviderName { TOGETHER = "together", LLAMALOCAL = "llama_local", GOOGLE = "google", + MISTRAL = "mistral", CLAUDE_VERTEX = "claude_vertex", REDPILL = "redpill", OPENROUTER = "openrouter", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 8b18a1730cb..87ad1658f82 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -982,6 +982,9 @@ importers: '@ai-sdk/groq': specifier: 0.0.3 version: 0.0.3(zod@3.23.8) + '@ai-sdk/mistral': + specifier: ^1.0.8 + version: 1.0.8(zod@3.23.8) '@ai-sdk/openai': specifier: 1.0.5 version: 1.0.5(zod@3.23.8) @@ -2827,6 +2830,12 @@ packages: peerDependencies: zod: ^3.0.0 + '@ai-sdk/mistral@1.0.8': + resolution: {integrity: sha512-jWH4HHK4cYvXaac9UprMiSUBwOVb3e0hpbiL1wPb+2bF75pqQQKFQWQyfmoLFrh1oXlMOGn+B6IzwUDSFHLanA==} + engines: {node: '>=18'} + peerDependencies: + zod: ^3.0.0 + '@ai-sdk/openai@1.0.17': resolution: {integrity: sha512-W0+VHIDuj8AFyuRJNIxunCf0WhjZSGM3ZtronMikd+QAqbkowN9ytah2fgW503nRq0Vvb77MGEV5mL/Zj7fmEg==} engines: {node: '>=18'} @@ -22362,6 +22371,12 @@ snapshots: '@ai-sdk/provider-utils': 1.0.22(zod@3.23.8) zod: 3.23.8 + '@ai-sdk/mistral@1.0.8(zod@3.23.8)': + dependencies: + '@ai-sdk/provider': 1.0.4 + '@ai-sdk/provider-utils': 2.0.7(zod@3.23.8) + zod: 3.23.8 + '@ai-sdk/openai@1.0.17(zod@3.24.1)': dependencies: '@ai-sdk/provider': 1.0.4 @@ -24863,7 +24878,7 @@ snapshots: '@cosmjs/proto-signing@0.32.2': dependencies: - '@cosmjs/amino': 0.32.2 + '@cosmjs/amino': 0.32.4 '@cosmjs/crypto': 0.32.4 '@cosmjs/encoding': 0.32.4 '@cosmjs/math': 0.32.4 @@ -24921,7 +24936,7 @@ snapshots: '@cosmjs/stargate@0.32.2(bufferutil@4.0.9)(utf-8-validate@5.0.10)': dependencies: '@confio/ics23': 0.6.8 - '@cosmjs/amino': 0.32.2 + '@cosmjs/amino': 0.32.4 '@cosmjs/encoding': 0.32.4 '@cosmjs/math': 0.32.4 '@cosmjs/proto-signing': 0.32.4 @@ -39018,7 +39033,7 @@ snapshots: extract-zip@2.0.1: dependencies: - debug: 4.3.4 + debug: 4.4.0(supports-color@5.5.0) get-stream: 5.2.0 yauzl: 2.10.0 optionalDependencies: @@ -41365,7 +41380,7 @@ snapshots: jest-diff@29.7.0: dependencies: - chalk: 4.1.0 + chalk: 4.1.2 diff-sequences: 29.6.3 jest-get-type: 29.6.3 pretty-format: 29.7.0 @@ -43438,7 +43453,7 @@ snapshots: array-differ: 3.0.0 array-union: 2.1.0 arrify: 2.0.1 - minimatch: 3.0.5 + minimatch: 3.1.2 multistream@4.1.0: dependencies: @@ -43872,7 +43887,7 @@ snapshots: '@yarnpkg/parsers': 3.0.0-rc.46 '@zkochan/js-yaml': 0.0.7 axios: 1.7.9(debug@4.4.0) - chalk: 4.1.0 + chalk: 4.1.2 cli-cursor: 3.1.0 cli-spinners: 2.6.1 cliui: 8.0.1 @@ -44163,7 +44178,7 @@ snapshots: ora@5.3.0: dependencies: bl: 4.1.0 - chalk: 4.1.0 + chalk: 4.1.2 cli-cursor: 3.1.0 cli-spinners: 2.6.1 is-interactive: 1.0.0