Skip to content

Commit

Permalink
Merge pull request #527 from corani/corani/weaviate_exists
Browse files — browse the repository at this point in the history
  • Loading branch information
tmc authored Jan 20, 2024
2 parents 46d563d + 14ba46b commit c9d054a
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 1 deletion.
17 changes: 16 additions & 1 deletion vectorstores/options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package vectorstores

import "github.com/tmc/langchaingo/embeddings"
import (
"context"

"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/schema"
)

// Option is a function that configures an Options. It receives a pointer to the
// Options value and mutates it in place; constructors such as WithEmbedder and
// WithDeduplicater return values of this type.
type Option func(*Options)
Expand All @@ -11,6 +16,7 @@ type Options struct {
ScoreThreshold float32
Filters any
Embedder embeddings.Embedder
Deduplicater func(context.Context, schema.Document) bool
}

// WithNameSpace returns an Option for setting the name space.
Expand Down Expand Up @@ -44,3 +50,12 @@ func WithEmbedder(embedder embeddings.Embedder) Option {
o.Embedder = embedder
}
}

// WithDeduplicater builds an Option that installs fn as the Deduplicater of the
// Options it is applied to. The deduplicater could be used when adding documents:
// a store that honors it (the weaviate store's AddDocuments does) skips every
// document for which fn reports true, which avoids spending time on creating an
// embedding when one already exists.
func WithDeduplicater(fn func(ctx context.Context, doc schema.Document) bool) Option {
	apply := func(opts *Options) {
		opts.Deduplicater = fn
	}

	return apply
}
55 changes: 55 additions & 0 deletions vectorstores/weaviate/weaviate.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ func (s Store) AddDocuments(ctx context.Context,
opts := s.getOptions(options...)
nameSpace := s.getNameSpace(opts)

docs = s.deduplicate(ctx, opts, docs)

if len(docs) == 0 {
// nothing to add (perhaps all documents were duplicates). This is not
// an error.
return nil, nil
}

texts := make([]string, 0, len(docs))
for _, doc := range docs {
texts = append(texts, doc.PageContent)
Expand Down Expand Up @@ -180,6 +188,35 @@ func (s Store) SimilaritySearch(
return s.parseDocumentsByGraphQLResponse(res)
}

// MetadataSearch searches weaviate based on metadata rather than based on
// similarity: it issues a GraphQL Get query restricted by a where condition
// instead of a vector search.
//
// Use `vectorstores.WithFilters(*filters.WhereBuilder)` to provide the where
// condition as an option. The name space (from the options or the store default)
// and the filter are both handed to createWhereBuilder, which produces the final
// where clause.
//
// numDocuments is used as the query limit, i.e. the maximum number of documents
// returned. The result is whatever parseDocumentsByGraphQLResponse extracts from
// the GraphQL response. Errors from building the where clause, from executing
// the query, or from parsing the response are returned to the caller unchanged.
func (s Store) MetadataSearch(
	ctx context.Context,
	numDocuments int,
	options ...vectorstores.Option,
) ([]schema.Document, error) {
	opts := s.getOptions(options...)
	nameSpace := s.getNameSpace(opts)
	filter := s.getFilters(opts)

	whereBuilder, err := s.createWhereBuilder(nameSpace, filter)
	if err != nil {
		return nil, err
	}

	// Plain Get query: no near-vector clause, only the where condition.
	res, err := s.client.GraphQL().
		Get().
		WithWhere(whereBuilder).
		WithClassName(s.indexName).
		WithLimit(numDocuments).
		WithFields(s.createFields()...).
		Do(ctx)
	if err != nil {
		return nil, err
	}

	return s.parseDocumentsByGraphQLResponse(res)
}

//nolint:cyclop
func (s Store) parseDocumentsByGraphQLResponse(res *models.GraphQLResponse) ([]schema.Document, error) {
if len(res.Errors) > 0 {
Expand Down Expand Up @@ -223,6 +260,24 @@ func (s Store) parseDocumentsByGraphQLResponse(res *models.GraphQLResponse) ([]s
return docs, nil
}

// deduplicate returns the documents from docs that the configured Deduplicater
// does not report as duplicates, preserving their order. When no Deduplicater
// is set in opts, docs is returned as-is. Otherwise a new slice is built, so
// the caller's slice is left untouched.
func (s Store) deduplicate(ctx context.Context,
	opts vectorstores.Options,
	docs []schema.Document,
) []schema.Document {
	isDuplicate := opts.Deduplicater
	if isDuplicate == nil {
		return docs
	}

	kept := make([]schema.Document, 0, len(docs))
	for i := range docs {
		if isDuplicate(ctx, docs[i]) {
			// Reported as a duplicate: leave it out of the result.
			continue
		}
		kept = append(kept, docs[i])
	}

	return kept
}

func (s Store) getNameSpace(opts vectorstores.Options) string {
if opts.NameSpace != "" {
return opts.NameSpace
Expand Down
88 changes: 88 additions & 0 deletions vectorstores/weaviate/weaviate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,94 @@ func TestWeaviateStoreRestWithScoreThreshold(t *testing.T) {
require.Len(t, docs, 10)
}

// TestMetadataSearch stores two documents with different "type" metadata and
// checks that MetadataSearch with an equality filter on "type" returns only the
// matching one.
func TestMetadataSearch(t *testing.T) {
	t.Parallel()

	ctx := context.Background()

	scheme, host := getValues(t)

	llm, err := openai.New()
	require.NoError(t, err)

	embedder, err := embeddings.NewEmbedder(llm)
	require.NoError(t, err)

	store, err := New(
		WithScheme(scheme),
		WithHost(host),
		WithEmbedder(embedder),
		WithNameSpace(uuid.New().String()),
		WithIndexName(randomizedCamelCaseClass()),
		WithQueryAttrs([]string{"type"}),
	)
	require.NoError(t, err)

	require.NoError(t, createTestClass(ctx, store))

	seed := []schema.Document{
		{PageContent: "tokyo", Metadata: map[string]any{"type": "city"}},
		{PageContent: "potato", Metadata: map[string]any{"type": "vegetable"}},
	}
	_, err = store.AddDocuments(ctx, seed)
	require.NoError(t, err)

	// Only documents whose "type" metadata equals "city" should match.
	cityOnly := filters.Where().
		WithPath([]string{"type"}).
		WithOperator(filters.Equal).
		WithValueString("city")

	found, err := store.MetadataSearch(ctx, 2, vectorstores.WithFilters(cityOnly))
	require.NoError(t, err)
	require.Len(t, found, 1)

	match := found[0]
	require.Equal(t, "tokyo", match.PageContent)
	require.Equal(t, "city", match.Metadata["type"])
}

// TestDeduplicater adds two documents with a deduplicater that flags "tokyo" as
// a duplicate, then checks via an unfiltered MetadataSearch that only "potato"
// was stored.
func TestDeduplicater(t *testing.T) {
	t.Parallel()

	ctx := context.Background()

	scheme, host := getValues(t)

	llm, err := openai.New()
	require.NoError(t, err)

	embedder, err := embeddings.NewEmbedder(llm)
	require.NoError(t, err)

	store, err := New(
		WithScheme(scheme),
		WithHost(host),
		WithEmbedder(embedder),
		WithNameSpace(uuid.New().String()),
		WithIndexName(randomizedCamelCaseClass()),
		WithQueryAttrs([]string{"type"}),
	)
	require.NoError(t, err)

	require.NoError(t, createTestClass(ctx, store))

	// Treat "tokyo" as already present so it is skipped on insert.
	skipTokyo := func(_ context.Context, doc schema.Document) bool {
		return doc.PageContent == "tokyo"
	}

	seed := []schema.Document{
		{PageContent: "tokyo", Metadata: map[string]any{"type": "city"}},
		{PageContent: "potato", Metadata: map[string]any{"type": "vegetable"}},
	}
	_, err = store.AddDocuments(ctx, seed, vectorstores.WithDeduplicater(skipTokyo))
	require.NoError(t, err)

	found, err := store.MetadataSearch(ctx, 2)
	require.NoError(t, err)
	require.Len(t, found, 1)

	remaining := found[0]
	require.Equal(t, "potato", remaining.PageContent)
	require.Equal(t, "vegetable", remaining.Metadata["type"])
}

func TestSimilaritySearchWithInvalidScoreThreshold(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit c9d054a

Please sign in to comment.