Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llms/mistral: Implementing embeddings.EmbedderClient for Mistral and an example with PGVector #1086

Merged
merged 12 commits into from
Jan 28, 2025
22 changes: 22 additions & 0 deletions examples/mistral-embedding-example/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
module github.com/tmc/langchaingo/examples/mistral-embedding-example

go 1.22.0

toolchain go1.22.1

require github.com/tmc/langchaingo v0.1.12

replace github.com/tmc/langchaingo => ../../

require (
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/gage-technologies/mistral-go v1.1.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/pgvector/pgvector-go v0.1.1 // indirect
github.com/pkoukk/tiktoken-go v0.1.6 // indirect
golang.org/x/crypto v0.23.0 // indirect
golang.org/x/text v0.15.0 // indirect
)
248 changes: 248 additions & 0 deletions examples/mistral-embedding-example/go.sum

Large diffs are not rendered by default.

131 changes: 131 additions & 0 deletions examples/mistral-embedding-example/mistral-embedding-example.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package main

import (
"context"
"flag"
"fmt"
"log"
"time"

"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms/mistral"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/vectorstores"
"github.com/tmc/langchaingo/vectorstores/pgvector"
)

func main() {
var dsn string
flag.StringVar(&dsn, "dsn", "", "PGvector connection string")
flag.Parse()
model, err := mistral.New()
if err != nil {
panic(err)
}

e, err := embeddings.NewEmbedder(model)

if err != nil {
panic(err)
}

// Create a new pgvector store.
ctx := context.Background()
store, err := pgvector.New(
ctx,
pgvector.WithConnectionURL(dsn),
pgvector.WithEmbedder(e),
)
if err != nil {
log.Fatal("pgvector.New", err)
}

// Add documents to the pgvector store.
_, err = store.AddDocuments(context.Background(), []schema.Document{
{
PageContent: "Tokyo",
Metadata: map[string]any{
"population": 38,
"area": 2190,
},
},
{
PageContent: "Paris",
Metadata: map[string]any{
"population": 11,
"area": 105,
},
},
{
PageContent: "London",
Metadata: map[string]any{
"population": 9.5,
"area": 1572,
},
},
{
PageContent: "Santiago",
Metadata: map[string]any{
"population": 6.9,
"area": 641,
},
},
{
PageContent: "Buenos Aires",
Metadata: map[string]any{
"population": 15.5,
"area": 203,
},
},
{
PageContent: "Rio de Janeiro",
Metadata: map[string]any{
"population": 13.7,
"area": 1200,
},
},
{
PageContent: "Sao Paulo",
Metadata: map[string]any{
"population": 22.6,
"area": 1523,
},
},
})
if err != nil {
log.Fatal("store.AddDocuments:\n", err)
}
time.Sleep(1 * time.Second)

// Search for similar documents.
docs, err := store.SimilaritySearch(ctx, "japan", 1)
if err != nil {
log.Fatal("store.SimilaritySearch1:\n", err)
}
fmt.Println("store.SimilaritySearch1:\n", docs)

time.Sleep(2 * time.Second) // Don't trigger cloudflare
mathiasb marked this conversation as resolved.
Show resolved Hide resolved

// Search for similar documents using score threshold.
docs, err = store.SimilaritySearch(ctx, "only cities in south america", 3, vectorstores.WithScoreThreshold(0.50))
if err != nil {
log.Fatal("store.SimilaritySearch2:\n", err)
}
fmt.Println("store.SimilaritySearch2:\n", docs)

time.Sleep(3 * time.Second) // Don't trigger cloudflare

// Search for similar documents using score threshold and metadata filter.
// Metadata filter for pgvector only supports key-value pairs for now.
filter := map[string]any{"area": "1523"} // Sao Paulo

docs, err = store.SimilaritySearch(ctx, "only cities in south america",
3,
vectorstores.WithScoreThreshold(0.50),
vectorstores.WithFilters(filter),
)
if err != nil {
log.Fatal("store.SimilaritySearch3:\n", err)
}
fmt.Println("store.SimilaritySearch3:\n", docs)
}
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ require (
gitlab.com/golang-commonmark/linkify v0.0.0-20191026162114-a0c2df6c8f82 // indirect
gitlab.com/golang-commonmark/mdurl v0.0.0-20191124015652-932350d1cb84 // indirect
gitlab.com/golang-commonmark/puny v0.0.0-20191124015043-9f83538fa04f // indirect
go.mongodb.org/mongo-driver/v2 v2.0.0-beta1 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
Expand Down Expand Up @@ -196,7 +195,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.1
github.com/cohere-ai/tokenizer v1.1.2
github.com/fatih/color v1.17.0
github.com/gage-technologies/mistral-go v1.0.0
github.com/gage-technologies/mistral-go v1.1.0
github.com/getzep/zep-go v1.0.4
github.com/go-openapi/strfmt v0.21.3
github.com/go-sql-driver/mysql v1.7.1
Expand All @@ -220,6 +219,7 @@ require (
github.com/weaviate/weaviate-go-client/v4 v4.13.1
gitlab.com/golang-commonmark/markdown v0.0.0-20211110145824-bf3e522c626a
go.mongodb.org/mongo-driver v1.14.0
go.mongodb.org/mongo-driver/v2 v2.0.0-beta1
go.starlark.net v0.0.0-20230302034142-4b1e35fe2254
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1
golang.org/x/tools v0.14.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSw
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gage-technologies/mistral-go v1.0.0 h1:Hwk0uJO+Iq4kMX/EwbfGRUq9zkO36w7HZ/g53N4N73A=
github.com/gage-technologies/mistral-go v1.0.0/go.mod h1:tF++Xt7U975GcLlzhrjSQb8l/x+PrriO9QEdsgm9l28=
github.com/gage-technologies/mistral-go v1.1.0 h1:POv1wM9jA/9OBXGV2YdPi9Y/h09+MjCbUF+9hRYlVUI=
github.com/gage-technologies/mistral-go v1.1.0/go.mod h1:tF++Xt7U975GcLlzhrjSQb8l/x+PrriO9QEdsgm9l28=
github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc=
github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ=
github.com/getsentry/sentry-go v0.12.0 h1:era7g0re5iY13bHSdN/xMkyV+5zZppjRVQhZrXCaEIk=
Expand Down
36 changes: 36 additions & 0 deletions llms/mistral/mistralembed.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package mistral

import (
	"context"
	"errors"
	"fmt"
)

// ErrEmptyEmbeddings is returned by CreateEmbedding when an entry in the
// embeddings response contains no values.
var ErrEmptyEmbeddings = errors.New("empty embeddings")

// convertFloat64ToFloat32 returns a new slice holding each element of input
// narrowed to float32. The result always has the same length as input and is
// never nil, even when input is nil or empty.
func convertFloat64ToFloat32(input []float64) []float32 {
	narrowed := make([]float32, 0, len(input))
	for _, value := range input {
		narrowed = append(narrowed, float32(value))
	}
	return narrowed
}

// CreateEmbedding implements the embeddings.EmbedderClient interface and
// creates embeddings for the given input texts using the "mistral-embed"
// model.
//
// The context argument is currently unused because the underlying client call
// does not take one. Each embedding in the response is converted from
// float64 to float32.
//
// It returns ErrEmptyEmbeddings if any entry in the response has no values.
// An error from the client is wrapped with %w, so callers can still inspect
// the cause with errors.Is or errors.As.
func (m *Model) CreateEmbedding(_ context.Context, inputTexts []string) ([][]float32, error) {
	embsRes, err := m.client.Embeddings("mistral-embed", inputTexts)
	if err != nil {
		return nil, fmt.Errorf("failed to create embeddings: %w", err)
	}

	allEmbds := make([][]float32, len(embsRes.Data))
	for i, embs := range embsRes.Data {
		if len(embs.Embedding) == 0 {
			return nil, ErrEmptyEmbeddings
		}
		allEmbds[i] = convertFloat64ToFloat32(embs.Embedding)
	}
	return allEmbds, nil
}
79 changes: 79 additions & 0 deletions llms/mistral/mistralembed_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package mistral

import (
"context"
"os"
"testing"

"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/embeddings"
)

// TestConvertFloat64ToFloat32 runs convertFloat64ToFloat32 over a table of
// inputs, checking the result length first and then every element.
func TestConvertFloat64ToFloat32(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name     string
		input    []float64
		expected []float32
	}

	cases := []testCase{
		{name: "empty slice", input: []float64{}, expected: []float32{}},
		{name: "single element", input: []float64{3.14}, expected: []float32{3.14}},
		{name: "multiple elements", input: []float64{1.23, 4.56, 7.89}, expected: []float32{1.23, 4.56, 7.89}},
		{name: "zero values", input: []float64{0.0, 0.0, 0.0}, expected: []float32{0.0, 0.0, 0.0}},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			got := convertFloat64ToFloat32(tc.input)

			require.Equal(t, len(tc.expected), len(got), "length mismatch")
			for idx, val := range got {
				require.Equal(t, tc.expected[idx], val, "at index %d", idx)
			}
		})
	}
}

func TestMistralEmbed(t *testing.T) {
t.Parallel()
envVar := "MISTRAL_API_KEY"

// Get the value of the environment variable
value := os.Getenv(envVar)

// Check if it is set (non-empty)
if value == "" {
mathiasb marked this conversation as resolved.
Show resolved Hide resolved
t.Skipf("Environment variable %s is not set, so skipping the test", envVar)
return
}

model, err := New()
require.NoError(t, err)

e, err := embeddings.NewEmbedder(model)
require.NoError(t, err)

_, err = e.EmbedDocuments(context.Background(), []string{"Hello world"})
require.NoError(t, err)

_, err = e.EmbedQuery(context.Background(), "Hello world")
require.NoError(t, err)
}
Loading