diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index 9bd05242a..2f4336b14 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -14,7 +14,8 @@ import ( type Handler interface { HandleText(ctx context.Context, text string) HandleLLMStart(ctx context.Context, prompts []string) - HandleLLMEnd(ctx context.Context, output llms.LLMResult) + HandleLLMGenerateContentStart(ctx context.Context, ms []llms.MessageContent) + HandleLLMGenerateContentEnd(ctx context.Context, res *llms.ContentResponse) HandleLLMError(ctx context.Context, err error) HandleChainStart(ctx context.Context, inputs map[string]any) HandleChainEnd(ctx context.Context, outputs map[string]any) diff --git a/callbacks/combining.go b/callbacks/combining.go index 22ad518bc..2e95e80aa 100644 --- a/callbacks/combining.go +++ b/callbacks/combining.go @@ -26,9 +26,15 @@ func (l CombiningHandler) HandleLLMStart(ctx context.Context, prompts []string) } } -func (l CombiningHandler) HandleLLMEnd(ctx context.Context, output llms.LLMResult) { +func (l CombiningHandler) HandleLLMGenerateContentStart(ctx context.Context, ms []llms.MessageContent) { for _, handle := range l.Callbacks { - handle.HandleLLMEnd(ctx, output) + handle.HandleLLMGenerateContentStart(ctx, ms) + } +} + +func (l CombiningHandler) HandleLLMGenerateContentEnd(ctx context.Context, res *llms.ContentResponse) { + for _, handle := range l.Callbacks { + handle.HandleLLMGenerateContentEnd(ctx, res) } } diff --git a/callbacks/log.go b/callbacks/log.go index f7a9453e3..135b7fdcd 100644 --- a/callbacks/log.go +++ b/callbacks/log.go @@ -6,7 +6,6 @@ import ( "fmt" "strings" - "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/schema" ) @@ -25,10 +24,6 @@ func (l LogHandler) HandleLLMStart(_ context.Context, prompts []string) { fmt.Println("Entering LLM with prompts:", prompts) } -func (l LogHandler) HandleLLMEnd(_ context.Context, output llms.LLMResult) { - fmt.Println("Exiting LLM with results:", formatLLMResult(output)) -} - func (l LogHandler) HandleLLMError(_ context.Context, err error) { fmt.Println("Exiting LLM with error:", err) } @@ -82,17 +77,6 @@ func formatChainValues(values map[string]any) string { return output } -func formatLLMResult(output llms.LLMResult) string { - results := "[ " - for i := 0; i < len(output.Generations); i++ { - for j := 0; j < len(output.Generations[i]); j++ { - results += output.Generations[i][j].Text - } - } - - return results + " ]" -} - func formatAgentAction(action schema.AgentAction) string { return fmt.Sprintf("\"%s\" with input \"%s\"", removeNewLines(action.Tool), removeNewLines(action.ToolInput)) } diff --git a/callbacks/simple.go b/callbacks/simple.go index db24e659b..94cf54174 100644 --- a/callbacks/simple.go +++ b/callbacks/simple.go @@ -12,18 +12,19 @@ type SimpleHandler struct{} var _ Handler = SimpleHandler{} -func (SimpleHandler) HandleText(context.Context, string) {} -func (SimpleHandler) HandleLLMStart(context.Context, []string) {} -func (SimpleHandler) HandleLLMEnd(context.Context, llms.LLMResult) {} -func (SimpleHandler) HandleLLMError(context.Context, error) {} -func (SimpleHandler) HandleChainStart(context.Context, map[string]any) {} -func (SimpleHandler) HandleChainEnd(context.Context, map[string]any) {} -func (SimpleHandler) HandleChainError(context.Context, error) {} -func (SimpleHandler) HandleToolStart(context.Context, string) {} -func (SimpleHandler) HandleToolEnd(context.Context, string) {} -func (SimpleHandler) HandleToolError(context.Context, error) {} -func (SimpleHandler) 
HandleAgentAction(context.Context, schema.AgentAction) {} -func (SimpleHandler) HandleAgentFinish(context.Context, schema.AgentFinish) {} -func (SimpleHandler) HandleRetrieverStart(context.Context, string) {} -func (SimpleHandler) HandleRetrieverEnd(context.Context, string, []schema.Document) {} -func (SimpleHandler) HandleStreamingFunc(context.Context, []byte) {} +func (SimpleHandler) HandleText(context.Context, string) {} +func (SimpleHandler) HandleLLMStart(context.Context, []string) {} +func (SimpleHandler) HandleLLMGenerateContentStart(context.Context, []llms.MessageContent) {} +func (SimpleHandler) HandleLLMGenerateContentEnd(context.Context, *llms.ContentResponse) {} +func (SimpleHandler) HandleLLMError(context.Context, error) {} +func (SimpleHandler) HandleChainStart(context.Context, map[string]any) {} +func (SimpleHandler) HandleChainEnd(context.Context, map[string]any) {} +func (SimpleHandler) HandleChainError(context.Context, error) {} +func (SimpleHandler) HandleToolStart(context.Context, string) {} +func (SimpleHandler) HandleToolEnd(context.Context, string) {} +func (SimpleHandler) HandleToolError(context.Context, error) {} +func (SimpleHandler) HandleAgentAction(context.Context, schema.AgentAction) {} +func (SimpleHandler) HandleAgentFinish(context.Context, schema.AgentFinish) {} +func (SimpleHandler) HandleRetrieverStart(context.Context, string) {} +func (SimpleHandler) HandleRetrieverEnd(context.Context, string, []schema.Document) {} +func (SimpleHandler) HandleStreamingFunc(context.Context, []byte) {} diff --git a/chains/chains_test.go b/chains/chains_test.go index 103fa946d..ce5410117 100644 --- a/chains/chains_test.go +++ b/chains/chains_test.go @@ -53,18 +53,9 @@ func (l *testLanguageModel) Call(_ context.Context, prompt string, _ ...llms.Cal return llmResult, nil } -func (l *testLanguageModel) Generate( - ctx context.Context, prompts []string, options ...llms.CallOption, -) ([]*llms.Generation, error) { - result, err := l.Call(ctx, prompts[0], options...) - if err != nil { - return nil, err - } - return []*llms.Generation{ - { - Text: result, - }, - }, nil +func (l *testLanguageModel) GenerateContent(_ context.Context, _ []llms.MessageContent, _ ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, whitespace + + panic("not implemented") } var _ llms.LLM = &testLanguageModel{} diff --git a/embeddings/openai_test.go b/embeddings/openai_test.go index eea891db7..c4adf5a61 100644 --- a/embeddings/openai_test.go +++ b/embeddings/openai_test.go @@ -100,19 +100,3 @@ func TestOpenaiEmbeddingsWithAzureAPI(t *testing.T) { require.NoError(t, err) assert.Len(t, embeddings, 1) } - -func TestUseLLMAndChatAsEmbedderClient(t *testing.T) { - t.Parallel() - - if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" { - t.Skip("OPENAI_API_KEY not set") - } - - // Shows that we can pass an openai chat value to NewEmbedder - chat, err := openai.NewChat() - require.NoError(t, err) - - embedderFromChat, err := NewEmbedder(chat) - require.NoError(t, err) - var _ Embedder = embedderFromChat -} diff --git a/llms/anthropic/anthropicllm.go b/llms/anthropic/anthropicllm.go index 620255ba6..2d40f8239 100644 --- a/llms/anthropic/anthropicllm.go +++ b/llms/anthropic/anthropicllm.go @@ -50,50 +50,46 @@ func newClient(opts ...Option) (*anthropicclient.Client, error) { // Call requests a completion for the given prompt. func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { - r, err := o.Generate(ctx, []string{prompt}, options...) 
- if err != nil { - return "", err - } - if len(r) == 0 { - return "", ErrEmptyResponse - } - return r[0].Text, nil + return llms.CallLLM(ctx, o, prompt, options...) } -func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { +// GenerateContent implements the Model interface. +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, whitespace + if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, prompts) + o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) } - opts := llms.CallOptions{} + opts := &llms.CallOptions{} for _, opt := range options { - opt(&opts) + opt(opts) } - generations := make([]*llms.Generation, 0, len(prompts)) - for _, prompt := range prompts { - result, err := o.client.CreateCompletion(ctx, &anthropicclient.CompletionRequest{ - Model: opts.Model, - Prompt: prompt, - MaxTokens: opts.MaxTokens, - StopWords: opts.StopWords, - Temperature: opts.Temperature, - TopP: opts.TopP, - StreamingFunc: opts.StreamingFunc, - }) - if err != nil { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) - } - return nil, err + // Assume we get a single text message + msg0 := messages[0] + part := msg0.Parts[0] + result, err := o.client.CreateCompletion(ctx, &anthropicclient.CompletionRequest{ + Model: opts.Model, + Prompt: part.(llms.TextContent).Text, + MaxTokens: opts.MaxTokens, + StopWords: opts.StopWords, + Temperature: opts.Temperature, + TopP: opts.TopP, + StreamingFunc: opts.StreamingFunc, + }) + if err != nil { + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMError(ctx, err) } - generations = append(generations, &llms.Generation{ - Text: result.Text, - }) + return nil, err } - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}}) + resp := &llms.ContentResponse{ + Choices: []*llms.ContentChoice{ + { + Content: result.Text, + }, + }, } - return generations, nil + return resp, nil } diff --git a/llms/cohere/coherellm.go b/llms/cohere/coherellm.go index a2b8d9b59..13c3a797d 100644 --- a/llms/cohere/coherellm.go +++ b/llms/cohere/coherellm.go @@ -25,49 +25,42 @@ type LLM struct { var _ llms.LLM = (*LLM)(nil) func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { - r, err := o.Generate(ctx, []string{prompt}, options...) - if err != nil { - return "", err - } - if len(r) == 0 { - return "", ErrEmptyResponse - } - return r[0].Text, nil + return llms.CallLLM(ctx, o, prompt, options...) } -func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { +// GenerateContent implements the Model interface. 
+func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, whitespace + if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, prompts) + o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) } - opts := llms.CallOptions{} + opts := &llms.CallOptions{} for _, opt := range options { - opt(&opts) + opt(opts) } - generations := make([]*llms.Generation, 0, len(prompts)) - - for _, prompt := range prompts { - result, err := o.client.CreateGeneration(ctx, &cohereclient.GenerationRequest{ - Prompt: prompt, - }) - if err != nil { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) - } - return nil, err + // Assume we get a single text message + msg0 := messages[0] + part := msg0.Parts[0] + result, err := o.client.CreateGeneration(ctx, &cohereclient.GenerationRequest{ + Prompt: part.(llms.TextContent).Text, + }) + if err != nil { + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMError(ctx, err) } - - generations = append(generations, &llms.Generation{ - Text: result.Text, - }) + return nil, err } - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}}) + resp := &llms.ContentResponse{ + Choices: []*llms.ContentChoice{ + { + Content: result.Text, + }, + }, } - - return generations, nil + return resp, nil } func New(opts ...Option) (*LLM, error) { diff --git a/llms/ernie/ernie_chat.go b/llms/ernie/ernie_chat.go deleted file mode 100644 index ef44b3379..000000000 --- a/llms/ernie/ernie_chat.go +++ /dev/null @@ -1,185 +0,0 @@ -package ernie - -import ( - "context" - "os" - "reflect" - - "github.com/tmc/langchaingo/callbacks" - "github.com/tmc/langchaingo/llms" - "github.com/tmc/langchaingo/llms/ernie/internal/ernieclient" - "github.com/tmc/langchaingo/schema" -) - -type ChatMessage = ernieclient.ChatMessage - -type Chat struct { - CallbacksHandler callbacks.Handler - client *ernieclient.Client -} - -var _ llms.ChatLLM = (*Chat)(nil) - -func NewChat(opts ...Option) (*Chat, error) { - options := &options{ - apiKey: os.Getenv(ernieAPIKey), - secretKey: os.Getenv(ernieSecretKey), - } - - for _, opt := range opts { - opt(options) - } - - c, err := newClient(options) - if err != nil { - return nil, err - } - c.ModelPath = modelToPath(ModelName(c.Model)) - - return &Chat{ - client: c, - }, err -} - -// Call requests a chat response for the given messages. -func (o *Chat) Call(ctx context.Context, messages []schema.ChatMessage, options ...llms.CallOption) (*schema.AIChatMessage, error) { // nolint: lll - r, err := o.Generate(ctx, [][]schema.ChatMessage{messages}, options...) 
- if err != nil { - return nil, err - } - if len(r) == 0 { - return nil, ErrEmptyResponse - } - return r[0].Message, nil -} - -//nolint:funlen -func (o *Chat) Generate(ctx context.Context, messageSets [][]schema.ChatMessage, options ...llms.CallOption) ([]*llms.Generation, error) { // nolint:lll,cyclop - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, getPromptsFromMessageSets(messageSets)) - } - - opts := llms.CallOptions{} - for _, opt := range options { - opt(&opts) - } - generations := make([]*llms.Generation, 0, len(messageSets)) - for _, messageSet := range messageSets { - req := &ernieclient.ChatRequest{ - Model: opts.Model, - StopWords: opts.StopWords, - Messages: messagesToClientMessages(messageSet), - StreamingFunc: opts.StreamingFunc, - Temperature: opts.Temperature, - MaxTokens: opts.MaxTokens, - N: opts.N, // TODO: note, we are not returning multiple completions - FrequencyPenalty: opts.FrequencyPenalty, - PresencePenalty: opts.PresencePenalty, - System: getSystem(messageSet), - - FunctionCallBehavior: ernieclient.FunctionCallBehavior(opts.FunctionCallBehavior), - } - for _, fn := range opts.Functions { - req.Functions = append(req.Functions, ernieclient.FunctionDefinition{ - Name: fn.Name, - Description: fn.Description, - Parameters: fn.Parameters, - }) - } - result, err := o.client.CreateChat(ctx, req) - if err != nil { - return nil, err - } - - if result.Result == "" && result.FunctionCall == nil { - return nil, ErrEmptyResponse - } - - generationInfo := make(map[string]any, reflect.ValueOf(result.Usage).NumField()) - generationInfo["CompletionTokens"] = result.Usage.CompletionTokens - generationInfo["PromptTokens"] = result.Usage.PromptTokens - generationInfo["TotalTokens"] = result.Usage.TotalTokens - msg := &schema.AIChatMessage{ - Content: result.Result, - } - - if result.FunctionCall != nil { - msg.FunctionCall = &schema.FunctionCall{ - Name: result.FunctionCall.Name, - Arguments: result.FunctionCall.Arguments, - } - } - generations = append(generations, &llms.Generation{ - Message: msg, - Text: msg.Content, - GenerationInfo: generationInfo, - }) - } - - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}}) - } - - return generations, nil -} - -func getPromptsFromMessageSets(messageSets [][]schema.ChatMessage) []string { - prompts := make([]string, 0, len(messageSets)) - for i := 0; i < len(messageSets); i++ { - curPrompt := "" - for j := 0; j < len(messageSets[i]); j++ { - curPrompt += messageSets[i][j].GetContent() - } - prompts = append(prompts, curPrompt) - } - - return prompts -} - -func messagesToClientMessages(messages []schema.ChatMessage) []*ernieclient.ChatMessage { - msgs := make([]*ernieclient.ChatMessage, 0) - for _, m := range messages { - msg := &ernieclient.ChatMessage{ - Content: m.GetContent(), - } - typ := m.GetType() - switch typ { - case schema.ChatMessageTypeSystem: // In Ernie's 'messages' parameter, there is no 'system' role. - continue - case schema.ChatMessageTypeAI: - msg.Role = "assistant" - case schema.ChatMessageTypeHuman: - msg.Role = "user" - case schema.ChatMessageTypeGeneric: - msg.Role = "user" - case schema.ChatMessageTypeFunction: - msg.Role = "function" - } - - if n, ok := m.(FunctionCalled); ok { - msg.FunctionCall = n.GetFunctionCall() - } - - if n, ok := m.(schema.Named); ok { - msg.Name = n.GetName() - } - msgs = append(msgs, msg) - } - - return msgs -} - -// getSystem Retrieve system parameter from messages. 
-func getSystem(messages []schema.ChatMessage) string { - for _, message := range messages { - if message.GetType() == schema.ChatMessageTypeSystem { - return message.GetContent() - } - } - return "" -} - -// FunctionCalled is an interface for objects that have a function call info. -type FunctionCalled interface { - GetFunctionCall() *schema.FunctionCall -} diff --git a/llms/ernie/ernie_chat_test.go b/llms/ernie/ernie_chat_test.go deleted file mode 100644 index 5d1669f26..000000000 --- a/llms/ernie/ernie_chat_test.go +++ /dev/null @@ -1,131 +0,0 @@ -package ernie - -import ( - "reflect" - "testing" - - "github.com/tmc/langchaingo/llms/ernie/internal/ernieclient" - "github.com/tmc/langchaingo/schema" -) - -func TestNewChat(t *testing.T) { - t.Parallel() - type args struct { - opts []Option - } - tests := []struct { - name string - args args - want *Chat - wantErr bool - }{ - {name: "", args: args{opts: []Option{ - WithModelName(ModelNameERNIEBot), - WithAKSK("ak", "sk"), - }}, want: nil, wantErr: true}, - {name: "", args: args{opts: []Option{ - WithModelName(ModelNameERNIEBot), - WithAKSK("ak", "sk"), - WithAccessToken("xxx"), - }}, want: nil, wantErr: false}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - got, err := NewChat(tt.args.opts...) - if (err != nil) != tt.wantErr { - t.Errorf("NewChat() error = %v, wantErr %v", err, tt.wantErr) - return - } - - expectedType := reflect.TypeOf(tt.want) - if reflect.TypeOf(got) != expectedType { - t.Errorf("NewChat() got = %T, want %T", got, tt.want) - } - }) - } -} - -func TestGetSystem(t *testing.T) { - t.Parallel() - type args struct { - messages []schema.ChatMessage - } - tests := []struct { - name string - args args - want string - }{ - { - name: "system message exists", - args: args{ - messages: []schema.ChatMessage{ - schema.SystemChatMessage{Content: "you are a robot."}, - schema.HumanChatMessage{Content: "who are you?"}, - }, - }, - want: "you are a robot.", - }, - { - name: "no system message", - args: args{ - messages: []schema.ChatMessage{ - schema.HumanChatMessage{Content: "who are you?"}, - }, - }, - want: "", - }, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - if got := getSystem(tt.args.messages); got != tt.want { - t.Errorf("getSystem() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestMessagesToClientMessages(t *testing.T) { - t.Parallel() - type args struct { - messages []schema.ChatMessage - } - tests := []struct { - name string - args args - want []*ernieclient.ChatMessage - }{ - { - name: "Test_MessagesToClientMessages_OK", - args: args{messages: []schema.ChatMessage{ - schema.AIChatMessage{Content: "assistant"}, - schema.HumanChatMessage{Content: "user"}, - schema.SystemChatMessage{Content: ""}, - schema.FunctionChatMessage{Content: "function"}, - schema.GenericChatMessage{Content: "user"}, - }}, - want: []*ernieclient.ChatMessage{ - {Content: "assistant", Role: "assistant"}, - {Content: "user", Role: "user"}, - {Content: "function", Role: "function"}, - {Content: "user", Role: "user"}, - }, - }, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - got := messagesToClientMessages(tt.args.messages) - for i, v := range got { - if !reflect.DeepEqual(v.Content, tt.want[i].Content) { - t.Errorf("messagesToClientMessages() = %v, want %v", got, tt.want) - } - } - }) - } -} diff --git a/llms/ernie/erniellm.go b/llms/ernie/erniellm.go index a4a2a1f30..1445abe36 100644 --- 
a/llms/ernie/erniellm.go +++ b/llms/ernie/erniellm.go @@ -61,60 +61,59 @@ doc: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2`, ernieclient.ErrNot // Call implements llms.LLM. func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { - r, err := o.Generate(ctx, []string{prompt}, options...) - if err != nil { - return "", err - } - - if len(r) == 0 { - return "", ErrEmptyResponse - } - - return r[0].Text, nil + return llms.CallLLM(ctx, o, prompt, options...) } -// Generate implements llms.LLM. -func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { +// GenerateContent implements the Model interface. +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, whitespace + if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, prompts) + o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) } - opts := llms.CallOptions{} + opts := &llms.CallOptions{} for _, opt := range options { - opt(&opts) + opt(opts) } - generations := make([]*llms.Generation, 0, len(prompts)) - for _, prompt := range prompts { - result, err := o.client.CreateCompletion(ctx, o.getModelPath(opts), &ernieclient.CompletionRequest{ - Messages: []ernieclient.Message{{Role: "user", Content: prompt}}, - Temperature: opts.Temperature, - TopP: opts.TopP, - PenaltyScore: opts.RepetitionPenalty, - StreamingFunc: opts.StreamingFunc, - Stream: opts.StreamingFunc != nil, - }) - if err != nil { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) - } - return nil, err + // Assume we get a single text message + msg0 := messages[0] + part := msg0.Parts[0] + result, err := o.client.CreateCompletion(ctx, o.getModelPath(*opts), &ernieclient.CompletionRequest{ + Messages: []ernieclient.Message{{Role: "user", Content: part.(llms.TextContent).Text}}, + Temperature: opts.Temperature, + TopP: opts.TopP, + PenaltyScore: opts.RepetitionPenalty, + StreamingFunc: opts.StreamingFunc, + Stream: opts.StreamingFunc != nil, + }) + if err != nil { + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMError(ctx, err) } - if result.ErrorCode > 0 { - err = fmt.Errorf("%w, error_code:%v, erro_msg:%v, id:%v", - ErrCodeResponse, result.ErrorCode, result.ErrorMsg, result.ID) - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) - } - return nil, err + return nil, err + } + if result.ErrorCode > 0 { + err = fmt.Errorf("%w, error_code:%v, erro_msg:%v, id:%v", + ErrCodeResponse, result.ErrorCode, result.ErrorMsg, result.ID) + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMError(ctx, err) } + return nil, err + } - generations = append(generations, &llms.Generation{ - Text: result.Result, - }) + resp := &llms.ContentResponse{ + Choices: []*llms.ContentChoice{ + { + Content: result.Result, + }, + }, + } + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, resp) } - return generations, nil + return resp, nil } // CreateEmbedding use ernie Embedding-V1. diff --git a/llms/generatecontent.go b/llms/generatecontent.go index 894da7a23..8d91c2003 100644 --- a/llms/generatecontent.go +++ b/llms/generatecontent.go @@ -62,13 +62,15 @@ type BinaryContent struct { func (BinaryContent) isPart() {} // ContentResponse is the response returned by a GenerateContent call. -// It can potentially return multiple response choices. 
+// It can potentially return multiple content choices. type ContentResponse struct { Choices []*ContentChoice } -// ContentChoice is one of the response choices returned by GenerateModel calls. +// ContentChoice is one of the response choices returned by GenerateContent +// calls. type ContentChoice struct { + // Content is the textual content of a response Content string // StopReason is the reason the model stopped generating output. diff --git a/llms/googleai/googleai_llm.go b/llms/googleai/googleai_llm.go index f4cf9d7dd..4c8735ce3 100644 --- a/llms/googleai/googleai_llm.go +++ b/llms/googleai/googleai_llm.go @@ -14,6 +14,7 @@ import ( "strings" "github.com/google/generative-ai-go/genai" + "github.com/tmc/langchaingo/callbacks" "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/schema" "google.golang.org/api/iterator" @@ -22,8 +23,9 @@ import ( // GoogleAI is a type that represents a Google AI API client. type GoogleAI struct { - client *genai.Client - opts options + CallbacksHandler callbacks.Handler + client *genai.Client + opts options } var ( @@ -62,8 +64,17 @@ func NewGoogleAI(ctx context.Context, opts ...Option) (*GoogleAI, error) { return gi, nil } +// Call Implement the call interface for LLM. +func (g *GoogleAI) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { + return llms.CallLLM(ctx, g, prompt, options...) +} + // GenerateContent calls the LLM with the provided parts. func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { + if g.CallbacksHandler != nil { + g.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) + } + opts := llms.CallOptions{ Model: g.opts.defaultModel, MaxTokens: int(g.opts.defaultMaxTokens), @@ -77,14 +88,27 @@ func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageC model.SetMaxOutputTokens(int32(opts.MaxTokens)) model.SetTemperature(float32(opts.Temperature)) + var response *llms.ContentResponse + var err error + if len(messages) == 1 { theMessage := messages[0] if theMessage.Role != schema.ChatMessageTypeHuman { return nil, fmt.Errorf("got %v message role, want human", theMessage.Role) } - return generateFromSingleMessage(ctx, model, theMessage.Parts, &opts) + response, err = generateFromSingleMessage(ctx, model, theMessage.Parts, &opts) + } else { + response, err = generateFromMessages(ctx, model, messages, &opts) } - return generateFromMessages(ctx, model, messages, &opts) + if err != nil { + return nil, err + } + + if g.CallbacksHandler != nil { + g.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, response) + } + + return response, nil } // downloadImageData downloads the content from the given URL and returns it as diff --git a/llms/huggingface/huggingfacellm.go b/llms/huggingface/huggingfacellm.go index c673d8dee..5a0f762d9 100644 --- a/llms/huggingface/huggingfacellm.go +++ b/llms/huggingface/huggingfacellm.go @@ -23,29 +23,29 @@ type LLM struct { var _ llms.LLM = (*LLM)(nil) +// Call implements the LLM interface. func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { - r, err := o.Generate(ctx, []string{prompt}, options...) - if err != nil { - return "", err - } - if len(r) == 0 { - return "", ErrEmptyResponse - } - return r[0].Text, nil + return llms.CallLLM(ctx, o, prompt, options...) 
} -func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { +// GenerateContent implements the Model interface. +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, whitespace + if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, prompts) + o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) } opts := &llms.CallOptions{Model: defaultModel} for _, opt := range options { opt(opts) } + + // Assume we get a single text message + msg0 := messages[0] + part := msg0.Parts[0] result, err := o.client.RunInference(ctx, &huggingfaceclient.InferenceRequest{ Model: o.client.Model, - Prompt: prompts[0], + Prompt: part.(llms.TextContent).Text, Task: huggingfaceclient.InferenceTaskTextGeneration, Temperature: opts.Temperature, TopP: opts.TopP, @@ -62,14 +62,14 @@ func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.Ca return nil, err } - generations := []*llms.Generation{ - {Text: result.Text}, - } - - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}}) + resp := &llms.ContentResponse{ + Choices: []*llms.ContentChoice{ + { + Content: result.Text, + }, + }, } - return generations, nil + return resp, nil } func New(opts ...Option) (*LLM, error) { diff --git a/llms/llms.go b/llms/llms.go index 33456af4b..8325d84b2 100644 --- a/llms/llms.go +++ b/llms/llms.go @@ -2,45 +2,49 @@ package llms import ( "context" + "errors" "github.com/tmc/langchaingo/schema" ) -// LLM is a langchaingo Large Language Model. -type LLM interface { - Call(ctx context.Context, prompt string, options ...CallOption) (string, error) - Generate(ctx context.Context, prompts []string, options ...CallOption) ([]*Generation, error) -} - -// ChatLLM is a langchaingo LLM that can be used for chatting. -type ChatLLM interface { - Call(ctx context.Context, messages []schema.ChatMessage, options ...CallOption) (*schema.AIChatMessage, error) - Generate(ctx context.Context, messages [][]schema.ChatMessage, options ...CallOption) ([]*Generation, error) -} +// LLM is an alias for model, for backwards compatibility. +// +// This alias may be removed in the future; please use Model +// instead. +type LLM = Model // Model is an interface multi-modal models implement. -// Note: this is an experimental API. type Model interface { // GenerateContent asks the model to generate content from a sequence of - // messages. It's the most general interface for LLMs that support chat-like - // interactions. - GenerateContent(ctx context.Context, parts []MessageContent, options ...CallOption) (*ContentResponse, error) -} + // messages. It's the most general interface for multi-modal LLMs that support + // chat-like interactions. + GenerateContent(ctx context.Context, messages []MessageContent, options ...CallOption) (*ContentResponse, error) -// Generation is a single generation from a langchaingo LLM. -type Generation struct { - // Text is the generated text. - Text string `json:"text"` - // Message stores the potentially generated message. - Message *schema.AIChatMessage `json:"message"` - // GenerationInfo is the generation info. This can contain vendor-specific information. - GenerationInfo map[string]any `json:"generation_info"` - // StopReason is the reason the generation stopped. 
- StopReason string `json:"stop_reason"` + // Call is a simplified interface for a text-only Model, generating a single + // string response from a single string prompt. + // + // It is here for backwards compatibility only and may be removed + // in the future; please use GenerateContent instead. + Call(ctx context.Context, prompt string, options ...CallOption) (string, error) } -// LLMResult is the class that contains all relevant information for an LLM Result. -type LLMResult struct { - Generations [][]*Generation - LLMOutput map[string]any +// CallLLM is a helper function for implementing Call in terms of +// GenerateContent. It's aimed to be used by Model providers. +func CallLLM(ctx context.Context, llm Model, prompt string, options ...CallOption) (string, error) { + msg := MessageContent{ + Role: schema.ChatMessageTypeHuman, + Parts: []ContentPart{TextContent{prompt}}, + } + + resp, err := llm.GenerateContent(ctx, []MessageContent{msg}, options...) + if err != nil { + return "", err + } + + choices := resp.Choices + if len(choices) < 1 { + return "", errors.New("empty response from model") //nolint:goerr113 + } + c1 := choices[0] + return c1.Content, nil } diff --git a/llms/local/localllm.go b/llms/local/localllm.go index d29fa277a..cf031f344 100644 --- a/llms/local/localllm.go +++ b/llms/local/localllm.go @@ -33,17 +33,10 @@ var ( // Call calls the local LLM binary with the given prompt. func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { - r, err := o.Generate(ctx, []string{prompt}, options...) - if err != nil { - return "", err - } - if len(r) == 0 { - return "", ErrEmptyResponse - } - return r[0].Text, nil + return llms.CallLLM(ctx, o, prompt, options...) } -func (o *LLM) appendGlobalsToArgs(opts llms.CallOptions) []string { +func (o *LLM) appendGlobalsToArgs(opts llms.CallOptions) { if opts.Temperature != 0 { o.client.Args = append(o.client.Args, fmt.Sprintf("--temperature=%f", opts.Temperature)) } @@ -65,14 +58,13 @@ func (o *LLM) appendGlobalsToArgs(opts llms.CallOptions) []string { if opts.Seed != 0 { o.client.Args = append(o.client.Args, fmt.Sprintf("--seed=%d", opts.Seed)) } - - return o.client.Args } -// Generate generates completions using the local LLM binary. -func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { +// GenerateContent implements the Model interface. 
+func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, whitespace + if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, prompts) + o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) } opts := &llms.CallOptions{} @@ -86,26 +78,29 @@ func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.Ca o.appendGlobalsToArgs(*opts) } - generations := make([]*llms.Generation, 0, len(prompts)) - for _, prompt := range prompts { - result, err := o.client.CreateCompletion(ctx, &localclient.CompletionRequest{ - Prompt: prompt, - }) - if err != nil { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) - } - return nil, err - } + // Assume we get a single text message + msg0 := messages[0] + part := msg0.Parts[0] + result, err := o.client.CreateCompletion(ctx, &localclient.CompletionRequest{ + Prompt: part.(llms.TextContent).Text, + }) + if err != nil { + return nil, err + } - generations = append(generations, &llms.Generation{Text: result.Text}) + resp := &llms.ContentResponse{ + Choices: []*llms.ContentChoice{ + { + Content: result.Text, + }, + }, } if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}}) + o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, resp) } - return generations, nil + return resp, nil } // New creates a new local LLM implementation. diff --git a/llms/ollama/ollama_test.go b/llms/ollama/ollama_test.go index be9f78e6c..57f3f7876 100644 --- a/llms/ollama/ollama_test.go +++ b/llms/ollama/ollama_test.go @@ -12,7 +12,7 @@ import ( "github.com/tmc/langchaingo/schema" ) -func newChatClient(t *testing.T) *Chat { +func newTestClient(t *testing.T) *LLM { t.Helper() var ollamaModel string if ollamaModel = os.Getenv("OLLAMA_TEST_MODEL"); ollamaModel == "" { @@ -20,27 +20,14 @@ func newChatClient(t *testing.T) *Chat { return nil } - c, err := NewChat(WithModel(ollamaModel)) + c, err := New(WithModel(ollamaModel)) require.NoError(t, err) return c } -func TestChatBasic(t *testing.T) { - t.Parallel() - - llm := newChatClient(t) - - resp, err := llm.Call(context.Background(), []schema.ChatMessage{ - schema.SystemChatMessage{Content: "You are producing poems in Spanish."}, - schema.HumanChatMessage{Content: "Write a very short poem about Donald Knuth"}, - }) - require.NoError(t, err) - assert.Regexp(t, "programa|comput|algoritm|libro", strings.ToLower(resp.Content)) //nolint:all -} - func TestGenerateContent(t *testing.T) { t.Parallel() - llm := newChatClient(t) + llm := newTestClient(t) parts := []llms.ContentPart{ llms.TextContent{Text: "How many feet are in a nautical mile?"}, @@ -62,7 +49,7 @@ func TestGenerateContent(t *testing.T) { func TestWithStreaming(t *testing.T) { t.Parallel() - llm := newChatClient(t) + llm := newTestClient(t) parts := []llms.ContentPart{ llms.TextContent{Text: "How many feet are in a nautical mile?"}, diff --git a/llms/ollama/ollamallm.go b/llms/ollama/ollamallm.go index 7f9825b6c..4c6f335b5 100644 --- a/llms/ollama/ollamallm.go +++ b/llms/ollama/ollamallm.go @@ -7,6 +7,7 @@ import ( "github.com/tmc/langchaingo/callbacks" "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms/ollama/internal/ollamaclient" + "github.com/tmc/langchaingo/schema" ) var ( @@ -40,20 +41,14 @@ func New(opts ...Option) (*LLM, error) { // Call Implement the call interface for LLM. 
func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { - r, err := o.Generate(ctx, []string{prompt}, options...) - if err != nil { - return "", err - } - if len(r) == 0 { - return "", ErrEmptyResponse - } - return r[0].Text, nil + return llms.CallLLM(ctx, o, prompt, options...) } -// Generate implemente the generate interface for LLM. -func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { +// GenerateContent implements the Model interface. +// nolint: goerr113 +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { // nolint: lll, cyclop, funlen if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, prompts) + o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) } opts := llms.CallOptions{} @@ -61,65 +56,105 @@ func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.Ca opt(&opts) } - // Load back CallOptions as ollamaOptions - ollamaOptions := o.options.ollamaOptions - ollamaOptions.NumPredict = opts.MaxTokens - ollamaOptions.Temperature = float32(opts.Temperature) - ollamaOptions.Stop = opts.StopWords - ollamaOptions.TopK = opts.TopK - ollamaOptions.TopP = float32(opts.TopP) - ollamaOptions.Seed = opts.Seed - ollamaOptions.RepeatPenalty = float32(opts.RepetitionPenalty) - ollamaOptions.FrequencyPenalty = float32(opts.FrequencyPenalty) - ollamaOptions.PresencePenalty = float32(opts.PresencePenalty) - // Override LLM model if set as llms.CallOption model := o.options.model if opts.Model != "" { model = opts.Model } - generations := make([]*llms.Generation, 0, len(prompts)) - - for _, prompt := range prompts { - req := &ollamaclient.GenerateRequest{ - Model: model, - System: o.options.system, - Prompt: prompt, - Template: o.options.customModelTemplate, - Options: ollamaOptions, - Stream: func(b bool) *bool { return &b }(opts.StreamingFunc != nil), + // Our input is a sequence of MessageContent, each of which potentially has + // a sequence of Part that could be text, images etc. + // We have to convert it to a format Ollama undestands: ChatRequest, which + // has a sequence of Message, each of which has a role and content - single + // text + potential images. + chatMsgs := make([]*ollamaclient.Message, 0, len(messages)) + for _, mc := range messages { + msg := &ollamaclient.Message{Role: typeToRole(mc.Role)} + + // Look at all the parts in mc; expect to find a single Text part and + // any number of binary parts. 
+ var text string + foundText := false + var images []ollamaclient.ImageData + + for _, p := range mc.Parts { + switch pt := p.(type) { + case llms.TextContent: + if foundText { + return nil, errors.New("expecting a single Text content") + } + foundText = true + text = pt.Text + case llms.BinaryContent: + images = append(images, ollamaclient.ImageData(pt.Data)) + default: + return nil, errors.New("only support Text and BinaryContent parts right now") + } } - var fn ollamaclient.GenerateResponseFunc + msg.Content = text + msg.Images = images + chatMsgs = append(chatMsgs, msg) + } + + // Get our ollamaOptions from llms.CallOptions + ollamaOptions := makeOllamaOptionsFromOptions(o.options.ollamaOptions, opts) + req := &ollamaclient.ChatRequest{ + Model: model, + Messages: chatMsgs, + Options: ollamaOptions, + Stream: func(b bool) *bool { return &b }(opts.StreamingFunc != nil), + } - var output string - fn = func(response ollamaclient.GenerateResponse) error { - if opts.StreamingFunc != nil { - if err := opts.StreamingFunc(ctx, []byte(response.Response)); err != nil { - return err - } + var fn ollamaclient.ChatResponseFunc + streamedResponse := "" + var resp ollamaclient.ChatResponse + + fn = func(response ollamaclient.ChatResponse) error { + if opts.StreamingFunc != nil && response.Message != nil { + if err := opts.StreamingFunc(ctx, []byte(response.Message.Content)); err != nil { + return err } - output += response.Response - return nil } - - err := o.client.Generate(ctx, req, fn) - if err != nil { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) + if response.Message != nil { + streamedResponse += response.Message.Content + } + if response.Done { + resp = response + resp.Message = &ollamaclient.Message{ + Role: "assistant", + Content: streamedResponse, } - return []*llms.Generation{}, err } + return nil + } + + err := o.client.GenerateChat(ctx, req, fn) + if err != nil { + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMError(ctx, err) + } + return nil, err + } - generations = append(generations, &llms.Generation{Text: output}) + choices := []*llms.ContentChoice{ + { + Content: resp.Message.Content, + GenerationInfo: map[string]any{ + "CompletionTokens": resp.EvalCount, + "PromptTokens": resp.PromptEvalCount, + "TotalTokesn": resp.EvalCount + resp.PromptEvalCount, + }, + }, } + response := &llms.ContentResponse{Choices: choices} + if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}}) + o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, response) } - return generations, nil + return response, nil } func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]float32, error) { @@ -147,3 +182,34 @@ func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]flo return embeddings, nil } + +func typeToRole(typ schema.ChatMessageType) string { + switch typ { + case schema.ChatMessageTypeSystem: + return "system" + case schema.ChatMessageTypeAI: + return "assistant" + case schema.ChatMessageTypeHuman: + fallthrough + case schema.ChatMessageTypeGeneric: + return "user" + case schema.ChatMessageTypeFunction: + return "function" + } + return "" +} + +func makeOllamaOptionsFromOptions(ollamaOptions ollamaclient.Options, opts llms.CallOptions) ollamaclient.Options { + // Load back CallOptions as ollamaOptions + ollamaOptions.NumPredict = opts.MaxTokens + ollamaOptions.Temperature = float32(opts.Temperature) + ollamaOptions.Stop = opts.StopWords + 
ollamaOptions.TopK = opts.TopK + ollamaOptions.TopP = float32(opts.TopP) + ollamaOptions.Seed = opts.Seed + ollamaOptions.RepeatPenalty = float32(opts.RepetitionPenalty) + ollamaOptions.FrequencyPenalty = float32(opts.FrequencyPenalty) + ollamaOptions.PresencePenalty = float32(opts.PresencePenalty) + + return ollamaOptions +} diff --git a/llms/ollama/ollamallm_chat.go b/llms/ollama/ollamallm_chat.go deleted file mode 100644 index 1fa54276f..000000000 --- a/llms/ollama/ollamallm_chat.go +++ /dev/null @@ -1,337 +0,0 @@ -package ollama - -import ( - "context" - "errors" - "fmt" - - "github.com/tmc/langchaingo/callbacks" - "github.com/tmc/langchaingo/llms" - "github.com/tmc/langchaingo/llms/ollama/internal/ollamaclient" - "github.com/tmc/langchaingo/schema" -) - -// LLM is a ollama LLM implementation. -type Chat struct { - CallbacksHandler callbacks.Handler - client *ollamaclient.Client - options options -} - -var _ llms.ChatLLM = (*Chat)(nil) - -// New creates a new ollama LLM implementation. -func NewChat(opts ...Option) (*Chat, error) { - o := options{} - for _, opt := range opts { - opt(&o) - } - - client, err := ollamaclient.NewClient(o.ollamaServerURL) - if err != nil { - return nil, err - } - - return &Chat{client: client, options: o}, nil -} - -// Call Implement the call interface for LLM. -func (o *Chat) Call(ctx context.Context, messages []schema.ChatMessage, options ...llms.CallOption) (*schema.AIChatMessage, error) { //nolint:lll - r, err := o.Generate(ctx, [][]schema.ChatMessage{messages}, options...) - if err != nil { - return nil, err - } - if len(r) == 0 { - return nil, ErrEmptyResponse - } - return r[0].Message, nil -} - -// Generate implemente the generate interface for LLM. -func (o *Chat) Generate(ctx context.Context, messageSets [][]schema.ChatMessage, options ...llms.CallOption) ([]*llms.Generation, error) { //nolint:lll,cyclop - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, o.getPromptsFromMessageSets(messageSets)) - } - - opts := llms.CallOptions{} - for _, opt := range options { - opt(&opts) - } - - // Override LLM model if set as llms.CallOption - model := o.options.model - if opts.Model != "" { - model = opts.Model - } - - // Get our ollamaOptions from llms.CallOptions - ollamaOptions := makeOllamaOptionsFromOptions(o.options.ollamaOptions, opts) - - generations := make([]*llms.Generation, 0, len(messageSets)) - for _, messages := range messageSets { - req, err := messagesToChatRequest(messages) - if err != nil { - return nil, err - } - - req.Model = model - req.Options = ollamaOptions - req.Stream = func(b bool) *bool { return &b }(opts.StreamingFunc != nil) - - var fn ollamaclient.ChatResponseFunc - - streamedResponse := "" - var resp ollamaclient.ChatResponse - - fn = func(response ollamaclient.ChatResponse) error { - if opts.StreamingFunc != nil && response.Message != nil { - if err := opts.StreamingFunc(ctx, []byte(response.Message.Content)); err != nil { - return err - } - } - if response.Message != nil { - streamedResponse += response.Message.Content - } - if response.Done { - resp = response - resp.Message = &ollamaclient.Message{ - Role: "assistant", - Content: streamedResponse, - } - } - return nil - } - - err = o.client.GenerateChat(ctx, req, fn) - if err != nil { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) - } - return []*llms.Generation{}, err - } - - generations = append(generations, makeGenerationFromChatResponse(resp)) - } - - if o.CallbacksHandler != nil { - 
o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}}) - } - - return generations, nil -} - -// GenerateContent implements the Model interface. -// nolint: goerr113 -func (o *Chat) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { // nolint: lll, cyclop, funlen - opts := llms.CallOptions{} - for _, opt := range options { - opt(&opts) - } - - // Override LLM model if set as llms.CallOption - model := o.options.model - if opts.Model != "" { - model = opts.Model - } - - // Our input is a sequence of MessageContent, each of which potentially has - // a sequence of Part that could be text, images etc. - // We have to convert it to a format Ollama undestands: ChatRequest, which - // has a sequence of Message, each of which has a role and content - single - // text + potential images. - chatMsgs := make([]*ollamaclient.Message, 0, len(messages)) - for _, mc := range messages { - msg := &ollamaclient.Message{Role: typeToRole(mc.Role)} - - // Look at all the parts in mc; expect to find a single Text part and - // any number of binary parts. - var text string - foundText := false - var images []ollamaclient.ImageData - - for _, p := range mc.Parts { - switch pt := p.(type) { - case llms.TextContent: - if foundText { - return nil, errors.New("expecting a single Text content") - } - foundText = true - text = pt.Text - case llms.BinaryContent: - images = append(images, ollamaclient.ImageData(pt.Data)) - default: - return nil, errors.New("only support Text and BinaryContent parts right now") - } - } - - msg.Content = text - msg.Images = images - chatMsgs = append(chatMsgs, msg) - } - - // Get our ollamaOptions from llms.CallOptions - ollamaOptions := makeOllamaOptionsFromOptions(o.options.ollamaOptions, opts) - req := &ollamaclient.ChatRequest{ - Model: model, - Messages: chatMsgs, - Options: ollamaOptions, - Stream: func(b bool) *bool { return &b }(opts.StreamingFunc != nil), - } - - var fn ollamaclient.ChatResponseFunc - streamedResponse := "" - var resp ollamaclient.ChatResponse - - fn = func(response ollamaclient.ChatResponse) error { - if opts.StreamingFunc != nil && response.Message != nil { - if err := opts.StreamingFunc(ctx, []byte(response.Message.Content)); err != nil { - return err - } - } - if response.Message != nil { - streamedResponse += response.Message.Content - } - if response.Done { - resp = response - resp.Message = &ollamaclient.Message{ - Role: "assistant", - Content: streamedResponse, - } - } - return nil - } - - err := o.client.GenerateChat(ctx, req, fn) - if err != nil { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) - } - return nil, err - } - - choices := []*llms.ContentChoice{ - { - Content: resp.Message.Content, - GenerationInfo: map[string]any{ - "CompletionTokens": resp.EvalCount, - "PromptTokens": resp.PromptEvalCount, - "TotalTokesn": resp.EvalCount + resp.PromptEvalCount, - }, - }, - } - - return &llms.ContentResponse{Choices: choices}, nil -} - -func makeGenerationFromChatResponse(resp ollamaclient.ChatResponse) *llms.Generation { - msg := &schema.AIChatMessage{ - Content: resp.Message.Content, - } - - gen := &llms.Generation{ - Message: msg, - Text: msg.Content, - GenerationInfo: make(map[string]any), - } - - gen.GenerationInfo["CompletionTokens"] = resp.EvalCount - gen.GenerationInfo["PromptTokens"] = resp.PromptEvalCount - gen.GenerationInfo["TotalTokens"] = resp.PromptEvalCount + resp.EvalCount - - return gen 
-} - -func makeOllamaOptionsFromOptions(ollamaOptions ollamaclient.Options, opts llms.CallOptions) ollamaclient.Options { - // Load back CallOptions as ollamaOptions - ollamaOptions.NumPredict = opts.MaxTokens - ollamaOptions.Temperature = float32(opts.Temperature) - ollamaOptions.Stop = opts.StopWords - ollamaOptions.TopK = opts.TopK - ollamaOptions.TopP = float32(opts.TopP) - ollamaOptions.Seed = opts.Seed - ollamaOptions.RepeatPenalty = float32(opts.RepetitionPenalty) - ollamaOptions.FrequencyPenalty = float32(opts.FrequencyPenalty) - ollamaOptions.PresencePenalty = float32(opts.PresencePenalty) - - return ollamaOptions -} - -func (o *Chat) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]float32, error) { - embeddings := [][]float32{} - - for _, input := range inputTexts { - embedding, err := o.client.CreateEmbedding(ctx, &ollamaclient.EmbeddingRequest{ - Prompt: input, - Model: o.options.model, - }) - if err != nil { - return nil, err - } - - if len(embedding.Embedding) == 0 { - return nil, ErrEmptyResponse - } - - embeddings = append(embeddings, embedding.Embedding) - } - - if len(inputTexts) != len(embeddings) { - return embeddings, ErrIncompleteEmbedding - } - - return embeddings, nil -} - -func (o Chat) getPromptsFromMessageSets(messageSets [][]schema.ChatMessage) []string { - prompts := make([]string, 0, len(messageSets)) - for i := 0; i < len(messageSets); i++ { - curPrompt := "" - for j := 0; j < len(messageSets[i]); j++ { - curPrompt += messageSets[i][j].GetContent() - } - prompts = append(prompts, curPrompt) - } - return prompts -} - -func messagesToChatRequest(messages []schema.ChatMessage) (*ollamaclient.ChatRequest, error) { - req := &ollamaclient.ChatRequest{} - for _, m := range messages { - typ := m.GetType() - switch typ { - case schema.ChatMessageTypeSystem: - fallthrough - case schema.ChatMessageTypeAI: - req.Messages = append(req.Messages, &ollamaclient.Message{ - Role: typeToRole(typ), - Content: m.GetContent(), - }) - case schema.ChatMessageTypeHuman: - fallthrough - case schema.ChatMessageTypeGeneric: - req.Messages = append(req.Messages, &ollamaclient.Message{ - Role: typeToRole(typ), - Content: m.GetContent(), - }) - case schema.ChatMessageTypeFunction: - return nil, fmt.Errorf("chat message type %s not implemented", typ) //nolint:goerr113 - } - } - return req, nil -} - -func typeToRole(typ schema.ChatMessageType) string { - switch typ { - case schema.ChatMessageTypeSystem: - return "system" - case schema.ChatMessageTypeAI: - return "assistant" - case schema.ChatMessageTypeHuman: - fallthrough - case schema.ChatMessageTypeGeneric: - return "user" - case schema.ChatMessageTypeFunction: - return "function" - } - return "" -} diff --git a/llms/openai/multicontent_test.go b/llms/openai/multicontent_test.go index 44ca0b858..bdbcb43ac 100644 --- a/llms/openai/multicontent_test.go +++ b/llms/openai/multicontent_test.go @@ -13,21 +13,21 @@ import ( "github.com/tmc/langchaingo/schema" ) -func newChatClient(t *testing.T, opts ...Option) *Chat { +func newTestClient(t *testing.T, opts ...Option) *LLM { t.Helper() if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" { t.Skip("OPENAI_API_KEY not set") return nil } - llm, err := NewChat(opts...) + llm, err := New(opts...) 
require.NoError(t, err) return llm } func TestMultiContentText(t *testing.T) { t.Parallel() - llm := newChatClient(t) + llm := newTestClient(t) parts := []llms.ContentPart{ llms.TextContent{Text: "I'm a pomeranian"}, @@ -50,7 +50,7 @@ func TestMultiContentText(t *testing.T) { func TestMultiContentTextChatSequence(t *testing.T) { t.Parallel() - llm := newChatClient(t) + llm := newTestClient(t) content := []llms.MessageContent{ { @@ -78,7 +78,7 @@ func TestMultiContentTextChatSequence(t *testing.T) { func TestMultiContentImage(t *testing.T) { t.Parallel() - llm := newChatClient(t, WithModel("gpt-4-vision-preview")) + llm := newTestClient(t, WithModel("gpt-4-vision-preview")) parts := []llms.ContentPart{ llms.ImageURLContent{URL: "https://github.com/tmc/langchaingo/blob/main/docs/static/img/parrot-icon.png?raw=true"}, @@ -101,7 +101,7 @@ func TestMultiContentImage(t *testing.T) { func TestWithStreaming(t *testing.T) { t.Parallel() - llm := newChatClient(t) + llm := newTestClient(t) parts := []llms.ContentPart{ llms.TextContent{Text: "I'm a pomeranian"}, @@ -132,7 +132,7 @@ func TestWithStreaming(t *testing.T) { //nolint:lll func TestFunctionCall(t *testing.T) { t.Parallel() - llm := newChatClient(t) + llm := newTestClient(t) parts := []llms.ContentPart{ llms.TextContent{Text: "What is the weather like in Boston?"}, diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go index 3fe80d8fd..3cfe94edb 100644 --- a/llms/openai/openaillm.go +++ b/llms/openai/openaillm.go @@ -2,17 +2,28 @@ package openai import ( "context" + "fmt" "github.com/tmc/langchaingo/callbacks" "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms/openai/internal/openaiclient" + "github.com/tmc/langchaingo/schema" ) +type ChatMessage = openaiclient.ChatMessage + type LLM struct { CallbacksHandler callbacks.Handler client *openaiclient.Client } +const ( + RoleSystem = "system" + RoleAssistant = "assistant" + RoleUser = "user" + RoleFunction = "function" +) + var _ llms.LLM = (*LLM)(nil) // New returns a new OpenAI LLM. @@ -29,19 +40,15 @@ func New(opts ...Option) (*LLM, error) { // Call requests a completion for the given prompt. func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { - r, err := o.Generate(ctx, []string{prompt}, options...) - if err != nil { - return "", err - } - if len(r) == 0 { - return "", ErrEmptyResponse - } - return r[0].Text, nil + return llms.CallLLM(ctx, o, prompt, options...) } -func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { +// GenerateContent implements the Model interface. 
+// +//nolint:goerr113 +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, prompts) + o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) } opts := llms.CallOptions{} @@ -49,36 +56,82 @@ func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.Ca opt(&opts) } - generations := make([]*llms.Generation, 0, len(prompts)) - for _, prompt := range prompts { - result, err := o.client.CreateCompletion(ctx, &openaiclient.CompletionRequest{ - Model: opts.Model, - Prompt: prompt, - MaxTokens: opts.MaxTokens, - StopWords: opts.StopWords, - Temperature: opts.Temperature, - N: opts.N, - FrequencyPenalty: opts.FrequencyPenalty, - PresencePenalty: opts.PresencePenalty, - TopP: opts.TopP, - StreamingFunc: opts.StreamingFunc, + chatMsgs := make([]*ChatMessage, 0, len(messages)) + for _, mc := range messages { + msg := &ChatMessage{MultiContent: mc.Parts} + switch mc.Role { + case schema.ChatMessageTypeSystem: + msg.Role = RoleSystem + case schema.ChatMessageTypeAI: + msg.Role = RoleAssistant + case schema.ChatMessageTypeHuman: + msg.Role = RoleUser + case schema.ChatMessageTypeGeneric: + msg.Role = RoleUser + case schema.ChatMessageTypeFunction: + fallthrough + default: + return nil, fmt.Errorf("role %v not supported", mc.Role) + } + + chatMsgs = append(chatMsgs, msg) + } + + req := &openaiclient.ChatRequest{ + Model: opts.Model, + StopWords: opts.StopWords, + Messages: chatMsgs, + StreamingFunc: opts.StreamingFunc, + Temperature: opts.Temperature, + MaxTokens: opts.MaxTokens, + N: opts.N, + FrequencyPenalty: opts.FrequencyPenalty, + PresencePenalty: opts.PresencePenalty, + FunctionCallBehavior: openaiclient.FunctionCallBehavior(opts.FunctionCallBehavior), + } + + for _, fn := range opts.Functions { + req.Functions = append(req.Functions, openaiclient.FunctionDefinition{ + Name: fn.Name, + Description: fn.Description, + Parameters: fn.Parameters, }) - if err != nil { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) + } + result, err := o.client.CreateChat(ctx, req) + if err != nil { + return nil, err + } + if len(result.Choices) == 0 { + return nil, ErrEmptyResponse + } + + choices := make([]*llms.ContentChoice, len(result.Choices)) + for i, c := range result.Choices { + choices[i] = &llms.ContentChoice{ + Content: c.Message.Content, + StopReason: c.FinishReason, + GenerationInfo: map[string]any{ + "CompletionTokens": result.Usage.CompletionTokens, + "PromptTokens": result.Usage.PromptTokens, + "TotalTokens": result.Usage.TotalTokens, + }, + } + + if c.FinishReason == "function_call" { + choices[i].FuncCall = &schema.FunctionCall{ + Name: c.Message.FunctionCall.Name, + Arguments: c.Message.FunctionCall.Arguments, } - return nil, err } - generations = append(generations, &llms.Generation{ - Text: result.Text, - }) } + response := &llms.ContentResponse{Choices: choices} + if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}}) + o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, response) } - return generations, nil + return response, nil } // CreateEmbedding creates embeddings for the given input texts. 
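
The hunk above moves the OpenAI provider onto the single GenerateContent entry point, with Call delegating to llms.CallLLM. Below is a minimal usage sketch of the consolidated API (not part of this patch): it builds the same MessageContent/TextContent values that llms.CallLLM and the tests in this patch use, and it assumes openai.New picks up OPENAI_API_KEY from the environment, as those tests do; the prompt string is illustrative only.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/openai"
	"github.com/tmc/langchaingo/schema"
)

func main() {
	llm, err := openai.New() // assumes OPENAI_API_KEY is set in the environment
	if err != nil {
		log.Fatal(err)
	}

	// A single human message with one text part, mirroring what llms.CallLLM builds.
	content := []llms.MessageContent{
		{
			Role:  schema.ChatMessageTypeHuman,
			Parts: []llms.ContentPart{llms.TextContent{Text: "How many feet are in a nautical mile?"}},
		},
	}

	resp, err := llm.GenerateContent(context.Background(), content)
	if err != nil {
		log.Fatal(err)
	}
	// Each ContentChoice carries the textual content of one response choice.
	fmt.Println(resp.Choices[0].Content)
}
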
diff --git a/llms/openai/openaillm_chat.go b/llms/openai/openaillm_chat.go
deleted file mode 100644
index 1e9dba7c3..000000000
--- a/llms/openai/openaillm_chat.go
+++ /dev/null
@@ -1,264 +0,0 @@
-package openai
-
-import (
-	"context"
-	"fmt"
-	"reflect"
-
-	"github.com/tmc/langchaingo/callbacks"
-	"github.com/tmc/langchaingo/llms"
-	"github.com/tmc/langchaingo/llms/openai/internal/openaiclient"
-	"github.com/tmc/langchaingo/schema"
-)
-
-type ChatMessage = openaiclient.ChatMessage
-
-type Chat struct {
-	CallbacksHandler callbacks.Handler
-	client           *openaiclient.Client
-}
-
-const (
-	RoleSystem    = "system"
-	RoleAssistant = "assistant"
-	RoleUser      = "user"
-	RoleFunction  = "function"
-)
-
-var _ llms.ChatLLM = (*Chat)(nil)
-
-// NewChat returns a new OpenAI chat LLM.
-func NewChat(opts ...Option) (*Chat, error) {
-	opt, c, err := newClient(opts...)
-	if err != nil {
-		return nil, err
-	}
-	return &Chat{
-		client:           c,
-		CallbacksHandler: opt.callbackHandler,
-	}, err
-}
-
-// GenerateContent implements the Model interface.
-//
-//nolint:goerr113
-func (o *Chat) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { // nolint: lll, cyclop
-	opts := llms.CallOptions{}
-	for _, opt := range options {
-		opt(&opts)
-	}
-
-	chatMsgs := make([]*ChatMessage, 0, len(messages))
-	for _, mc := range messages {
-		msg := &ChatMessage{MultiContent: mc.Parts}
-		switch mc.Role {
-		case schema.ChatMessageTypeSystem:
-			msg.Role = RoleSystem
-		case schema.ChatMessageTypeAI:
-			msg.Role = RoleAssistant
-		case schema.ChatMessageTypeHuman:
-			msg.Role = RoleUser
-		case schema.ChatMessageTypeGeneric:
-			msg.Role = RoleUser
-		case schema.ChatMessageTypeFunction:
-			fallthrough
-		default:
-			return nil, fmt.Errorf("role %v not supported", mc.Role)
-		}
-
-		chatMsgs = append(chatMsgs, msg)
-	}
-
-	req := &openaiclient.ChatRequest{
-		Model:                opts.Model,
-		StopWords:            opts.StopWords,
-		Messages:             chatMsgs,
-		StreamingFunc:        opts.StreamingFunc,
-		Temperature:          opts.Temperature,
-		MaxTokens:            opts.MaxTokens,
-		N:                    opts.N,
-		FrequencyPenalty:     opts.FrequencyPenalty,
-		PresencePenalty:      opts.PresencePenalty,
-		FunctionCallBehavior: openaiclient.FunctionCallBehavior(opts.FunctionCallBehavior),
-	}
-
-	for _, fn := range opts.Functions {
-		req.Functions = append(req.Functions, openaiclient.FunctionDefinition{
-			Name:        fn.Name,
-			Description: fn.Description,
-			Parameters:  fn.Parameters,
-		})
-	}
-	result, err := o.client.CreateChat(ctx, req)
-	if err != nil {
-		return nil, err
-	}
-	if len(result.Choices) == 0 {
-		return nil, ErrEmptyResponse
-	}
-
-	choices := make([]*llms.ContentChoice, len(result.Choices))
-	for i, c := range result.Choices {
-		choices[i] = &llms.ContentChoice{
-			Content:    c.Message.Content,
-			StopReason: c.FinishReason,
-			GenerationInfo: map[string]any{
-				"CompletionTokens": result.Usage.CompletionTokens,
-				"PromptTokens":     result.Usage.PromptTokens,
-				"TotalTokens":      result.Usage.TotalTokens,
-			},
-		}
-
-		if c.FinishReason == "function_call" {
-			choices[i].FuncCall = &schema.FunctionCall{
-				Name:      c.Message.FunctionCall.Name,
-				Arguments: c.Message.FunctionCall.Arguments,
-			}
-		}
-	}
-
-	return &llms.ContentResponse{Choices: choices}, nil
-}
-
-// Call requests a chat response for the given messages.
-func (o *Chat) Call(ctx context.Context, messages []schema.ChatMessage, options ...llms.CallOption) (*schema.AIChatMessage, error) { // nolint: lll
-	r, err := o.Generate(ctx, [][]schema.ChatMessage{messages}, options...)
-	if err != nil {
-		return nil, err
-	}
-	if len(r) == 0 {
-		return nil, ErrEmptyResponse
-	}
-	return r[0].Message, nil
-}
-
-//nolint:funlen
-func (o *Chat) Generate(ctx context.Context, messageSets [][]schema.ChatMessage, options ...llms.CallOption) ([]*llms.Generation, error) { // nolint:lll,cyclop
-	if o.CallbacksHandler != nil {
-		o.CallbacksHandler.HandleLLMStart(ctx, getPromptsFromMessageSets(messageSets))
-	}
-
-	opts := llms.CallOptions{}
-	for _, opt := range options {
-		opt(&opts)
-	}
-	generations := make([]*llms.Generation, 0, len(messageSets))
-	for _, messageSet := range messageSets {
-		req := &openaiclient.ChatRequest{
-			Model:            opts.Model,
-			StopWords:        opts.StopWords,
-			Messages:         messagesToClientMessages(messageSet),
-			StreamingFunc:    opts.StreamingFunc,
-			Temperature:      opts.Temperature,
-			MaxTokens:        opts.MaxTokens,
-			N:                opts.N, // TODO: note, we are not returning multiple completions
-			FrequencyPenalty: opts.FrequencyPenalty,
-			PresencePenalty:  opts.PresencePenalty,
-
-			FunctionCallBehavior: openaiclient.FunctionCallBehavior(opts.FunctionCallBehavior),
-		}
-		for _, fn := range opts.Functions {
-			req.Functions = append(req.Functions, openaiclient.FunctionDefinition{
-				Name:        fn.Name,
-				Description: fn.Description,
-				Parameters:  fn.Parameters,
-			})
-		}
-		result, err := o.client.CreateChat(ctx, req)
-		if err != nil {
-			return nil, err
-		}
-		if len(result.Choices) == 0 {
-			return nil, ErrEmptyResponse
-		}
-		generationInfo := make(map[string]any, reflect.ValueOf(result.Usage).NumField())
-		generationInfo["CompletionTokens"] = result.Usage.CompletionTokens
-		generationInfo["PromptTokens"] = result.Usage.PromptTokens
-		generationInfo["TotalTokens"] = result.Usage.TotalTokens
-		msg := &schema.AIChatMessage{
-			Content: result.Choices[0].Message.Content,
-		}
-		if result.Choices[0].FinishReason == "function_call" {
-			msg.FunctionCall = &schema.FunctionCall{
-				Name:      result.Choices[0].Message.FunctionCall.Name,
-				Arguments: result.Choices[0].Message.FunctionCall.Arguments,
-			}
-		}
-		generations = append(generations, &llms.Generation{
-			Message:        msg,
-			Text:           msg.Content,
-			GenerationInfo: generationInfo,
-			StopReason:     result.Choices[0].FinishReason,
-		})
-	}
-
-	if o.CallbacksHandler != nil {
-		o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}})
-	}
-
-	return generations, nil
-}
-
-// CreateEmbedding creates embeddings for the given input texts.
-func (o *Chat) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]float32, error) {
-	embeddings, err := o.client.CreateEmbedding(ctx, &openaiclient.EmbeddingRequest{
-		Input: inputTexts,
-	})
-	if err != nil {
-		return nil, err
-	}
-	if len(embeddings) == 0 {
-		return nil, ErrEmptyResponse
-	}
-	if len(inputTexts) != len(embeddings) {
-		return embeddings, ErrUnexpectedResponseLength
-	}
-	return embeddings, nil
-}
-
-func getPromptsFromMessageSets(messageSets [][]schema.ChatMessage) []string {
-	prompts := make([]string, 0, len(messageSets))
-	for i := 0; i < len(messageSets); i++ {
-		curPrompt := ""
-		for j := 0; j < len(messageSets[i]); j++ {
-			curPrompt += messageSets[i][j].GetContent()
-		}
-		prompts = append(prompts, curPrompt)
-	}
-
-	return prompts
-}
-
-func messagesToClientMessages(messages []schema.ChatMessage) []*openaiclient.ChatMessage {
-	msgs := make([]*openaiclient.ChatMessage, len(messages))
-	for i, m := range messages {
-		msg := &openaiclient.ChatMessage{
-			Content: m.GetContent(),
-		}
-		typ := m.GetType()
-		switch typ {
-		case schema.ChatMessageTypeSystem:
-			msg.Role = "system"
-		case schema.ChatMessageTypeAI:
-			msg.Role = "assistant"
-			if mm, ok := m.(schema.AIChatMessage); ok && mm.FunctionCall != nil {
-				msg.FunctionCall = &openaiclient.FunctionCall{
-					Name:      mm.FunctionCall.Name,
-					Arguments: mm.FunctionCall.Arguments,
-				}
-			}
-		case schema.ChatMessageTypeHuman:
-			msg.Role = "user"
-		case schema.ChatMessageTypeGeneric:
-			msg.Role = "user"
-		case schema.ChatMessageTypeFunction:
-			msg.Role = "function"
-		}
-		if n, ok := m.(schema.Named); ok {
-			msg.Name = n.GetName()
-		}
-		msgs[i] = msg
-	}
-
-	return msgs
-}
diff --git a/llms/vertexai/vertexai_palm_llm.go b/llms/vertexai/vertexai_palm_llm.go
index d9bde1e13..52bf59a14 100644
--- a/llms/vertexai/vertexai_palm_llm.go
+++ b/llms/vertexai/vertexai_palm_llm.go
@@ -25,27 +25,27 @@ var _ llms.LLM = (*LLM)(nil)
 
 // Call requests a completion for the given prompt.
 func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
-	r, err := o.Generate(ctx, []string{prompt}, options...)
-	if err != nil {
-		return "", err
-	}
-	if len(r) == 0 {
-		return "", ErrEmptyResponse
-	}
-	return r[0].Text, nil
+	return llms.CallLLM(ctx, o, prompt, options...)
 }
 
-func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) {
+// GenerateContent implements the Model interface.
+func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, whitespace
+
 	if o.CallbacksHandler != nil {
-		o.CallbacksHandler.HandleLLMStart(ctx, prompts)
+		o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages)
 	}
 
 	opts := llms.CallOptions{}
 	for _, opt := range options {
 		opt(&opts)
 	}
+
+	// Assume we get a single text message
+	msg0 := messages[0]
+	part := msg0.Parts[0]
+
 	results, err := o.client.CreateCompletion(ctx, &vertexaiclient.CompletionRequest{
-		Prompts:       prompts,
+		Prompts:       []string{part.(llms.TextContent).Text},
 		MaxTokens:     opts.MaxTokens,
 		Temperature:   opts.Temperature,
 		StopSequences: opts.StopWords,
@@ -57,17 +57,18 @@ func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.Ca
 		return nil, err
 	}
 
-	generations := []*llms.Generation{}
-	for _, r := range results {
-		generations = append(generations, &llms.Generation{
-			Text: r.Text,
-		})
+	resp := &llms.ContentResponse{
+		Choices: []*llms.ContentChoice{
+			{
+				Content: results[0].Text,
+			},
+		},
 	}
-
 	if o.CallbacksHandler != nil {
-		o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}})
+		o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, resp)
 	}
-	return generations, nil
+
+	return resp, nil
 }
 
 // CreateEmbedding creates embeddings for the given input texts.
diff --git a/llms/vertexai/vertexai_palm_llm_chat.go b/llms/vertexai/vertexai_palm_llm_chat.go
deleted file mode 100644
index 80a96a049..000000000
--- a/llms/vertexai/vertexai_palm_llm_chat.go
+++ /dev/null
@@ -1,137 +0,0 @@
-package vertexai
-
-import (
-	"context"
-
-	"github.com/tmc/langchaingo/llms"
-	"github.com/tmc/langchaingo/llms/vertexai/internal/vertexaiclient"
-	"github.com/tmc/langchaingo/schema"
-)
-
-const (
-	userAuthor = "user"
-	botAuthor  = "bot"
-)
-
-type ChatMessage = vertexaiclient.ChatMessage
-
-type Chat struct {
-	client *vertexaiclient.PaLMClient
-}
-
-var _ llms.ChatLLM = (*Chat)(nil)
-
-// Chat requests a chat response for the given messages.
-func (o *Chat) Call(ctx context.Context, messages []schema.ChatMessage, options ...llms.CallOption) (*schema.AIChatMessage, error) { // nolint: lll
-	r, err := o.Generate(ctx, [][]schema.ChatMessage{messages}, options...)
-	if err != nil {
-		return nil, err
-	}
-	if len(r) == 0 {
-		return nil, ErrEmptyResponse
-	}
-	return r[0].Message, nil
-}
-
-// Generate requests a chat response for each of the sets of messages.
-func (o *Chat) Generate(ctx context.Context, messageSets [][]schema.ChatMessage, options ...llms.CallOption) ([]*llms.Generation, error) { // nolint: lll
-	opts := llms.CallOptions{}
-	for _, opt := range options {
-		opt(&opts)
-	}
-	if opts.StreamingFunc != nil {
-		return nil, ErrNotImplemented
-	}
-
-	generations := make([]*llms.Generation, 0, len(messageSets))
-	for _, messages := range messageSets {
-		chatContext := parseContext(messages)
-		if len(chatContext) > 0 {
-			// remove system context from messages
-			messages = messages[1:]
-		}
-		msgs := toClientChatMessage(messages)
-		result, err := o.client.CreateChat(ctx, &vertexaiclient.ChatRequest{
-			Temperature: opts.Temperature,
-			Messages:    msgs,
-			Context:     chatContext,
-		})
-		if err != nil {
-			return nil, err
-		}
-		if len(result.Candidates) == 0 {
-			return nil, ErrEmptyResponse
-		}
-		generations = append(generations, &llms.Generation{
-			Message: &schema.AIChatMessage{
-				Content: result.Candidates[0].Content,
-			},
-			Text: result.Candidates[0].Content,
-		})
-	}
-
-	return generations, nil
-}
-
-func toClientChatMessage(messages []schema.ChatMessage) []*vertexaiclient.ChatMessage {
-	msgs := make([]*vertexaiclient.ChatMessage, len(messages))
-
-	for i, m := range messages {
-		msg := &vertexaiclient.ChatMessage{
-			Content: m.GetContent(),
-		}
-		typ := m.GetType()
-
-		switch typ {
-		case schema.ChatMessageTypeSystem:
-			msg.Author = botAuthor
-		case schema.ChatMessageTypeAI:
-			msg.Author = botAuthor
-		case schema.ChatMessageTypeHuman:
-			msg.Author = userAuthor
-		case schema.ChatMessageTypeGeneric:
-			msg.Author = userAuthor
-		case schema.ChatMessageTypeFunction:
-			msg.Author = userAuthor
-		}
-		if n, ok := m.(schema.Named); ok {
-			msg.Author = n.GetName()
-		}
-		msgs[i] = msg
-	}
-	return msgs
-}
-
-func parseContext(messages []schema.ChatMessage) string {
-	if len(messages) == 0 {
-		return ""
-	}
-	// check if 1st message type is system. use it as context.
-	if messages[0].GetType() == schema.ChatMessageTypeSystem {
-		return messages[0].GetContent()
-	}
-	return ""
-}
-
-// NewChat returns a new VertexAI PaLM Chat LLM.
-func NewChat(opts ...Option) (*Chat, error) {
-	client, err := newClient(opts...)
-	return &Chat{client: client}, err
-}
-
-// CreateEmbedding creates embeddings for the given input texts.
-func (o *Chat) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]float32, error) {
-	embeddings, err := o.client.CreateEmbedding(ctx, &vertexaiclient.EmbeddingRequest{
-		Input: inputTexts,
-	})
-	if err != nil {
-		return nil, err
-	}
-	if len(embeddings) == 0 {
-		return nil, ErrEmptyResponse
-	}
-	if len(inputTexts) != len(embeddings) {
-		return embeddings, ErrUnexpectedResponseLength
-	}
-	return embeddings, nil
-}
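The VertexAI GenerateContent in vertexai_palm_llm.go above only reads messages[0].Parts[0] and type-asserts it to llms.TextContent, so for now callers are expected to pass a single text-only message. The sketch below is illustrative only and not part of the patch; vertexai.New is assumed to be the package's exported constructor, which is not shown in this diff.

// Illustrative sketch only; not part of the diff above.
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/vertexai"
	"github.com/tmc/langchaingo/schema"
)

func main() {
	ctx := context.Background()

	// Assumption: vertexai.New is the exported constructor for the VertexAI LLM.
	llm, err := vertexai.New()
	if err != nil {
		log.Fatal(err)
	}

	// A single text-only message, matching the "Assume we get a single text
	// message" restriction in GenerateContent; a non-text first part would
	// fail the llms.TextContent type assertion there.
	resp, err := llm.GenerateContent(ctx, []llms.MessageContent{
		{
			Role:  schema.ChatMessageTypeHuman,
			Parts: []llms.ContentPart{llms.TextContent{Text: "Name three dog breeds."}},
		},
	})
	if err != nil {
		log.Fatal(err)
	}

	// The response carries a single choice built from results[0].Text.
	fmt.Println(resp.Choices[0].Content)
}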