-
-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathtokens.go
99 lines (86 loc) · 2.79 KB
/
tokens.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
package main
import (
"bytes"
"fmt"
"text/template"
"github.com/tmc/langchaingo/llms"
)
// getAvailableTokensForContent calculates how many tokens remain for the
// "Content" field of a prompt once the rest of the template is accounted for.
//
// It renders tmpl with a shallow copy of data in which "Content" is set to the
// empty string, counts the rendered prompt's tokens with getTokenCount, adds a
// fixed safety margin, and subtracts that from the package-level tokenLimit.
// The caller's data map is not modified.
//
// It returns -1 with a nil error when tokenLimit <= 0 (token limiting
// disabled). It returns an error when the template fails to execute, when
// token counting fails, or when the prompt alone (plus the margin) exceeds
// tokenLimit. Underlying errors are wrapped with %w.
func getAvailableTokensForContent(tmpl *template.Template, data map[string]interface{}) (int, error) {
	// Extra tokens reserved on top of the measured prompt size.
	// NOTE(review): presumably slack for tokenization differences once real
	// content is inserted — confirm the intended rationale.
	const promptSafetyMargin = 10

	if tokenLimit <= 0 {
		return -1, nil // No limit when disabled
	}

	// Work on a copy so the caller's map keeps its own "Content" value.
	templateData := make(map[string]interface{}, len(data)+1)
	for k, v := range data {
		templateData[k] = v
	}
	templateData["Content"] = ""

	// Render the prompt with empty content to measure the template overhead.
	var promptBuffer bytes.Buffer
	if err := tmpl.Execute(&promptBuffer, templateData); err != nil {
		return 0, fmt.Errorf("error executing template: %w", err)
	}

	promptTokens, err := getTokenCount(promptBuffer.String())
	if err != nil {
		return 0, fmt.Errorf("error counting tokens in prompt: %w", err)
	}
	log.Debugf("Prompt template uses %d tokens", promptTokens)

	promptTokens += promptSafetyMargin

	availableTokens := tokenLimit - promptTokens
	if availableTokens < 0 {
		return 0, fmt.Errorf("prompt template exceeds token limit: needs %d tokens (including %d-token safety margin), limit is %d",
			promptTokens, promptSafetyMargin, tokenLimit)
	}
	return availableTokens, nil
}
// getTokenCount reports how many tokens llms.CountTokens assigns to content
// for the package-level llmModel. In this implementation the error result is
// always nil.
func getTokenCount(content string) (int, error) {
	tokens := llms.CountTokens(llmModel, content)
	return tokens, nil
}
// truncateContentByTokens returns the longest rune prefix of content whose
// token count, as reported by getTokenCount, does not exceed availableTokens.
//
// Content is returned unchanged when:
//   - token limiting is disabled (tokenLimit <= 0),
//   - availableTokens is negative (the "no limit" value that
//     getAvailableTokensForContent returns), or
//   - content already fits within availableTokens.
//
// An availableTokens of 0 is a real budget, not "no limit". Content that does
// not fit is truncated, typically to "".
//
// The cut point is found by binary search over rune boundaries, so a
// multi-byte character is never split. Converting to []rune replaces invalid
// UTF-8 bytes with U+FFFD in the truncated result. The search assumes token
// counts grow monotonically with prefix length. If the tokenizer violates
// that, the result still fits the budget but may be shorter than the true
// maximum. Token-counting errors are wrapped with %w.
func truncateContentByTokens(content string, availableTokens int) (string, error) {
	if availableTokens < 0 || tokenLimit <= 0 {
		return content, nil
	}

	totalTokens, err := getTokenCount(content)
	if err != nil {
		return "", fmt.Errorf("error counting tokens: %w", err)
	}
	if totalTokens <= availableTokens {
		return content, nil
	}

	// Slice on runes, not bytes, so multi-byte characters stay intact.
	runes := []rune(content)

	// Invariant: runes[:validCut] is the longest prefix probed so far that fits.
	low, high := 0, len(runes)
	validCut := 0
	for low <= high {
		mid := low + (high-low)/2
		count, err := getTokenCount(string(runes[:mid]))
		if err != nil {
			return "", fmt.Errorf("error counting tokens in substring: %w", err)
		}
		if count <= availableTokens {
			validCut = mid
			low = mid + 1
		} else {
			high = mid - 1
		}
	}

	truncated := string(runes[:validCut])

	// Defensive re-check of the final result before handing it back.
	finalTokens, err := getTokenCount(truncated)
	if err != nil {
		return "", fmt.Errorf("error counting tokens in final truncated content: %w", err)
	}
	if finalTokens > availableTokens {
		return "", fmt.Errorf("truncated content still exceeds the available token limit")
	}
	return truncated, nil
}