Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

all: clean up now unused types Generation and LLMResult #523

Merged
merged 18 commits
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion callbacks/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ import (
type Handler interface {
HandleText(ctx context.Context, text string)
HandleLLMStart(ctx context.Context, prompts []string)
HandleLLMEnd(ctx context.Context, output llms.LLMResult)
HandleLLMGenerateContentStart(ctx context.Context, ms []llms.MessageContent)
HandleLLMGenerateContentEnd(ctx context.Context, res *llms.ContentResponse)
HandleLLMError(ctx context.Context, err error)
HandleChainStart(ctx context.Context, inputs map[string]any)
HandleChainEnd(ctx context.Context, outputs map[string]any)
Expand Down
10 changes: 8 additions & 2 deletions callbacks/combining.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,15 @@ func (l CombiningHandler) HandleLLMStart(ctx context.Context, prompts []string)
}
}

func (l CombiningHandler) HandleLLMEnd(ctx context.Context, output llms.LLMResult) {
func (l CombiningHandler) HandleLLMGenerateContentStart(ctx context.Context, ms []llms.MessageContent) {
for _, handle := range l.Callbacks {
handle.HandleLLMEnd(ctx, output)
handle.HandleLLMGenerateContentStart(ctx, ms)
}
}

// HandleLLMGenerateContentEnd forwards the content-generation result to
// every registered callback handler, in registration order.
func (l CombiningHandler) HandleLLMGenerateContentEnd(ctx context.Context, res *llms.ContentResponse) {
	for _, cb := range l.Callbacks {
		cb.HandleLLMGenerateContentEnd(ctx, res)
	}
}

Expand Down
16 changes: 0 additions & 16 deletions callbacks/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"strings"

"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
)

Expand All @@ -25,10 +24,6 @@ func (l LogHandler) HandleLLMStart(_ context.Context, prompts []string) {
fmt.Println("Entering LLM with prompts:", prompts)
}

// HandleLLMEnd logs the result of a completed LLM call to stdout.
func (l LogHandler) HandleLLMEnd(_ context.Context, output llms.LLMResult) {
	formatted := formatLLMResult(output)
	fmt.Println("Exiting LLM with results:", formatted)
}

// HandleLLMError logs the error returned by a failed LLM call to stdout.
func (l LogHandler) HandleLLMError(_ context.Context, err error) {
	const prefix = "Exiting LLM with error:"
	fmt.Println(prefix, err)
}
Expand Down Expand Up @@ -82,17 +77,6 @@ func formatChainValues(values map[string]any) string {
return output
}

// formatLLMResult flattens every generation's text in the result into one
// bracketed string of the form "[ text1text2... ]", preserving order.
func formatLLMResult(output llms.LLMResult) string {
	var sb strings.Builder
	sb.WriteString("[ ")
	for _, generations := range output.Generations {
		for _, g := range generations {
			sb.WriteString(g.Text)
		}
	}
	sb.WriteString(" ]")
	return sb.String()
}

// formatAgentAction renders an agent action as `"tool" with input "input"`,
// with newlines stripped from both the tool name and its input.
func formatAgentAction(action schema.AgentAction) string {
	tool := removeNewLines(action.Tool)
	input := removeNewLines(action.ToolInput)
	return fmt.Sprintf("\"%s\" with input \"%s\"", tool, input)
}
Expand Down
31 changes: 16 additions & 15 deletions callbacks/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,19 @@ type SimpleHandler struct{}

var _ Handler = SimpleHandler{}

func (SimpleHandler) HandleText(context.Context, string) {}
func (SimpleHandler) HandleLLMStart(context.Context, []string) {}
func (SimpleHandler) HandleLLMEnd(context.Context, llms.LLMResult) {}
func (SimpleHandler) HandleLLMError(context.Context, error) {}
func (SimpleHandler) HandleChainStart(context.Context, map[string]any) {}
func (SimpleHandler) HandleChainEnd(context.Context, map[string]any) {}
func (SimpleHandler) HandleChainError(context.Context, error) {}
func (SimpleHandler) HandleToolStart(context.Context, string) {}
func (SimpleHandler) HandleToolEnd(context.Context, string) {}
func (SimpleHandler) HandleToolError(context.Context, error) {}
func (SimpleHandler) HandleAgentAction(context.Context, schema.AgentAction) {}
func (SimpleHandler) HandleAgentFinish(context.Context, schema.AgentFinish) {}
func (SimpleHandler) HandleRetrieverStart(context.Context, string) {}
func (SimpleHandler) HandleRetrieverEnd(context.Context, string, []schema.Document) {}
func (SimpleHandler) HandleStreamingFunc(context.Context, []byte) {}
func (SimpleHandler) HandleText(context.Context, string) {}
func (SimpleHandler) HandleLLMStart(context.Context, []string) {}
func (SimpleHandler) HandleLLMGenerateContentStart(context.Context, []llms.MessageContent) {}
func (SimpleHandler) HandleLLMGenerateContentEnd(context.Context, *llms.ContentResponse) {}
func (SimpleHandler) HandleLLMError(context.Context, error) {}
func (SimpleHandler) HandleChainStart(context.Context, map[string]any) {}
func (SimpleHandler) HandleChainEnd(context.Context, map[string]any) {}
func (SimpleHandler) HandleChainError(context.Context, error) {}
func (SimpleHandler) HandleToolStart(context.Context, string) {}
func (SimpleHandler) HandleToolEnd(context.Context, string) {}
func (SimpleHandler) HandleToolError(context.Context, error) {}
func (SimpleHandler) HandleAgentAction(context.Context, schema.AgentAction) {}
func (SimpleHandler) HandleAgentFinish(context.Context, schema.AgentFinish) {}
func (SimpleHandler) HandleRetrieverStart(context.Context, string) {}
func (SimpleHandler) HandleRetrieverEnd(context.Context, string, []schema.Document) {}
func (SimpleHandler) HandleStreamingFunc(context.Context, []byte) {}
15 changes: 3 additions & 12 deletions chains/chains_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,9 @@ func (l *testLanguageModel) Call(_ context.Context, prompt string, _ ...llms.Cal
return llmResult, nil
}

