Skip to content

Commit

Permalink
Merge pull request #531 from corani/corani/cybertron
Browse files Browse the repository at this point in the history
  • Loading branch information
tmc authored Jan 23, 2024
2 parents 61b3cda + 6982a61 commit 6f20ee5
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 19 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ go.work.sum

# Test outputs
coverage.out
cover.cov

# macOS Specific
.DS_Store
Expand All @@ -14,3 +15,5 @@ coverage.out
# dev
.env
vendor/*

embeddings/cybertron/models/*
44 changes: 44 additions & 0 deletions embeddings/cybertron/cybertron.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package cybertron

import (
"context"

"github.com/nlpodyssey/cybertron/pkg/models/bert"
"github.com/nlpodyssey/cybertron/pkg/tasks/textencoding"
"github.com/tmc/langchaingo/embeddings"
)

// Cybertron is the embedder using Cybertron to run embedding models locally.
type Cybertron struct {
encoder textencoding.Interface
Model string
ModelsDir string
PoolingStrategy bert.PoolingStrategyType
}

var _ embeddings.EmbedderClient = (*Cybertron)(nil)

// NewCybertron returns a new embedding client that uses Cybertron to run embedding
// models locally (on the CPU). The embedding model will be downloaded and cached
// automatically. Use `WithModel` and `WithModelsDir` to change which model is used
// and where it is cached.
func NewCybertron(opts ...Option) (*Cybertron, error) {
return applyOptions(opts...)
}

// CreateEmbedding implements the `embeddings.EmbedderClient` and creates an embedding
// vector for each of the supplied texts.
func (c *Cybertron) CreateEmbedding(ctx context.Context, texts []string) ([][]float32, error) {
result := make([][]float32, 0, len(texts))

for _, text := range texts {
embedding, err := c.encoder.Encode(ctx, text, int(c.PoolingStrategy))
if err != nil {
return nil, err
}

result = append(result, embedding.Vector.Normalize2().Data().F32())
}

return result, nil
}
34 changes: 34 additions & 0 deletions embeddings/cybertron/cybertron_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package cybertron

import (
"context"
"os"
"testing"

"github.com/stretchr/testify/require"
)

func TestCybertronEmbeddings(t *testing.T) {
t.Parallel()

_, err := os.Stat(_defaultModelsDir)
if os.IsNotExist(err) && os.Getenv("CYBERTRON_DO_DOWNLOAD") == "" {
// Cybertron downloads the embedding model and caches it in ModelsDir. Doing this as
// part of the tests would be costly in terms of time and bandwidth and likely make
// the test flaky.
t.Skipf("ModelsDir %q doesn't exist", _defaultModelsDir)
}

emb, err := NewCybertron(
WithModelsDir(_defaultModelsDir),
WithModel(_defaultModel),
WithPoolingStrategy(_defaultPoolingStrategy),
)
require.NoError(t, err)

res, err := emb.CreateEmbedding(context.Background(), []string{
"Hello world", "The world is ending", "good bye",
})
require.NoError(t, err)
require.Len(t, res, 3)
}
79 changes: 79 additions & 0 deletions embeddings/cybertron/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package cybertron

import (
"github.com/nlpodyssey/cybertron/pkg/models/bert"
"github.com/nlpodyssey/cybertron/pkg/tasks"
"github.com/nlpodyssey/cybertron/pkg/tasks/textencoding"
)

const (
_defaultModel = "sentence-transformers/all-MiniLM-L6-v2"
_defaultModelsDir = "models"
_defaultPoolingStrategy = bert.MeanPooling
)

// Option is a function type that can be used to modify the client.
type Option func(c *Cybertron)

// apply the option to the instance.
func (o Option) apply(c *Cybertron) {
o(c)
}

// WithModel is an option for providing the model name to use. Default is
// "sentence-transformers/all-MiniLM-L6-v2". Note that not all embedding models
// are supported.
func WithModel(model string) Option {
return func(c *Cybertron) {
c.Model = model
}
}

// WithModelsDir is an option for setting the directory to store downloaded models.
// Default is "models".
func WithModelsDir(dir string) Option {
return func(c *Cybertron) {
c.ModelsDir = dir
}
}

// WithPoolingStrategy sets the pooling strategy. Default is mean pooling.
func WithPoolingStrategy(strategy bert.PoolingStrategyType) Option {
return func(c *Cybertron) {
c.PoolingStrategy = strategy
}
}

// WithEncoder is an option for providing the Encoder.
func WithEncoder(encoder textencoding.Interface) Option {
return func(c *Cybertron) {
c.encoder = encoder
}
}

func applyOptions(opts ...Option) (*Cybertron, error) {
c := &Cybertron{
Model: _defaultModel,
ModelsDir: _defaultModelsDir,
PoolingStrategy: _defaultPoolingStrategy,
encoder: nil,
}

for _, opt := range opts {
opt.apply(c)
}

if c.encoder == nil {
encoder, err := tasks.Load[textencoding.Interface](&tasks.Config{
ModelsDir: c.ModelsDir,
ModelName: c.Model,
})
if err != nil {
return nil, err
}

c.encoder = encoder
}

return c, nil
}
17 changes: 10 additions & 7 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ require (
github.com/Masterminds/semver/v3 v3.2.0 // indirect
github.com/PuerkitoBio/purell v1.1.1 // indirect
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect
github.com/alecthomas/colour v0.1.0 // indirect
github.com/alecthomas/repr v0.0.0-20210801044451-80ca428c5142 // indirect
github.com/andybalholm/cascadia v1.3.2 // indirect
github.com/antchfx/htmlquery v1.3.0 // indirect
github.com/antchfx/xmlquery v1.3.17 // indirect
Expand All @@ -47,37 +45,42 @@ require (
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/flatbuffers v23.5.26+incompatible // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.0 // indirect
github.com/goph/emperror v0.17.2 // indirect
github.com/gorilla/css v1.0.0 // indirect
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.2 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.0 // indirect
github.com/huandu/xstrings v1.3.3 // indirect
github.com/imdario/mergo v0.3.11 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/kennygrant/sanitize v1.2.4 // indirect
github.com/kr/pretty v0.3.0 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.2 // indirect
github.com/mitchellh/copystructure v1.0.0 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/mitchellh/reflectwalk v1.0.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/nlpodyssey/gopickle v0.2.0 // indirect
github.com/nlpodyssey/gotokenizers v0.2.0 // indirect
github.com/nlpodyssey/spago v1.1.0 // indirect
github.com/oklog/ulid v1.3.1 // indirect
github.com/pelletier/go-toml/v2 v2.0.9 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect
github.com/rs/zerolog v1.31.0 // indirect
github.com/saintfish/chardet v0.0.0-20230101081208-5e3ef4b5456d // indirect
github.com/sergi/go-diff v1.2.0 // indirect
github.com/shopspring/decimal v1.2.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/spf13/cast v1.3.1 // indirect
Expand Down Expand Up @@ -111,7 +114,6 @@ require (
cloud.google.com/go/vertexai v0.6.0
github.com/Masterminds/sprig/v3 v3.2.3
github.com/PuerkitoBio/goquery v1.8.1
github.com/alecthomas/assert v1.0.0
github.com/amikos-tech/chroma-go v0.0.0-20231228181736-e8f5e927093e
github.com/cohere-ai/tokenizer v1.1.2
github.com/go-openapi/strfmt v0.21.3
Expand All @@ -126,6 +128,7 @@ require (
github.com/microcosm-cc/bluemonday v1.0.26
github.com/milvus-io/milvus-sdk-go/v2 v2.3.2
github.com/nikolalohinski/gonja v1.5.3
github.com/nlpodyssey/cybertron v0.2.1
github.com/opensearch-project/opensearch-go v1.1.0
github.com/pgvector/pgvector-go v0.1.1
github.com/pinecone-io/go-pinecone v0.3.0
Expand Down
Loading

0 comments on commit 6f20ee5

Please sign in to comment.