From 405efead699ca28c33b63931f52bec8aefe79c10 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Tue, 16 Jan 2024 05:52:39 -0800 Subject: [PATCH 01/19] Add new callback for GenerateContent start and end And invoke it in the new methods Re #465 --- callbacks/callbacks.go | 2 ++ callbacks/combining.go | 12 ++++++++++++ callbacks/simple.go | 32 +++++++++++++++++--------------- llms/googleai/googleai_llm.go | 27 +++++++++++++++++++++++---- llms/llms.go | 2 +- llms/ollama/ollamallm_chat.go | 12 +++++++++++- llms/openai/openaillm_chat.go | 12 +++++++++++- 7 files changed, 77 insertions(+), 22 deletions(-) diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index 9bd05242a..15e83cc6c 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -15,6 +15,8 @@ type Handler interface { HandleText(ctx context.Context, text string) HandleLLMStart(ctx context.Context, prompts []string) HandleLLMEnd(ctx context.Context, output llms.LLMResult) + HandleLLMGenerateContentStart(ctx context.Context, ms []llms.MessageContent) + HandleLLMGenerateContentEnd(ctx context.Context, res *llms.ContentResponse) HandleLLMError(ctx context.Context, err error) HandleChainStart(ctx context.Context, inputs map[string]any) HandleChainEnd(ctx context.Context, outputs map[string]any) diff --git a/callbacks/combining.go b/callbacks/combining.go index 22ad518bc..0f3ef07ac 100644 --- a/callbacks/combining.go +++ b/callbacks/combining.go @@ -32,6 +32,18 @@ func (l CombiningHandler) HandleLLMEnd(ctx context.Context, output llms.LLMResul } } +func (l CombiningHandler) HandleLLMGenerateContentStart(ctx context.Context, ms []llms.MessageContent) { + for _, handle := range l.Callbacks { + handle.HandleLLMGenerateContentStart(ctx, ms) + } +} + +func (l CombiningHandler) HandleLLMGenerateContentEnd(ctx context.Context, res *llms.ContentResponse) { + for _, handle := range l.Callbacks { + handle.HandleLLMGenerateContentEnd(ctx, res) + } +} + func (l CombiningHandler) HandleChainStart(ctx context.Context, inputs map[string]any) { for _, handle := range l.Callbacks { handle.HandleChainStart(ctx, inputs) diff --git a/callbacks/simple.go b/callbacks/simple.go index db24e659b..3f8ded395 100644 --- a/callbacks/simple.go +++ b/callbacks/simple.go @@ -12,18 +12,20 @@ type SimpleHandler struct{} var _ Handler = SimpleHandler{} -func (SimpleHandler) HandleText(context.Context, string) {} -func (SimpleHandler) HandleLLMStart(context.Context, []string) {} -func (SimpleHandler) HandleLLMEnd(context.Context, llms.LLMResult) {} -func (SimpleHandler) HandleLLMError(context.Context, error) {} -func (SimpleHandler) HandleChainStart(context.Context, map[string]any) {} -func (SimpleHandler) HandleChainEnd(context.Context, map[string]any) {} -func (SimpleHandler) HandleChainError(context.Context, error) {} -func (SimpleHandler) HandleToolStart(context.Context, string) {} -func (SimpleHandler) HandleToolEnd(context.Context, string) {} -func (SimpleHandler) HandleToolError(context.Context, error) {} -func (SimpleHandler) HandleAgentAction(context.Context, schema.AgentAction) {} -func (SimpleHandler) HandleAgentFinish(context.Context, schema.AgentFinish) {} -func (SimpleHandler) HandleRetrieverStart(context.Context, string) {} -func (SimpleHandler) HandleRetrieverEnd(context.Context, string, []schema.Document) {} -func (SimpleHandler) HandleStreamingFunc(context.Context, []byte) {} +func (SimpleHandler) HandleText(context.Context, string) {} +func (SimpleHandler) HandleLLMStart(context.Context, []string) {} +func (SimpleHandler) 
HandleLLMEnd(context.Context, llms.LLMResult) {} +func (SimpleHandler) HandleLLMGenerateContentStart(context.Context, []llms.MessageContent) {} +func (SimpleHandler) HandleLLMGenerateContentEnd(context.Context, *llms.ContentResponse) {} +func (SimpleHandler) HandleLLMError(context.Context, error) {} +func (SimpleHandler) HandleChainStart(context.Context, map[string]any) {} +func (SimpleHandler) HandleChainEnd(context.Context, map[string]any) {} +func (SimpleHandler) HandleChainError(context.Context, error) {} +func (SimpleHandler) HandleToolStart(context.Context, string) {} +func (SimpleHandler) HandleToolEnd(context.Context, string) {} +func (SimpleHandler) HandleToolError(context.Context, error) {} +func (SimpleHandler) HandleAgentAction(context.Context, schema.AgentAction) {} +func (SimpleHandler) HandleAgentFinish(context.Context, schema.AgentFinish) {} +func (SimpleHandler) HandleRetrieverStart(context.Context, string) {} +func (SimpleHandler) HandleRetrieverEnd(context.Context, string, []schema.Document) {} +func (SimpleHandler) HandleStreamingFunc(context.Context, []byte) {} diff --git a/llms/googleai/googleai_llm.go b/llms/googleai/googleai_llm.go index f4cf9d7dd..833504497 100644 --- a/llms/googleai/googleai_llm.go +++ b/llms/googleai/googleai_llm.go @@ -14,6 +14,7 @@ import ( "strings" "github.com/google/generative-ai-go/genai" + "github.com/tmc/langchaingo/callbacks" "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/schema" "google.golang.org/api/iterator" @@ -22,8 +23,9 @@ import ( // GoogleAI is a type that represents a Google AI API client. type GoogleAI struct { - client *genai.Client - opts options + CallbacksHandler callbacks.Handler + client *genai.Client + opts options } var ( @@ -64,6 +66,10 @@ func NewGoogleAI(ctx context.Context, opts ...Option) (*GoogleAI, error) { // GenerateContent calls the LLM with the provided parts. func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { + if g.CallbacksHandler != nil { + g.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) + } + opts := llms.CallOptions{ Model: g.opts.defaultModel, MaxTokens: int(g.opts.defaultMaxTokens), @@ -77,14 +83,27 @@ func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageC model.SetMaxOutputTokens(int32(opts.MaxTokens)) model.SetTemperature(float32(opts.Temperature)) + var response *llms.ContentResponse + var err error + if len(messages) == 1 { theMessage := messages[0] if theMessage.Role != schema.ChatMessageTypeHuman { return nil, fmt.Errorf("got %v message role, want human", theMessage.Role) } - return generateFromSingleMessage(ctx, model, theMessage.Parts, &opts) + response, err = generateFromSingleMessage(ctx, model, theMessage.Parts, &opts) + } else { + response, err = generateFromMessages(ctx, model, messages, &opts) } - return generateFromMessages(ctx, model, messages, &opts) + if err != nil { + return nil, err + } + + if g.CallbacksHandler != nil { + g.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, response) + } + + return response, nil } // downloadImageData downloads the content from the given URL and returns it as diff --git a/llms/llms.go b/llms/llms.go index 33456af4b..200fe189c 100644 --- a/llms/llms.go +++ b/llms/llms.go @@ -24,7 +24,7 @@ type Model interface { // GenerateContent asks the model to generate content from a sequence of // messages. It's the most general interface for LLMs that support chat-like // interactions. 
- GenerateContent(ctx context.Context, parts []MessageContent, options ...CallOption) (*ContentResponse, error) + GenerateContent(ctx context.Context, messages []MessageContent, options ...CallOption) (*ContentResponse, error) } // Generation is a single generation from a langchaingo LLM. diff --git a/llms/ollama/ollamallm_chat.go b/llms/ollama/ollamallm_chat.go index 1fa54276f..b21788d8c 100644 --- a/llms/ollama/ollamallm_chat.go +++ b/llms/ollama/ollamallm_chat.go @@ -123,6 +123,10 @@ func (o *Chat) Generate(ctx context.Context, messageSets [][]schema.ChatMessage, // GenerateContent implements the Model interface. // nolint: goerr113 func (o *Chat) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { // nolint: lll, cyclop, funlen + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) + } + opts := llms.CallOptions{} for _, opt := range options { opt(&opts) @@ -220,7 +224,13 @@ func (o *Chat) GenerateContent(ctx context.Context, messages []llms.MessageConte }, } - return &llms.ContentResponse{Choices: choices}, nil + response := &llms.ContentResponse{Choices: choices} + + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, response) + } + + return response, nil } func makeGenerationFromChatResponse(resp ollamaclient.ChatResponse) *llms.Generation { diff --git a/llms/openai/openaillm_chat.go b/llms/openai/openaillm_chat.go index 1e9dba7c3..12219c501 100644 --- a/llms/openai/openaillm_chat.go +++ b/llms/openai/openaillm_chat.go @@ -43,6 +43,10 @@ func NewChat(opts ...Option) (*Chat, error) { // //nolint:goerr113 func (o *Chat) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { // nolint: lll, cyclop + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) + } + opts := llms.CallOptions{} for _, opt := range options { opt(&opts) @@ -117,7 +121,13 @@ func (o *Chat) GenerateContent(ctx context.Context, messages []llms.MessageConte } } - return &llms.ContentResponse{Choices: choices}, nil + response := &llms.ContentResponse{Choices: choices} + + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, response) + } + + return response, nil } // Call requests a chat response for the given messages. 
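A minimal sketch of how the new hooks can be consumed. LoggingHandler is a hypothetical type, not part of this change; it embeds callbacks.SimpleHandler so every other Handler method stays a no-op, and the GoogleAI client is used only because this patch exposes its CallbacksHandler field.

package main

import (
	"context"
	"log"

	"github.com/tmc/langchaingo/callbacks"
	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/googleai"
)

// LoggingHandler (hypothetical) logs the new GenerateContent start/end events.
type LoggingHandler struct {
	callbacks.SimpleHandler
}

func (LoggingHandler) HandleLLMGenerateContentStart(ctx context.Context, ms []llms.MessageContent) {
	log.Printf("GenerateContent start: %d messages", len(ms))
}

func (LoggingHandler) HandleLLMGenerateContentEnd(ctx context.Context, res *llms.ContentResponse) {
	log.Printf("GenerateContent end: %d choices", len(res.Choices))
}

func main() {
	// Assumes Google AI credentials are configured through the package options
	// or environment; only fatal error handling is shown.
	g, err := googleai.NewGoogleAI(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	g.CallbacksHandler = LoggingHandler{}
	// Subsequent g.GenerateContent(...) calls now report start and end
	// through the handler before and after the underlying request.
}
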
From 97397f14414aefd3a75fe610b2d4b5248a142464 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 17 Jan 2024 04:51:53 -0800 Subject: [PATCH 02/19] Move openai chat client functionality to its LLM, remove separate Chat The chat llm client calls use the newer OpenAI API to implement the same functionality --- embeddings/openai_test.go | 16 -- llms/openai/multicontent_test.go | 14 +- llms/openai/openaillm.go | 102 ++++++++++++ llms/openai/openaillm_chat.go | 274 ------------------------------- 4 files changed, 109 insertions(+), 297 deletions(-) delete mode 100644 llms/openai/openaillm_chat.go diff --git a/embeddings/openai_test.go b/embeddings/openai_test.go index eea891db7..c4adf5a61 100644 --- a/embeddings/openai_test.go +++ b/embeddings/openai_test.go @@ -100,19 +100,3 @@ func TestOpenaiEmbeddingsWithAzureAPI(t *testing.T) { require.NoError(t, err) assert.Len(t, embeddings, 1) } - -func TestUseLLMAndChatAsEmbedderClient(t *testing.T) { - t.Parallel() - - if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" { - t.Skip("OPENAI_API_KEY not set") - } - - // Shows that we can pass an openai chat value to NewEmbedder - chat, err := openai.NewChat() - require.NoError(t, err) - - embedderFromChat, err := NewEmbedder(chat) - require.NoError(t, err) - var _ Embedder = embedderFromChat -} diff --git a/llms/openai/multicontent_test.go b/llms/openai/multicontent_test.go index 44ca0b858..bdbcb43ac 100644 --- a/llms/openai/multicontent_test.go +++ b/llms/openai/multicontent_test.go @@ -13,21 +13,21 @@ import ( "github.com/tmc/langchaingo/schema" ) -func newChatClient(t *testing.T, opts ...Option) *Chat { +func newTestClient(t *testing.T, opts ...Option) *LLM { t.Helper() if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" { t.Skip("OPENAI_API_KEY not set") return nil } - llm, err := NewChat(opts...) + llm, err := New(opts...) 
require.NoError(t, err) return llm } func TestMultiContentText(t *testing.T) { t.Parallel() - llm := newChatClient(t) + llm := newTestClient(t) parts := []llms.ContentPart{ llms.TextContent{Text: "I'm a pomeranian"}, @@ -50,7 +50,7 @@ func TestMultiContentText(t *testing.T) { func TestMultiContentTextChatSequence(t *testing.T) { t.Parallel() - llm := newChatClient(t) + llm := newTestClient(t) content := []llms.MessageContent{ { @@ -78,7 +78,7 @@ func TestMultiContentTextChatSequence(t *testing.T) { func TestMultiContentImage(t *testing.T) { t.Parallel() - llm := newChatClient(t, WithModel("gpt-4-vision-preview")) + llm := newTestClient(t, WithModel("gpt-4-vision-preview")) parts := []llms.ContentPart{ llms.ImageURLContent{URL: "https://github.com/tmc/langchaingo/blob/main/docs/static/img/parrot-icon.png?raw=true"}, @@ -101,7 +101,7 @@ func TestMultiContentImage(t *testing.T) { func TestWithStreaming(t *testing.T) { t.Parallel() - llm := newChatClient(t) + llm := newTestClient(t) parts := []llms.ContentPart{ llms.TextContent{Text: "I'm a pomeranian"}, @@ -132,7 +132,7 @@ func TestWithStreaming(t *testing.T) { //nolint:lll func TestFunctionCall(t *testing.T) { t.Parallel() - llm := newChatClient(t) + llm := newTestClient(t) parts := []llms.ContentPart{ llms.TextContent{Text: "What is the weather like in Boston?"}, diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go index 3fe80d8fd..4c85d8845 100644 --- a/llms/openai/openaillm.go +++ b/llms/openai/openaillm.go @@ -2,17 +2,28 @@ package openai import ( "context" + "fmt" "github.com/tmc/langchaingo/callbacks" "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms/openai/internal/openaiclient" + "github.com/tmc/langchaingo/schema" ) +type ChatMessage = openaiclient.ChatMessage + type LLM struct { CallbacksHandler callbacks.Handler client *openaiclient.Client } +const ( + RoleSystem = "system" + RoleAssistant = "assistant" + RoleUser = "user" + RoleFunction = "function" +) + var _ llms.LLM = (*LLM)(nil) // New returns a new OpenAI LLM. @@ -39,6 +50,97 @@ func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOptio return r[0].Text, nil } +// GenerateContent implements the Model interface. 
+// +//nolint:goerr113 +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { // nolint: lll, cyclop + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) + } + + opts := llms.CallOptions{} + for _, opt := range options { + opt(&opts) + } + + chatMsgs := make([]*ChatMessage, 0, len(messages)) + for _, mc := range messages { + msg := &ChatMessage{MultiContent: mc.Parts} + switch mc.Role { + case schema.ChatMessageTypeSystem: + msg.Role = RoleSystem + case schema.ChatMessageTypeAI: + msg.Role = RoleAssistant + case schema.ChatMessageTypeHuman: + msg.Role = RoleUser + case schema.ChatMessageTypeGeneric: + msg.Role = RoleUser + case schema.ChatMessageTypeFunction: + fallthrough + default: + return nil, fmt.Errorf("role %v not supported", mc.Role) + } + + chatMsgs = append(chatMsgs, msg) + } + + req := &openaiclient.ChatRequest{ + Model: opts.Model, + StopWords: opts.StopWords, + Messages: chatMsgs, + StreamingFunc: opts.StreamingFunc, + Temperature: opts.Temperature, + MaxTokens: opts.MaxTokens, + N: opts.N, + FrequencyPenalty: opts.FrequencyPenalty, + PresencePenalty: opts.PresencePenalty, + FunctionCallBehavior: openaiclient.FunctionCallBehavior(opts.FunctionCallBehavior), + } + + for _, fn := range opts.Functions { + req.Functions = append(req.Functions, openaiclient.FunctionDefinition{ + Name: fn.Name, + Description: fn.Description, + Parameters: fn.Parameters, + }) + } + result, err := o.client.CreateChat(ctx, req) + if err != nil { + return nil, err + } + if len(result.Choices) == 0 { + return nil, ErrEmptyResponse + } + + choices := make([]*llms.ContentChoice, len(result.Choices)) + for i, c := range result.Choices { + choices[i] = &llms.ContentChoice{ + Content: c.Message.Content, + StopReason: c.FinishReason, + GenerationInfo: map[string]any{ + "CompletionTokens": result.Usage.CompletionTokens, + "PromptTokens": result.Usage.PromptTokens, + "TotalTokens": result.Usage.TotalTokens, + }, + } + + if c.FinishReason == "function_call" { + choices[i].FuncCall = &schema.FunctionCall{ + Name: c.Message.FunctionCall.Name, + Arguments: c.Message.FunctionCall.Arguments, + } + } + } + + response := &llms.ContentResponse{Choices: choices} + + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, response) + } + + return response, nil +} + func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { if o.CallbacksHandler != nil { o.CallbacksHandler.HandleLLMStart(ctx, prompts) diff --git a/llms/openai/openaillm_chat.go b/llms/openai/openaillm_chat.go deleted file mode 100644 index 12219c501..000000000 --- a/llms/openai/openaillm_chat.go +++ /dev/null @@ -1,274 +0,0 @@ -package openai - -import ( - "context" - "fmt" - "reflect" - - "github.com/tmc/langchaingo/callbacks" - "github.com/tmc/langchaingo/llms" - "github.com/tmc/langchaingo/llms/openai/internal/openaiclient" - "github.com/tmc/langchaingo/schema" -) - -type ChatMessage = openaiclient.ChatMessage - -type Chat struct { - CallbacksHandler callbacks.Handler - client *openaiclient.Client -} - -const ( - RoleSystem = "system" - RoleAssistant = "assistant" - RoleUser = "user" - RoleFunction = "function" -) - -var _ llms.ChatLLM = (*Chat)(nil) - -// NewChat returns a new OpenAI chat LLM. -func NewChat(opts ...Option) (*Chat, error) { - opt, c, err := newClient(opts...) 
- if err != nil { - return nil, err - } - return &Chat{ - client: c, - CallbacksHandler: opt.callbackHandler, - }, err -} - -// GenerateContent implements the Model interface. -// -//nolint:goerr113 -func (o *Chat) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { // nolint: lll, cyclop - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) - } - - opts := llms.CallOptions{} - for _, opt := range options { - opt(&opts) - } - - chatMsgs := make([]*ChatMessage, 0, len(messages)) - for _, mc := range messages { - msg := &ChatMessage{MultiContent: mc.Parts} - switch mc.Role { - case schema.ChatMessageTypeSystem: - msg.Role = RoleSystem - case schema.ChatMessageTypeAI: - msg.Role = RoleAssistant - case schema.ChatMessageTypeHuman: - msg.Role = RoleUser - case schema.ChatMessageTypeGeneric: - msg.Role = RoleUser - case schema.ChatMessageTypeFunction: - fallthrough - default: - return nil, fmt.Errorf("role %v not supported", mc.Role) - } - - chatMsgs = append(chatMsgs, msg) - } - - req := &openaiclient.ChatRequest{ - Model: opts.Model, - StopWords: opts.StopWords, - Messages: chatMsgs, - StreamingFunc: opts.StreamingFunc, - Temperature: opts.Temperature, - MaxTokens: opts.MaxTokens, - N: opts.N, - FrequencyPenalty: opts.FrequencyPenalty, - PresencePenalty: opts.PresencePenalty, - FunctionCallBehavior: openaiclient.FunctionCallBehavior(opts.FunctionCallBehavior), - } - - for _, fn := range opts.Functions { - req.Functions = append(req.Functions, openaiclient.FunctionDefinition{ - Name: fn.Name, - Description: fn.Description, - Parameters: fn.Parameters, - }) - } - result, err := o.client.CreateChat(ctx, req) - if err != nil { - return nil, err - } - if len(result.Choices) == 0 { - return nil, ErrEmptyResponse - } - - choices := make([]*llms.ContentChoice, len(result.Choices)) - for i, c := range result.Choices { - choices[i] = &llms.ContentChoice{ - Content: c.Message.Content, - StopReason: c.FinishReason, - GenerationInfo: map[string]any{ - "CompletionTokens": result.Usage.CompletionTokens, - "PromptTokens": result.Usage.PromptTokens, - "TotalTokens": result.Usage.TotalTokens, - }, - } - - if c.FinishReason == "function_call" { - choices[i].FuncCall = &schema.FunctionCall{ - Name: c.Message.FunctionCall.Name, - Arguments: c.Message.FunctionCall.Arguments, - } - } - } - - response := &llms.ContentResponse{Choices: choices} - - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, response) - } - - return response, nil -} - -// Call requests a chat response for the given messages. -func (o *Chat) Call(ctx context.Context, messages []schema.ChatMessage, options ...llms.CallOption) (*schema.AIChatMessage, error) { // nolint: lll - r, err := o.Generate(ctx, [][]schema.ChatMessage{messages}, options...) 
- if err != nil { - return nil, err - } - if len(r) == 0 { - return nil, ErrEmptyResponse - } - return r[0].Message, nil -} - -//nolint:funlen -func (o *Chat) Generate(ctx context.Context, messageSets [][]schema.ChatMessage, options ...llms.CallOption) ([]*llms.Generation, error) { // nolint:lll,cyclop - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, getPromptsFromMessageSets(messageSets)) - } - - opts := llms.CallOptions{} - for _, opt := range options { - opt(&opts) - } - generations := make([]*llms.Generation, 0, len(messageSets)) - for _, messageSet := range messageSets { - req := &openaiclient.ChatRequest{ - Model: opts.Model, - StopWords: opts.StopWords, - Messages: messagesToClientMessages(messageSet), - StreamingFunc: opts.StreamingFunc, - Temperature: opts.Temperature, - MaxTokens: opts.MaxTokens, - N: opts.N, // TODO: note, we are not returning multiple completions - FrequencyPenalty: opts.FrequencyPenalty, - PresencePenalty: opts.PresencePenalty, - - FunctionCallBehavior: openaiclient.FunctionCallBehavior(opts.FunctionCallBehavior), - } - for _, fn := range opts.Functions { - req.Functions = append(req.Functions, openaiclient.FunctionDefinition{ - Name: fn.Name, - Description: fn.Description, - Parameters: fn.Parameters, - }) - } - result, err := o.client.CreateChat(ctx, req) - if err != nil { - return nil, err - } - if len(result.Choices) == 0 { - return nil, ErrEmptyResponse - } - generationInfo := make(map[string]any, reflect.ValueOf(result.Usage).NumField()) - generationInfo["CompletionTokens"] = result.Usage.CompletionTokens - generationInfo["PromptTokens"] = result.Usage.PromptTokens - generationInfo["TotalTokens"] = result.Usage.TotalTokens - msg := &schema.AIChatMessage{ - Content: result.Choices[0].Message.Content, - } - if result.Choices[0].FinishReason == "function_call" { - msg.FunctionCall = &schema.FunctionCall{ - Name: result.Choices[0].Message.FunctionCall.Name, - Arguments: result.Choices[0].Message.FunctionCall.Arguments, - } - } - generations = append(generations, &llms.Generation{ - Message: msg, - Text: msg.Content, - GenerationInfo: generationInfo, - StopReason: result.Choices[0].FinishReason, - }) - } - - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}}) - } - - return generations, nil -} - -// CreateEmbedding creates embeddings for the given input texts. 
-func (o *Chat) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]float32, error) { - embeddings, err := o.client.CreateEmbedding(ctx, &openaiclient.EmbeddingRequest{ - Input: inputTexts, - }) - if err != nil { - return nil, err - } - if len(embeddings) == 0 { - return nil, ErrEmptyResponse - } - if len(inputTexts) != len(embeddings) { - return embeddings, ErrUnexpectedResponseLength - } - return embeddings, nil -} - -func getPromptsFromMessageSets(messageSets [][]schema.ChatMessage) []string { - prompts := make([]string, 0, len(messageSets)) - for i := 0; i < len(messageSets); i++ { - curPrompt := "" - for j := 0; j < len(messageSets[i]); j++ { - curPrompt += messageSets[i][j].GetContent() - } - prompts = append(prompts, curPrompt) - } - - return prompts -} - -func messagesToClientMessages(messages []schema.ChatMessage) []*openaiclient.ChatMessage { - msgs := make([]*openaiclient.ChatMessage, len(messages)) - for i, m := range messages { - msg := &openaiclient.ChatMessage{ - Content: m.GetContent(), - } - typ := m.GetType() - switch typ { - case schema.ChatMessageTypeSystem: - msg.Role = "system" - case schema.ChatMessageTypeAI: - msg.Role = "assistant" - if mm, ok := m.(schema.AIChatMessage); ok && mm.FunctionCall != nil { - msg.FunctionCall = &openaiclient.FunctionCall{ - Name: mm.FunctionCall.Name, - Arguments: mm.FunctionCall.Arguments, - } - } - case schema.ChatMessageTypeHuman: - msg.Role = "user" - case schema.ChatMessageTypeGeneric: - msg.Role = "user" - case schema.ChatMessageTypeFunction: - msg.Role = "function" - } - if n, ok := m.(schema.Named); ok { - msg.Name = n.GetName() - } - msgs[i] = msg - } - - return msgs -} From c9a7d613c164c6349bcd8c45fccac78c3e25b992 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 17 Jan 2024 05:04:16 -0800 Subject: [PATCH 03/19] openai: use CallLLM to implement Call Adding CallLLM as a general utility --- llms/llms.go | 20 ++++++++++++++++++++ llms/openai/openaillm.go | 9 +-------- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/llms/llms.go b/llms/llms.go index 200fe189c..bdbbe1727 100644 --- a/llms/llms.go +++ b/llms/llms.go @@ -2,6 +2,7 @@ package llms import ( "context" + "errors" "github.com/tmc/langchaingo/schema" ) @@ -44,3 +45,22 @@ type LLMResult struct { Generations [][]*Generation LLMOutput map[string]any } + +func CallLLM(ctx context.Context, llm Model, prompt string, options ...CallOption) (string, error) { + msg := MessageContent{ + Role: schema.ChatMessageTypeHuman, + Parts: []ContentPart{TextContent{prompt}}, + } + + resp, err := llm.GenerateContent(ctx, []MessageContent{msg}, options...) + if err != nil { + return "", err + } + + choices := resp.Choices + if len(choices) < 1 { + return "", errors.New("empty response from model") //nolint:goerr113 + } + c1 := choices[0] + return c1.Content, nil +} diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go index 4c85d8845..975671888 100644 --- a/llms/openai/openaillm.go +++ b/llms/openai/openaillm.go @@ -40,14 +40,7 @@ func New(opts ...Option) (*LLM, error) { // Call requests a completion for the given prompt. func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { - r, err := o.Generate(ctx, []string{prompt}, options...) - if err != nil { - return "", err - } - if len(r) == 0 { - return "", ErrEmptyResponse - } - return r[0].Text, nil + return llms.CallLLM(ctx, o, prompt, options...) } // GenerateContent implements the Model interface. 
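A sketch of what the new helper means for callers: llms.CallLLM wraps a plain prompt into a single human MessageContent, forwards it to GenerateContent, and returns the text of the first choice, so each backend's Call reduces to one line. The OpenAI client is used here only as an example and is assumed to pick up its API key from the environment, as in the package tests.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/openai"
)

func main() {
	ctx := context.Background()
	llm, err := openai.New() // assumes OPENAI_API_KEY is set, as in the tests
	if err != nil {
		log.Fatal(err)
	}

	// Equivalent to llm.Call(ctx, prompt): build the single-message request,
	// run GenerateContent, and return the first choice's content.
	out, err := llms.CallLLM(ctx, llm, "Name one famous computer scientist.")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out)
}
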
From da53f691bd52eea0d66c37d28a934defb18f7f50 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 17 Jan 2024 05:08:56 -0800 Subject: [PATCH 04/19] ollama: merge chat into llm, implement Call with GenerateContent --- llms/ollama/ollama_test.go | 21 +- llms/ollama/ollamallm.go | 154 ++++++++++++++- llms/ollama/ollamallm_chat.go | 347 ---------------------------------- 3 files changed, 150 insertions(+), 372 deletions(-) delete mode 100644 llms/ollama/ollamallm_chat.go diff --git a/llms/ollama/ollama_test.go b/llms/ollama/ollama_test.go index be9f78e6c..57f3f7876 100644 --- a/llms/ollama/ollama_test.go +++ b/llms/ollama/ollama_test.go @@ -12,7 +12,7 @@ import ( "github.com/tmc/langchaingo/schema" ) -func newChatClient(t *testing.T) *Chat { +func newTestClient(t *testing.T) *LLM { t.Helper() var ollamaModel string if ollamaModel = os.Getenv("OLLAMA_TEST_MODEL"); ollamaModel == "" { @@ -20,27 +20,14 @@ func newChatClient(t *testing.T) *Chat { return nil } - c, err := NewChat(WithModel(ollamaModel)) + c, err := New(WithModel(ollamaModel)) require.NoError(t, err) return c } -func TestChatBasic(t *testing.T) { - t.Parallel() - - llm := newChatClient(t) - - resp, err := llm.Call(context.Background(), []schema.ChatMessage{ - schema.SystemChatMessage{Content: "You are producing poems in Spanish."}, - schema.HumanChatMessage{Content: "Write a very short poem about Donald Knuth"}, - }) - require.NoError(t, err) - assert.Regexp(t, "programa|comput|algoritm|libro", strings.ToLower(resp.Content)) //nolint:all -} - func TestGenerateContent(t *testing.T) { t.Parallel() - llm := newChatClient(t) + llm := newTestClient(t) parts := []llms.ContentPart{ llms.TextContent{Text: "How many feet are in a nautical mile?"}, @@ -62,7 +49,7 @@ func TestGenerateContent(t *testing.T) { func TestWithStreaming(t *testing.T) { t.Parallel() - llm := newChatClient(t) + llm := newTestClient(t) parts := []llms.ContentPart{ llms.TextContent{Text: "How many feet are in a nautical mile?"}, diff --git a/llms/ollama/ollamallm.go b/llms/ollama/ollamallm.go index 7f9825b6c..eaf7866f1 100644 --- a/llms/ollama/ollamallm.go +++ b/llms/ollama/ollamallm.go @@ -7,6 +7,7 @@ import ( "github.com/tmc/langchaingo/callbacks" "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms/ollama/internal/ollamaclient" + "github.com/tmc/langchaingo/schema" ) var ( @@ -40,14 +41,7 @@ func New(opts ...Option) (*LLM, error) { // Call Implement the call interface for LLM. func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { - r, err := o.Generate(ctx, []string{prompt}, options...) - if err != nil { - return "", err - } - if len(r) == 0 { - return "", ErrEmptyResponse - } - return r[0].Text, nil + return llms.CallLLM(ctx, o, prompt, options...) } // Generate implemente the generate interface for LLM. @@ -122,6 +116,119 @@ func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.Ca return generations, nil } +// GenerateContent implements the Model interface. 
+// nolint: goerr113 +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { // nolint: lll, cyclop, funlen + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) + } + + opts := llms.CallOptions{} + for _, opt := range options { + opt(&opts) + } + + // Override LLM model if set as llms.CallOption + model := o.options.model + if opts.Model != "" { + model = opts.Model + } + + // Our input is a sequence of MessageContent, each of which potentially has + // a sequence of Part that could be text, images etc. + // We have to convert it to a format Ollama undestands: ChatRequest, which + // has a sequence of Message, each of which has a role and content - single + // text + potential images. + chatMsgs := make([]*ollamaclient.Message, 0, len(messages)) + for _, mc := range messages { + msg := &ollamaclient.Message{Role: typeToRole(mc.Role)} + + // Look at all the parts in mc; expect to find a single Text part and + // any number of binary parts. + var text string + foundText := false + var images []ollamaclient.ImageData + + for _, p := range mc.Parts { + switch pt := p.(type) { + case llms.TextContent: + if foundText { + return nil, errors.New("expecting a single Text content") + } + foundText = true + text = pt.Text + case llms.BinaryContent: + images = append(images, ollamaclient.ImageData(pt.Data)) + default: + return nil, errors.New("only support Text and BinaryContent parts right now") + } + } + + msg.Content = text + msg.Images = images + chatMsgs = append(chatMsgs, msg) + } + + // Get our ollamaOptions from llms.CallOptions + ollamaOptions := makeOllamaOptionsFromOptions(o.options.ollamaOptions, opts) + req := &ollamaclient.ChatRequest{ + Model: model, + Messages: chatMsgs, + Options: ollamaOptions, + Stream: func(b bool) *bool { return &b }(opts.StreamingFunc != nil), + } + + var fn ollamaclient.ChatResponseFunc + streamedResponse := "" + var resp ollamaclient.ChatResponse + + fn = func(response ollamaclient.ChatResponse) error { + if opts.StreamingFunc != nil && response.Message != nil { + if err := opts.StreamingFunc(ctx, []byte(response.Message.Content)); err != nil { + return err + } + } + if response.Message != nil { + streamedResponse += response.Message.Content + } + if response.Done { + resp = response + resp.Message = &ollamaclient.Message{ + Role: "assistant", + Content: streamedResponse, + } + } + return nil + } + + err := o.client.GenerateChat(ctx, req, fn) + if err != nil { + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMError(ctx, err) + } + return nil, err + } + + choices := []*llms.ContentChoice{ + { + Content: resp.Message.Content, + GenerationInfo: map[string]any{ + "CompletionTokens": resp.EvalCount, + "PromptTokens": resp.PromptEvalCount, + "TotalTokesn": resp.EvalCount + resp.PromptEvalCount, + }, + }, + } + + response := &llms.ContentResponse{Choices: choices} + + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, response) + } + + return response, nil +} + func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]float32, error) { embeddings := [][]float32{} @@ -147,3 +254,34 @@ func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]flo return embeddings, nil } + +func typeToRole(typ schema.ChatMessageType) string { + switch typ { + case schema.ChatMessageTypeSystem: + return "system" + case schema.ChatMessageTypeAI: + return "assistant" + case 
schema.ChatMessageTypeHuman: + fallthrough + case schema.ChatMessageTypeGeneric: + return "user" + case schema.ChatMessageTypeFunction: + return "function" + } + return "" +} + +func makeOllamaOptionsFromOptions(ollamaOptions ollamaclient.Options, opts llms.CallOptions) ollamaclient.Options { + // Load back CallOptions as ollamaOptions + ollamaOptions.NumPredict = opts.MaxTokens + ollamaOptions.Temperature = float32(opts.Temperature) + ollamaOptions.Stop = opts.StopWords + ollamaOptions.TopK = opts.TopK + ollamaOptions.TopP = float32(opts.TopP) + ollamaOptions.Seed = opts.Seed + ollamaOptions.RepeatPenalty = float32(opts.RepetitionPenalty) + ollamaOptions.FrequencyPenalty = float32(opts.FrequencyPenalty) + ollamaOptions.PresencePenalty = float32(opts.PresencePenalty) + + return ollamaOptions +} diff --git a/llms/ollama/ollamallm_chat.go b/llms/ollama/ollamallm_chat.go deleted file mode 100644 index b21788d8c..000000000 --- a/llms/ollama/ollamallm_chat.go +++ /dev/null @@ -1,347 +0,0 @@ -package ollama - -import ( - "context" - "errors" - "fmt" - - "github.com/tmc/langchaingo/callbacks" - "github.com/tmc/langchaingo/llms" - "github.com/tmc/langchaingo/llms/ollama/internal/ollamaclient" - "github.com/tmc/langchaingo/schema" -) - -// LLM is a ollama LLM implementation. -type Chat struct { - CallbacksHandler callbacks.Handler - client *ollamaclient.Client - options options -} - -var _ llms.ChatLLM = (*Chat)(nil) - -// New creates a new ollama LLM implementation. -func NewChat(opts ...Option) (*Chat, error) { - o := options{} - for _, opt := range opts { - opt(&o) - } - - client, err := ollamaclient.NewClient(o.ollamaServerURL) - if err != nil { - return nil, err - } - - return &Chat{client: client, options: o}, nil -} - -// Call Implement the call interface for LLM. -func (o *Chat) Call(ctx context.Context, messages []schema.ChatMessage, options ...llms.CallOption) (*schema.AIChatMessage, error) { //nolint:lll - r, err := o.Generate(ctx, [][]schema.ChatMessage{messages}, options...) - if err != nil { - return nil, err - } - if len(r) == 0 { - return nil, ErrEmptyResponse - } - return r[0].Message, nil -} - -// Generate implemente the generate interface for LLM. 
-func (o *Chat) Generate(ctx context.Context, messageSets [][]schema.ChatMessage, options ...llms.CallOption) ([]*llms.Generation, error) { //nolint:lll,cyclop - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, o.getPromptsFromMessageSets(messageSets)) - } - - opts := llms.CallOptions{} - for _, opt := range options { - opt(&opts) - } - - // Override LLM model if set as llms.CallOption - model := o.options.model - if opts.Model != "" { - model = opts.Model - } - - // Get our ollamaOptions from llms.CallOptions - ollamaOptions := makeOllamaOptionsFromOptions(o.options.ollamaOptions, opts) - - generations := make([]*llms.Generation, 0, len(messageSets)) - for _, messages := range messageSets { - req, err := messagesToChatRequest(messages) - if err != nil { - return nil, err - } - - req.Model = model - req.Options = ollamaOptions - req.Stream = func(b bool) *bool { return &b }(opts.StreamingFunc != nil) - - var fn ollamaclient.ChatResponseFunc - - streamedResponse := "" - var resp ollamaclient.ChatResponse - - fn = func(response ollamaclient.ChatResponse) error { - if opts.StreamingFunc != nil && response.Message != nil { - if err := opts.StreamingFunc(ctx, []byte(response.Message.Content)); err != nil { - return err - } - } - if response.Message != nil { - streamedResponse += response.Message.Content - } - if response.Done { - resp = response - resp.Message = &ollamaclient.Message{ - Role: "assistant", - Content: streamedResponse, - } - } - return nil - } - - err = o.client.GenerateChat(ctx, req, fn) - if err != nil { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) - } - return []*llms.Generation{}, err - } - - generations = append(generations, makeGenerationFromChatResponse(resp)) - } - - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}}) - } - - return generations, nil -} - -// GenerateContent implements the Model interface. -// nolint: goerr113 -func (o *Chat) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { // nolint: lll, cyclop, funlen - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) - } - - opts := llms.CallOptions{} - for _, opt := range options { - opt(&opts) - } - - // Override LLM model if set as llms.CallOption - model := o.options.model - if opts.Model != "" { - model = opts.Model - } - - // Our input is a sequence of MessageContent, each of which potentially has - // a sequence of Part that could be text, images etc. - // We have to convert it to a format Ollama undestands: ChatRequest, which - // has a sequence of Message, each of which has a role and content - single - // text + potential images. - chatMsgs := make([]*ollamaclient.Message, 0, len(messages)) - for _, mc := range messages { - msg := &ollamaclient.Message{Role: typeToRole(mc.Role)} - - // Look at all the parts in mc; expect to find a single Text part and - // any number of binary parts. 
- var text string - foundText := false - var images []ollamaclient.ImageData - - for _, p := range mc.Parts { - switch pt := p.(type) { - case llms.TextContent: - if foundText { - return nil, errors.New("expecting a single Text content") - } - foundText = true - text = pt.Text - case llms.BinaryContent: - images = append(images, ollamaclient.ImageData(pt.Data)) - default: - return nil, errors.New("only support Text and BinaryContent parts right now") - } - } - - msg.Content = text - msg.Images = images - chatMsgs = append(chatMsgs, msg) - } - - // Get our ollamaOptions from llms.CallOptions - ollamaOptions := makeOllamaOptionsFromOptions(o.options.ollamaOptions, opts) - req := &ollamaclient.ChatRequest{ - Model: model, - Messages: chatMsgs, - Options: ollamaOptions, - Stream: func(b bool) *bool { return &b }(opts.StreamingFunc != nil), - } - - var fn ollamaclient.ChatResponseFunc - streamedResponse := "" - var resp ollamaclient.ChatResponse - - fn = func(response ollamaclient.ChatResponse) error { - if opts.StreamingFunc != nil && response.Message != nil { - if err := opts.StreamingFunc(ctx, []byte(response.Message.Content)); err != nil { - return err - } - } - if response.Message != nil { - streamedResponse += response.Message.Content - } - if response.Done { - resp = response - resp.Message = &ollamaclient.Message{ - Role: "assistant", - Content: streamedResponse, - } - } - return nil - } - - err := o.client.GenerateChat(ctx, req, fn) - if err != nil { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) - } - return nil, err - } - - choices := []*llms.ContentChoice{ - { - Content: resp.Message.Content, - GenerationInfo: map[string]any{ - "CompletionTokens": resp.EvalCount, - "PromptTokens": resp.PromptEvalCount, - "TotalTokesn": resp.EvalCount + resp.PromptEvalCount, - }, - }, - } - - response := &llms.ContentResponse{Choices: choices} - - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, response) - } - - return response, nil -} - -func makeGenerationFromChatResponse(resp ollamaclient.ChatResponse) *llms.Generation { - msg := &schema.AIChatMessage{ - Content: resp.Message.Content, - } - - gen := &llms.Generation{ - Message: msg, - Text: msg.Content, - GenerationInfo: make(map[string]any), - } - - gen.GenerationInfo["CompletionTokens"] = resp.EvalCount - gen.GenerationInfo["PromptTokens"] = resp.PromptEvalCount - gen.GenerationInfo["TotalTokens"] = resp.PromptEvalCount + resp.EvalCount - - return gen -} - -func makeOllamaOptionsFromOptions(ollamaOptions ollamaclient.Options, opts llms.CallOptions) ollamaclient.Options { - // Load back CallOptions as ollamaOptions - ollamaOptions.NumPredict = opts.MaxTokens - ollamaOptions.Temperature = float32(opts.Temperature) - ollamaOptions.Stop = opts.StopWords - ollamaOptions.TopK = opts.TopK - ollamaOptions.TopP = float32(opts.TopP) - ollamaOptions.Seed = opts.Seed - ollamaOptions.RepeatPenalty = float32(opts.RepetitionPenalty) - ollamaOptions.FrequencyPenalty = float32(opts.FrequencyPenalty) - ollamaOptions.PresencePenalty = float32(opts.PresencePenalty) - - return ollamaOptions -} - -func (o *Chat) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]float32, error) { - embeddings := [][]float32{} - - for _, input := range inputTexts { - embedding, err := o.client.CreateEmbedding(ctx, &ollamaclient.EmbeddingRequest{ - Prompt: input, - Model: o.options.model, - }) - if err != nil { - return nil, err - } - - if len(embedding.Embedding) == 0 { - return nil, ErrEmptyResponse - } 
- - embeddings = append(embeddings, embedding.Embedding) - } - - if len(inputTexts) != len(embeddings) { - return embeddings, ErrIncompleteEmbedding - } - - return embeddings, nil -} - -func (o Chat) getPromptsFromMessageSets(messageSets [][]schema.ChatMessage) []string { - prompts := make([]string, 0, len(messageSets)) - for i := 0; i < len(messageSets); i++ { - curPrompt := "" - for j := 0; j < len(messageSets[i]); j++ { - curPrompt += messageSets[i][j].GetContent() - } - prompts = append(prompts, curPrompt) - } - return prompts -} - -func messagesToChatRequest(messages []schema.ChatMessage) (*ollamaclient.ChatRequest, error) { - req := &ollamaclient.ChatRequest{} - for _, m := range messages { - typ := m.GetType() - switch typ { - case schema.ChatMessageTypeSystem: - fallthrough - case schema.ChatMessageTypeAI: - req.Messages = append(req.Messages, &ollamaclient.Message{ - Role: typeToRole(typ), - Content: m.GetContent(), - }) - case schema.ChatMessageTypeHuman: - fallthrough - case schema.ChatMessageTypeGeneric: - req.Messages = append(req.Messages, &ollamaclient.Message{ - Role: typeToRole(typ), - Content: m.GetContent(), - }) - case schema.ChatMessageTypeFunction: - return nil, fmt.Errorf("chat message type %s not implemented", typ) //nolint:goerr113 - } - } - return req, nil -} - -func typeToRole(typ schema.ChatMessageType) string { - switch typ { - case schema.ChatMessageTypeSystem: - return "system" - case schema.ChatMessageTypeAI: - return "assistant" - case schema.ChatMessageTypeHuman: - fallthrough - case schema.ChatMessageTypeGeneric: - return "user" - case schema.ChatMessageTypeFunction: - return "function" - } - return "" -} From 537d9ea3cb680d84fa8c4fb0e6d1ddd6c8672939 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 17 Jan 2024 05:09:46 -0800 Subject: [PATCH 05/19] googleai: add Call in terms of GenerateContent --- llms/googleai/googleai_llm.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/llms/googleai/googleai_llm.go b/llms/googleai/googleai_llm.go index 833504497..4c8735ce3 100644 --- a/llms/googleai/googleai_llm.go +++ b/llms/googleai/googleai_llm.go @@ -64,6 +64,11 @@ func NewGoogleAI(ctx context.Context, opts ...Option) (*GoogleAI, error) { return gi, nil } +// Call Implement the call interface for LLM. +func (g *GoogleAI) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { + return llms.CallLLM(ctx, g, prompt, options...) +} + // GenerateContent calls the LLM with the provided parts. func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { if g.CallbacksHandler != nil { From ab1fc1e0791c44efa5361895885d28af012dc150 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 17 Jan 2024 05:18:53 -0800 Subject: [PATCH 06/19] huggingface: add GenerateContent and implement Model --- llms/huggingface/huggingfacellm.go | 47 ++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/llms/huggingface/huggingfacellm.go b/llms/huggingface/huggingfacellm.go index c673d8dee..865bf7016 100644 --- a/llms/huggingface/huggingfacellm.go +++ b/llms/huggingface/huggingfacellm.go @@ -23,15 +23,52 @@ type LLM struct { var _ llms.LLM = (*LLM)(nil) +// Call implements the LLM interface. func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { - r, err := o.Generate(ctx, []string{prompt}, options...) + return llms.CallLLM(ctx, o, prompt, options...) 
+} + +// GenerateContent implements the Model interface. +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) + } + + opts := &llms.CallOptions{Model: defaultModel} + for _, opt := range options { + opt(opts) + } + + // Assume we get a single text message + msg0 := messages[0] + part := msg0.Parts[0] + result, err := o.client.RunInference(ctx, &huggingfaceclient.InferenceRequest{ + Model: o.client.Model, + Prompt: part.(llms.TextContent).Text, + Task: huggingfaceclient.InferenceTaskTextGeneration, + Temperature: opts.Temperature, + TopP: opts.TopP, + TopK: opts.TopK, + MinLength: opts.MinLength, + MaxLength: opts.MaxLength, + RepetitionPenalty: opts.RepetitionPenalty, + Seed: opts.Seed, + }) if err != nil { - return "", err + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMError(ctx, err) + } + return nil, err } - if len(r) == 0 { - return "", ErrEmptyResponse + + resp := &llms.ContentResponse{ + Choices: []*llms.ContentChoice{ + &llms.ContentChoice{ + Content: result.Text, + }, + }, } - return r[0].Text, nil + return resp, nil } func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { From b31c6f46900573d2c55dea5a0b9c27a1fd2f8b2e Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 17 Jan 2024 05:22:18 -0800 Subject: [PATCH 07/19] localllm: implement Model with GenerateContent --- llms/local/localllm.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/llms/local/localllm.go b/llms/local/localllm.go index d29fa277a..d7191744a 100644 --- a/llms/local/localllm.go +++ b/llms/local/localllm.go @@ -69,6 +69,48 @@ func (o *LLM) appendGlobalsToArgs(opts llms.CallOptions) []string { return o.client.Args } +// GenerateContent implements the Model interface. +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) + } + + opts := &llms.CallOptions{} + for _, opt := range options { + opt(opts) + } + + // If o.client.GlobalAsArgs is true + if o.client.GlobalAsArgs { + // Then add the option to the args in --key=value format + o.appendGlobalsToArgs(*opts) + } + + // Assume we get a single text message + msg0 := messages[0] + part := msg0.Parts[0] + result, err := o.client.CreateCompletion(ctx, &localclient.CompletionRequest{ + Prompt: part.(llms.TextContent).Text, + }) + if err != nil { + return nil, err + } + + resp := &llms.ContentResponse{ + Choices: []*llms.ContentChoice{ + &llms.ContentChoice{ + Content: result.Text, + }, + }, + } + + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, resp) + } + + return resp, nil +} + // Generate generates completions using the local LLM binary. 
func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { if o.CallbacksHandler != nil { From 9fee6b6a1915c3bdb57c94952c695d6e73ba295b Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 17 Jan 2024 05:33:12 -0800 Subject: [PATCH 08/19] vertexai: implement GenerateContent --- llms/vertexai/vertexai_palm_llm.go | 46 ++++++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/llms/vertexai/vertexai_palm_llm.go b/llms/vertexai/vertexai_palm_llm.go index d9bde1e13..38348b7ca 100644 --- a/llms/vertexai/vertexai_palm_llm.go +++ b/llms/vertexai/vertexai_palm_llm.go @@ -25,14 +25,50 @@ var _ llms.LLM = (*LLM)(nil) // Call requests a completion for the given prompt. func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { - r, err := o.Generate(ctx, []string{prompt}, options...) + return llms.CallLLM(ctx, o, prompt, options...) +} + +// GenerateContent implements the Model interface. +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) + } + + opts := llms.CallOptions{} + for _, opt := range options { + opt(&opts) + } + + // Assume we get a single text message + msg0 := messages[0] + part := msg0.Parts[0] + + results, err := o.client.CreateCompletion(ctx, &vertexaiclient.CompletionRequest{ + Prompts: []string{part.(llms.TextContent).Text}, + MaxTokens: opts.MaxTokens, + Temperature: opts.Temperature, + StopSequences: opts.StopWords, + }) + if err != nil { - return "", err + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMError(ctx, err) + } + return nil, err + } + + resp := &llms.ContentResponse{ + Choices: []*llms.ContentChoice{ + &llms.ContentChoice{ + Content: results[0].Text, + }, + }, } - if len(r) == 0 { - return "", ErrEmptyResponse + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, resp) } - return r[0].Text, nil + + return resp, nil } func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { From e1a90e03f07f218ea8bf7fa68dc5ba3430a9e5ec Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 17 Jan 2024 05:33:50 -0800 Subject: [PATCH 09/19] vertexai: remove chat implementation since we now have Model --- llms/vertexai/vertexai_palm_llm_chat.go | 137 ------------------------ 1 file changed, 137 deletions(-) delete mode 100644 llms/vertexai/vertexai_palm_llm_chat.go diff --git a/llms/vertexai/vertexai_palm_llm_chat.go b/llms/vertexai/vertexai_palm_llm_chat.go deleted file mode 100644 index 80a96a049..000000000 --- a/llms/vertexai/vertexai_palm_llm_chat.go +++ /dev/null @@ -1,137 +0,0 @@ -package vertexai - -import ( - "context" - - "github.com/tmc/langchaingo/llms" - "github.com/tmc/langchaingo/llms/vertexai/internal/vertexaiclient" - "github.com/tmc/langchaingo/schema" -) - -const ( - userAuthor = "user" - botAuthor = "bot" -) - -type ChatMessage = vertexaiclient.ChatMessage - -type Chat struct { - client *vertexaiclient.PaLMClient -} - -var _ llms.ChatLLM = (*Chat)(nil) - -// Chat requests a chat response for the given messages. -func (o *Chat) Call(ctx context.Context, messages []schema.ChatMessage, options ...llms.CallOption) (*schema.AIChatMessage, error) { // nolint: lll - r, err := o.Generate(ctx, [][]schema.ChatMessage{messages}, options...) 
- if err != nil { - return nil, err - } - if len(r) == 0 { - return nil, ErrEmptyResponse - } - return r[0].Message, nil -} - -// Generate requests a chat response for each of the sets of messages. -func (o *Chat) Generate(ctx context.Context, messageSets [][]schema.ChatMessage, options ...llms.CallOption) ([]*llms.Generation, error) { // nolint: lll - opts := llms.CallOptions{} - for _, opt := range options { - opt(&opts) - } - if opts.StreamingFunc != nil { - return nil, ErrNotImplemented - } - - generations := make([]*llms.Generation, 0, len(messageSets)) - for _, messages := range messageSets { - chatContext := parseContext(messages) - if len(chatContext) > 0 { - // remove system context from messages - messages = messages[1:] - } - msgs := toClientChatMessage(messages) - result, err := o.client.CreateChat(ctx, &vertexaiclient.ChatRequest{ - Temperature: opts.Temperature, - Messages: msgs, - Context: chatContext, - }) - if err != nil { - return nil, err - } - if len(result.Candidates) == 0 { - return nil, ErrEmptyResponse - } - generations = append(generations, &llms.Generation{ - Message: &schema.AIChatMessage{ - Content: result.Candidates[0].Content, - }, - Text: result.Candidates[0].Content, - }) - } - - return generations, nil -} - -func toClientChatMessage(messages []schema.ChatMessage) []*vertexaiclient.ChatMessage { - msgs := make([]*vertexaiclient.ChatMessage, len(messages)) - - for i, m := range messages { - msg := &vertexaiclient.ChatMessage{ - Content: m.GetContent(), - } - typ := m.GetType() - - switch typ { - case schema.ChatMessageTypeSystem: - msg.Author = botAuthor - case schema.ChatMessageTypeAI: - msg.Author = botAuthor - case schema.ChatMessageTypeHuman: - msg.Author = userAuthor - case schema.ChatMessageTypeGeneric: - msg.Author = userAuthor - case schema.ChatMessageTypeFunction: - msg.Author = userAuthor - } - if n, ok := m.(schema.Named); ok { - msg.Author = n.GetName() - } - msgs[i] = msg - } - return msgs -} - -func parseContext(messages []schema.ChatMessage) string { - if len(messages) == 0 { - return "" - } - // check if 1st message type is system. use it as context. - if messages[0].GetType() == schema.ChatMessageTypeSystem { - return messages[0].GetContent() - } - return "" -} - -// NewChat returns a new VertexAI PaLM Chat LLM. -func NewChat(opts ...Option) (*Chat, error) { - client, err := newClient(opts...) - return &Chat{client: client}, err -} - -// CreateEmbedding creates embeddings for the given input texts. 
-func (o *Chat) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]float32, error) { - embeddings, err := o.client.CreateEmbedding(ctx, &vertexaiclient.EmbeddingRequest{ - Input: inputTexts, - }) - if err != nil { - return nil, err - } - if len(embeddings) == 0 { - return nil, ErrEmptyResponse - } - if len(inputTexts) != len(embeddings) { - return embeddings, ErrUnexpectedResponseLength - } - return embeddings, nil -} From dd74cc7d6da9ac48652bfa85a5d42989b0f3030c Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 17 Jan 2024 05:36:15 -0800 Subject: [PATCH 10/19] anthropic: implement GenerateContent and Model --- llms/anthropic/anthropicllm.go | 43 ++++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/llms/anthropic/anthropicllm.go b/llms/anthropic/anthropicllm.go index 620255ba6..a686441f4 100644 --- a/llms/anthropic/anthropicllm.go +++ b/llms/anthropic/anthropicllm.go @@ -50,14 +50,47 @@ func newClient(opts ...Option) (*anthropicclient.Client, error) { // Call requests a completion for the given prompt. func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { - r, err := o.Generate(ctx, []string{prompt}, options...) + return llms.CallLLM(ctx, o, prompt, options...) +} + +// GenerateContent implements the Model interface. +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) + } + + opts := &llms.CallOptions{} + for _, opt := range options { + opt(opts) + } + + // Assume we get a single text message + msg0 := messages[0] + part := msg0.Parts[0] + result, err := o.client.CreateCompletion(ctx, &anthropicclient.CompletionRequest{ + Model: opts.Model, + Prompt: part.(llms.TextContent).Text, + MaxTokens: opts.MaxTokens, + StopWords: opts.StopWords, + Temperature: opts.Temperature, + TopP: opts.TopP, + StreamingFunc: opts.StreamingFunc, + }) if err != nil { - return "", err + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMError(ctx, err) + } + return nil, err } - if len(r) == 0 { - return "", ErrEmptyResponse + + resp := &llms.ContentResponse{ + Choices: []*llms.ContentChoice{ + &llms.ContentChoice{ + Content: result.Text, + }, + }, } - return r[0].Text, nil + return resp, nil } func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { From 22406075733b9a0420e7244ccb43d7bef35aef57 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 17 Jan 2024 05:39:19 -0800 Subject: [PATCH 11/19] cohere: implement GenerateContent --- llms/cohere/coherellm.go | 37 ++++++++++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/llms/cohere/coherellm.go b/llms/cohere/coherellm.go index a2b8d9b59..e5d452c4d 100644 --- a/llms/cohere/coherellm.go +++ b/llms/cohere/coherellm.go @@ -25,14 +25,41 @@ type LLM struct { var _ llms.LLM = (*LLM)(nil) func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { - r, err := o.Generate(ctx, []string{prompt}, options...) + return llms.CallLLM(ctx, o, prompt, options...) +} + +// GenerateContent implements the Model interface. 
+func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) + } + + opts := &llms.CallOptions{} + for _, opt := range options { + opt(opts) + } + + // Assume we get a single text message + msg0 := messages[0] + part := msg0.Parts[0] + result, err := o.client.CreateGeneration(ctx, &cohereclient.GenerationRequest{ + Prompt: part.(llms.TextContent).Text, + }) if err != nil { - return "", err + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMError(ctx, err) + } + return nil, err } - if len(r) == 0 { - return "", ErrEmptyResponse + + resp := &llms.ContentResponse{ + Choices: []*llms.ContentChoice{ + &llms.ContentChoice{ + Content: result.Text, + }, + }, } - return r[0].Text, nil + return resp, nil } func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { From 51af7083b1558dfbaab4e96316c5d4d5b0928100 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 17 Jan 2024 05:42:57 -0800 Subject: [PATCH 12/19] ernie: implement GenerateContent --- llms/ernie/erniellm.go | 52 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 47 insertions(+), 5 deletions(-) diff --git a/llms/ernie/erniellm.go b/llms/ernie/erniellm.go index a4a2a1f30..1c499b0f4 100644 --- a/llms/ernie/erniellm.go +++ b/llms/ernie/erniellm.go @@ -61,16 +61,58 @@ doc: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2`, ernieclient.ErrNot // Call implements llms.LLM. func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { - r, err := o.Generate(ctx, []string{prompt}, options...) + return llms.CallLLM(ctx, o, prompt, options...) +} + +// GenerateContent implements the Model interface. +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) + } + + opts := &llms.CallOptions{} + for _, opt := range options { + opt(opts) + } + + // Assume we get a single text message + msg0 := messages[0] + part := msg0.Parts[0] + result, err := o.client.CreateCompletion(ctx, o.getModelPath(*opts), &ernieclient.CompletionRequest{ + Messages: []ernieclient.Message{{Role: "user", Content: part.(llms.TextContent).Text}}, + Temperature: opts.Temperature, + TopP: opts.TopP, + PenaltyScore: opts.RepetitionPenalty, + StreamingFunc: opts.StreamingFunc, + Stream: opts.StreamingFunc != nil, + }) if err != nil { - return "", err + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMError(ctx, err) + } + return nil, err + } + if result.ErrorCode > 0 { + err = fmt.Errorf("%w, error_code:%v, erro_msg:%v, id:%v", + ErrCodeResponse, result.ErrorCode, result.ErrorMsg, result.ID) + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMError(ctx, err) + } + return nil, err } - if len(r) == 0 { - return "", ErrEmptyResponse + resp := &llms.ContentResponse{ + Choices: []*llms.ContentChoice{ + &llms.ContentChoice{ + Content: result.Result, + }, + }, + } + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, resp) } - return r[0].Text, nil + return resp, nil } // Generate implements llms.LLM. 
From 06ed2fba829247122e681d1d0f6833cbda3035ea Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 17 Jan 2024 05:48:50 -0800 Subject: [PATCH 13/19] all: lint --- llms/anthropic/anthropicllm.go | 5 +++-- llms/cohere/coherellm.go | 5 +++-- llms/ernie/erniellm.go | 5 +++-- llms/huggingface/huggingfacellm.go | 5 +++-- llms/local/localllm.go | 9 ++++----- llms/openai/openaillm.go | 2 +- llms/vertexai/vertexai_palm_llm.go | 6 +++--- 7 files changed, 20 insertions(+), 17 deletions(-) diff --git a/llms/anthropic/anthropicllm.go b/llms/anthropic/anthropicllm.go index a686441f4..dcfaf64d6 100644 --- a/llms/anthropic/anthropicllm.go +++ b/llms/anthropic/anthropicllm.go @@ -54,7 +54,8 @@ func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOptio } // GenerateContent implements the Model interface. -func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, whitespace + if o.CallbacksHandler != nil { o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) } @@ -85,7 +86,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten resp := &llms.ContentResponse{ Choices: []*llms.ContentChoice{ - &llms.ContentChoice{ + { Content: result.Text, }, }, diff --git a/llms/cohere/coherellm.go b/llms/cohere/coherellm.go index e5d452c4d..79bc4845d 100644 --- a/llms/cohere/coherellm.go +++ b/llms/cohere/coherellm.go @@ -29,7 +29,8 @@ func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOptio } // GenerateContent implements the Model interface. -func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, whitespace + if o.CallbacksHandler != nil { o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) } @@ -54,7 +55,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten resp := &llms.ContentResponse{ Choices: []*llms.ContentChoice{ - &llms.ContentChoice{ + { Content: result.Text, }, }, diff --git a/llms/ernie/erniellm.go b/llms/ernie/erniellm.go index 1c499b0f4..0ad9cffdb 100644 --- a/llms/ernie/erniellm.go +++ b/llms/ernie/erniellm.go @@ -65,7 +65,8 @@ func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOptio } // GenerateContent implements the Model interface. 
-func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, whitespace + if o.CallbacksHandler != nil { o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) } @@ -103,7 +104,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten resp := &llms.ContentResponse{ Choices: []*llms.ContentChoice{ - &llms.ContentChoice{ + { Content: result.Result, }, }, diff --git a/llms/huggingface/huggingfacellm.go b/llms/huggingface/huggingfacellm.go index 865bf7016..2eb5ea2a7 100644 --- a/llms/huggingface/huggingfacellm.go +++ b/llms/huggingface/huggingfacellm.go @@ -29,7 +29,8 @@ func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOptio } // GenerateContent implements the Model interface. -func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, whitespace + if o.CallbacksHandler != nil { o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) } @@ -63,7 +64,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten resp := &llms.ContentResponse{ Choices: []*llms.ContentChoice{ - &llms.ContentChoice{ + { Content: result.Text, }, }, diff --git a/llms/local/localllm.go b/llms/local/localllm.go index d7191744a..581630e52 100644 --- a/llms/local/localllm.go +++ b/llms/local/localllm.go @@ -43,7 +43,7 @@ func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOptio return r[0].Text, nil } -func (o *LLM) appendGlobalsToArgs(opts llms.CallOptions) []string { +func (o *LLM) appendGlobalsToArgs(opts llms.CallOptions) { if opts.Temperature != 0 { o.client.Args = append(o.client.Args, fmt.Sprintf("--temperature=%f", opts.Temperature)) } @@ -65,12 +65,11 @@ func (o *LLM) appendGlobalsToArgs(opts llms.CallOptions) []string { if opts.Seed != 0 { o.client.Args = append(o.client.Args, fmt.Sprintf("--seed=%d", opts.Seed)) } - - return o.client.Args } // GenerateContent implements the Model interface. -func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, whitespace + if o.CallbacksHandler != nil { o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) } @@ -98,7 +97,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten resp := &llms.ContentResponse{ Choices: []*llms.ContentChoice{ - &llms.ContentChoice{ + { Content: result.Text, }, }, diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go index 975671888..694d3ed24 100644 --- a/llms/openai/openaillm.go +++ b/llms/openai/openaillm.go @@ -46,7 +46,7 @@ func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOptio // GenerateContent implements the Model interface. 
// //nolint:goerr113 -func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { // nolint: lll, cyclop +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop if o.CallbacksHandler != nil { o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) } diff --git a/llms/vertexai/vertexai_palm_llm.go b/llms/vertexai/vertexai_palm_llm.go index 38348b7ca..f668ea239 100644 --- a/llms/vertexai/vertexai_palm_llm.go +++ b/llms/vertexai/vertexai_palm_llm.go @@ -29,7 +29,8 @@ func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOptio } // GenerateContent implements the Model interface. -func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, whitespace + if o.CallbacksHandler != nil { o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) } @@ -49,7 +50,6 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten Temperature: opts.Temperature, StopSequences: opts.StopWords, }) - if err != nil { if o.CallbacksHandler != nil { o.CallbacksHandler.HandleLLMError(ctx, err) @@ -59,7 +59,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten resp := &llms.ContentResponse{ Choices: []*llms.ContentChoice{ - &llms.ContentChoice{ + { Content: results[0].Text, }, }, From 7d102723f9b5e84a1f2625aed08af35ed2c72c10 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 17 Jan 2024 05:49:58 -0800 Subject: [PATCH 14/19] all: remove ChatLLM interface - it's now unused --- llms/ernie/ernie_chat.go | 185 ---------------------------------- llms/ernie/ernie_chat_test.go | 131 ------------------------ llms/llms.go | 6 -- 3 files changed, 322 deletions(-) delete mode 100644 llms/ernie/ernie_chat.go delete mode 100644 llms/ernie/ernie_chat_test.go diff --git a/llms/ernie/ernie_chat.go b/llms/ernie/ernie_chat.go deleted file mode 100644 index ef44b3379..000000000 --- a/llms/ernie/ernie_chat.go +++ /dev/null @@ -1,185 +0,0 @@ -package ernie - -import ( - "context" - "os" - "reflect" - - "github.com/tmc/langchaingo/callbacks" - "github.com/tmc/langchaingo/llms" - "github.com/tmc/langchaingo/llms/ernie/internal/ernieclient" - "github.com/tmc/langchaingo/schema" -) - -type ChatMessage = ernieclient.ChatMessage - -type Chat struct { - CallbacksHandler callbacks.Handler - client *ernieclient.Client -} - -var _ llms.ChatLLM = (*Chat)(nil) - -func NewChat(opts ...Option) (*Chat, error) { - options := &options{ - apiKey: os.Getenv(ernieAPIKey), - secretKey: os.Getenv(ernieSecretKey), - } - - for _, opt := range opts { - opt(options) - } - - c, err := newClient(options) - if err != nil { - return nil, err - } - c.ModelPath = modelToPath(ModelName(c.Model)) - - return &Chat{ - client: c, - }, err -} - -// Call requests a chat response for the given messages. -func (o *Chat) Call(ctx context.Context, messages []schema.ChatMessage, options ...llms.CallOption) (*schema.AIChatMessage, error) { // nolint: lll - r, err := o.Generate(ctx, [][]schema.ChatMessage{messages}, options...) 
- if err != nil { - return nil, err - } - if len(r) == 0 { - return nil, ErrEmptyResponse - } - return r[0].Message, nil -} - -//nolint:funlen -func (o *Chat) Generate(ctx context.Context, messageSets [][]schema.ChatMessage, options ...llms.CallOption) ([]*llms.Generation, error) { // nolint:lll,cyclop - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, getPromptsFromMessageSets(messageSets)) - } - - opts := llms.CallOptions{} - for _, opt := range options { - opt(&opts) - } - generations := make([]*llms.Generation, 0, len(messageSets)) - for _, messageSet := range messageSets { - req := &ernieclient.ChatRequest{ - Model: opts.Model, - StopWords: opts.StopWords, - Messages: messagesToClientMessages(messageSet), - StreamingFunc: opts.StreamingFunc, - Temperature: opts.Temperature, - MaxTokens: opts.MaxTokens, - N: opts.N, // TODO: note, we are not returning multiple completions - FrequencyPenalty: opts.FrequencyPenalty, - PresencePenalty: opts.PresencePenalty, - System: getSystem(messageSet), - - FunctionCallBehavior: ernieclient.FunctionCallBehavior(opts.FunctionCallBehavior), - } - for _, fn := range opts.Functions { - req.Functions = append(req.Functions, ernieclient.FunctionDefinition{ - Name: fn.Name, - Description: fn.Description, - Parameters: fn.Parameters, - }) - } - result, err := o.client.CreateChat(ctx, req) - if err != nil { - return nil, err - } - - if result.Result == "" && result.FunctionCall == nil { - return nil, ErrEmptyResponse - } - - generationInfo := make(map[string]any, reflect.ValueOf(result.Usage).NumField()) - generationInfo["CompletionTokens"] = result.Usage.CompletionTokens - generationInfo["PromptTokens"] = result.Usage.PromptTokens - generationInfo["TotalTokens"] = result.Usage.TotalTokens - msg := &schema.AIChatMessage{ - Content: result.Result, - } - - if result.FunctionCall != nil { - msg.FunctionCall = &schema.FunctionCall{ - Name: result.FunctionCall.Name, - Arguments: result.FunctionCall.Arguments, - } - } - generations = append(generations, &llms.Generation{ - Message: msg, - Text: msg.Content, - GenerationInfo: generationInfo, - }) - } - - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}}) - } - - return generations, nil -} - -func getPromptsFromMessageSets(messageSets [][]schema.ChatMessage) []string { - prompts := make([]string, 0, len(messageSets)) - for i := 0; i < len(messageSets); i++ { - curPrompt := "" - for j := 0; j < len(messageSets[i]); j++ { - curPrompt += messageSets[i][j].GetContent() - } - prompts = append(prompts, curPrompt) - } - - return prompts -} - -func messagesToClientMessages(messages []schema.ChatMessage) []*ernieclient.ChatMessage { - msgs := make([]*ernieclient.ChatMessage, 0) - for _, m := range messages { - msg := &ernieclient.ChatMessage{ - Content: m.GetContent(), - } - typ := m.GetType() - switch typ { - case schema.ChatMessageTypeSystem: // In Ernie's 'messages' parameter, there is no 'system' role. - continue - case schema.ChatMessageTypeAI: - msg.Role = "assistant" - case schema.ChatMessageTypeHuman: - msg.Role = "user" - case schema.ChatMessageTypeGeneric: - msg.Role = "user" - case schema.ChatMessageTypeFunction: - msg.Role = "function" - } - - if n, ok := m.(FunctionCalled); ok { - msg.FunctionCall = n.GetFunctionCall() - } - - if n, ok := m.(schema.Named); ok { - msg.Name = n.GetName() - } - msgs = append(msgs, msg) - } - - return msgs -} - -// getSystem Retrieve system parameter from messages. 
-func getSystem(messages []schema.ChatMessage) string { - for _, message := range messages { - if message.GetType() == schema.ChatMessageTypeSystem { - return message.GetContent() - } - } - return "" -} - -// FunctionCalled is an interface for objects that have a function call info. -type FunctionCalled interface { - GetFunctionCall() *schema.FunctionCall -} diff --git a/llms/ernie/ernie_chat_test.go b/llms/ernie/ernie_chat_test.go deleted file mode 100644 index 5d1669f26..000000000 --- a/llms/ernie/ernie_chat_test.go +++ /dev/null @@ -1,131 +0,0 @@ -package ernie - -import ( - "reflect" - "testing" - - "github.com/tmc/langchaingo/llms/ernie/internal/ernieclient" - "github.com/tmc/langchaingo/schema" -) - -func TestNewChat(t *testing.T) { - t.Parallel() - type args struct { - opts []Option - } - tests := []struct { - name string - args args - want *Chat - wantErr bool - }{ - {name: "", args: args{opts: []Option{ - WithModelName(ModelNameERNIEBot), - WithAKSK("ak", "sk"), - }}, want: nil, wantErr: true}, - {name: "", args: args{opts: []Option{ - WithModelName(ModelNameERNIEBot), - WithAKSK("ak", "sk"), - WithAccessToken("xxx"), - }}, want: nil, wantErr: false}, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - got, err := NewChat(tt.args.opts...) - if (err != nil) != tt.wantErr { - t.Errorf("NewChat() error = %v, wantErr %v", err, tt.wantErr) - return - } - - expectedType := reflect.TypeOf(tt.want) - if reflect.TypeOf(got) != expectedType { - t.Errorf("NewChat() got = %T, want %T", got, tt.want) - } - }) - } -} - -func TestGetSystem(t *testing.T) { - t.Parallel() - type args struct { - messages []schema.ChatMessage - } - tests := []struct { - name string - args args - want string - }{ - { - name: "system message exists", - args: args{ - messages: []schema.ChatMessage{ - schema.SystemChatMessage{Content: "you are a robot."}, - schema.HumanChatMessage{Content: "who are you?"}, - }, - }, - want: "you are a robot.", - }, - { - name: "no system message", - args: args{ - messages: []schema.ChatMessage{ - schema.HumanChatMessage{Content: "who are you?"}, - }, - }, - want: "", - }, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - if got := getSystem(tt.args.messages); got != tt.want { - t.Errorf("getSystem() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestMessagesToClientMessages(t *testing.T) { - t.Parallel() - type args struct { - messages []schema.ChatMessage - } - tests := []struct { - name string - args args - want []*ernieclient.ChatMessage - }{ - { - name: "Test_MessagesToClientMessages_OK", - args: args{messages: []schema.ChatMessage{ - schema.AIChatMessage{Content: "assistant"}, - schema.HumanChatMessage{Content: "user"}, - schema.SystemChatMessage{Content: ""}, - schema.FunctionChatMessage{Content: "function"}, - schema.GenericChatMessage{Content: "user"}, - }}, - want: []*ernieclient.ChatMessage{ - {Content: "assistant", Role: "assistant"}, - {Content: "user", Role: "user"}, - {Content: "function", Role: "function"}, - {Content: "user", Role: "user"}, - }, - }, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - got := messagesToClientMessages(tt.args.messages) - for i, v := range got { - if !reflect.DeepEqual(v.Content, tt.want[i].Content) { - t.Errorf("messagesToClientMessages() = %v, want %v", got, tt.want) - } - } - }) - } -} diff --git a/llms/llms.go b/llms/llms.go index bdbbe1727..ac6f03a19 100644 --- a/llms/llms.go +++ 
b/llms/llms.go @@ -13,12 +13,6 @@ type LLM interface { Generate(ctx context.Context, prompts []string, options ...CallOption) ([]*Generation, error) } -// ChatLLM is a langchaingo LLM that can be used for chatting. -type ChatLLM interface { - Call(ctx context.Context, messages []schema.ChatMessage, options ...CallOption) (*schema.AIChatMessage, error) - Generate(ctx context.Context, messages [][]schema.ChatMessage, options ...CallOption) ([]*Generation, error) -} - // Model is an interface multi-modal models implement. // Note: this is an experimental API. type Model interface { From d7633d6fdfad62b8e88e1809b1c59922577768c1 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 17 Jan 2024 05:52:46 -0800 Subject: [PATCH 15/19] all: remove Generate method, as it's no longer used --- llms/anthropic/anthropicllm.go | 38 ---------------- llms/cohere/coherellm.go | 35 --------------- llms/ernie/erniellm.go | 44 ------------------ llms/huggingface/huggingfacellm.go | 38 ---------------- llms/llms.go | 1 - llms/local/localllm.go | 48 +------------------- llms/ollama/ollamallm.go | 72 ------------------------------ llms/openai/openaillm.go | 42 ----------------- llms/vertexai/vertexai_palm_llm.go | 35 --------------- 9 files changed, 1 insertion(+), 352 deletions(-) diff --git a/llms/anthropic/anthropicllm.go b/llms/anthropic/anthropicllm.go index dcfaf64d6..2d40f8239 100644 --- a/llms/anthropic/anthropicllm.go +++ b/llms/anthropic/anthropicllm.go @@ -93,41 +93,3 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten } return resp, nil } - -func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, prompts) - } - - opts := llms.CallOptions{} - for _, opt := range options { - opt(&opts) - } - - generations := make([]*llms.Generation, 0, len(prompts)) - for _, prompt := range prompts { - result, err := o.client.CreateCompletion(ctx, &anthropicclient.CompletionRequest{ - Model: opts.Model, - Prompt: prompt, - MaxTokens: opts.MaxTokens, - StopWords: opts.StopWords, - Temperature: opts.Temperature, - TopP: opts.TopP, - StreamingFunc: opts.StreamingFunc, - }) - if err != nil { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) - } - return nil, err - } - generations = append(generations, &llms.Generation{ - Text: result.Text, - }) - } - - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}}) - } - return generations, nil -} diff --git a/llms/cohere/coherellm.go b/llms/cohere/coherellm.go index 79bc4845d..13c3a797d 100644 --- a/llms/cohere/coherellm.go +++ b/llms/cohere/coherellm.go @@ -63,41 +63,6 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten return resp, nil } -func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, prompts) - } - - opts := llms.CallOptions{} - for _, opt := range options { - opt(&opts) - } - - generations := make([]*llms.Generation, 0, len(prompts)) - - for _, prompt := range prompts { - result, err := o.client.CreateGeneration(ctx, &cohereclient.GenerationRequest{ - Prompt: prompt, - }) - if err != nil { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) - } - return nil, err - } - - generations = 
append(generations, &llms.Generation{ - Text: result.Text, - }) - } - - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}}) - } - - return generations, nil -} - func New(opts ...Option) (*LLM, error) { c, err := newClient(opts...) return &LLM{ diff --git a/llms/ernie/erniellm.go b/llms/ernie/erniellm.go index 0ad9cffdb..1445abe36 100644 --- a/llms/ernie/erniellm.go +++ b/llms/ernie/erniellm.go @@ -116,50 +116,6 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten return resp, nil } -// Generate implements llms.LLM. -func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, prompts) - } - - opts := llms.CallOptions{} - for _, opt := range options { - opt(&opts) - } - - generations := make([]*llms.Generation, 0, len(prompts)) - for _, prompt := range prompts { - result, err := o.client.CreateCompletion(ctx, o.getModelPath(opts), &ernieclient.CompletionRequest{ - Messages: []ernieclient.Message{{Role: "user", Content: prompt}}, - Temperature: opts.Temperature, - TopP: opts.TopP, - PenaltyScore: opts.RepetitionPenalty, - StreamingFunc: opts.StreamingFunc, - Stream: opts.StreamingFunc != nil, - }) - if err != nil { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) - } - return nil, err - } - if result.ErrorCode > 0 { - err = fmt.Errorf("%w, error_code:%v, erro_msg:%v, id:%v", - ErrCodeResponse, result.ErrorCode, result.ErrorMsg, result.ID) - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) - } - return nil, err - } - - generations = append(generations, &llms.Generation{ - Text: result.Result, - }) - } - - return generations, nil -} - // CreateEmbedding use ernie Embedding-V1. // 1. texts counts less than 16 // 2. 
text runes counts less than 384 diff --git a/llms/huggingface/huggingfacellm.go b/llms/huggingface/huggingfacellm.go index 2eb5ea2a7..5a0f762d9 100644 --- a/llms/huggingface/huggingfacellm.go +++ b/llms/huggingface/huggingfacellm.go @@ -72,44 +72,6 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten return resp, nil } -func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, prompts) - } - - opts := &llms.CallOptions{Model: defaultModel} - for _, opt := range options { - opt(opts) - } - result, err := o.client.RunInference(ctx, &huggingfaceclient.InferenceRequest{ - Model: o.client.Model, - Prompt: prompts[0], - Task: huggingfaceclient.InferenceTaskTextGeneration, - Temperature: opts.Temperature, - TopP: opts.TopP, - TopK: opts.TopK, - MinLength: opts.MinLength, - MaxLength: opts.MaxLength, - RepetitionPenalty: opts.RepetitionPenalty, - Seed: opts.Seed, - }) - if err != nil { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) - } - return nil, err - } - - generations := []*llms.Generation{ - {Text: result.Text}, - } - - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}}) - } - return generations, nil -} - func New(opts ...Option) (*LLM, error) { options := &options{ token: os.Getenv(tokenEnvVarName), diff --git a/llms/llms.go b/llms/llms.go index ac6f03a19..c43d7c058 100644 --- a/llms/llms.go +++ b/llms/llms.go @@ -10,7 +10,6 @@ import ( // LLM is a langchaingo Large Language Model. type LLM interface { Call(ctx context.Context, prompt string, options ...CallOption) (string, error) - Generate(ctx context.Context, prompts []string, options ...CallOption) ([]*Generation, error) } // Model is an interface multi-modal models implement. diff --git a/llms/local/localllm.go b/llms/local/localllm.go index 581630e52..cf031f344 100644 --- a/llms/local/localllm.go +++ b/llms/local/localllm.go @@ -33,14 +33,7 @@ var ( // Call calls the local LLM binary with the given prompt. func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { - r, err := o.Generate(ctx, []string{prompt}, options...) - if err != nil { - return "", err - } - if len(r) == 0 { - return "", ErrEmptyResponse - } - return r[0].Text, nil + return llms.CallLLM(ctx, o, prompt, options...) } func (o *LLM) appendGlobalsToArgs(opts llms.CallOptions) { @@ -110,45 +103,6 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten return resp, nil } -// Generate generates completions using the local LLM binary. 
-func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, prompts) - } - - opts := &llms.CallOptions{} - for _, opt := range options { - opt(opts) - } - - // If o.client.GlobalAsArgs is true - if o.client.GlobalAsArgs { - // Then add the option to the args in --key=value format - o.appendGlobalsToArgs(*opts) - } - - generations := make([]*llms.Generation, 0, len(prompts)) - for _, prompt := range prompts { - result, err := o.client.CreateCompletion(ctx, &localclient.CompletionRequest{ - Prompt: prompt, - }) - if err != nil { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) - } - return nil, err - } - - generations = append(generations, &llms.Generation{Text: result.Text}) - } - - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}}) - } - - return generations, nil -} - // New creates a new local LLM implementation. func New(opts ...Option) (*LLM, error) { options := &options{ diff --git a/llms/ollama/ollamallm.go b/llms/ollama/ollamallm.go index eaf7866f1..4c6f335b5 100644 --- a/llms/ollama/ollamallm.go +++ b/llms/ollama/ollamallm.go @@ -44,78 +44,6 @@ func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOptio return llms.CallLLM(ctx, o, prompt, options...) } -// Generate implemente the generate interface for LLM. -func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, prompts) - } - - opts := llms.CallOptions{} - for _, opt := range options { - opt(&opts) - } - - // Load back CallOptions as ollamaOptions - ollamaOptions := o.options.ollamaOptions - ollamaOptions.NumPredict = opts.MaxTokens - ollamaOptions.Temperature = float32(opts.Temperature) - ollamaOptions.Stop = opts.StopWords - ollamaOptions.TopK = opts.TopK - ollamaOptions.TopP = float32(opts.TopP) - ollamaOptions.Seed = opts.Seed - ollamaOptions.RepeatPenalty = float32(opts.RepetitionPenalty) - ollamaOptions.FrequencyPenalty = float32(opts.FrequencyPenalty) - ollamaOptions.PresencePenalty = float32(opts.PresencePenalty) - - // Override LLM model if set as llms.CallOption - model := o.options.model - if opts.Model != "" { - model = opts.Model - } - - generations := make([]*llms.Generation, 0, len(prompts)) - - for _, prompt := range prompts { - req := &ollamaclient.GenerateRequest{ - Model: model, - System: o.options.system, - Prompt: prompt, - Template: o.options.customModelTemplate, - Options: ollamaOptions, - Stream: func(b bool) *bool { return &b }(opts.StreamingFunc != nil), - } - - var fn ollamaclient.GenerateResponseFunc - - var output string - fn = func(response ollamaclient.GenerateResponse) error { - if opts.StreamingFunc != nil { - if err := opts.StreamingFunc(ctx, []byte(response.Response)); err != nil { - return err - } - } - output += response.Response - return nil - } - - err := o.client.Generate(ctx, req, fn) - if err != nil { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) - } - return []*llms.Generation{}, err - } - - generations = append(generations, &llms.Generation{Text: output}) - } - - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}}) - } - - return generations, nil -} - // GenerateContent 
implements the Model interface. // nolint: goerr113 func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { // nolint: lll, cyclop, funlen diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go index 694d3ed24..3cfe94edb 100644 --- a/llms/openai/openaillm.go +++ b/llms/openai/openaillm.go @@ -134,48 +134,6 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten return response, nil } -func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, prompts) - } - - opts := llms.CallOptions{} - for _, opt := range options { - opt(&opts) - } - - generations := make([]*llms.Generation, 0, len(prompts)) - for _, prompt := range prompts { - result, err := o.client.CreateCompletion(ctx, &openaiclient.CompletionRequest{ - Model: opts.Model, - Prompt: prompt, - MaxTokens: opts.MaxTokens, - StopWords: opts.StopWords, - Temperature: opts.Temperature, - N: opts.N, - FrequencyPenalty: opts.FrequencyPenalty, - PresencePenalty: opts.PresencePenalty, - TopP: opts.TopP, - StreamingFunc: opts.StreamingFunc, - }) - if err != nil { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) - } - return nil, err - } - generations = append(generations, &llms.Generation{ - Text: result.Text, - }) - } - - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}}) - } - - return generations, nil -} - // CreateEmbedding creates embeddings for the given input texts. func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]float32, error) { embeddings, err := o.client.CreateEmbedding(ctx, &openaiclient.EmbeddingRequest{ diff --git a/llms/vertexai/vertexai_palm_llm.go b/llms/vertexai/vertexai_palm_llm.go index f668ea239..52bf59a14 100644 --- a/llms/vertexai/vertexai_palm_llm.go +++ b/llms/vertexai/vertexai_palm_llm.go @@ -71,41 +71,6 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten return resp, nil } -func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMStart(ctx, prompts) - } - - opts := llms.CallOptions{} - for _, opt := range options { - opt(&opts) - } - results, err := o.client.CreateCompletion(ctx, &vertexaiclient.CompletionRequest{ - Prompts: prompts, - MaxTokens: opts.MaxTokens, - Temperature: opts.Temperature, - StopSequences: opts.StopWords, - }) - if err != nil { - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMError(ctx, err) - } - return nil, err - } - - generations := []*llms.Generation{} - for _, r := range results { - generations = append(generations, &llms.Generation{ - Text: r.Text, - }) - } - - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}}) - } - return generations, nil -} - // CreateEmbedding creates embeddings for the given input texts. 
func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]float32, error) { embeddings, err := o.client.CreateEmbedding(ctx, &vertexaiclient.EmbeddingRequest{ From 9325c2c8f5e32dd4608bcfb91da5ffcd7940b483 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 17 Jan 2024 05:58:01 -0800 Subject: [PATCH 16/19] Add alias and comments --- llms/llms.go | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/llms/llms.go b/llms/llms.go index c43d7c058..4feca7631 100644 --- a/llms/llms.go +++ b/llms/llms.go @@ -7,14 +7,21 @@ import ( "github.com/tmc/langchaingo/schema" ) -// LLM is a langchaingo Large Language Model. -type LLM interface { - Call(ctx context.Context, prompt string, options ...CallOption) (string, error) -} +// LLM is an alias for model, for backwards compatibility. +// +// This alias may be removed in the future; please use Model instead. +type LLM = Model // Model is an interface multi-modal models implement. // Note: this is an experimental API. type Model interface { + // Call is a simplified interace for Model, generating a single string + // response from a single string prompt. + // + // It is here for backwards compatibility only and may be removed in the + // future; please use GenerateContent instead. + Call(ctx context.Context, prompt string, options ...CallOption) (string, error) + // GenerateContent asks the model to generate content from a sequence of // messages. It's the most general interface for LLMs that support chat-like // interactions. @@ -22,6 +29,7 @@ type Model interface { } // Generation is a single generation from a langchaingo LLM. +// This type may be removed in the future; please don't use in new code. type Generation struct { // Text is the generated text. Text string `json:"text"` @@ -34,11 +42,14 @@ type Generation struct { } // LLMResult is the class that contains all relevant information for an LLM Result. +// This type may be removed in the future; please don't use in new code. type LLMResult struct { Generations [][]*Generation LLMOutput map[string]any } +// CallLLM is a helper function for implementing Call in terms of +// GenerateContent. It's aimed to be used by Model providers. func CallLLM(ctx context.Context, llm Model, prompt string, options ...CallOption) (string, error) { msg := MessageContent{ Role: schema.ChatMessageTypeHuman, From 0e4d69e6f19c81d1813c2af0b4862d6e6fb605eb Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 17 Jan 2024 06:01:45 -0800 Subject: [PATCH 17/19] chains: dummy implementation of GenerateContent because of new interfaces --- chains/chains_test.go | 5 +++++ llms/llms.go | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/chains/chains_test.go b/chains/chains_test.go index 103fa946d..f30ea66ae 100644 --- a/chains/chains_test.go +++ b/chains/chains_test.go @@ -53,6 +53,11 @@ func (l *testLanguageModel) Call(_ context.Context, prompt string, _ ...llms.Cal return llmResult, nil } +func (l *testLanguageModel) GenerateContent(_ context.Context, _ []llms.MessageContent, _ ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, whitespace + + panic("not implemented") +} + func (l *testLanguageModel) Generate( ctx context.Context, prompts []string, options ...llms.CallOption, ) ([]*llms.Generation, error) { diff --git a/llms/llms.go b/llms/llms.go index 4feca7631..2e78e5b76 100644 --- a/llms/llms.go +++ b/llms/llms.go @@ -15,7 +15,7 @@ type LLM = Model // Model is an interface multi-modal models implement. 
// Note: this is an experimental API. type Model interface { - // Call is a simplified interace for Model, generating a single string + // Call is a simplified interface for Model, generating a single string // response from a single string prompt. // // It is here for backwards compatibility only and may be removed in the From 8b30d4b229ceb4aca4a4b4a55db2ebe3f1b97e4f Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 17 Jan 2024 07:07:35 -0800 Subject: [PATCH 18/19] all: clean up now unused types Generation and LLMResult --- callbacks/callbacks.go | 1 - callbacks/combining.go | 6 ------ callbacks/log.go | 16 ---------------- callbacks/simple.go | 1 - chains/chains_test.go | 14 -------------- llms/llms.go | 20 -------------------- 6 files changed, 58 deletions(-) diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index 15e83cc6c..2f4336b14 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -14,7 +14,6 @@ import ( type Handler interface { HandleText(ctx context.Context, text string) HandleLLMStart(ctx context.Context, prompts []string) - HandleLLMEnd(ctx context.Context, output llms.LLMResult) HandleLLMGenerateContentStart(ctx context.Context, ms []llms.MessageContent) HandleLLMGenerateContentEnd(ctx context.Context, res *llms.ContentResponse) HandleLLMError(ctx context.Context, err error) diff --git a/callbacks/combining.go b/callbacks/combining.go index 0f3ef07ac..2e95e80aa 100644 --- a/callbacks/combining.go +++ b/callbacks/combining.go @@ -26,12 +26,6 @@ func (l CombiningHandler) HandleLLMStart(ctx context.Context, prompts []string) } } -func (l CombiningHandler) HandleLLMEnd(ctx context.Context, output llms.LLMResult) { - for _, handle := range l.Callbacks { - handle.HandleLLMEnd(ctx, output) - } -} - func (l CombiningHandler) HandleLLMGenerateContentStart(ctx context.Context, ms []llms.MessageContent) { for _, handle := range l.Callbacks { handle.HandleLLMGenerateContentStart(ctx, ms) diff --git a/callbacks/log.go b/callbacks/log.go index f7a9453e3..135b7fdcd 100644 --- a/callbacks/log.go +++ b/callbacks/log.go @@ -6,7 +6,6 @@ import ( "fmt" "strings" - "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/schema" ) @@ -25,10 +24,6 @@ func (l LogHandler) HandleLLMStart(_ context.Context, prompts []string) { fmt.Println("Entering LLM with prompts:", prompts) } -func (l LogHandler) HandleLLMEnd(_ context.Context, output llms.LLMResult) { - fmt.Println("Exiting LLM with results:", formatLLMResult(output)) -} - func (l LogHandler) HandleLLMError(_ context.Context, err error) { fmt.Println("Exiting LLM with error:", err) } @@ -82,17 +77,6 @@ func formatChainValues(values map[string]any) string { return output } -func formatLLMResult(output llms.LLMResult) string { - results := "[ " - for i := 0; i < len(output.Generations); i++ { - for j := 0; j < len(output.Generations[i]); j++ { - results += output.Generations[i][j].Text - } - } - - return results + " ]" -} - func formatAgentAction(action schema.AgentAction) string { return fmt.Sprintf("\"%s\" with input \"%s\"", removeNewLines(action.Tool), removeNewLines(action.ToolInput)) } diff --git a/callbacks/simple.go b/callbacks/simple.go index 3f8ded395..94cf54174 100644 --- a/callbacks/simple.go +++ b/callbacks/simple.go @@ -14,7 +14,6 @@ var _ Handler = SimpleHandler{} func (SimpleHandler) HandleText(context.Context, string) {} func (SimpleHandler) HandleLLMStart(context.Context, []string) {} -func (SimpleHandler) HandleLLMEnd(context.Context, llms.LLMResult) {} func (SimpleHandler) 
HandleLLMGenerateContentStart(context.Context, []llms.MessageContent) {} func (SimpleHandler) HandleLLMGenerateContentEnd(context.Context, *llms.ContentResponse) {} func (SimpleHandler) HandleLLMError(context.Context, error) {} diff --git a/chains/chains_test.go b/chains/chains_test.go index f30ea66ae..ce5410117 100644 --- a/chains/chains_test.go +++ b/chains/chains_test.go @@ -58,20 +58,6 @@ func (l *testLanguageModel) GenerateContent(_ context.Context, _ []llms.MessageC panic("not implemented") } -func (l *testLanguageModel) Generate( - ctx context.Context, prompts []string, options ...llms.CallOption, -) ([]*llms.Generation, error) { - result, err := l.Call(ctx, prompts[0], options...) - if err != nil { - return nil, err - } - return []*llms.Generation{ - { - Text: result, - }, - }, nil -} - var _ llms.LLM = &testLanguageModel{} func TestApply(t *testing.T) { diff --git a/llms/llms.go b/llms/llms.go index 2e78e5b76..33f1b0220 100644 --- a/llms/llms.go +++ b/llms/llms.go @@ -28,26 +28,6 @@ type Model interface { GenerateContent(ctx context.Context, messages []MessageContent, options ...CallOption) (*ContentResponse, error) } -// Generation is a single generation from a langchaingo LLM. -// This type may be removed in the future; please don't use in new code. -type Generation struct { - // Text is the generated text. - Text string `json:"text"` - // Message stores the potentially generated message. - Message *schema.AIChatMessage `json:"message"` - // GenerationInfo is the generation info. This can contain vendor-specific information. - GenerationInfo map[string]any `json:"generation_info"` - // StopReason is the reason the generation stopped. - StopReason string `json:"stop_reason"` -} - -// LLMResult is the class that contains all relevant information for an LLM Result. -// This type may be removed in the future; please don't use in new code. -type LLMResult struct { - Generations [][]*Generation - LLMOutput map[string]any -} - // CallLLM is a helper function for implementing Call in terms of // GenerateContent. It's aimed to be used by Model providers. func CallLLM(ctx context.Context, llm Model, prompt string, options ...CallOption) (string, error) { From 75ee21080d8a95ff6b38737b333bab40eaf64996 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 17 Jan 2024 07:24:56 -0800 Subject: [PATCH 19/19] llms: clean up comments --- llms/generatecontent.go | 6 ++++-- llms/llms.go | 22 +++++++++++----------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/llms/generatecontent.go b/llms/generatecontent.go index 894da7a23..8d91c2003 100644 --- a/llms/generatecontent.go +++ b/llms/generatecontent.go @@ -62,13 +62,15 @@ type BinaryContent struct { func (BinaryContent) isPart() {} // ContentResponse is the response returned by a GenerateContent call. -// It can potentially return multiple response choices. +// It can potentially return multiple content choices. type ContentResponse struct { Choices []*ContentChoice } -// ContentChoice is one of the response choices returned by GenerateModel calls. +// ContentChoice is one of the response choices returned by GenerateContent +// calls. type ContentChoice struct { + // Content is the textual content of a response Content string // StopReason is the reason the model stopped generating output. diff --git a/llms/llms.go b/llms/llms.go index 33f1b0220..8325d84b2 100644 --- a/llms/llms.go +++ b/llms/llms.go @@ -9,23 +9,23 @@ import ( // LLM is an alias for model, for backwards compatibility. 
// -// This alias may be removed in the future; please use Model instead. +// This alias may be removed in the future; please use Model +// instead. type LLM = Model // Model is an interface multi-modal models implement. -// Note: this is an experimental API. type Model interface { - // Call is a simplified interface for Model, generating a single string - // response from a single string prompt. - // - // It is here for backwards compatibility only and may be removed in the - // future; please use GenerateContent instead. - Call(ctx context.Context, prompt string, options ...CallOption) (string, error) - // GenerateContent asks the model to generate content from a sequence of - // messages. It's the most general interface for LLMs that support chat-like - // interactions. + // messages. It's the most general interface for multi-modal LLMs that support + // chat-like interactions. GenerateContent(ctx context.Context, messages []MessageContent, options ...CallOption) (*ContentResponse, error) + + // Call is a simplified interface for a text-only Model, generating a single + // string response from a single string prompt. + // + // It is here for backwards compatibility only and may be removed + // in the future; please use GenerateContent instead. + Call(ctx context.Context, prompt string, options ...CallOption) (string, error) } // CallLLM is a helper function for implementing Call in terms of