-
-
Notifications
You must be signed in to change notification settings - Fork 695
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1086 from mathiasb/feature-mistral-embedding-pgvector
llms/mistral: Implementing embeddings.EmbedderClient for Mistral and an example with PGVector
- Loading branch information
Showing
7 changed files
with
520 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
module github.com/tmc/langchaingo/examples/mistral-embedding-example | ||
|
||
go 1.22.0 | ||
|
||
toolchain go1.22.1 | ||
|
||
require github.com/tmc/langchaingo v0.1.12 | ||
|
||
replace github.com/tmc/langchaingo => ../../ | ||
|
||
require ( | ||
github.com/dlclark/regexp2 v1.10.0 // indirect | ||
github.com/gage-technologies/mistral-go v1.1.0 // indirect | ||
github.com/google/uuid v1.6.0 // indirect | ||
github.com/jackc/pgpassfile v1.0.0 // indirect | ||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect | ||
github.com/jackc/pgx/v5 v5.5.5 // indirect | ||
github.com/pgvector/pgvector-go v0.1.1 // indirect | ||
github.com/pkoukk/tiktoken-go v0.1.6 // indirect | ||
golang.org/x/crypto v0.23.0 // indirect | ||
golang.org/x/text v0.15.0 // indirect | ||
) |
Large diffs are not rendered by default.
Oops, something went wrong.
131 changes: 131 additions & 0 deletions
131
examples/mistral-embedding-example/mistral-embedding-example.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"flag" | ||
"fmt" | ||
"log" | ||
"time" | ||
|
||
"github.com/tmc/langchaingo/embeddings" | ||
"github.com/tmc/langchaingo/llms/mistral" | ||
"github.com/tmc/langchaingo/schema" | ||
"github.com/tmc/langchaingo/vectorstores" | ||
"github.com/tmc/langchaingo/vectorstores/pgvector" | ||
) | ||
|
||
func main() { | ||
var dsn string | ||
flag.StringVar(&dsn, "dsn", "", "PGvector connection string") | ||
flag.Parse() | ||
model, err := mistral.New() | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
e, err := embeddings.NewEmbedder(model) | ||
|
||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
// Create a new pgvector store. | ||
ctx := context.Background() | ||
store, err := pgvector.New( | ||
ctx, | ||
pgvector.WithConnectionURL(dsn), | ||
pgvector.WithEmbedder(e), | ||
) | ||
if err != nil { | ||
log.Fatal("pgvector.New", err) | ||
} | ||
|
||
// Add documents to the pgvector store. | ||
_, err = store.AddDocuments(context.Background(), []schema.Document{ | ||
{ | ||
PageContent: "Tokyo", | ||
Metadata: map[string]any{ | ||
"population": 38, | ||
"area": 2190, | ||
}, | ||
}, | ||
{ | ||
PageContent: "Paris", | ||
Metadata: map[string]any{ | ||
"population": 11, | ||
"area": 105, | ||
}, | ||
}, | ||
{ | ||
PageContent: "London", | ||
Metadata: map[string]any{ | ||
"population": 9.5, | ||
"area": 1572, | ||
}, | ||
}, | ||
{ | ||
PageContent: "Santiago", | ||
Metadata: map[string]any{ | ||
"population": 6.9, | ||
"area": 641, | ||
}, | ||
}, | ||
{ | ||
PageContent: "Buenos Aires", | ||
Metadata: map[string]any{ | ||
"population": 15.5, | ||
"area": 203, | ||
}, | ||
}, | ||
{ | ||
PageContent: "Rio de Janeiro", | ||
Metadata: map[string]any{ | ||
"population": 13.7, | ||
"area": 1200, | ||
}, | ||
}, | ||
{ | ||
PageContent: "Sao Paulo", | ||
Metadata: map[string]any{ | ||
"population": 22.6, | ||
"area": 1523, | ||
}, | ||
}, | ||
}) | ||
if err != nil { | ||
log.Fatal("store.AddDocuments:\n", err) | ||
} | ||
time.Sleep(1 * time.Second) | ||
|
||
// Search for similar documents. | ||
docs, err := store.SimilaritySearch(ctx, "japan", 1) | ||
if err != nil { | ||
log.Fatal("store.SimilaritySearch1:\n", err) | ||
} | ||
fmt.Println("store.SimilaritySearch1:\n", docs) | ||
|
||
time.Sleep(2 * time.Second) // Don't trigger cloudflare | ||
|
||
// Search for similar documents using score threshold. | ||
docs, err = store.SimilaritySearch(ctx, "only cities in south america", 3, vectorstores.WithScoreThreshold(0.50)) | ||
if err != nil { | ||
log.Fatal("store.SimilaritySearch2:\n", err) | ||
} | ||
fmt.Println("store.SimilaritySearch2:\n", docs) | ||
|
||
time.Sleep(3 * time.Second) // Don't trigger cloudflare | ||
|
||
// Search for similar documents using score threshold and metadata filter. | ||
// Metadata filter for pgvector only supports key-value pairs for now. | ||
filter := map[string]any{"area": "1523"} // Sao Paulo | ||
|
||
docs, err = store.SimilaritySearch(ctx, "only cities in south america", | ||
3, | ||
vectorstores.WithScoreThreshold(0.50), | ||
vectorstores.WithFilters(filter), | ||
) | ||
if err != nil { | ||
log.Fatal("store.SimilaritySearch3:\n", err) | ||
} | ||
fmt.Println("store.SimilaritySearch3:\n", docs) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package mistral | ||
|
||
import (
	"context"
	"errors"
	"fmt"
)
|
||
// ErrEmptyEmbeddings is returned by CreateEmbedding when an entry in the
// embeddings response contains no values.
var ErrEmptyEmbeddings = errors.New("empty embeddings") | ||
|
||
// convertFloat64ToFloat32 returns a newly allocated slice holding every value
// of input narrowed to float32, in the same order. The result is always
// non-nil, even when input is nil or empty.
func convertFloat64ToFloat32(input []float64) []float32 {
	narrowed := make([]float32, 0, len(input))
	for _, value := range input {
		narrowed = append(narrowed, float32(value))
	}
	return narrowed
}
|
||
// CreateEmbedding implements the embeddings.EmbedderClient interface and creates embeddings for the given input texts. | ||
func (m *Model) CreateEmbedding(_ context.Context, inputTexts []string) ([][]float32, error) { | ||
embsRes, err := m.client.Embeddings("mistral-embed", inputTexts) | ||
if err != nil { | ||
return nil, errors.New("failed to create embeddings: " + err.Error()) | ||
} | ||
allEmbds := make([][]float32, len(embsRes.Data)) | ||
for i, embs := range embsRes.Data { | ||
if len(embs.Embedding) == 0 { | ||
return nil, ErrEmptyEmbeddings | ||
} | ||
allEmbds[i] = convertFloat64ToFloat32(embs.Embedding) | ||
} | ||
return allEmbds, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
package mistral | ||
|
||
import ( | ||
"context" | ||
"os" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
"github.com/tmc/langchaingo/embeddings" | ||
) | ||
|
||
// TestConvertFloat64ToFloat32 tests the ConvertFloat64ToFloat32 function using table-driven tests. | ||
func TestConvertFloat64ToFloat32(t *testing.T) { | ||
t.Parallel() | ||
tests := []struct { | ||
name string | ||
input []float64 | ||
expected []float32 | ||
}{ | ||
{ | ||
name: "empty slice", | ||
input: []float64{}, | ||
expected: []float32{}, | ||
}, | ||
{ | ||
name: "single element", | ||
input: []float64{3.14}, | ||
expected: []float32{3.14}, | ||
}, | ||
{ | ||
name: "multiple elements", | ||
input: []float64{1.23, 4.56, 7.89}, | ||
expected: []float32{1.23, 4.56, 7.89}, | ||
}, | ||
{ | ||
name: "zero values", | ||
input: []float64{0.0, 0.0, 0.0}, | ||
expected: []float32{0.0, 0.0, 0.0}, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
t.Parallel() | ||
output := convertFloat64ToFloat32(tt.input) | ||
|
||
require.Equal(t, len(tt.expected), len(output), "length mismatch") | ||
for i := range output { | ||
require.Equal(t, tt.expected[i], output[i], "at index %d", i) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestMistralEmbed(t *testing.T) { | ||
t.Parallel() | ||
envVar := "MISTRAL_API_KEY" | ||
|
||
// Get the value of the environment variable | ||
value := os.Getenv(envVar) | ||
|
||
// Check if it is set (non-empty) | ||
if value == "" { | ||
t.Skipf("Environment variable %s is not set, so skipping the test", envVar) | ||
return | ||
} | ||
|
||
model, err := New() | ||
require.NoError(t, err) | ||
|
||
e, err := embeddings.NewEmbedder(model) | ||
require.NoError(t, err) | ||
|
||
_, err = e.EmbedDocuments(context.Background(), []string{"Hello world"}) | ||
require.NoError(t, err) | ||
|
||
_, err = e.EmbedQuery(context.Background(), "Hello world") | ||
require.NoError(t, err) | ||
} |