func (l *testLanguageModel) Generate(
ctx context.Context, prompts []string, options ...llms.CallOption,
) ([]*llms.Generation, error) {
result, err := l.Call(ctx, prompts[0], options...)
if err != nil {
return nil, err
}
return []*llms.Generation{
{
Text: result,
},
}, nil
// GenerateContent implements the llms.Model interface for the test double.
// It is deliberately unimplemented: chains under test go through Call, so
// reaching this method indicates a wiring bug and panics loudly.
func (l *testLanguageModel) GenerateContent(_ context.Context, _ []llms.MessageContent, _ ...llms.CallOption) (*llms.ContentResponse, error) { //nolint:lll
	panic("not implemented")
}

var _ llms.LLM = &testLanguageModel{}
Expand Down
16 changes: 0 additions & 16 deletions embeddings/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,3 @@ func TestOpenaiEmbeddingsWithAzureAPI(t *testing.T) {
require.NoError(t, err)
assert.Len(t, embeddings, 1)
}

// TestUseLLMAndChatAsEmbedderClient shows that an openai chat value can be
// passed to NewEmbedder as its client.
func TestUseLLMAndChatAsEmbedderClient(t *testing.T) {
	t.Parallel()

	if os.Getenv("OPENAI_API_KEY") == "" {
		t.Skip("OPENAI_API_KEY not set")
	}

	chat, err := openai.NewChat()
	require.NoError(t, err)

	embedder, err := NewEmbedder(chat)
	require.NoError(t, err)
	var _ Embedder = embedder
}
64 changes: 30 additions & 34 deletions llms/anthropic/anthropicllm.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,50 +50,46 @@ func newClient(opts ...Option) (*anthropicclient.Client, error) {

// Call requests a completion for the given prompt.
func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
r, err := o.Generate(ctx, []string{prompt}, options...)
if err != nil {
return "", err
}
if len(r) == 0 {
return "", ErrEmptyResponse
}
return r[0].Text, nil
return llms.CallLLM(ctx, o, prompt, options...)
}

func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) {
// GenerateContent implements the Model interface.
func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, whitespace

if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMStart(ctx, prompts)
o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages)
}

opts := llms.CallOptions{}
opts := &llms.CallOptions{}
for _, opt := range options {
opt(&opts)
opt(opts)
}

generations := make([]*llms.Generation, 0, len(prompts))
for _, prompt := range prompts {
result, err := o.client.CreateCompletion(ctx, &anthropicclient.CompletionRequest{
Model: opts.Model,
Prompt: prompt,
MaxTokens: opts.MaxTokens,
StopWords: opts.StopWords,
Temperature: opts.Temperature,
TopP: opts.TopP,
StreamingFunc: opts.StreamingFunc,
})
if err != nil {
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMError(ctx, err)
}
return nil, err
// Assume we get a single text message
msg0 := messages[0]
part := msg0.Parts[0]
result, err := o.client.CreateCompletion(ctx, &anthropicclient.CompletionRequest{
Model: opts.Model,
Prompt: part.(llms.TextContent).Text,
MaxTokens: opts.MaxTokens,
StopWords: opts.StopWords,
Temperature: opts.Temperature,
TopP: opts.TopP,
StreamingFunc: opts.StreamingFunc,
})
if err != nil {
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMError(ctx, err)
}
generations = append(generations, &llms.Generation{
Text: result.Text,
})
return nil, err
}

if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}})
resp := &llms.ContentResponse{
Choices: []*llms.ContentChoice{
{
Content: result.Text,
},
},
}
return generations, nil
return resp, nil
}
55 changes: 24 additions & 31 deletions llms/cohere/coherellm.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,49 +25,42 @@ type LLM struct {
var _ llms.LLM = (*LLM)(nil)

func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
r, err := o.Generate(ctx, []string{prompt}, options...)
if err != nil {
return "", err
}
if len(r) == 0 {
return "", ErrEmptyResponse
}
return r[0].Text, nil
return llms.CallLLM(ctx, o, prompt, options...)
}

func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) {
// GenerateContent implements the Model interface.
func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, whitespace

if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMStart(ctx, prompts)
o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages)
}

opts := llms.CallOptions{}
opts := &llms.CallOptions{}
for _, opt := range options {
opt(&opts)
opt(opts)
}

generations := make([]*llms.Generation, 0, len(prompts))

for _, prompt := range prompts {
result, err := o.client.CreateGeneration(ctx, &cohereclient.GenerationRequest{
Prompt: prompt,
})
if err != nil {
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMError(ctx, err)
}
return nil, err
// Assume we get a single text message
msg0 := messages[0]
part := msg0.Parts[0]
result, err := o.client.CreateGeneration(ctx, &cohereclient.GenerationRequest{
Prompt: part.(llms.TextContent).Text,
})
if err != nil {
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMError(ctx, err)
}

generations = append(generations, &llms.Generation{
Text: result.Text,
})
return nil, err
}

if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}})
resp := &llms.ContentResponse{
Choices: []*llms.ContentChoice{
{
Content: result.Text,
},
},
}

return generations, nil
return resp, nil
}

func New(opts ...Option) (*LLM, error) {
Expand Down
Loading
Loading