Skip to content

Commit

Permalink
adding Hugging Face text embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonjp committed Nov 10, 2023
1 parent 2b9f753 commit 9c44a16
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 0 deletions.
8 changes: 8 additions & 0 deletions embeddings/tei/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/*
Package tei provides a wrapper around the Huggingface Text Embeddings
Inference project (https://github.com/huggingface/text-embeddings-inference),
which can be run locally to create vector embeddings.
*/
package tei
97 changes: 97 additions & 0 deletions embeddings/tei/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package tei

import (
"errors"
"runtime"
"time"

client "github.com/gage-technologies/tei-go"
)

const (
_defaultBatchSize = 512
_defaultStripNewLines = true
_defaultTimeNanoSeconds = 60 * 1000000000
)

// ErrMissingAPIBaseURL is returned when no base URL for the
// text-embeddings-inference API has been configured via WithAPIBaseURL.
var ErrMissingAPIBaseURL = errors.New("missing the API Base URL")

// Option is a functional option that mutates a TextEmbeddingsInference
// during construction.
type Option func(emb *TextEmbeddingsInference)

// WithStripNewLines configures whether newline characters are removed
// from input text before it is embedded.
func WithStripNewLines(stripNewLines bool) Option {
	return func(emb *TextEmbeddingsInference) {
		emb.StripNewLines = stripNewLines
	}
}

// WithPoolSize sets the number of goroutines used to embed batches
// concurrently.
func WithPoolSize(poolSize int) Option {
	return func(emb *TextEmbeddingsInference) {
		emb.poolSize = poolSize
	}
}

// WithBatchSize sets the maximum number of texts sent in a single
// embedding request.
func WithBatchSize(batchSize int) Option {
	return func(emb *TextEmbeddingsInference) {
		emb.BatchSize = batchSize
	}
}

// WithAPIBaseURL sets the base URL of the text-embeddings-inference API.
// It is required; construction fails with ErrMissingAPIBaseURL without it.
func WithAPIBaseURL(apiBaseURL string) Option {
	return func(emb *TextEmbeddingsInference) {
		emb.baseURL = apiBaseURL
	}
}

// WithHeaders merges the given HTTP headers into those sent with every
// request. Existing headers with the same key are overwritten.
func WithHeaders(headers map[string]string) Option {
	return func(emb *TextEmbeddingsInference) {
		if emb.headers == nil {
			emb.headers = make(map[string]string, len(headers))
		}
		for name, value := range headers {
			emb.headers[name] = value
		}
	}
}

// WithCookies merges the given cookies into those sent with every
// request. Existing cookies with the same key are overwritten.
func WithCookies(cookies map[string]string) Option {
	return func(emb *TextEmbeddingsInference) {
		if emb.cookies == nil {
			emb.cookies = make(map[string]string, len(cookies))
		}
		for name, value := range cookies {
			emb.cookies[name] = value
		}
	}
}

// WithTimeout sets the timeout applied to each embedding request.
func WithTimeout(timeout time.Duration) Option {
	return func(emb *TextEmbeddingsInference) {
		emb.timeout = timeout
	}
}

// applyClientOptions builds a TextEmbeddingsInference from the package
// defaults, applies every Option in order, and validates the result.
// It returns ErrMissingAPIBaseURL when no base URL was configured, and
// constructs the underlying client only if one was not already injected.
func applyClientOptions(opts ...Option) (TextEmbeddingsInference, error) {
	cfg := TextEmbeddingsInference{
		StripNewLines: _defaultStripNewLines,
		BatchSize:     _defaultBatchSize,
		timeout:       time.Duration(_defaultTimeNanoSeconds),
		poolSize:      runtime.GOMAXPROCS(0),
	}

	for _, apply := range opts {
		apply(&cfg)
	}

	if cfg.baseURL == "" {
		return cfg, ErrMissingAPIBaseURL
	}

	if cfg.client == nil {
		cfg.client = client.NewClient(cfg.baseURL, cfg.headers, cfg.cookies, cfg.timeout)
	}

	return cfg, nil
}
84 changes: 84 additions & 0 deletions embeddings/tei/text_embeddings_inference..go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package tei

