Skip to content

Commit

Permalink
Add new callback for GenerateContent start and end
Browse files Browse the repository at this point in the history
And invoke it in the new methods

Re #465
  • Loading branch information
eliben committed Jan 17, 2024
1 parent 2335fa4 commit 405efea
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 22 deletions.
2 changes: 2 additions & 0 deletions callbacks/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ type Handler interface {
HandleText(ctx context.Context, text string)
HandleLLMStart(ctx context.Context, prompts []string)
HandleLLMEnd(ctx context.Context, output llms.LLMResult)
HandleLLMGenerateContentStart(ctx context.Context, ms []llms.MessageContent)
HandleLLMGenerateContentEnd(ctx context.Context, res *llms.ContentResponse)
HandleLLMError(ctx context.Context, err error)
HandleChainStart(ctx context.Context, inputs map[string]any)
HandleChainEnd(ctx context.Context, outputs map[string]any)
Expand Down
12 changes: 12 additions & 0 deletions callbacks/combining.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ func (l CombiningHandler) HandleLLMEnd(ctx context.Context, output llms.LLMResul
}
}

func (l CombiningHandler) HandleLLMGenerateContentStart(ctx context.Context, ms []llms.MessageContent) {
for _, handle := range l.Callbacks {
handle.HandleLLMGenerateContentStart(ctx, ms)
}
}

func (l CombiningHandler) HandleLLMGenerateContentEnd(ctx context.Context, res *llms.ContentResponse) {
for _, handle := range l.Callbacks {
handle.HandleLLMGenerateContentEnd(ctx, res)
}
}

