Merge pull request #514 from tmc/cb
Add callback handler method for GenerateContent start and end
tmc authored Jan 17, 2024
2 parents 2335fa4 + 405efea commit bf1b8da
Showing 7 changed files with 77 additions and 22 deletions.
2 changes: 2 additions & 0 deletions callbacks/callbacks.go
@@ -15,6 +15,8 @@ type Handler interface {
 	HandleText(ctx context.Context, text string)
 	HandleLLMStart(ctx context.Context, prompts []string)
 	HandleLLMEnd(ctx context.Context, output llms.LLMResult)
+	HandleLLMGenerateContentStart(ctx context.Context, ms []llms.MessageContent)
+	HandleLLMGenerateContentEnd(ctx context.Context, res *llms.ContentResponse)
 	HandleLLMError(ctx context.Context, err error)
 	HandleChainStart(ctx context.Context, inputs map[string]any)
 	HandleChainEnd(ctx context.Context, outputs map[string]any)
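A minimal sketch of how a downstream handler might adopt the two new hooks (not part of this commit; the LoggingHandler name and log messages are invented for illustration). Embedding SimpleHandler, updated below in callbacks/simple.go, supplies no-op implementations for every other Handler method:

package handlers

import (
	"context"
	"log"

	"github.com/tmc/langchaingo/callbacks"
	"github.com/tmc/langchaingo/llms"
)

// LoggingHandler logs the new GenerateContent lifecycle events. Embedding
// SimpleHandler supplies no-op implementations of every other Handler method.
type LoggingHandler struct {
	callbacks.SimpleHandler
}

func (LoggingHandler) HandleLLMGenerateContentStart(ctx context.Context, ms []llms.MessageContent) {
	log.Printf("GenerateContent start: %d messages", len(ms))
}

func (LoggingHandler) HandleLLMGenerateContentEnd(ctx context.Context, res *llms.ContentResponse) {
	log.Printf("GenerateContent end: %d choices", len(res.Choices))
}

// Compile-time check that LoggingHandler satisfies the full interface.
var _ callbacks.Handler = LoggingHandler{}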
12 changes: 12 additions & 0 deletions callbacks/combining.go
@@ -32,6 +32,18 @@ func (l CombiningHandler) HandleLLMEnd(ctx context.Context, output llms.LLMResult) {
 	}
 }
 
+func (l CombiningHandler) HandleLLMGenerateContentStart(ctx context.Context, ms []llms.MessageContent) {
+	for _, handle := range l.Callbacks {
+		handle.HandleLLMGenerateContentStart(ctx, ms)
+	}
+}
+
+func (l CombiningHandler) HandleLLMGenerateContentEnd(ctx context.Context, res *llms.ContentResponse) {
+	for _, handle := range l.Callbacks {
+		handle.HandleLLMGenerateContentEnd(ctx, res)
+	}
+}
+
 func (l CombiningHandler) HandleChainStart(ctx context.Context, inputs map[string]any) {
 	for _, handle := range l.Callbacks {
 		handle.HandleChainStart(ctx, inputs)
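CombiningHandler fans each event out to every handler in its Callbacks slice, so the new GenerateContent hooks compose exactly like the existing ones. A small sketch, reusing the hypothetical LoggingHandler from above:

package handlers

import "github.com/tmc/langchaingo/callbacks"

// NewCombinedHandler forwards every callback event, including the two new
// GenerateContent hooks, to each listed handler in order.
func NewCombinedHandler() callbacks.Handler {
	return callbacks.CombiningHandler{
		Callbacks: []callbacks.Handler{
			LoggingHandler{},          // hypothetical handler from the sketch above
			callbacks.SimpleHandler{}, // stand-in for any second handler (all no-ops)
		},
	}
}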
32 changes: 17 additions & 15 deletions callbacks/simple.go
@@ -12,18 +12,20 @@ type SimpleHandler struct{}
 
 var _ Handler = SimpleHandler{}
 
-func (SimpleHandler) HandleText(context.Context, string)                            {}
-func (SimpleHandler) HandleLLMStart(context.Context, []string)                      {}
-func (SimpleHandler) HandleLLMEnd(context.Context, llms.LLMResult)                  {}
-func (SimpleHandler) HandleLLMError(context.Context, error)                         {}
-func (SimpleHandler) HandleChainStart(context.Context, map[string]any)              {}
-func (SimpleHandler) HandleChainEnd(context.Context, map[string]any)                {}
-func (SimpleHandler) HandleChainError(context.Context, error)                       {}
-func (SimpleHandler) HandleToolStart(context.Context, string)                       {}
-func (SimpleHandler) HandleToolEnd(context.Context, string)                         {}
-func (SimpleHandler) HandleToolError(context.Context, error)                        {}
-func (SimpleHandler) HandleAgentAction(context.Context, schema.AgentAction)         {}
-func (SimpleHandler) HandleAgentFinish(context.Context, schema.AgentFinish)         {}
-func (SimpleHandler) HandleRetrieverStart(context.Context, string)                  {}
-func (SimpleHandler) HandleRetrieverEnd(context.Context, string, []schema.Document) {}
-func (SimpleHandler) HandleStreamingFunc(context.Context, []byte)                   {}
+func (SimpleHandler) HandleText(context.Context, string)                                   {}
+func (SimpleHandler) HandleLLMStart(context.Context, []string)                             {}
+func (SimpleHandler) HandleLLMEnd(context.Context, llms.LLMResult)                         {}
+func (SimpleHandler) HandleLLMGenerateContentStart(context.Context, []llms.MessageContent) {}
+func (SimpleHandler) HandleLLMGenerateContentEnd(context.Context, *llms.ContentResponse)   {}
+func (SimpleHandler) HandleLLMError(context.Context, error)                                {}
+func (SimpleHandler) HandleChainStart(context.Context, map[string]any)                     {}
+func (SimpleHandler) HandleChainEnd(context.Context, map[string]any)                       {}
+func (SimpleHandler) HandleChainError(context.Context, error)                              {}
+func (SimpleHandler) HandleToolStart(context.Context, string)                              {}
+func (SimpleHandler) HandleToolEnd(context.Context, string)                                {}
+func (SimpleHandler) HandleToolError(context.Context, error)                               {}
+func (SimpleHandler) HandleAgentAction(context.Context, schema.AgentAction)                {}
+func (SimpleHandler) HandleAgentFinish(context.Context, schema.AgentFinish)                {}
+func (SimpleHandler) HandleRetrieverStart(context.Context, string)                         {}
+func (SimpleHandler) HandleRetrieverEnd(context.Context, string, []schema.Document)        {}
+func (SimpleHandler) HandleStreamingFunc(context.Context, []byte)                          {}
27 changes: 23 additions & 4 deletions llms/googleai/googleai_llm.go
@@ -14,6 +14,7 @@ import (
 	"strings"
 
 	"github.com/google/generative-ai-go/genai"
+	"github.com/tmc/langchaingo/callbacks"
 	"github.com/tmc/langchaingo/llms"
 	"github.com/tmc/langchaingo/schema"
 	"google.golang.org/api/iterator"
@@ -22,8 +23,9 @@ import (
 
 // GoogleAI is a type that represents a Google AI API client.
 type GoogleAI struct {
-	client *genai.Client
-	opts   options
+	CallbacksHandler callbacks.Handler
+	client           *genai.Client
+	opts             options
 }
 
 var (
@@ -64,6 +66,10 @@ func NewGoogleAI(ctx context.Context, opts ...Option) (*GoogleAI, error) {
 
 // GenerateContent calls the LLM with the provided parts.
 func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) {
+	if g.CallbacksHandler != nil {
+		g.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages)
+	}
+
 	opts := llms.CallOptions{
 		Model:     g.opts.defaultModel,
 		MaxTokens: int(g.opts.defaultMaxTokens),
@@ -77,14 +83,27 @@ func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) {
 	model.SetMaxOutputTokens(int32(opts.MaxTokens))
 	model.SetTemperature(float32(opts.Temperature))
 
+	var response *llms.ContentResponse
+	var err error
+
 	if len(messages) == 1 {
 		theMessage := messages[0]
 		if theMessage.Role != schema.ChatMessageTypeHuman {
 			return nil, fmt.Errorf("got %v message role, want human", theMessage.Role)
 		}
-		return generateFromSingleMessage(ctx, model, theMessage.Parts, &opts)
+		response, err = generateFromSingleMessage(ctx, model, theMessage.Parts, &opts)
+	} else {
+		response, err = generateFromMessages(ctx, model, messages, &opts)
 	}
-	return generateFromMessages(ctx, model, messages, &opts)
+	if err != nil {
+		return nil, err
+	}
+
+	if g.CallbacksHandler != nil {
+		g.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, response)
+	}
+
+	return response, nil
 }
 
 // downloadImageData downloads the content from the given URL and returns it as
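A minimal end-to-end sketch of the new hooks with this provider (not part of the commit; it assumes googleai.WithAPIKey is among the package's Options and reuses the hypothetical LoggingHandler from above). CallbacksHandler is an exported field, so a handler can be attached after construction, and both hooks then fire around every GenerateContent call:

package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/googleai"
	"github.com/tmc/langchaingo/schema"
)

func main() {
	ctx := context.Background()

	// Assumption: googleai.WithAPIKey is one of the package's Options.
	llm, err := googleai.NewGoogleAI(ctx, googleai.WithAPIKey(os.Getenv("GOOGLE_API_KEY")))
	if err != nil {
		log.Fatal(err)
	}
	llm.CallbacksHandler = LoggingHandler{} // hypothetical handler sketched earlier

	// A single message must carry the human role, per the check above.
	messages := []llms.MessageContent{{
		Role:  schema.ChatMessageTypeHuman,
		Parts: []llms.ContentPart{llms.TextContent{Text: "Hello!"}},
	}}

	// HandleLLMGenerateContentStart fires before the call,
	// HandleLLMGenerateContentEnd after a successful response.
	resp, err := llm.GenerateContent(ctx, messages)
	if err != nil {
		log.Fatal(err)
	}
	if len(resp.Choices) > 0 {
		fmt.Println(resp.Choices[0].Content)
	}
}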
2 changes: 1 addition & 1 deletion llms/llms.go
@@ -24,7 +24,7 @@ type Model interface {
 	// GenerateContent asks the model to generate content from a sequence of
 	// messages. It's the most general interface for LLMs that support chat-like
 	// interactions.
-	GenerateContent(ctx context.Context, parts []MessageContent, options ...CallOption) (*ContentResponse, error)
+	GenerateContent(ctx context.Context, messages []MessageContent, options ...CallOption) (*ContentResponse, error)
 }
 
 // Generation is a single generation from a langchaingo LLM.
12 changes: 11 additions & 1 deletion llms/ollama/ollamallm_chat.go
@@ -123,6 +123,10 @@ func (o *Chat) Generate(ctx context.Context, messageSets [][]schema.ChatMessage, options ...llms.CallOption) ([]*llms.Generation, error) {
 // GenerateContent implements the Model interface.
 // nolint: goerr113
 func (o *Chat) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { // nolint: lll, cyclop, funlen
+	if o.CallbacksHandler != nil {
+		o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages)
+	}
+
 	opts := llms.CallOptions{}
 	for _, opt := range options {
 		opt(&opts)
@@ -220,7 +224,13 @@ func (o *Chat) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) {
 		},
 	}
 
-	return &llms.ContentResponse{Choices: choices}, nil
+	response := &llms.ContentResponse{Choices: choices}
+
+	if o.CallbacksHandler != nil {
+		o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, response)
+	}
+
+	return response, nil
 }
 
 func makeGenerationFromChatResponse(resp ollamaclient.ChatResponse) *llms.Generation {
12 changes: 11 additions & 1 deletion llms/openai/openaillm_chat.go
@@ -43,6 +43,10 @@ func NewChat(opts ...Option) (*Chat, error) {
 //
 //nolint:goerr113
 func (o *Chat) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { // nolint: lll, cyclop
+	if o.CallbacksHandler != nil {
+		o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages)
+	}
+
 	opts := llms.CallOptions{}
 	for _, opt := range options {
 		opt(&opts)
@@ -117,7 +121,13 @@ func (o *Chat) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) {
 		}
 	}
 
-	return &llms.ContentResponse{Choices: choices}, nil
+	response := &llms.ContentResponse{Choices: choices}
+
+	if o.CallbacksHandler != nil {
+		o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, response)
+	}
+
+	return response, nil
 }
 
 // Call requests a chat response for the given messages.
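The ollama and openai providers follow the same start/end pattern on their Chat types, as the two hunks above show. A wiring sketch for the openai provider (assumptions: NewChat reads OPENAI_API_KEY from the environment, and LoggingHandler is the hypothetical handler sketched earlier):

package main

import (
	"context"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/openai"
	"github.com/tmc/langchaingo/schema"
)

func main() {
	chat, err := openai.NewChat() // assumes OPENAI_API_KEY is set
	if err != nil {
		log.Fatal(err)
	}
	chat.CallbacksHandler = LoggingHandler{} // hypothetical handler sketched earlier

	// Both new hooks fire around this call.
	_, err = chat.GenerateContent(context.Background(), []llms.MessageContent{{
		Role:  schema.ChatMessageTypeHuman,
		Parts: []llms.ContentPart{llms.TextContent{Text: "Hello!"}},
	}})
	if err != nil {
		log.Fatal(err)
	}
}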
