-
-
Notifications
You must be signed in to change notification settings - Fork 743
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #531 from corani/corani/cybertron
- Loading branch information
Showing
6 changed files
with
199 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
package cybertron | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/nlpodyssey/cybertron/pkg/models/bert" | ||
"github.com/nlpodyssey/cybertron/pkg/tasks/textencoding" | ||
"github.com/tmc/langchaingo/embeddings" | ||
) | ||
|
||
// Cybertron is the embedder using Cybertron to run embedding models locally. | ||
type Cybertron struct { | ||
encoder textencoding.Interface | ||
Model string | ||
ModelsDir string | ||
PoolingStrategy bert.PoolingStrategyType | ||
} | ||
|
||
var _ embeddings.EmbedderClient = (*Cybertron)(nil) | ||
|
||
// NewCybertron returns a new embedding client that uses Cybertron to run embedding | ||
// models locally (on the CPU). The embedding model will be downloaded and cached | ||
// automatically. Use `WithModel` and `WithModelsDir` to change which model is used | ||
// and where it is cached. | ||
func NewCybertron(opts ...Option) (*Cybertron, error) { | ||
return applyOptions(opts...) | ||
} | ||
|
||
// CreateEmbedding implements the `embeddings.EmbedderClient` and creates an embedding | ||
// vector for each of the supplied texts. | ||
func (c *Cybertron) CreateEmbedding(ctx context.Context, texts []string) ([][]float32, error) { | ||
result := make([][]float32, 0, len(texts)) | ||
|
||
for _, text := range texts { | ||
embedding, err := c.encoder.Encode(ctx, text, int(c.PoolingStrategy)) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
result = append(result, embedding.Vector.Normalize2().Data().F32()) | ||
} | ||
|
||
return result, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
package cybertron | ||
|
||
import ( | ||
"context" | ||
"os" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestCybertronEmbeddings(t *testing.T) { | ||
t.Parallel() | ||
|
||
_, err := os.Stat(_defaultModelsDir) | ||
if os.IsNotExist(err) && os.Getenv("CYBERTRON_DO_DOWNLOAD") == "" { | ||
// Cybertron downloads the embedding model and caches it in ModelsDir. Doing this as | ||
// part of the tests would be costly in terms of time and bandwidth and likely make | ||
// the test flaky. | ||
t.Skipf("ModelsDir %q doesn't exist", _defaultModelsDir) | ||
} | ||
|
||
emb, err := NewCybertron( | ||
WithModelsDir(_defaultModelsDir), | ||
WithModel(_defaultModel), | ||
WithPoolingStrategy(_defaultPoolingStrategy), | ||
) | ||
require.NoError(t, err) | ||
|
||
res, err := emb.CreateEmbedding(context.Background(), []string{ | ||
"Hello world", "The world is ending", "good bye", | ||
}) | ||
require.NoError(t, err) | ||
require.Len(t, res, 3) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
package cybertron | ||
|
||
import ( | ||
"github.com/nlpodyssey/cybertron/pkg/models/bert" | ||
"github.com/nlpodyssey/cybertron/pkg/tasks" | ||
"github.com/nlpodyssey/cybertron/pkg/tasks/textencoding" | ||
) | ||
|
||
const ( | ||
_defaultModel = "sentence-transformers/all-MiniLM-L6-v2" | ||
_defaultModelsDir = "models" | ||
_defaultPoolingStrategy = bert.MeanPooling | ||
) | ||
|
||
// Option is a function type that can be used to modify the client. | ||
type Option func(c *Cybertron) | ||
|
||
// apply the option to the instance. | ||
func (o Option) apply(c *Cybertron) { | ||
o(c) | ||
} | ||
|
||
// WithModel is an option for providing the model name to use. Default is | ||
// "sentence-transformers/all-MiniLM-L6-v2". Note that not all embedding models | ||
// are supported. | ||
func WithModel(model string) Option { | ||
return func(c *Cybertron) { | ||
c.Model = model | ||
} | ||
} | ||
|
||
// WithModelsDir is an option for setting the directory to store downloaded models. | ||
// Default is "models". | ||
func WithModelsDir(dir string) Option { | ||
return func(c *Cybertron) { | ||
c.ModelsDir = dir | ||
} | ||
} | ||
|
||
// WithPoolingStrategy sets the pooling strategy. Default is mean pooling. | ||
func WithPoolingStrategy(strategy bert.PoolingStrategyType) Option { | ||
return func(c *Cybertron) { | ||
c.PoolingStrategy = strategy | ||
} | ||
} | ||
|
||
// WithEncoder is an option for providing the Encoder. | ||
func WithEncoder(encoder textencoding.Interface) Option { | ||
return func(c *Cybertron) { | ||
c.encoder = encoder | ||
} | ||
} | ||
|
||
func applyOptions(opts ...Option) (*Cybertron, error) { | ||
c := &Cybertron{ | ||
Model: _defaultModel, | ||
ModelsDir: _defaultModelsDir, | ||
PoolingStrategy: _defaultPoolingStrategy, | ||
encoder: nil, | ||
} | ||
|
||
for _, opt := range opts { | ||
opt.apply(c) | ||
} | ||
|
||
if c.encoder == nil { | ||
encoder, err := tasks.Load[textencoding.Interface](&tasks.Config{ | ||
ModelsDir: c.ModelsDir, | ||
ModelName: c.Model, | ||
}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
c.encoder = encoder | ||
} | ||
|
||
return c, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.