Skip to content

Commit

Permalink
Merge pull request #1086 from mathiasb/feature-mistral-embedding-pgve…
Browse files Browse the repository at this point in the history
…ctor

llms/mistral: Implementing embeddings.EmbedderClient for Mistral and an example with PGVector
  • Loading branch information
FluffyKebab authored Jan 28, 2025
2 parents 71ded3c + c62063e commit a4e1e5a
Show file tree
Hide file tree
Showing 7 changed files with 520 additions and 4 deletions.
22 changes: 22 additions & 0 deletions examples/mistral-embedding-example/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
module github.com/tmc/langchaingo/examples/mistral-embedding-example

go 1.22.0

toolchain go1.22.1

require github.com/tmc/langchaingo v0.1.12

replace github.com/tmc/langchaingo => ../../

require (
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/gage-technologies/mistral-go v1.1.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/pgvector/pgvector-go v0.1.1 // indirect
github.com/pkoukk/tiktoken-go v0.1.6 // indirect
golang.org/x/crypto v0.23.0 // indirect
golang.org/x/text v0.15.0 // indirect
)
248 changes: 248 additions & 0 deletions examples/mistral-embedding-example/go.sum

Large diffs are not rendered by default.

131 changes: 131 additions & 0 deletions examples/mistral-embedding-example/mistral-embedding-example.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package main

import (
"context"
"flag"
"fmt"
"log"
"time"

"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms/mistral"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/vectorstores"
"github.com/tmc/langchaingo/vectorstores/pgvector"
)

