@@ -14,6 +14,7 @@ import { Buffer } from "buffer";
import { createOllama } from "ollama-ai-provider";
import OpenAI from "openai";
import { encodingForModel, TiktokenModel } from "js-tiktoken";
+ import { AutoTokenizer } from "@huggingface/transformers";
import Together from "together-ai";
import { ZodSchema } from "zod";
import { elizaLogger } from "./index.ts";
@@ -37,13 +38,122 @@ import {
    SearchResponse,
    ActionResponse,
    TelemetrySettings,
+     TokenizerType,
} from "./types.ts";
import { fal } from "@fal-ai/client";
import { tavily } from "@tavily/core";

type Tool = CoreTool<any, any>;
type StepResult = AIStepResult<any>;

+ /**
+  * Trims the provided text context to a specified token limit using a tokenizer model and type.
+  *
+  * The function dynamically determines the truncation method based on the tokenizer settings
+  * provided by the runtime. If no tokenizer settings are defined, it defaults to using the
+  * TikToken truncation method with the "gpt-4o" model.
+  *
+  * @async
+  * @function trimTokens
+  * @param {string} context - The text to be tokenized and trimmed.
+  * @param {number} maxTokens - The maximum number of tokens allowed after truncation.
+  * @param {IAgentRuntime} runtime - The runtime interface providing tokenizer settings.
+  *
+  * @returns {Promise<string>} A promise that resolves to the trimmed text.
+  *
+  * @throws {Error} Throws an error if the runtime settings are invalid or missing required fields.
+  *
+  * @example
+  * const trimmedText = await trimTokens("This is an example text", 50, runtime);
+  * console.log(trimmedText); // Output will be a truncated version of the input text.
+  */
+ export async function trimTokens(
+     context: string,
+     maxTokens: number,
+     runtime: IAgentRuntime
+ ) {
+     if (!context) return "";
+     if (maxTokens <= 0) throw new Error("maxTokens must be positive");
+
+     const tokenizerModel = runtime.getSetting("TOKENIZER_MODEL");
+     const tokenizerType = runtime.getSetting("TOKENIZER_TYPE");
+
+     if (!tokenizerModel || !tokenizerType) {
+         // Default to TikToken truncation using the "gpt-4o" model if tokenizer settings are not defined
+         return truncateTiktoken("gpt-4o", context, maxTokens);
+     }
+
+     // Choose the truncation method based on tokenizer type
+     if (tokenizerType === TokenizerType.Auto) {
+         return truncateAuto(tokenizerModel, context, maxTokens);
+     }
+
+     if (tokenizerType === TokenizerType.TikToken) {
+         return truncateTiktoken(
+             tokenizerModel as TiktokenModel,
+             context,
+             maxTokens
+         );
+     }
+
+     elizaLogger.warn(`Unsupported tokenizer type: ${tokenizerType}`);
+     return truncateTiktoken("gpt-4o", context, maxTokens);
+ }
+
+ async function truncateAuto(
+     modelPath: string,
+     context: string,
+     maxTokens: number
+ ) {
+     try {
+         const tokenizer = await AutoTokenizer.from_pretrained(modelPath);
+         const tokens = tokenizer.encode(context);
+
+         // If already within limits, return unchanged
+         if (tokens.length <= maxTokens) {
+             return context;
+         }
+
+         // Keep the most recent tokens by slicing from the end
+         const truncatedTokens = tokens.slice(-maxTokens);
+
+         // Decode back to text - the Hugging Face tokenizer's decode() returns a string directly
+         return tokenizer.decode(truncatedTokens);
+     } catch (error) {
+         elizaLogger.error("Error in trimTokens:", error);
+         // Return truncated string if tokenization fails
+         return context.slice(-maxTokens * 4); // Rough estimate of 4 chars per token
+     }
+ }
+
+ async function truncateTiktoken(
+     model: TiktokenModel,
+     context: string,
+     maxTokens: number
+ ) {
+     try {
+         const encoding = encodingForModel(model);
+
+         // Encode the text into tokens
+         const tokens = encoding.encode(context);
+
+         // If already within limits, return unchanged
+         if (tokens.length <= maxTokens) {
+             return context;
+         }
+
+         // Keep the most recent tokens by slicing from the end
+         const truncatedTokens = tokens.slice(-maxTokens);
+
+         // Decode back to text - js-tiktoken decode() returns a string directly
+         return encoding.decode(truncatedTokens);
+     } catch (error) {
+         elizaLogger.error("Error in trimTokens:", error);
+         // Return truncated string if tokenization fails
+         return context.slice(-maxTokens * 4); // Rough estimate of 4 chars per token
+     }
+ }
+
/**
 * Send a message to the model for a text generateText - receive a string back and parse how you'd like
 * @param opts - The options for the generateText request.
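
A minimal usage sketch for orientation (illustrative only, not part of the diff; composeLongPrompt and the 8000-token budget are made-up placeholders). With TOKENIZER_TYPE and TOKENIZER_MODEL unset, the call falls back to js-tiktoken's "gpt-4o" encoding, so existing callers keep the old behavior; with TOKENIZER_TYPE set to TokenizerType.Auto, the configured model is loaded through AutoTokenizer instead.

    // Hypothetical caller; `runtime` is whatever IAgentRuntime instance the agent already holds.
    const longContext = composeLongPrompt();                      // placeholder for an oversized prompt
    const trimmed = await trimTokens(longContext, 8000, runtime); // keeps the most recent ~8000 tokens
    // A context already within the limit is returned unchanged.
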
@@ -187,7 +297,8 @@ export async function generateText({
    elizaLogger.debug(
        `Trimming context to max length of ${max_context_length} tokens.`
    );
-     context = trimTokens(context, max_context_length, "gpt-4o");
+
+     context = await trimTokens(context, max_context_length, runtime);

    let response: string;

@@ -653,45 +764,6 @@ export async function generateText({
    }
}

- /**
-  * Truncate the context to the maximum length allowed by the model.
-  * @param context The text to truncate
-  * @param maxTokens Maximum number of tokens to keep
-  * @param model The tokenizer model to use
-  * @returns The truncated text
-  */
- export function trimTokens(
-     context: string,
-     maxTokens: number,
-     model: TiktokenModel
- ): string {
-     if (!context) return "";
-     if (maxTokens <= 0) throw new Error("maxTokens must be positive");
-
-     // Get the tokenizer for the model
-     const encoding = encodingForModel(model);
-
-     try {
-         // Encode the text into tokens
-         const tokens = encoding.encode(context);
-
-         // If already within limits, return unchanged
-         if (tokens.length <= maxTokens) {
-             return context;
-         }
-
-         // Keep the most recent tokens by slicing from the end
-         const truncatedTokens = tokens.slice(-maxTokens);
-
-         // Decode back to text - js-tiktoken decode() returns a string directly
-         return encoding.decode(truncatedTokens);
-     } catch (error) {
-         console.error("Error in trimTokens:", error);
-         // Return truncated string if tokenization fails
-         return context.slice(-maxTokens * 4); // Rough estimate of 4 chars per token
-     }
- }
-
/**
 * Sends a message to the model to determine if it should respond to the given context.
 * @param opts - The options for the generateText request
@@ -973,9 +1045,10 @@ export async function generateMessageResponse({
    context: string;
    modelClass: string;
}): Promise<Content> {
-     const max_context_length =
-         models[runtime.modelProvider].settings.maxInputTokens;
-     context = trimTokens(context, max_context_length, "gpt-4o");
+     const provider = runtime.modelProvider;
+     const max_context_length = models[provider].settings.maxInputTokens;
+
+     context = await trimTokens(context, max_context_length, runtime);
    let retryLength = 1000; // exponential backoff
    while (true) {
        try {
@@ -1443,20 +1516,18 @@ export const generateObject = async ({
    }

    const provider = runtime.modelProvider;
-     const model = models[provider].model[modelClass] as TiktokenModel;
-     if (!model) {
-         throw new Error(`Unsupported model class: ${modelClass}`);
-     }
+     const model = models[provider].model[modelClass];
    const temperature = models[provider].settings.temperature;
    const frequency_penalty = models[provider].settings.frequency_penalty;
    const presence_penalty = models[provider].settings.presence_penalty;
    const max_context_length = models[provider].settings.maxInputTokens;
    const max_response_length = models[provider].settings.maxOutputTokens;
-     const experimental_telemetry = models[provider].settings.experimental_telemetry;
+     const experimental_telemetry =
+         models[provider].settings.experimental_telemetry;
    const apiKey = runtime.token;

    try {
-         context = trimTokens(context, max_context_length, model);
+         context = await trimTokens(context, max_context_length, runtime);

        const modelOptions: ModelSettings = {
            prompt: context,