Skip to content

Commit

Permalink
Merge pull request #526 from corani/corani/weaviate
Browse files Browse the repository at this point in the history
  • Loading branch information
tmc authored Jan 20, 2024
2 parents c9d054a + a7cd70c commit cc899e5
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 3 deletions.
9 changes: 9 additions & 0 deletions embeddings/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ type EmbedderClient interface {
CreateEmbedding(ctx context.Context, texts []string) ([][]float32, error)
}

// EmbedderClientFunc is an adapter to allow the use of ordinary functions as Embedder Clients. If
// `f` is a function with the appropriate signature, `EmbedderClientFunc(f)` is an `EmbedderClient`
// that calls `f`.
type EmbedderClientFunc func(ctx context.Context, texts []string) ([][]float32, error)

func (e EmbedderClientFunc) CreateEmbedding(ctx context.Context, texts []string) ([][]float32, error) {
return e(ctx, texts)
}

type EmbedderImpl struct {
client EmbedderClient

Expand Down
10 changes: 7 additions & 3 deletions vectorstores/weaviate/weaviate.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (s Store) AddDocuments(ctx context.Context,
texts = append(texts, doc.PageContent)
}

vectors, err := s.embedder.EmbedDocuments(ctx, texts)
vectors, err := opts.Embedder.EmbedDocuments(ctx, texts)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -166,7 +166,7 @@ func (s Store) SimilaritySearch(
return nil, err
}

vector, err := s.embedder.EmbedQuery(ctx, query)
vector, err := opts.Embedder.EmbedQuery(ctx, query)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -300,7 +300,11 @@ func (s Store) getFilters(opts vectorstores.Options) any {
}

func (s Store) getOptions(options ...vectorstores.Option) vectorstores.Options {
opts := vectorstores.Options{}
// use the embedder from the store by default, this can be overwritten by passing
// an `vectorstores.WithEmbedder` option.
opts := vectorstores.Options{
Embedder: s.embedder,
}
for _, opt := range options {
opt(&opts)
}
Expand Down
55 changes: 55 additions & 0 deletions vectorstores/weaviate/weaviate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -718,3 +718,58 @@ func TestWeaviateStoreAdditionalFieldsAdded(t *testing.T) {
require.NotEmpty(t, additional["certainty"], "expected the certainty to be present")
require.NotEmpty(t, additional["distance"], "expected the distance to be present")
}

// TestWeaviateWithOptionEmbedder ensures that the embedder provided as an option to either
// `AddDocuments` or `SimilaritySearch` takes precedence over the one provided when creating
// the `Store`.
func TestWeaviateWithOptionEmbedder(t *testing.T) {
t.Parallel()

scheme, host := getValues(t)

llm, err := openai.New()
require.NoError(t, err)

notme, err := embeddings.NewEmbedder(
embeddings.EmbedderClientFunc(func(context.Context, []string) ([][]float32, error) {
require.FailNow(t, "wrong embedder was called")
return nil, nil
}),
)
require.NoError(t, err)

butme, err := embeddings.NewEmbedder(
embeddings.EmbedderClientFunc(func(ctx context.Context, texts []string) ([][]float32, error) {
return llm.CreateEmbedding(ctx, texts)
}),
)
require.NoError(t, err)

store, err := New(
WithScheme(scheme),
WithHost(host),
WithEmbedder(notme),
WithNameSpace(uuid.New().String()),
WithIndexName(randomizedCamelCaseClass()),
WithQueryAttrs([]string{"location"}),
)
require.NoError(t, err)

err = createTestClass(context.Background(), store)
require.NoError(t, err)

_, err = store.AddDocuments(context.Background(), []schema.Document{
{PageContent: "tokyo", Metadata: map[string]any{
"country": "japan",
}},
{PageContent: "potato"},
}, vectorstores.WithEmbedder(butme))
require.NoError(t, err)

docs, err := store.SimilaritySearch(context.Background(), "japan", 1,
vectorstores.WithEmbedder(butme))
require.NoError(t, err)
require.Len(t, docs, 1)
require.Equal(t, "tokyo", docs[0].PageContent)
require.Equal(t, "japan", docs[0].Metadata["country"])
}

0 comments on commit cc899e5

Please sign in to comment.