Skip to content

Commit

Permalink
Merge pull request #2137 from GravitonINC/mistral-text-generation
Browse files Browse the repository at this point in the history
feat: Add Mistral AI as new model provider
  • Loading branch information
wtfsayo authored Jan 11, 2025
2 parents 5bade12 + a50c24f commit eedf278
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 7 deletions.
6 changes: 6 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,12 @@ MEDIUM_GOOGLE_MODEL= # Default: gemini-1.5-flash-latest
LARGE_GOOGLE_MODEL= # Default: gemini-1.5-pro-latest
EMBEDDING_GOOGLE_MODEL= # Default: text-embedding-004

# Mistral Configuration
MISTRAL_MODEL=
SMALL_MISTRAL_MODEL= # Default: mistral-small-latest
MEDIUM_MISTRAL_MODEL= # Default: mistral-large-latest
LARGE_MISTRAL_MODEL= # Default: mistral-large-latest

# Groq Configuration
GROQ_API_KEY= # Starts with gsk_
SMALL_GROQ_MODEL= # Default: llama-3.1-8b-instant
Expand Down
5 changes: 5 additions & 0 deletions agent/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,11 @@ export function getTokenForProvider(
character.settings?.secrets?.GOOGLE_GENERATIVE_AI_API_KEY ||
settings.GOOGLE_GENERATIVE_AI_API_KEY
);
case ModelProviderName.MISTRAL:
return (
character.settings?.secrets?.MISTRAL_API_KEY ||
settings.MISTRAL_API_KEY
);
case ModelProviderName.LETZAI:
return (
character.settings?.secrets?.LETZAI_API_KEY ||
Expand Down
1 change: 1 addition & 0 deletions docs/docs/advanced/fine-tuning.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ enum ModelProviderName {
LLAMACLOUD,
LLAMALOCAL,
GOOGLE,
MISTRAL,
REDPILL,
OPENROUTER,
HEURIST,
Expand Down
1 change: 1 addition & 0 deletions packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
"@ai-sdk/google": "0.0.55",
"@ai-sdk/google-vertex": "0.0.43",
"@ai-sdk/groq": "0.0.3",
"@ai-sdk/mistral": "^1.0.8",
"@ai-sdk/openai": "1.0.5",
"@anthropic-ai/sdk": "0.30.1",
"@fal-ai/client": "1.2.0",
Expand Down
49 changes: 49 additions & 0 deletions packages/core/src/generation.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { createAnthropic } from "@ai-sdk/anthropic";
import { createGoogleGenerativeAI } from "@ai-sdk/google";
import { createMistral } from "@ai-sdk/mistral";
import { createGroq } from "@ai-sdk/groq";
import { createOpenAI } from "@ai-sdk/openai";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
Expand Down Expand Up @@ -528,6 +529,27 @@ export async function generateText({
break;
}

case ModelProviderName.MISTRAL: {
const mistral = createMistral();

const { text: mistralResponse } = await aiGenerateText({
model: mistral(model),
prompt: context,
system:
runtime.character.system ??
settings.SYSTEM_PROMPT ??
undefined,
temperature: temperature,
maxTokens: max_response_length,
frequencyPenalty: frequency_penalty,
presencePenalty: presence_penalty,
});

response = mistralResponse;
elizaLogger.debug("Received response from Mistral model.");
break;
}

case ModelProviderName.ANTHROPIC: {
elizaLogger.debug("Initializing Anthropic model with Cloudflare check");
const baseURL = getCloudflareGatewayBaseURL(runtime, 'anthropic') || "https://api.anthropic.com/v1";
Expand Down Expand Up @@ -1863,6 +1885,8 @@ export async function handleProvider(
});
case ModelProviderName.GOOGLE:
return await handleGoogle(options);
case ModelProviderName.MISTRAL:
return await handleMistral(options);
case ModelProviderName.REDPILL:
return await handleRedPill(options);
case ModelProviderName.OPENROUTER:
Expand Down Expand Up @@ -2019,6 +2043,31 @@ async function handleGoogle({
});
}

/**
* Handles object generation for Mistral models.
*
* @param {ProviderOptions} options - Options specific to Mistral.
* @returns {Promise<GenerateObjectResult<unknown>>} - A promise that resolves to generated objects.
*/
async function handleMistral({
model,
schema,
schemaName,
schemaDescription,
mode,
modelOptions,
}: ProviderOptions): Promise<GenerateObjectResult<unknown>> {
const mistral = createMistral();
return await aiGenerateObject({
model: mistral(model),
schema,
schemaName,
schemaDescription,
mode,
...modelOptions,
});
}

/**
* Handles object generation for Redpill models.
*
Expand Down
40 changes: 40 additions & 0 deletions packages/core/src/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,46 @@ export const models: Models = {
},
},
},
[ModelProviderName.MISTRAL]: {
model: {
[ModelClass.SMALL]: {
name:
settings.SMALL_MISTRAL_MODEL ||
settings.MISTRAL_MODEL ||
"mistral-small-latest",
stop: [],
maxInputTokens: 128000,
maxOutputTokens: 8192,
frequency_penalty: 0.4,
presence_penalty: 0.4,
temperature: 0.7,
},
[ModelClass.MEDIUM]: {
name:
settings.MEDIUM_MISTRAL_MODEL ||
settings.MISTRAL_MODEL ||
"mistral-large-latest",
stop: [],
maxInputTokens: 128000,
maxOutputTokens: 8192,
frequency_penalty: 0.4,
presence_penalty: 0.4,
temperature: 0.7,
},
[ModelClass.LARGE]: {
name:
settings.LARGE_MISTRAL_MODEL ||
settings.MISTRAL_MODEL ||
"mistral-large-latest",
stop: [],
maxInputTokens: 128000,
maxOutputTokens: 8192,
frequency_penalty: 0.4,
presence_penalty: 0.4,
temperature: 0.7,
}
},
},
[ModelProviderName.REDPILL]: {
endpoint: "https://api.red-pill.ai/v1",
// Available models: https://docs.red-pill.ai/get-started/supported-models
Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ export type Models = {
[ModelProviderName.TOGETHER]: Model;
[ModelProviderName.LLAMALOCAL]: Model;
[ModelProviderName.GOOGLE]: Model;
[ModelProviderName.MISTRAL]: Model;
[ModelProviderName.CLAUDE_VERTEX]: Model;
[ModelProviderName.REDPILL]: Model;
[ModelProviderName.OPENROUTER]: Model;
Expand Down Expand Up @@ -242,6 +243,7 @@ export enum ModelProviderName {
TOGETHER = "together",
LLAMALOCAL = "llama_local",
GOOGLE = "google",
MISTRAL = "mistral",
CLAUDE_VERTEX = "claude_vertex",
REDPILL = "redpill",
OPENROUTER = "openrouter",
Expand Down
29 changes: 22 additions & 7 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit eedf278

Please sign in to comment.