Skip to content

Commit

Permalink
update pgvector Follow reviewer's advice
Browse files Browse the repository at this point in the history
Signed-off-by: Abirdcfly <fp544037857@gmail.com>
  • Loading branch information
Abirdcfly committed Dec 3, 2023
1 parent 46b24e7 commit 279e046
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 17 deletions.
26 changes: 17 additions & 9 deletions vectorstores/pgvector/pgvector.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ PRIMARY KEY (uuid))`, s.embeddingTableName, s.collectionTableName)

func (s Store) AddDocuments(ctx context.Context, docs []schema.Document, options ...vectorstores.Option) error {
opts := s.getOptions(options...)
if opts.Embedder != nil || opts.ScoreThreshold != 0 || opts.Filters != nil || opts.NameSpace != "" {
if opts.ScoreThreshold != 0 || opts.Filters != nil || opts.NameSpace != "" {
return ErrUnsupportedOptions
}

Expand All @@ -160,7 +160,11 @@ func (s Store) AddDocuments(ctx context.Context, docs []schema.Document, options
texts = append(texts, doc.PageContent)
}

vectors, err := s.embedder.EmbedDocuments(ctx, texts)
embedder := s.embedder
if opts.Embedder != nil {
embedder = opts.Embedder
}
vectors, err := embedder.EmbedDocuments(ctx, texts)
if err != nil {
return err
}
Expand All @@ -186,11 +190,8 @@ func (s Store) SimilaritySearch(
numDocuments int,
options ...vectorstores.Option,
) ([]schema.Document, error) {
collectionName := s.collectionName
opts := s.getOptions(options...)
if nameSpace := s.getNameSpace(opts); nameSpace != "" {
collectionName = nameSpace
}
collectionName := s.getNameSpace(opts)
scoreThreshold, err := s.getScoreThreshold(opts)
if err != nil {
return nil, err
Expand All @@ -199,7 +200,11 @@ func (s Store) SimilaritySearch(
if err != nil {
return nil, err
}
embedder, err := s.embedder.EmbedQuery(ctx, query)
embedder := s.embedder
if opts.Embedder != nil {
embedder = opts.Embedder
}
embedderData, err := embedder.EmbedQuery(ctx, query)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -236,7 +241,7 @@ LIMIT %d`, s.embeddingTableName,
s.embeddingTableName,
s.collectionTableName, s.embeddingTableName, s.collectionTableName, s.collectionTableName, collectionName,
whereQuery, numDocuments)
rows, err := tx.Query(ctx, sql, pgvector.NewVector(embedder))
rows, err := tx.Query(ctx, sql, pgvector.NewVector(embedderData))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -295,6 +300,8 @@ func (s Store) createOrGetCollection(ctx context.Context) (string, error) {
return collectionUUID, nil
}

// getOptions applies given options to default Options and returns it
// This uses options pattern so clients can easily pass options without changing function signature.
func (s Store) getOptions(options ...vectorstores.Option) vectorstores.Options {
opts := vectorstores.Options{}
for _, opt := range options {
Expand All @@ -307,7 +314,7 @@ func (s Store) getNameSpace(opts vectorstores.Options) string {
if opts.NameSpace != "" {
return opts.NameSpace
}
return ""
return s.collectionName
}

func (s Store) getScoreThreshold(opts vectorstores.Options) (float32, error) {
Expand All @@ -317,6 +324,7 @@ func (s Store) getScoreThreshold(opts vectorstores.Options) (float32, error) {
return opts.ScoreThreshold, nil
}

// getFilters return metadata filters, now only support map[key]value pattern
// TODO: should support more types like {"key1": {"key2":"values2"}} or {"key": ["value1", "values2"]}.
func (s Store) getFilters(opts vectorstores.Options) (map[string]any, error) {
if opts.Filters != nil {
Expand Down
16 changes: 8 additions & 8 deletions vectorstores/pgvector/pgvector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func preCheckEnvSetting(t *testing.T) {
}
}

func getTestCollectionName() string {
func makeNewCollectionName() string {
return fmt.Sprintf("test-collection-%s", uuid.New().String())
}

Expand All @@ -54,7 +54,7 @@ func TestPgvectorStoreRest(t *testing.T) {
ctx,
pgvector.WithEmbedder(e),
pgvector.WithPreDeleteCollection(true),
pgvector.WithCollectionName(getTestCollectionName()),
pgvector.WithCollectionName(makeNewCollectionName()),
)
require.NoError(t, err)

Expand Down Expand Up @@ -89,7 +89,7 @@ func TestPgvectorStoreRestWithScoreThreshold(t *testing.T) {
ctx,
pgvector.WithEmbedder(e),
pgvector.WithPreDeleteCollection(true),
pgvector.WithCollectionName(getTestCollectionName()),
pgvector.WithCollectionName(makeNewCollectionName()),
)
require.NoError(t, err)

Expand Down Expand Up @@ -143,7 +143,7 @@ func TestSimilaritySearchWithInvalidScoreThreshold(t *testing.T) {
ctx,
pgvector.WithEmbedder(e),
pgvector.WithPreDeleteCollection(true),
pgvector.WithCollectionName(getTestCollectionName()),
pgvector.WithCollectionName(makeNewCollectionName()),
)
require.NoError(t, err)

Expand Down Expand Up @@ -194,7 +194,7 @@ func TestPgvectorAsRetriever(t *testing.T) {
ctx,
pgvector.WithEmbedder(e),
pgvector.WithPreDeleteCollection(true),
pgvector.WithCollectionName(getTestCollectionName()),
pgvector.WithCollectionName(makeNewCollectionName()),
)
require.NoError(t, err)

Expand Down Expand Up @@ -236,7 +236,7 @@ func TestPgvectorAsRetrieverWithScoreThreshold(t *testing.T) {
ctx,
pgvector.WithEmbedder(e),
pgvector.WithPreDeleteCollection(true),
pgvector.WithCollectionName(getTestCollectionName()),
pgvector.WithCollectionName(makeNewCollectionName()),
)
require.NoError(t, err)

Expand Down Expand Up @@ -283,7 +283,7 @@ func TestPgvectorAsRetrieverWithMetadataFilterNotSelected(t *testing.T) {
ctx,
pgvector.WithEmbedder(e),
pgvector.WithPreDeleteCollection(true),
pgvector.WithCollectionName(getTestCollectionName()),
pgvector.WithCollectionName(makeNewCollectionName()),
)
require.NoError(t, err)

Expand Down Expand Up @@ -357,7 +357,7 @@ func TestPgvectorAsRetrieverWithMetadataFilters(t *testing.T) {
ctx,
pgvector.WithEmbedder(e),
pgvector.WithPreDeleteCollection(true),
pgvector.WithCollectionName(getTestCollectionName()),
pgvector.WithCollectionName(makeNewCollectionName()),
)
require.NoError(t, err)

Expand Down

0 comments on commit 279e046

Please sign in to comment.