Skip to content

Commit

Permalink
Merge pull request #522 from jvgrootveld/feature/store-weaviate-addit…
Browse files Browse the repository at this point in the history
…ional-fields-option
  • Loading branch information
tmc authored Jan 18, 2024
2 parents d41e440 + 0936014 commit b63d845
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 6 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ All types of contributions are encouraged and valued. See the [Table of Contents
## Code of Conduct

This project and everyone participating in it is governed by the
[langchaingo Code of Conduct](/CODE_OF_CONDUCT.md).
[langchaingo Code of Conduct](CODE_OF_CONDUCT.md).
By participating, you are expected to uphold this code. Please report unacceptable behavior
to <travis.cline@gmail.com>.

Expand Down
33 changes: 33 additions & 0 deletions vectorstores/weaviate/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ func WithQueryAttrs(queryAttrs []string) Option {
}
}

// WithAdditionalFields is an option for setting additional fields query attributes of the weaviate server.
func WithAdditionalFields(additionalFields []string) Option {
return func(p *Store) {
p.additionalFields = additionalFields
}
}

func applyClientOptions(opts ...Option) (Store, error) {
o := &Store{
textKey: _defaultTextKey,
Expand Down Expand Up @@ -143,5 +150,31 @@ func applyClientOptions(opts ...Option) (Store, error) {
o.queryAttrs = append(o.queryAttrs, o.nameSpaceKey)
}

// add additional fields
defaultAdditionalFields := []string{"certainty"}

if o.additionalFields == nil {
o.additionalFields = defaultAdditionalFields
} else {
o.additionalFields = mergeValuesAsUnique(defaultAdditionalFields, o.additionalFields)
}

return *o, nil
}

func mergeValuesAsUnique(collections ...[]string) []string {
valueMap := make(map[string]bool)

for _, collection := range collections {
for _, value := range collection {
valueMap[value] = true
}
}

uniqueValues := make([]string, 0, len(valueMap))
for k := range valueMap {
uniqueValues = append(uniqueValues, k)
}

return uniqueValues
}
18 changes: 13 additions & 5 deletions vectorstores/weaviate/weaviate.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ type Store struct {
connectionClient *http.Client

// optional
queryAttrs []string
queryAttrs []string
additionalFields []string
}

var _ vectorstores.VectorStore = Store{}
Expand Down Expand Up @@ -273,11 +274,18 @@ func (s Store) createFields() []graphql.Field {
Name: attr,
})
}

additionalFields := make([]graphql.Field, 0, len(s.additionalFields))
for _, attr := range s.additionalFields {
additionalFields = append(additionalFields, graphql.Field{
Name: attr,
})
}

fields = append(fields, graphql.Field{
Name: "_additional",
Fields: []graphql.Field{
{Name: "certainty"},
},
Name: "_additional",
Fields: additionalFields,
})

return fields
}
85 changes: 85 additions & 0 deletions vectorstores/weaviate/weaviate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,3 +545,88 @@ func TestWeaviateAsRetrieverWithMetadataFilters(t *testing.T) {
require.NotContains(t, result, "orange", "expected not orange in result")
require.NotContains(t, result, "yellow", "expected not yellow in result")
}

func TestWeaviateStoreAdditionalFieldsDefaults(t *testing.T) {
t.Parallel()

scheme, host := getValues(t)

llm, err := openai.New()
require.NoError(t, err)
e, err := embeddings.NewEmbedder(llm)
require.NoError(t, err)

store, err := New(
WithScheme(scheme),
WithHost(host),
WithEmbedder(e),
WithNameSpace(uuid.New().String()),
WithIndexName(randomizedCamelCaseClass()),
)
require.NoError(t, err)

err = createTestClass(context.Background(), store)
require.NoError(t, err)

_, err = store.AddDocuments(context.Background(), []schema.Document{
{PageContent: "Foo"},
})
require.NoError(t, err)

// Check if the default additional fields are present in the result
docs, err := store.SimilaritySearch(context.Background(),
"Foo", 1)
require.NoError(t, err)
require.Len(t, docs, 1)

additional, ok := docs[0].Metadata["_additional"].(map[string]any)
require.True(t, ok, "expected '_additional' to be present in the metadata and parsable as 'map[string]any'")
require.Len(t, additional, 1)

certainty, _ := additional["certainty"].(float64)
require.InDelta(t, docs[0].Score, float32(certainty), 0, "expect score to be equal to the certainty")
}

func TestWeaviateStoreAdditionalFieldsAdded(t *testing.T) {
t.Parallel()

scheme, host := getValues(t)

llm, err := openai.New()
require.NoError(t, err)
e, err := embeddings.NewEmbedder(llm)
require.NoError(t, err)

store, err := New(
WithScheme(scheme),
WithHost(host),
WithEmbedder(e),
WithNameSpace(uuid.New().String()),
WithIndexName(randomizedCamelCaseClass()),
WithAdditionalFields([]string{"id", "vector", "certainty", "distance"}),
)
require.NoError(t, err)

err = createTestClass(context.Background(), store)
require.NoError(t, err)

_, err = store.AddDocuments(context.Background(), []schema.Document{
{PageContent: "Foo"},
})
require.NoError(t, err)

// Check if all the additional fields are present in the result
docs, err := store.SimilaritySearch(context.Background(),
"Foo", 1)
require.NoError(t, err)
require.Len(t, docs, 1)

additional, ok := docs[0].Metadata["_additional"].(map[string]any)
require.True(t, ok, "expected '_additional' to be present in the metadata and parsable as 'map[string]any'")
require.Len(t, additional, 4)

require.NotEmpty(t, additional["id"], "expected the id to be present")
require.NotEmpty(t, additional["vector"], "expected the vector to be present")
require.NotEmpty(t, additional["certainty"], "expected the certainty to be present")
require.NotEmpty(t, additional["distance"], "expected the distance to be present")
}

0 comments on commit b63d845

Please sign in to comment.