import (
"context"
"strings"
"time"

client "github.com/gage-technologies/tei-go"
"github.com/sourcegraph/conc/pool"
"github.com/tmc/langchaingo/embeddings"
)

// TextEmbeddingsInference is an embeddings.Embedder backed by a
// Huggingface text-embeddings-inference server.
type TextEmbeddingsInference struct {
	// client talks to the TEI HTTP API; set from options or built in
	// applyClientOptions/New.
	client *client.Client
	// StripNewLines, when true, replaces newlines with spaces before embedding.
	StripNewLines bool
	// BatchSize is the maximum number of texts grouped into one request.
	BatchSize int
	baseURL   string            // base URL of the TEI server (required)
	headers   map[string]string // extra HTTP headers sent with each request
	cookies   map[string]string // cookies sent with each request
	timeout   time.Duration     // per-request timeout
	poolSize  int               // number of concurrent embedding goroutines
}

// Compile-time check that TextEmbeddingsInference satisfies embeddings.Embedder.
var _ embeddings.Embedder = TextEmbeddingsInference{}

// New creates a TextEmbeddingsInference embedder from the given options.
// WithAPIBaseURL is required; ErrMissingAPIBaseURL is returned when it is
// absent. All other settings fall back to package defaults.
func New(opts ...Option) (TextEmbeddingsInference, error) {
	// applyClientOptions already constructs the underlying client when one
	// was not injected; rebuilding it here unconditionally (as the original
	// did) duplicated the work and discarded any injected client.
	return applyClientOptions(opts...)
}

// EmbedDocuments creates one vector embedding per batch of texts.
//
// Texts are optionally stripped of newlines, grouped into batches of at
// most BatchSize, and embedded concurrently using up to poolSize
// goroutines. On error, a nil slice and the first error encountered are
// returned.
func (e TextEmbeddingsInference) EmbedDocuments(_ context.Context, texts []string) ([][]float32, error) {
	batchedTexts := embeddings.BatchTexts(
		embeddings.MaybeRemoveNewLines(texts, e.StripNewLines),
		e.BatchSize,
	)

	// Pre-size the result and let each goroutine write only its own slot:
	// the original appended to a shared slice from multiple goroutines,
	// which is a data race and also made the result order nondeterministic.
	emb := make([][]float32, len(batchedTexts))

	p := pool.New().WithMaxGoroutines(e.poolSize).WithErrors()

	for i, txt := range batchedTexts {
		i, txt := i, txt // per-iteration copies (pre-Go 1.22 loop semantics)
		p.Go(func() error {
			curTextEmbeddings, err := e.client.Embed(strings.Join(txt, " "), false)
			if err != nil {
				return err
			}

			textLengths := make([]int, 0, len(txt))
			for _, text := range txt {
				textLengths = append(textLengths, len(text))
			}

			combined, err := embeddings.CombineVectors(curTextEmbeddings, textLengths)
			if err != nil {
				return err
			}

			emb[i] = combined

			return nil
		})
	}

	// Do not return a partially-populated result alongside an error.
	if err := p.Wait(); err != nil {
		return nil, err
	}

	return emb, nil
}

// EmbedQuery embeds a single text and returns its vector.
func (e TextEmbeddingsInference) EmbedQuery(_ context.Context, text string) ([]float32, error) {
	query := text
	if e.StripNewLines {
		query = strings.ReplaceAll(query, "\n", " ")
	}

	vectors, err := e.client.Embed(query, false)
	if err != nil {
		return nil, err
	}

	return vectors[0], nil
}

0 comments on commit 9c44a16

Please sign in to comment.