Skip to content

Commit

Permalink
Merge pull request #506 from Abirdcfly/fixhistory
Browse files Browse the repository at this point in the history
  • Loading branch information
tmc authored Jan 20, 2024
2 parents 15a180f + 008cf7d commit bb1dc7b
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 4 deletions.
12 changes: 8 additions & 4 deletions memory/window_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ import (
"github.com/tmc/langchaingo/schema"
)

// defaultConversationWindowSize is the default number of previous conversation.
const defaultConversationWindowSize = 5
const (
// defaultConversationWindowSize is the default number of previous conversation.
defaultConversationWindowSize = 5
// defaultMessageSize indicates the length of a complete message, currently consisting of 2 parts: ai and human.
defaultMessageSize = 2
)

// ConversationWindowBuffer for storing conversation memory.
type ConversationWindowBuffer struct {
Expand Down Expand Up @@ -85,8 +89,8 @@ func (wb *ConversationWindowBuffer) SaveContext(
}

func (wb *ConversationWindowBuffer) cutMessages(message []schema.ChatMessage) ([]schema.ChatMessage, bool) {
if len(message) > wb.ConversationWindowSize {
return message[len(message)-wb.ConversationWindowSize*2:], true
if len(message) > wb.ConversationWindowSize*defaultMessageSize {
return message[len(message)-wb.ConversationWindowSize*defaultMessageSize:], true
}
return message, false
}
Expand Down
101 changes: 101 additions & 0 deletions memory/window_buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,104 @@ func TestWindowBufferMemoryWithPreLoadedHistory(t *testing.T) {
expected := map[string]any{"history": "Human: bar2\nAI: foo2\nHuman: bar3\nAI: foo3"}
assert.Equal(t, expected, result)
}

func TestConversationWindowBuffer_cutMessages(t *testing.T) { // nolint:funlen
t.Parallel()
type fields struct {
ConversationBuffer ConversationBuffer
ConversationWindowSize int
}
type args struct {
message []schema.ChatMessage
}
tests := []struct {
name string
fields fields
args args
wantMessage []schema.ChatMessage
isCut bool
}{
{
name: "empty messages, do not need cut",
fields: fields{
ConversationBuffer: *NewConversationBuffer(),
ConversationWindowSize: 1,
},
args: args{
message: []schema.ChatMessage{},
},
wantMessage: []schema.ChatMessage{},
isCut: false,
},
{
name: "message less than buffer size, do not need cut",
fields: fields{
ConversationBuffer: *NewConversationBuffer(),
ConversationWindowSize: 1,
},
args: args{
message: []schema.ChatMessage{
schema.HumanChatMessage{Content: "foo"},
schema.AIChatMessage{Content: "bar"},
},
},
wantMessage: []schema.ChatMessage{
schema.HumanChatMessage{Content: "foo"},
schema.AIChatMessage{Content: "bar"},
},
isCut: false,
},
{
name: "add human message, will cut",
fields: fields{
ConversationBuffer: *NewConversationBuffer(),
ConversationWindowSize: 1,
},
args: args{
message: []schema.ChatMessage{
schema.HumanChatMessage{Content: "foo"},
schema.AIChatMessage{Content: "bar"},
schema.HumanChatMessage{Content: "foo1"},
},
},
wantMessage: []schema.ChatMessage{
schema.AIChatMessage{Content: "bar"},
schema.HumanChatMessage{Content: "foo1"},
},
isCut: true,
},
{
name: "message more than buffer size, will cut",
fields: fields{
ConversationBuffer: *NewConversationBuffer(),
ConversationWindowSize: 1,
},
args: args{
message: []schema.ChatMessage{
schema.HumanChatMessage{Content: "foo"},
schema.AIChatMessage{Content: "bar"},
schema.HumanChatMessage{Content: "foo1"},
schema.AIChatMessage{Content: "bar1"},
},
},
wantMessage: []schema.ChatMessage{
schema.HumanChatMessage{Content: "foo1"},
schema.AIChatMessage{Content: "bar1"},
},
isCut: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
wb := &ConversationWindowBuffer{
ConversationBuffer: tt.fields.ConversationBuffer,
ConversationWindowSize: tt.fields.ConversationWindowSize,
}
cut, isCut := wb.cutMessages(tt.args.message)
assert.Equalf(t, tt.wantMessage, cut, "cutMessages(%s), want:%v, get:%v", tt.name, tt.wantMessage, cut)
assert.Equalf(t, tt.isCut, isCut, "cutMessages(%s), want:%t, get:%t", tt.name, tt.isCut, isCut)
})
}
}

0 comments on commit bb1dc7b

Please sign in to comment.