diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1b1c3c0d6..c1a68716f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -33,7 +33,7 @@ All types of contributions are encouraged and valued. See the [Table of Contents ## Code of Conduct This project and everyone participating in it is governed by the -[langchaingo Code of Conduct](/CODE_OF_CONDUCT.md). +[langchaingo Code of Conduct](CODE_OF_CONDUCT.md). By participating, you are expected to uphold this code. Please report unacceptable behavior to . diff --git a/vectorstores/weaviate/options.go b/vectorstores/weaviate/options.go index d09389645..f099b4513 100644 --- a/vectorstores/weaviate/options.go +++ b/vectorstores/weaviate/options.go @@ -105,6 +105,13 @@ func WithQueryAttrs(queryAttrs []string) Option { } } +// WithAdditionalFields is an option for setting additional fields query attributes of the weaviate server. +func WithAdditionalFields(additionalFields []string) Option { + return func(p *Store) { + p.additionalFields = additionalFields + } +} + func applyClientOptions(opts ...Option) (Store, error) { o := &Store{ textKey: _defaultTextKey, @@ -143,5 +150,31 @@ func applyClientOptions(opts ...Option) (Store, error) { o.queryAttrs = append(o.queryAttrs, o.nameSpaceKey) } + // add additional fields + defaultAdditionalFields := []string{"certainty"} + + if o.additionalFields == nil { + o.additionalFields = defaultAdditionalFields + } else { + o.additionalFields = mergeValuesAsUnique(defaultAdditionalFields, o.additionalFields) + } + return *o, nil } + +func mergeValuesAsUnique(collections ...[]string) []string { + valueMap := make(map[string]bool) + + for _, collection := range collections { + for _, value := range collection { + valueMap[value] = true + } + } + + uniqueValues := make([]string, 0, len(valueMap)) + for k := range valueMap { + uniqueValues = append(uniqueValues, k) + } + + return uniqueValues +} diff --git a/vectorstores/weaviate/weaviate.go b/vectorstores/weaviate/weaviate.go index b0020abc4..5a22beea9 100644 --- a/vectorstores/weaviate/weaviate.go +++ b/vectorstores/weaviate/weaviate.go @@ -57,7 +57,8 @@ type Store struct { connectionClient *http.Client // optional - queryAttrs []string + queryAttrs []string + additionalFields []string } var _ vectorstores.VectorStore = Store{} @@ -273,11 +274,18 @@ func (s Store) createFields() []graphql.Field { Name: attr, }) } + + additionalFields := make([]graphql.Field, 0, len(s.additionalFields)) + for _, attr := range s.additionalFields { + additionalFields = append(additionalFields, graphql.Field{ + Name: attr, + }) + } + fields = append(fields, graphql.Field{ - Name: "_additional", - Fields: []graphql.Field{ - {Name: "certainty"}, - }, + Name: "_additional", + Fields: additionalFields, }) + return fields } diff --git a/vectorstores/weaviate/weaviate_test.go b/vectorstores/weaviate/weaviate_test.go index f00b608ab..174c92357 100644 --- a/vectorstores/weaviate/weaviate_test.go +++ b/vectorstores/weaviate/weaviate_test.go @@ -545,3 +545,88 @@ func TestWeaviateAsRetrieverWithMetadataFilters(t *testing.T) { require.NotContains(t, result, "orange", "expected not orange in result") require.NotContains(t, result, "yellow", "expected not yellow in result") } + +func TestWeaviateStoreAdditionalFieldsDefaults(t *testing.T) { + t.Parallel() + + scheme, host := getValues(t) + + llm, err := openai.New() + require.NoError(t, err) + e, err := embeddings.NewEmbedder(llm) + require.NoError(t, err) + + store, err := New( + WithScheme(scheme), + WithHost(host), + WithEmbedder(e), + WithNameSpace(uuid.New().String()), + WithIndexName(randomizedCamelCaseClass()), + ) + require.NoError(t, err) + + err = createTestClass(context.Background(), store) + require.NoError(t, err) + + _, err = store.AddDocuments(context.Background(), []schema.Document{ + {PageContent: "Foo"}, + }) + require.NoError(t, err) + + // Check if the default additional fields are present in the result + docs, err := store.SimilaritySearch(context.Background(), + "Foo", 1) + require.NoError(t, err) + require.Len(t, docs, 1) + + additional, ok := docs[0].Metadata["_additional"].(map[string]any) + require.True(t, ok, "expected '_additional' to be present in the metadata and parsable as 'map[string]any'") + require.Len(t, additional, 1) + + certainty, _ := additional["certainty"].(float64) + require.InDelta(t, docs[0].Score, float32(certainty), 0, "expect score to be equal to the certainty") +} + +func TestWeaviateStoreAdditionalFieldsAdded(t *testing.T) { + t.Parallel() + + scheme, host := getValues(t) + + llm, err := openai.New() + require.NoError(t, err) + e, err := embeddings.NewEmbedder(llm) + require.NoError(t, err) + + store, err := New( + WithScheme(scheme), + WithHost(host), + WithEmbedder(e), + WithNameSpace(uuid.New().String()), + WithIndexName(randomizedCamelCaseClass()), + WithAdditionalFields([]string{"id", "vector", "certainty", "distance"}), + ) + require.NoError(t, err) + + err = createTestClass(context.Background(), store) + require.NoError(t, err) + + _, err = store.AddDocuments(context.Background(), []schema.Document{ + {PageContent: "Foo"}, + }) + require.NoError(t, err) + + // Check if all the additional fields are present in the result + docs, err := store.SimilaritySearch(context.Background(), + "Foo", 1) + require.NoError(t, err) + require.Len(t, docs, 1) + + additional, ok := docs[0].Metadata["_additional"].(map[string]any) + require.True(t, ok, "expected '_additional' to be present in the metadata and parsable as 'map[string]any'") + require.Len(t, additional, 4) + + require.NotEmpty(t, additional["id"], "expected the id to be present") + require.NotEmpty(t, additional["vector"], "expected the vector to be present") + require.NotEmpty(t, additional["certainty"], "expected the certainty to be present") + require.NotEmpty(t, additional["distance"], "expected the distance to be present") +}