func main() {
var dsn string
flag.StringVar(&dsn, "dsn", "", "PGvector connection string")
flag.Parse()
model, err := mistral.New()
if err != nil {
panic(err)
}

e, err := embeddings.NewEmbedder(model)

if err != nil {
panic(err)
}

// Create a new pgvector store.
ctx := context.Background()
store, err := pgvector.New(
ctx,
pgvector.WithConnectionURL(dsn),
pgvector.WithEmbedder(e),
)
if err != nil {
log.Fatal("pgvector.New", err)
}

// Add documents to the pgvector store.
_, err = store.AddDocuments(context.Background(), []schema.Document{
{
PageContent: "Tokyo",
Metadata: map[string]any{
"population": 38,
"area": 2190,
},
},
{
PageContent: "Paris",
Metadata: map[string]any{
"population": 11,
"area": 105,
},
},
{
PageContent: "London",
Metadata: map[string]any{
"population": 9.5,
"area": 1572,
},
},
{
PageContent: "Santiago",
Metadata: map[string]any{
"population": 6.9,
"area": 641,
},
},
{
PageContent: "Buenos Aires",
Metadata: map[string]any{
"population": 15.5,
"area": 203,
},
},
{
PageContent: "Rio de Janeiro",
Metadata: map[string]any{
"population": 13.7,
"area": 1200,
},
},
{
PageContent: "Sao Paulo",
Metadata: map[string]any{
"population": 22.6,
"area": 1523,
},
},
})
if err != nil {
log.Fatal("store.AddDocuments:\n", err)
}
time.Sleep(1 * time.Second)

// Search for similar documents.
docs, err := store.SimilaritySearch(ctx, "japan", 1)
if err != nil {
log.Fatal("store.SimilaritySearch1:\n", err)
}
fmt.Println("store.SimilaritySearch1:\n", docs)

time.Sleep(2 * time.Second) // Don't trigger cloudflare

// Search for similar documents using score threshold.
docs, err = store.SimilaritySearch(ctx, "only cities in south america", 3, vectorstores.WithScoreThreshold(0.50))
if err != nil {
log.Fatal("store.SimilaritySearch2:\n", err)
}
fmt.Println("store.SimilaritySearch2:\n", docs)

time.Sleep(3 * time.Second) // Don't trigger cloudflare

// Search for similar documents using score threshold and metadata filter.
// Metadata filter for pgvector only supports key-value pairs for now.
filter := map[string]any{"area": "1523"} // Sao Paulo

docs, err = store.SimilaritySearch(ctx, "only cities in south america",
3,
vectorstores.WithScoreThreshold(0.50),
vectorstores.WithFilters(filter),
)
if err != nil {
log.Fatal("store.SimilaritySearch3:\n", err)
}
fmt.Println("store.SimilaritySearch3:\n", docs)
}
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ require (
gitlab.com/golang-commonmark/linkify v0.0.0-20191026162114-a0c2df6c8f82 // indirect
gitlab.com/golang-commonmark/mdurl v0.0.0-20191124015652-932350d1cb84 // indirect
gitlab.com/golang-commonmark/puny v0.0.0-20191124015043-9f83538fa04f // indirect
go.mongodb.org/mongo-driver/v2 v2.0.0-beta1 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
Expand Down Expand Up @@ -196,7 +195,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.1
github.com/cohere-ai/tokenizer v1.1.2
github.com/fatih/color v1.17.0
github.com/gage-technologies/mistral-go v1.0.0
github.com/gage-technologies/mistral-go v1.1.0
github.com/getzep/zep-go v1.0.4
github.com/go-openapi/strfmt v0.21.3
github.com/go-sql-driver/mysql v1.7.1
Expand All @@ -220,6 +219,7 @@ require (
github.com/weaviate/weaviate-go-client/v4 v4.13.1
gitlab.com/golang-commonmark/markdown v0.0.0-20211110145824-bf3e522c626a
go.mongodb.org/mongo-driver v1.14.0
go.mongodb.org/mongo-driver/v2 v2.0.0-beta1
go.starlark.net v0.0.0-20230302034142-4b1e35fe2254
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1
golang.org/x/tools v0.14.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSw
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gage-technologies/mistral-go v1.0.0 h1:Hwk0uJO+Iq4kMX/EwbfGRUq9zkO36w7HZ/g53N4N73A=
github.com/gage-technologies/mistral-go v1.0.0/go.mod h1:tF++Xt7U975GcLlzhrjSQb8l/x+PrriO9QEdsgm9l28=
github.com/gage-technologies/mistral-go v1.1.0 h1:POv1wM9jA/9OBXGV2YdPi9Y/h09+MjCbUF+9hRYlVUI=
github.com/gage-technologies/mistral-go v1.1.0/go.mod h1:tF++Xt7U975GcLlzhrjSQb8l/x+PrriO9QEdsgm9l28=
github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc=
github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ=
github.com/getsentry/sentry-go v0.12.0 h1:era7g0re5iY13bHSdN/xMkyV+5zZppjRVQhZrXCaEIk=
Expand Down
36 changes: 36 additions & 0 deletions llms/mistral/mistralembed.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package mistral

import (
"context"
"errors"
)

var ErrEmptyEmbeddings = errors.New("empty embeddings")

func convertFloat64ToFloat32(input []float64) []float32 {
// Create a slice with the same length as the input.
output := make([]float32, len(input))

// Iterate over the input slice and convert each element.
for i, v := range input {
output[i] = float32(v)
}

return output
}

// CreateEmbedding implements the embeddings.EmbedderClient interface and creates embeddings for the given input texts.
func (m *Model) CreateEmbedding(_ context.Context, inputTexts []string) ([][]float32, error) {
embsRes, err := m.client.Embeddings("mistral-embed", inputTexts)
if err != nil {
return nil, errors.New("failed to create embeddings: " + err.Error())
}
allEmbds := make([][]float32, len(embsRes.Data))
for i, embs := range embsRes.Data {
if len(embs.Embedding) == 0 {
return nil, ErrEmptyEmbeddings
}
allEmbds[i] = convertFloat64ToFloat32(embs.Embedding)
}
return allEmbds, nil
}
79 changes: 79 additions & 0 deletions llms/mistral/mistralembed_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package mistral

import (
"context"
"os"
"testing"

"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/embeddings"
)

// TestConvertFloat64ToFloat32 tests the ConvertFloat64ToFloat32 function using table-driven tests.
func TestConvertFloat64ToFloat32(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input []float64
expected []float32
}{
{
name: "empty slice",
input: []float64{},
expected: []float32{},
},
{
name: "single element",
input: []float64{3.14},
expected: []float32{3.14},
},
{
name: "multiple elements",
input: []float64{1.23, 4.56, 7.89},
expected: []float32{1.23, 4.56, 7.89},
},
{
name: "zero values",
input: []float64{0.0, 0.0, 0.0},
expected: []float32{0.0, 0.0, 0.0},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
output := convertFloat64ToFloat32(tt.input)

require.Equal(t, len(tt.expected), len(output), "length mismatch")
for i := range output {
require.Equal(t, tt.expected[i], output[i], "at index %d", i)
}
})
}
}

func TestMistralEmbed(t *testing.T) {
t.Parallel()
envVar := "MISTRAL_API_KEY"

// Get the value of the environment variable
value := os.Getenv(envVar)

// Check if it is set (non-empty)
if value == "" {
t.Skipf("Environment variable %s is not set, so skipping the test", envVar)
return
}

model, err := New()
require.NoError(t, err)

e, err := embeddings.NewEmbedder(model)
require.NoError(t, err)

_, err = e.EmbedDocuments(context.Background(), []string{"Hello world"})
require.NoError(t, err)

_, err = e.EmbedQuery(context.Background(), "Hello world")
require.NoError(t, err)
}

0 comments on commit a4e1e5a

Please sign in to comment.