LLM: Add support for Ernie chat_completions
The current implementation lacks support for Ernie chat_completions, so I've added code for this. The committed changes include basic calls, streaming, function calls, and other capabilities.

The invocation of Ernie is based on the official documentation at https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11.

For documentation on this integration in the Python version of LangChain, please refer to https://python.langchain.com/docs/integrations/chat/ernie.
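
For reference, a minimal non-streaming invocation of the new chat client (a sketch reusing the ernie.NewChat options from the examples in this commit; the credentials and prompt are placeholders) looks like this:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms/ernie"
	"github.com/tmc/langchaingo/schema"
)

func main() {
	// Placeholder AK/SK; supply real Baidu Qianfan credentials.
	llm, err := ernie.NewChat(
		ernie.WithModelName(ernie.ModelNameERNIEBot),
		ernie.WithAKSK("ak", "sk"),
	)
	if err != nil {
		log.Fatal(err)
	}
	completion, err := llm.Call(context.Background(), []schema.ChatMessage{
		schema.HumanChatMessage{Content: "Introduce yourself in one sentence."},
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(completion.Content)
}

Streaming and function-call variants are shown in the example files below.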
sxk10812139 committed Nov 15, 2023
1 parent 09a09b3 commit 51245f7
Showing 13 changed files with 814 additions and 3 deletions.
38 changes: 38 additions & 0 deletions examples/ernie-chat-example/ernie_chat_example.go
@@ -0,0 +1,38 @@
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/ernie"
	"github.com/tmc/langchaingo/schema"
)

func main() {
	llm, err := ernie.NewChat(
		ernie.WithModelName(ernie.ModelNameERNIEBot),
		// Fill in your AK and SK here.
		ernie.WithAKSK("ak", "sk"),
		// Use an external cache for the access token.
		ernie.WithAccessToken("accesstoken"),
	)
	if err != nil {
		log.Fatal(err)
	}
	ctx := context.Background()
	completion, err := llm.Call(ctx, []schema.ChatMessage{
		schema.SystemChatMessage{Content: "Hello, I am a friendly chatbot. I love to talk about movies, books and music. Answer in long form yaml."},
		schema.HumanChatMessage{Content: "What would be a good company name for a company that makes colorful socks?"},
	}, llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error {
		// Each streamed chunk is printed as it arrives.
		log.Println(string(chunk))
		return nil
	}))
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(completion)
}
15 changes: 15 additions & 0 deletions examples/ernie-chat-example/go.mod
@@ -0,0 +1,15 @@
module github.com/tmc/langchaingo/examples/ernie-chat-example

go 1.19

require github.com/tmc/langchaingo v0.0.0-20231028223410-5f4451567823

require (
	github.com/dlclark/regexp2 v1.8.1 // indirect
	github.com/google/uuid v1.3.1 // indirect
	github.com/pkoukk/tiktoken-go v0.1.2 // indirect
)

replace github.com/tmc/langchaingo v0.0.0-20231028223410-5f4451567823 => ../../
12 changes: 12 additions & 0 deletions examples/ernie-chat-example/go.sum
@@ -0,0 +1,12 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/pkoukk/tiktoken-go v0.1.2 h1:u7PCSBiWJ3nJYoTGShyM9iHXz4dNyYkurwwp+GHtyHY=
github.com/pkoukk/tiktoken-go v0.1.2/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tmc/langchaingo v0.0.0-20231028223410-5f4451567823 h1:vIYDdskaXk+iwfrMIaN9dM66o3wUedqvWWww1o7a1m4=
github.com/tmc/langchaingo v0.0.0-20231028223410-5f4451567823/go.mod h1:wwzKIaam0XFmiWfTlvSvdKwq7CkxE9Tz5rIkz1KKDws=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
@@ -0,0 +1,61 @@
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/ernie"
	"github.com/tmc/langchaingo/schema"
)

func main() {
	llm, err := ernie.NewChat(
		ernie.WithModelName(ernie.ModelNameERNIEBot),
		// Fill in your AK and SK here.
		ernie.WithAKSK("ak", "sk"),
		// Use an external cache for the access token.
		ernie.WithAccessToken("accesstoken"),
	)
	if err != nil {
		log.Fatal(err)
	}
	ctx := context.Background()
	completion, err := llm.Call(ctx, []schema.ChatMessage{
		schema.HumanChatMessage{Content: "What is the weather like in Boston?"},
	}, llms.WithFunctions(functions))
	if err != nil {
		log.Fatal(err)
	}

	if completion != nil {
		fmt.Printf("Function call: %v\n", completion.FunctionCall)
	}
}

// getCurrentWeather is the local implementation described to the model by the
// "getCurrentWeather" function definition below; see the dispatch sketch after
// this file.
func getCurrentWeather(location string, unit string) (string, error) {
	weatherInfo := map[string]interface{}{
		"location":    location,
		"temperature": "72",
		"unit":        unit,
		"forecast":    []string{"sunny", "windy"},
	}
	b, err := json.Marshal(weatherInfo)
	if err != nil {
		return "", err
	}
	return string(b), nil
}

var functions = []llms.FunctionDefinition{
	{
		Name:        "getCurrentWeather",
		Description: "Get the current weather in a given location",
		Parameters:  json.RawMessage(`{"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}, "required": ["location"]}`),
	},
}
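
The example above prints the returned function call but never routes it back into getCurrentWeather. A minimal dispatch sketch, placed after the Call in main (assuming completion.FunctionCall exposes Name and Arguments fields, with Arguments as a JSON-encoded string in the OpenAI-compatible shape), could look like:

	// Hypothetical follow-up: route the model's function call into the local implementation.
	if completion != nil && completion.FunctionCall != nil && completion.FunctionCall.Name == "getCurrentWeather" {
		var args struct {
			Location string `json:"location"`
			Unit     string `json:"unit"`
		}
		// Arguments is assumed to be a JSON-encoded string of the schema-defined parameters.
		if err := json.Unmarshal([]byte(completion.FunctionCall.Arguments), &args); err != nil {
			log.Fatal(err)
		}
		weather, err := getCurrentWeather(args.Location, args.Unit)
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println(weather)
	}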
15 changes: 15 additions & 0 deletions examples/ernie-function-call-example/go.mod
@@ -0,0 +1,15 @@
module github.com/tmc/langchaingo/examples/ernie-function-call-example

go 1.19

require github.com/tmc/langchaingo v0.0.0-20231028223410-5f4451567823

require (
	github.com/dlclark/regexp2 v1.8.1 // indirect
	github.com/google/uuid v1.3.1 // indirect
	github.com/pkoukk/tiktoken-go v0.1.2 // indirect
)

replace github.com/tmc/langchaingo v0.0.0-20231028223410-5f4451567823 => ../../
12 changes: 12 additions & 0 deletions examples/ernie-function-call-example/go.sum
@@ -0,0 +1,12 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/pkoukk/tiktoken-go v0.1.2 h1:u7PCSBiWJ3nJYoTGShyM9iHXz4dNyYkurwwp+GHtyHY=
github.com/pkoukk/tiktoken-go v0.1.2/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tmc/langchaingo v0.0.0-20231028223410-5f4451567823 h1:vIYDdskaXk+iwfrMIaN9dM66o3wUedqvWWww1o7a1m4=
github.com/tmc/langchaingo v0.0.0-20231028223410-5f4451567823/go.mod h1:wwzKIaam0XFmiWfTlvSvdKwq7CkxE9Tz5rIkz1KKDws=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
@@ -0,0 +1,125 @@
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/jsonschema"
	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/ernie"
	"github.com/tmc/langchaingo/schema"
)

func main() {
	llm, err := ernie.NewChat(
		ernie.WithModelName(ernie.ModelNameERNIEBot),
		// Fill in your AK and SK here.
		ernie.WithAKSK("ak", "sk"),
		// Use an external cache for the access token.
		ernie.WithAccessToken("accesstoken"),
	)
	if err != nil {
		log.Fatal(err)
	}
	ctx := context.Background()
	completion, err := llm.Call(ctx, []schema.ChatMessage{
		schema.HumanChatMessage{Content: "What is the weather going to be like in Boston?"},
	}, llms.WithFunctions(functions), llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error {
		fmt.Printf("Received chunk: %s\n", chunk)
		return nil
	}))
	if err != nil {
		log.Fatal(err)
	}

	if completion.FunctionCall != nil {
		fmt.Printf("Function call: %+v\n", completion.FunctionCall)
	}
	fmt.Println(completion.Content)
}

// getCurrentWeather is the local implementation described to the model by the
// "getCurrentWeather" function definition below.
func getCurrentWeather(location string, unit string) (string, error) {
	weatherInfo := map[string]interface{}{
		"location":    location,
		"temperature": "72",
		"unit":        unit,
		"forecast":    []string{"sunny", "windy"},
	}
	b, err := json.Marshal(weatherInfo)
	if err != nil {
		return "", err
	}
	return string(b), nil
}

// Parameters may alternatively be supplied as a raw JSON schema, e.g.:
// json.RawMessage(`{"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}, "required": ["location"]}`)

var functions = []llms.FunctionDefinition{
	{
		Name:        "getCurrentWeather",
		Description: "Get the current weather in a given location",
		Parameters: jsonschema.Definition{
			Type: jsonschema.Object,
			Properties: map[string]jsonschema.Definition{
				"rationale": {
					Type:        jsonschema.String,
					Description: "The rationale for choosing this function call with these parameters",
				},
				"location": {
					Type:        jsonschema.String,
					Description: "The city and state, e.g. San Francisco, CA",
				},
				"unit": {
					Type: jsonschema.String,
					Enum: []string{"celsius", "fahrenheit"},
				},
			},
			Required: []string{"rationale", "location"},
		},
	},
	{
		Name:        "getTomorrowWeather",
		Description: "Get the predicted weather in a given location",
		Parameters: jsonschema.Definition{
			Type: jsonschema.Object,
			Properties: map[string]jsonschema.Definition{
				"rationale": {
					Type:        jsonschema.String,
					Description: "The rationale for choosing this function call with these parameters",
				},
				"location": {
					Type:        jsonschema.String,
					Description: "The city and state, e.g. San Francisco, CA",
				},
				"unit": {
					Type: jsonschema.String,
					Enum: []string{"celsius", "fahrenheit"},
				},
			},
			Required: []string{"rationale", "location"},
		},
	},
	{
		Name:        "getSuggestedPrompts",
		Description: "Given the user's input prompt suggest some related prompts",
		Parameters: jsonschema.Definition{
			Type: jsonschema.Object,
			Properties: map[string]jsonschema.Definition{
				"rationale": {
					Type:        jsonschema.String,
					Description: "The rationale for choosing this function call with these parameters",
				},
				"suggestions": {
					Type: jsonschema.Array,
					Items: &jsonschema.Definition{
						Type:        jsonschema.String,
						Description: "A suggested prompt",
					},
				},
			},
			Required: []string{"rationale", "suggestions"},
		},
	},
}
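
The streaming callbacks in these examples only print each chunk as it arrives. A caller that also wants the full concatenated text after Call returns can collect the chunks with a small helper — a sketch built on the llms.WithStreamingFunc option used above; the helper and package names are hypothetical, and llms.CallOption is assumed to be the option type accepted by Call:

package ernieexamples

import (
	"context"
	"strings"

	"github.com/tmc/langchaingo/llms"
)

// streamToBuilder returns a call option that appends every streamed chunk to sb,
// so the complete response text is available once Call returns.
func streamToBuilder(sb *strings.Builder) llms.CallOption {
	return llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error {
		sb.Write(chunk)
		return nil
	})
}

Passing streamToBuilder(&sb) in place of the inline callback keeps streaming enabled while retaining the full text in sb.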
15 changes: 15 additions & 0 deletions examples/ernie-function-call-streaming-example/go.mod
@@ -0,0 +1,15 @@
module github.com/tmc/langchaingo/examples/ernie-function-call-streaming-example

go 1.19

require github.com/tmc/langchaingo v0.0.0-20231028223410-5f4451567823

require (
	github.com/dlclark/regexp2 v1.8.1 // indirect
	github.com/google/uuid v1.3.1 // indirect
	github.com/pkoukk/tiktoken-go v0.1.2 // indirect
)

replace github.com/tmc/langchaingo v0.0.0-20231028223410-5f4451567823 => ../../
12 changes: 12 additions & 0 deletions examples/ernie-function-call-streaming-example/go.sum
@@ -0,0 +1,12 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/pkoukk/tiktoken-go v0.1.2 h1:u7PCSBiWJ3nJYoTGShyM9iHXz4dNyYkurwwp+GHtyHY=
github.com/pkoukk/tiktoken-go v0.1.2/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tmc/langchaingo v0.0.0-20231028223410-5f4451567823 h1:vIYDdskaXk+iwfrMIaN9dM66o3wUedqvWWww1o7a1m4=
github.com/tmc/langchaingo v0.0.0-20231028223410-5f4451567823/go.mod h1:wwzKIaam0XFmiWfTlvSvdKwq7CkxE9Tz5rIkz1KKDws=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=