-
-
Notifications
You must be signed in to change notification settings - Fork 745
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Added Qdrant vectorstore support
- Loading branch information
Showing
6 changed files
with
772 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
// Package qdrant contains an implementation of the VectorStore | ||
// interface using Qdrant. | ||
package qdrant |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
package qdrant | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"net/url" | ||
|
||
"github.com/tmc/langchaingo/embeddings" | ||
) | ||
|
||
const ( | ||
defaultContentKey = "content" | ||
) | ||
|
||
// ErrInvalidOptions is returned when the options given are invalid. | ||
var ErrInvalidOptions = errors.New("invalid options") | ||
|
||
// Option is a function that configures an Options. | ||
type Option func(p *Store) | ||
|
||
// WithCollectionName returns an Option for setting the collection name. Required. | ||
func WithCollectionName(name string) Option { | ||
return func(p *Store) { | ||
p.collectionName = name | ||
} | ||
} | ||
|
||
// WithURL returns an Option for setting the Qdrant instance URL. | ||
// Example: 'http://localhost:63333'. Required. | ||
func WithURL(qdrantURL url.URL) Option { | ||
return func(p *Store) { | ||
p.qdrantURL = qdrantURL | ||
} | ||
} | ||
|
||
// WithEmbedder returns an Option for setting the embedder to be used when | ||
// adding documents or doing similarity search. Required. | ||
func WithEmbedder(embedder embeddings.Embedder) Option { | ||
return func(p *Store) { | ||
p.embedder = embedder | ||
} | ||
} | ||
|
||
// WithAPIKey returns an Option for setting the API key to authenticate the connection. Optional. | ||
func WithAPIKey(apiKey string) Option { | ||
return func(p *Store) { | ||
p.apiKey = apiKey | ||
} | ||
} | ||
|
||
// WithContent returns an Option for setting field name of the document content | ||
// in the Qdrant payload. Optional. Defaults to "content". | ||
func WithContentKey(contentKey string) Option { | ||
return func(p *Store) { | ||
p.contentKey = contentKey | ||
} | ||
} | ||
|
||
func applyClientOptions(opts ...Option) (Store, error) { | ||
o := &Store{ | ||
contentKey: defaultContentKey, | ||
} | ||
|
||
for _, opt := range opts { | ||
opt(o) | ||
} | ||
|
||
if o.collectionName == "" { | ||
return Store{}, fmt.Errorf("%w: missing collection name", ErrInvalidOptions) | ||
} | ||
|
||
if o.qdrantURL == (url.URL{}) { | ||
return Store{}, fmt.Errorf("%w: missing Qdrant URL", ErrInvalidOptions) | ||
} | ||
|
||
if o.embedder == nil { | ||
return Store{}, fmt.Errorf("%w: missing embedder", ErrInvalidOptions) | ||
} | ||
|
||
return *o, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
package qdrant | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"net/url" | ||
|
||
"github.com/tmc/langchaingo/embeddings" | ||
"github.com/tmc/langchaingo/schema" | ||
"github.com/tmc/langchaingo/vectorstores" | ||
) | ||
|
||
type Store struct { | ||
embedder embeddings.Embedder | ||
collectionName string | ||
qdrantURL url.URL | ||
apiKey string | ||
contentKey string | ||
} | ||
|
||
var _ vectorstores.VectorStore = Store{} | ||
|
||
func New(opts ...Option) (Store, error) { | ||
s, err := applyClientOptions(opts...) | ||
if err != nil { | ||
return Store{}, err | ||
} | ||
return s, nil | ||
} | ||
|
||
func (s Store) AddDocuments(ctx context.Context, | ||
docs []schema.Document, | ||
_ ...vectorstores.Option, | ||
) ([]string, error) { | ||
texts := make([]string, 0, len(docs)) | ||
for _, doc := range docs { | ||
texts = append(texts, doc.PageContent) | ||
} | ||
|
||
vectors, | ||
err := s.embedder.EmbedDocuments(ctx, texts) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if len(vectors) != len(docs) { | ||
return nil, errors.New("number of vectors from embedder does not match number of documents") | ||
} | ||
|
||
metadatas := make([]map[string]interface{}, 0, len(docs)) | ||
for i := 0; i < len(docs); i++ { | ||
metadata := make(map[string]interface{}, len(docs[i].Metadata)) | ||
for key, value := range docs[i].Metadata { | ||
metadata[key] = value | ||
} | ||
metadata[s.contentKey] = texts[i] | ||
|
||
metadatas = append(metadatas, metadata) | ||
} | ||
|
||
return s.upsertPoints(ctx, &s.qdrantURL, vectors, metadatas) | ||
} | ||
|
||
func (s Store) SimilaritySearch(ctx context.Context, | ||
query string, numDocuments int, | ||
options ...vectorstores.Option, | ||
) ([]schema.Document, error) { | ||
opts := s.getOptions(options...) | ||
|
||
filters := s.getFilters(opts) | ||
|
||
scoreThreshold, | ||
err := s.getScoreThreshold(opts) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
vector, | ||
err := s.embedder.EmbedQuery(ctx, query) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return s.searchPoints(ctx, &s.qdrantURL, vector, numDocuments, scoreThreshold, filters) | ||
} | ||
|
||
func (s Store) getScoreThreshold(opts vectorstores.Options) (float32, error) { | ||
if opts.ScoreThreshold < 0 || opts.ScoreThreshold > 1 { | ||
return 0, errors.New("score threshold must be between 0 and 1") | ||
} | ||
return opts.ScoreThreshold, nil | ||
} | ||
|
||
func (s Store) getFilters(opts vectorstores.Options) any { | ||
if opts.Filters != nil { | ||
return opts.Filters | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (s Store) getOptions(options ...vectorstores.Option) vectorstores.Options { | ||
opts := vectorstores.Options{} | ||
for _, opt := range options { | ||
opt(&opts) | ||
} | ||
return opts | ||
} |
Oops, something went wrong.