Skip to content

Commit

Permalink
Merge pull request #499 from Abirdcfly/pgvector
Browse files Browse the repository at this point in the history
  • Loading branch information
tmc authored Jan 20, 2024
2 parents cc899e5 + 1443e5f commit 15a180f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 17 deletions.
11 changes: 10 additions & 1 deletion vectorstores/pgvector/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ func WithCollectionTableName(name string) Option {
}
}

// WithConn is an option for specifying the Postgres connection.
// From pgx doc: it is not safe for concurrent usage.Use a connection pool to manage access
// to multiple database connections from multiple goroutines.
func WithConn(conn *pgx.Conn) Option {
return func(p *Store) {
p.conn = conn
}
}

func applyClientOptions(opts ...Option) (Store, error) {
o := &Store{
collectionName: DefaultCollectionName,
Expand All @@ -81,7 +90,7 @@ func applyClientOptions(opts ...Option) (Store, error) {
o.postgresConnectionURL = os.Getenv("PGVECTOR_CONNECTION_STRING")
}

if o.postgresConnectionURL == "" {
if o.postgresConnectionURL == "" && o.conn == nil {
return Store{}, fmt.Errorf("%w: missing postgresConnectionURL", ErrInvalidOptions)
}

Expand Down
41 changes: 25 additions & 16 deletions vectorstores/pgvector/pgvector.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,33 +55,41 @@ func New(ctx context.Context, opts ...Option) (Store, error) {
if err != nil {
return Store{}, err
}
store.conn, err = pgx.Connect(ctx, store.postgresConnectionURL)
if err != nil {
return Store{}, err
if store.conn == nil {
store.conn, err = pgx.Connect(ctx, store.postgresConnectionURL)
if err != nil {
return Store{}, err
}
}

if err = store.conn.Ping(ctx); err != nil {
return Store{}, err
}

if err = store.createVectorExtensionIfNotExists(ctx); err != nil {
if err = store.init(ctx); err != nil {
return Store{}, err
}
if err = store.createCollectionTableIfNotExists(ctx); err != nil {
return Store{}, err
return store, nil
}

func (s *Store) init(ctx context.Context) error {
if err := s.createVectorExtensionIfNotExists(ctx); err != nil {
return err
}
if err = store.createEmbeddingTableIfNotExists(ctx); err != nil {
return Store{}, err
if err := s.createCollectionTableIfNotExists(ctx); err != nil {
return err
}
if store.preDeleteCollection {
if err = store.RemoveCollection(ctx); err != nil {
return Store{}, err
if err := s.createEmbeddingTableIfNotExists(ctx); err != nil {
return err
}
if s.preDeleteCollection {
if err := s.RemoveCollection(ctx); err != nil {
return err
}
}
if err = store.createOrGetCollection(ctx); err != nil {
return Store{}, err
if err := s.createOrGetCollection(ctx); err != nil {
return err
}
return store, nil
return nil
}

func (s Store) createVectorExtensionIfNotExists(ctx context.Context) error {
Expand Down Expand Up @@ -164,7 +172,8 @@ PRIMARY KEY (uuid))`, s.embeddingTableName, s.collectionTableName)

// AddDocuments adds documents to the Postgres collection associated with 'Store'.
// and returns the ids of the added documents.
func (s Store) AddDocuments(ctx context.Context,
func (s Store) AddDocuments(
ctx context.Context,
docs []schema.Document,
options ...vectorstores.Option,
) ([]string, error) {
Expand Down

0 comments on commit 15a180f

Please sign in to comment.