diff --git a/embeddings/embedding.go b/embeddings/embedding.go index 96f8c171f..c3d7198d7 100644 --- a/embeddings/embedding.go +++ b/embeddings/embedding.go @@ -35,6 +35,15 @@ type EmbedderClient interface { CreateEmbedding(ctx context.Context, texts []string) ([][]float32, error) } +// EmbedderClientFunc is an adapter to allow the use of ordinary functions as Embedder Clients. If +// `f` is a function with the appropriate signature, `EmbedderClientFunc(f)` is an `EmbedderClient` +// that calls `f`. +type EmbedderClientFunc func(ctx context.Context, texts []string) ([][]float32, error) + +func (e EmbedderClientFunc) CreateEmbedding(ctx context.Context, texts []string) ([][]float32, error) { + return e(ctx, texts) +} + type EmbedderImpl struct { client EmbedderClient diff --git a/vectorstores/weaviate/weaviate.go b/vectorstores/weaviate/weaviate.go index 6d6849975..55790deca 100644 --- a/vectorstores/weaviate/weaviate.go +++ b/vectorstores/weaviate/weaviate.go @@ -109,7 +109,7 @@ func (s Store) AddDocuments(ctx context.Context, texts = append(texts, doc.PageContent) } - vectors, err := s.embedder.EmbedDocuments(ctx, texts) + vectors, err := opts.Embedder.EmbedDocuments(ctx, texts) if err != nil { return nil, err } @@ -166,7 +166,7 @@ func (s Store) SimilaritySearch( return nil, err } - vector, err := s.embedder.EmbedQuery(ctx, query) + vector, err := opts.Embedder.EmbedQuery(ctx, query) if err != nil { return nil, err } @@ -300,7 +300,11 @@ func (s Store) getFilters(opts vectorstores.Options) any { } func (s Store) getOptions(options ...vectorstores.Option) vectorstores.Options { - opts := vectorstores.Options{} + // use the embedder from the store by default; this can be overridden by passing + // a `vectorstores.WithEmbedder` option. 
+ opts := vectorstores.Options{ + Embedder: s.embedder, + } for _, opt := range options { opt(&opts) } diff --git a/vectorstores/weaviate/weaviate_test.go b/vectorstores/weaviate/weaviate_test.go index ef158dbf7..8505a4ecf 100644 --- a/vectorstores/weaviate/weaviate_test.go +++ b/vectorstores/weaviate/weaviate_test.go @@ -718,3 +718,58 @@ func TestWeaviateStoreAdditionalFieldsAdded(t *testing.T) { require.NotEmpty(t, additional["certainty"], "expected the certainty to be present") require.NotEmpty(t, additional["distance"], "expected the distance to be present") } + +// TestWeaviateWithOptionEmbedder ensures that the embedder provided as an option to either +// `AddDocuments` or `SimilaritySearch` takes precedence over the one provided when creating +// the `Store`. +func TestWeaviateWithOptionEmbedder(t *testing.T) { + t.Parallel() + + scheme, host := getValues(t) + + llm, err := openai.New() + require.NoError(t, err) + + notme, err := embeddings.NewEmbedder( + embeddings.EmbedderClientFunc(func(context.Context, []string) ([][]float32, error) { + require.FailNow(t, "wrong embedder was called") + return nil, nil + }), + ) + require.NoError(t, err) + + butme, err := embeddings.NewEmbedder( + embeddings.EmbedderClientFunc(func(ctx context.Context, texts []string) ([][]float32, error) { + return llm.CreateEmbedding(ctx, texts) + }), + ) + require.NoError(t, err) + + store, err := New( + WithScheme(scheme), + WithHost(host), + WithEmbedder(notme), + WithNameSpace(uuid.New().String()), + WithIndexName(randomizedCamelCaseClass()), + WithQueryAttrs([]string{"location"}), + ) + require.NoError(t, err) + + err = createTestClass(context.Background(), store) + require.NoError(t, err) + + _, err = store.AddDocuments(context.Background(), []schema.Document{ + {PageContent: "tokyo", Metadata: map[string]any{ + "country": "japan", + }}, + {PageContent: "potato"}, + }, vectorstores.WithEmbedder(butme)) + require.NoError(t, err) + + docs, err := 
store.SimilaritySearch(context.Background(), "japan", 1, + vectorstores.WithEmbedder(butme)) + require.NoError(t, err) + require.Len(t, docs, 1) + require.Equal(t, "tokyo", docs[0].PageContent) + require.Equal(t, "japan", docs[0].Metadata["country"]) +}