func (l CombiningHandler) HandleChainStart(ctx context.Context, inputs map[string]any) {
for _, handle := range l.Callbacks {
handle.HandleChainStart(ctx, inputs)
Expand Down
32 changes: 17 additions & 15 deletions callbacks/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,20 @@ type SimpleHandler struct{}

var _ Handler = SimpleHandler{}

func (SimpleHandler) HandleText(context.Context, string) {}
func (SimpleHandler) HandleLLMStart(context.Context, []string) {}
func (SimpleHandler) HandleLLMEnd(context.Context, llms.LLMResult) {}
func (SimpleHandler) HandleLLMError(context.Context, error) {}
func (SimpleHandler) HandleChainStart(context.Context, map[string]any) {}
func (SimpleHandler) HandleChainEnd(context.Context, map[string]any) {}
func (SimpleHandler) HandleChainError(context.Context, error) {}
func (SimpleHandler) HandleToolStart(context.Context, string) {}
func (SimpleHandler) HandleToolEnd(context.Context, string) {}
func (SimpleHandler) HandleToolError(context.Context, error) {}
func (SimpleHandler) HandleAgentAction(context.Context, schema.AgentAction) {}
func (SimpleHandler) HandleAgentFinish(context.Context, schema.AgentFinish) {}
func (SimpleHandler) HandleRetrieverStart(context.Context, string) {}
func (SimpleHandler) HandleRetrieverEnd(context.Context, string, []schema.Document) {}
func (SimpleHandler) HandleStreamingFunc(context.Context, []byte) {}
func (SimpleHandler) HandleText(context.Context, string) {}
func (SimpleHandler) HandleLLMStart(context.Context, []string) {}
func (SimpleHandler) HandleLLMEnd(context.Context, llms.LLMResult) {}
func (SimpleHandler) HandleLLMGenerateContentStart(context.Context, []llms.MessageContent) {}
func (SimpleHandler) HandleLLMGenerateContentEnd(context.Context, *llms.ContentResponse) {}
func (SimpleHandler) HandleLLMError(context.Context, error) {}
func (SimpleHandler) HandleChainStart(context.Context, map[string]any) {}
func (SimpleHandler) HandleChainEnd(context.Context, map[string]any) {}
func (SimpleHandler) HandleChainError(context.Context, error) {}
func (SimpleHandler) HandleToolStart(context.Context, string) {}
func (SimpleHandler) HandleToolEnd(context.Context, string) {}
func (SimpleHandler) HandleToolError(context.Context, error) {}
func (SimpleHandler) HandleAgentAction(context.Context, schema.AgentAction) {}
func (SimpleHandler) HandleAgentFinish(context.Context, schema.AgentFinish) {}
func (SimpleHandler) HandleRetrieverStart(context.Context, string) {}
func (SimpleHandler) HandleRetrieverEnd(context.Context, string, []schema.Document) {}
func (SimpleHandler) HandleStreamingFunc(context.Context, []byte) {}
27 changes: 23 additions & 4 deletions llms/googleai/googleai_llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"strings"

"github.com/google/generative-ai-go/genai"
"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
"google.golang.org/api/iterator"
Expand All @@ -22,8 +23,9 @@ import (

// GoogleAI is a type that represents a Google AI API client.
type GoogleAI struct {
client *genai.Client
opts options
CallbacksHandler callbacks.Handler
client *genai.Client
opts options
}

var (
Expand Down Expand Up @@ -64,6 +66,10 @@ func NewGoogleAI(ctx context.Context, opts ...Option) (*GoogleAI, error) {

// GenerateContent calls the LLM with the provided parts.
func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) {
if g.CallbacksHandler != nil {
g.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages)
}

opts := llms.CallOptions{
Model: g.opts.defaultModel,
MaxTokens: int(g.opts.defaultMaxTokens),
Expand All @@ -77,14 +83,27 @@ func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageC
model.SetMaxOutputTokens(int32(opts.MaxTokens))
model.SetTemperature(float32(opts.Temperature))

var response *llms.ContentResponse
var err error

if len(messages) == 1 {
theMessage := messages[0]
if theMessage.Role != schema.ChatMessageTypeHuman {
return nil, fmt.Errorf("got %v message role, want human", theMessage.Role)
}
return generateFromSingleMessage(ctx, model, theMessage.Parts, &opts)
response, err = generateFromSingleMessage(ctx, model, theMessage.Parts, &opts)
} else {
response, err = generateFromMessages(ctx, model, messages, &opts)
}
return generateFromMessages(ctx, model, messages, &opts)
if err != nil {
return nil, err
}

if g.CallbacksHandler != nil {
g.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, response)
}

return response, nil
}

// downloadImageData downloads the content from the given URL and returns it as
Expand Down
2 changes: 1 addition & 1 deletion llms/llms.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type Model interface {
// GenerateContent asks the model to generate content from a sequence of
// messages. It's the most general interface for LLMs that support chat-like
// interactions.
GenerateContent(ctx context.Context, parts []MessageContent, options ...CallOption) (*ContentResponse, error)
GenerateContent(ctx context.Context, messages []MessageContent, options ...CallOption) (*ContentResponse, error)
}

// Generation is a single generation from a langchaingo LLM.
Expand Down
12 changes: 11 additions & 1 deletion llms/ollama/ollamallm_chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ func (o *Chat) Generate(ctx context.Context, messageSets [][]schema.ChatMessage,
// GenerateContent implements the Model interface.
// nolint: goerr113
func (o *Chat) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { // nolint: lll, cyclop, funlen
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages)
}

opts := llms.CallOptions{}
for _, opt := range options {
opt(&opts)
Expand Down Expand Up @@ -220,7 +224,13 @@ func (o *Chat) GenerateContent(ctx context.Context, messages []llms.MessageConte
},
}

return &llms.ContentResponse{Choices: choices}, nil
response := &llms.ContentResponse{Choices: choices}

if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, response)
}

return response, nil
}

func makeGenerationFromChatResponse(resp ollamaclient.ChatResponse) *llms.Generation {
Expand Down
12 changes: 11 additions & 1 deletion llms/openai/openaillm_chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ func NewChat(opts ...Option) (*Chat, error) {
//
//nolint:goerr113
func (o *Chat) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { // nolint: lll, cyclop
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages)
}

opts := llms.CallOptions{}
for _, opt := range options {
opt(&opts)
Expand Down Expand Up @@ -117,7 +121,13 @@ func (o *Chat) GenerateContent(ctx context.Context, messages []llms.MessageConte
}
}

return &llms.ContentResponse{Choices: choices}, nil
response := &llms.ContentResponse{Choices: choices}

if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, response)
}

return response, nil
}

// Call requests a chat response for the given messages.
Expand Down

0 comments on commit 405efea

Please sign in to comment.