From 070ff5dffed157e7abb550c6ded7ffc962809412 Mon Sep 17 00:00:00 2001 From: Abirdcfly Date: Mon, 8 Jan 2024 13:58:20 +0800 Subject: [PATCH] vectorstores: pgvector add option WithConn Signed-off-by: Abirdcfly --- vectorstores/pgvector/options.go | 11 ++++++++- vectorstores/pgvector/pgvector.go | 41 +++++++++++++++++++------------ 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/vectorstores/pgvector/options.go b/vectorstores/pgvector/options.go index 2222e5292..e67a8173c 100644 --- a/vectorstores/pgvector/options.go +++ b/vectorstores/pgvector/options.go @@ -65,6 +65,15 @@ func WithCollectionTableName(name string) Option { } } +// WithConn is an option for specifying the Postgres connection. +// From pgx doc: it is not safe for concurrent usage.Use a connection pool to manage access +// to multiple database connections from multiple goroutines. +func WithConn(conn *pgx.Conn) Option { + return func(p *Store) { + p.conn = conn + } +} + func applyClientOptions(opts ...Option) (Store, error) { o := &Store{ collectionName: DefaultCollectionName, @@ -81,7 +90,7 @@ func applyClientOptions(opts ...Option) (Store, error) { o.postgresConnectionURL = os.Getenv("PGVECTOR_CONNECTION_STRING") } - if o.postgresConnectionURL == "" { + if o.postgresConnectionURL == "" && o.conn == nil { return Store{}, fmt.Errorf("%w: missing postgresConnectionURL", ErrInvalidOptions) } diff --git a/vectorstores/pgvector/pgvector.go b/vectorstores/pgvector/pgvector.go index a315a2941..11444b242 100644 --- a/vectorstores/pgvector/pgvector.go +++ b/vectorstores/pgvector/pgvector.go @@ -55,33 +55,41 @@ func New(ctx context.Context, opts ...Option) (Store, error) { if err != nil { return Store{}, err } - store.conn, err = pgx.Connect(ctx, store.postgresConnectionURL) - if err != nil { - return Store{}, err + if store.conn == nil { + store.conn, err = pgx.Connect(ctx, store.postgresConnectionURL) + if err != nil { + return Store{}, err + } } if err = store.conn.Ping(ctx); err != nil { return Store{}, err } - - if err = store.createVectorExtensionIfNotExists(ctx); err != nil { + if err = store.init(ctx); err != nil { return Store{}, err } - if err = store.createCollectionTableIfNotExists(ctx); err != nil { - return Store{}, err + return store, nil +} + +func (s *Store) init(ctx context.Context) error { + if err := s.createVectorExtensionIfNotExists(ctx); err != nil { + return err } - if err = store.createEmbeddingTableIfNotExists(ctx); err != nil { - return Store{}, err + if err := s.createCollectionTableIfNotExists(ctx); err != nil { + return err } - if store.preDeleteCollection { - if err = store.RemoveCollection(ctx); err != nil { - return Store{}, err + if err := s.createEmbeddingTableIfNotExists(ctx); err != nil { + return err + } + if s.preDeleteCollection { + if err := s.RemoveCollection(ctx); err != nil { + return err } } - if err = store.createOrGetCollection(ctx); err != nil { - return Store{}, err + if err := s.createOrGetCollection(ctx); err != nil { + return err } - return store, nil + return nil } func (s Store) createVectorExtensionIfNotExists(ctx context.Context) error { @@ -164,7 +172,8 @@ PRIMARY KEY (uuid))`, s.embeddingTableName, s.collectionTableName) // AddDocuments adds documents to the Postgres collection associated with 'Store'. // and returns the ids of the added documents. -func (s Store) AddDocuments(ctx context.Context, +func (s Store) AddDocuments( + ctx context.Context, docs []schema.Document, options ...vectorstores.Option, ) ([]string, error) {