Skip to content

Commit

Permalink
Merge branch 'main' into testcontainers-go
Browse files Browse the repository at this point in the history
* main:
  googleai: refactor to better separate generated code (tmc#547)
  googleai: add embeddings to vertex (tmc#546)
  embeddings: add cybertron local embeddings
  googleai: move the PaLM provider into googleai (tmc#541)
  • Loading branch information
mdelapenya committed Jan 24, 2024
2 parents 10b58fc + 2545ace commit 863cff7
Show file tree
Hide file tree
Showing 20 changed files with 353 additions and 98 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ go.work.sum

# Test outputs
coverage.out
cover.cov

# macOS Specific
.DS_Store
Expand All @@ -14,3 +15,5 @@ coverage.out
# dev
.env
vendor/*

embeddings/cybertron/models/*
1 change: 1 addition & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ linters:
- tagliatelle # As we're dealing with third parties we must accept snake case.
- wsl # We don't agree with wsl's style rules
- exhaustruct
- lll
- varnamelen
- nlreturn
- gomnd
Expand Down
44 changes: 44 additions & 0 deletions embeddings/cybertron/cybertron.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package cybertron

import (
"context"

"github.com/nlpodyssey/cybertron/pkg/models/bert"
"github.com/nlpodyssey/cybertron/pkg/tasks/textencoding"
"github.com/tmc/langchaingo/embeddings"
)

// Cybertron is the embedder using Cybertron to run embedding models locally.
type Cybertron struct {
encoder textencoding.Interface
Model string
ModelsDir string
PoolingStrategy bert.PoolingStrategyType
}

var _ embeddings.EmbedderClient = (*Cybertron)(nil)

// NewCybertron returns a new embedding client that uses Cybertron to run embedding
// models locally (on the CPU). The embedding model will be downloaded and cached
// automatically. Use `WithModel` and `WithModelsDir` to change which model is used
// and where it is cached.
func NewCybertron(opts ...Option) (*Cybertron, error) {
return applyOptions(opts...)
}

// CreateEmbedding implements the `embeddings.EmbedderClient` and creates an embedding
// vector for each of the supplied texts.
func (c *Cybertron) CreateEmbedding(ctx context.Context, texts []string) ([][]float32, error) {
result := make([][]float32, 0, len(texts))

for _, text := range texts {
embedding, err := c.encoder.Encode(ctx, text, int(c.PoolingStrategy))
if err != nil {
return nil, err
}

result = append(result, embedding.Vector.Normalize2().Data().F32())
}

return result, nil
}
34 changes: 34 additions & 0 deletions embeddings/cybertron/cybertron_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package cybertron

import (
"context"
"os"
"testing"

"github.com/stretchr/testify/require"
)

func TestCybertronEmbeddings(t *testing.T) {
t.Parallel()

_, err := os.Stat(_defaultModelsDir)
if os.IsNotExist(err) && os.Getenv("CYBERTRON_DO_DOWNLOAD") == "" {
// Cybertron downloads the embedding model and caches it in ModelsDir. Doing this as
// part of the tests would be costly in terms of time and bandwidth and likely make
// the test flaky.
t.Skipf("ModelsDir %q doesn't exist", _defaultModelsDir)
}

emb, err := NewCybertron(
WithModelsDir(_defaultModelsDir),
WithModel(_defaultModel),
WithPoolingStrategy(_defaultPoolingStrategy),
)
require.NoError(t, err)

res, err := emb.CreateEmbedding(context.Background(), []string{
"Hello world", "The world is ending", "good bye",
})
require.NoError(t, err)
require.Len(t, res, 3)
}
79 changes: 79 additions & 0 deletions embeddings/cybertron/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package cybertron

import (
"github.com/nlpodyssey/cybertron/pkg/models/bert"
"github.com/nlpodyssey/cybertron/pkg/tasks"
"github.com/nlpodyssey/cybertron/pkg/tasks/textencoding"
)

const (
_defaultModel = "sentence-transformers/all-MiniLM-L6-v2"
_defaultModelsDir = "models"
_defaultPoolingStrategy = bert.MeanPooling
)

// Option is a function type that can be used to modify the client.
type Option func(c *Cybertron)

// apply the option to the instance.
func (o Option) apply(c *Cybertron) {
o(c)
}

// WithModel is an option for providing the model name to use. Default is
// "sentence-transformers/all-MiniLM-L6-v2". Note that not all embedding models
// are supported.
func WithModel(model string) Option {
return func(c *Cybertron) {
c.Model = model
}
}

// WithModelsDir is an option for setting the directory to store downloaded models.
// Default is "models".
func WithModelsDir(dir string) Option {
return func(c *Cybertron) {
c.ModelsDir = dir
}
}

// WithPoolingStrategy sets the pooling strategy. Default is mean pooling.
func WithPoolingStrategy(strategy bert.PoolingStrategyType) Option {
return func(c *Cybertron) {
c.PoolingStrategy = strategy
}
}

// WithEncoder is an option for providing the Encoder.
func WithEncoder(encoder textencoding.Interface) Option {
return func(c *Cybertron) {
c.encoder = encoder
}
}

func applyOptions(opts ...Option) (*Cybertron, error) {
c := &Cybertron{
Model: _defaultModel,
ModelsDir: _defaultModelsDir,
PoolingStrategy: _defaultPoolingStrategy,
encoder: nil,
}

for _, opt := range opts {
opt.apply(c)
}

if c.encoder == nil {
encoder, err := tasks.Load[textencoding.Interface](&tasks.Config{
ModelsDir: c.ModelsDir,
ModelName: c.Model,
})
if err != nil {
return nil, err
}

c.encoder = encoder
}

return c, nil
}
4 changes: 2 additions & 2 deletions embeddings/vertexai_palm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/llms/vertexai"
"github.com/tmc/langchaingo/llms/googleai/palm"
)

func newVertexEmbedder(t *testing.T, opts ...Option) *EmbedderImpl {
Expand All @@ -17,7 +17,7 @@ func newVertexEmbedder(t *testing.T, opts ...Option) *EmbedderImpl {
return nil
}

llm, err := vertexai.New()
llm, err := palm.New()
require.NoError(t, err)

embedder, err := NewEmbedder(llm, opts...)
Expand Down
21 changes: 15 additions & 6 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ require (
cloud.google.com/go/iam v1.1.5 // indirect
cloud.google.com/go/longrunning v0.5.4 // indirect
dario.cat/mergo v1.0.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver/v3 v3.2.0 // indirect
github.com/Microsoft/go-winio v0.6.1 // indirect
Expand All @@ -41,8 +41,9 @@ require (
github.com/containerd/log v0.1.0 // indirect
github.com/cpuguy83/dockercfg v0.3.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/distribution/reference v0.5.0 // indirect
github.com/dlclark/regexp2 v1.8.1 // indirect
github.com/docker/distribution v2.8.2+incompatible // indirect
github.com/docker/distribution v2.8.3+incompatible // indirect
github.com/docker/docker v24.0.7+incompatible // indirect
github.com/docker/go-connections v0.4.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
Expand All @@ -61,26 +62,29 @@ require (
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/flatbuffers v23.5.26+incompatible // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.0 // indirect
github.com/goph/emperror v0.17.2 // indirect
github.com/gorilla/css v1.0.0 // indirect
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.0 // indirect
github.com/huandu/xstrings v1.3.3 // indirect
github.com/imdario/mergo v0.3.13 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/kennygrant/sanitize v1.2.4 // indirect
github.com/klauspost/compress v1.16.0 // indirect
github.com/klauspost/compress v1.17.2 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.2 // indirect
github.com/mitchellh/copystructure v1.0.0 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
Expand All @@ -91,6 +95,9 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/nlpodyssey/gopickle v0.2.0 // indirect
github.com/nlpodyssey/gotokenizers v0.2.0 // indirect
github.com/nlpodyssey/spago v1.1.0 // indirect
github.com/oklog/ulid v1.3.1 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0-rc5 // indirect
Expand All @@ -100,6 +107,7 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect
github.com/rs/zerolog v1.31.0 // indirect
github.com/saintfish/chardet v0.0.0-20230101081208-5e3ef4b5456d // indirect
github.com/shirou/gopsutil/v3 v3.23.11 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
Expand All @@ -121,14 +129,14 @@ require (
go.mongodb.org/mongo-driver v1.11.3 // indirect
go.opencensus.io v0.24.0 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/mod v0.11.0 // indirect
golang.org/x/mod v0.13.0 // indirect
golang.org/x/net v0.19.0 // indirect
golang.org/x/oauth2 v0.15.0 // indirect
golang.org/x/sync v0.5.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.10.0 // indirect
golang.org/x/tools v0.14.0 // indirect
google.golang.org/appengine v1.6.8 // indirect
google.golang.org/genproto v0.0.0-20231120223509-83a465c0220f // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20231211222908-989df2bf70f3 // indirect
Expand All @@ -155,6 +163,7 @@ require (
github.com/microcosm-cc/bluemonday v1.0.26
github.com/milvus-io/milvus-sdk-go/v2 v2.3.2
github.com/nikolalohinski/gonja v1.5.3
github.com/nlpodyssey/cybertron v0.2.1
github.com/opensearch-project/opensearch-go v1.1.0
github.com/pgvector/pgvector-go v0.1.1
github.com/pinecone-io/go-pinecone v0.3.0
Expand Down
Loading

0 comments on commit 863cff7

Please sign in to comment.