From 9c44a166d4a344c2b82b44f99f48661640d14e0e Mon Sep 17 00:00:00 2001
From: pattonjp
Date: Fri, 10 Nov 2023 13:46:15 -0600
Subject: [PATCH] adding hugging face text embeddings

---
 embeddings/tei/doc.go                       |   6 ++
 embeddings/tei/options.go                   | 104 ++++++++++++++++++++
 embeddings/tei/text_embeddings_inference.go |  84 ++++++++++++++++
 3 files changed, 194 insertions(+)
 create mode 100644 embeddings/tei/doc.go
 create mode 100644 embeddings/tei/options.go
 create mode 100644 embeddings/tei/text_embeddings_inference.go

diff --git a/embeddings/tei/doc.go b/embeddings/tei/doc.go
new file mode 100644
index 000000000..0a05401fc
--- /dev/null
+++ b/embeddings/tei/doc.go
@@ -0,0 +1,6 @@
+/*
+Package tei wraps the Hugging Face text embeddings inference server
+(https://github.com/huggingface/text-embeddings-inference), which can
+be run locally for creating vector embeddings.
+*/
+package tei
diff --git a/embeddings/tei/options.go b/embeddings/tei/options.go
new file mode 100644
index 000000000..c234c9de8
--- /dev/null
+++ b/embeddings/tei/options.go
@@ -0,0 +1,104 @@
+package tei
+
+import (
+	"errors"
+	"runtime"
+	"time"
+
+	client "github.com/gage-technologies/tei-go"
+)
+
+const (
+	_defaultBatchSize     = 512
+	_defaultStripNewLines = true
+	_defaultTimeout       = time.Minute
+)
+
+// ErrMissingAPIBaseURL is returned when no API base URL was configured.
+var ErrMissingAPIBaseURL = errors.New("missing the API Base URL")
+
+// Option mutates a TextEmbeddingsInference while it is being constructed.
+type Option func(emb *TextEmbeddingsInference)
+
+// WithStripNewLines specifies whether new-line characters are stripped
+// from texts before they are embedded.
+func WithStripNewLines(stripNewLines bool) Option {
+	return func(emb *TextEmbeddingsInference) {
+		emb.StripNewLines = stripNewLines
+	}
+}
+
+// WithPoolSize specifies the number of goroutines used to embed batches
+// concurrently. Defaults to runtime.GOMAXPROCS(0).
+func WithPoolSize(poolSize int) Option {
+	return func(emb *TextEmbeddingsInference) {
+		emb.poolSize = poolSize
+	}
+}
+
+// WithBatchSize specifies the maximum number of texts grouped into a
+// single embedding request.
+func WithBatchSize(batchSize int) Option {
+	return func(emb *TextEmbeddingsInference) {
+		emb.BatchSize = batchSize
+	}
+}
+
+// WithAPIBaseURL sets the base URL of the inference API. Required.
+func WithAPIBaseURL(url string) Option {
+	return func(emb *TextEmbeddingsInference) {
+		emb.baseURL = url
+	}
+}
+
+// WithHeaders adds request headers.
+func WithHeaders(headers map[string]string) Option {
+	return func(emb *TextEmbeddingsInference) {
+		if emb.headers == nil {
+			emb.headers = make(map[string]string, len(headers))
+		}
+		for k, v := range headers {
+			emb.headers[k] = v
+		}
+	}
+}
+
+// WithCookies adds request cookies.
+func WithCookies(cookies map[string]string) Option {
+	return func(emb *TextEmbeddingsInference) {
+		if emb.cookies == nil {
+			emb.cookies = make(map[string]string, len(cookies))
+		}
+		for k, v := range cookies {
+			emb.cookies[k] = v
+		}
+	}
+}
+
+// WithTimeout sets the request timeout. Defaults to one minute.
+func WithTimeout(dur time.Duration) Option {
+	return func(emb *TextEmbeddingsInference) {
+		emb.timeout = dur
+	}
+}
+
+// applyClientOptions builds an embedder from the defaults plus the
+// supplied options, creating the HTTP client unless one is already set.
+func applyClientOptions(opts ...Option) (TextEmbeddingsInference, error) {
+	emb := TextEmbeddingsInference{
+		StripNewLines: _defaultStripNewLines,
+		BatchSize:     _defaultBatchSize,
+		timeout:       _defaultTimeout,
+		poolSize:      runtime.GOMAXPROCS(0),
+	}
+	for _, opt := range opts {
+		opt(&emb)
+	}
+	if emb.baseURL == "" {
+		return emb, ErrMissingAPIBaseURL
+	}
+	if emb.client == nil {
+		emb.client = client.NewClient(emb.baseURL, emb.headers, emb.cookies, emb.timeout)
+	}
+	return emb, nil
+}
diff --git a/embeddings/tei/text_embeddings_inference.go b/embeddings/tei/text_embeddings_inference.go
new file mode 100644
index 000000000..3a5c253f6
--- /dev/null
+++ b/embeddings/tei/text_embeddings_inference.go
@@ -0,0 +1,84 @@
+package tei
+
+import (
+	"context"
+	"strings"
+	"time"
+
+	client "github.com/gage-technologies/tei-go"
+	"github.com/sourcegraph/conc/pool"
+	"github.com/tmc/langchaingo/embeddings"
+)
+
+// TextEmbeddingsInference is an embeddings.Embedder backed by a
+// Hugging Face text-embeddings-inference server.
+type TextEmbeddingsInference struct {
+	client        *client.Client
+	StripNewLines bool
+	BatchSize     int
+	baseURL       string
+	headers       map[string]string
+	cookies       map[string]string
+	timeout       time.Duration
+	poolSize      int
+}
+
+var _ embeddings.Embedder = TextEmbeddingsInference{}
+
+// New creates the embedder; WithAPIBaseURL is required.
+// applyClientOptions already builds the client, so nothing else to do.
+func New(opts ...Option) (TextEmbeddingsInference, error) {
+	return applyClientOptions(opts...)
+}
+
+// EmbedDocuments creates one vector embedding for each batch of texts.
+func (e TextEmbeddingsInference) EmbedDocuments(_ context.Context, texts []string) ([][]float32, error) {
+	batchedTexts := embeddings.BatchTexts(
+		embeddings.MaybeRemoveNewLines(texts, e.StripNewLines),
+		e.BatchSize,
+	)
+
+	// One slot per batch; each goroutine writes only its own index, which
+	// keeps the result race-free and in input order (a shared append here
+	// would be a data race and scramble the ordering).
+	emb := make([][]float32, len(batchedTexts))
+
+	p := pool.New().WithMaxGoroutines(e.poolSize).WithErrors()
+
+	for i, txt := range batchedTexts {
+		i, txt := i, txt // capture loop variables for the goroutine
+		p.Go(func() error {
+			// NOTE(review): the batch is joined into one request string and
+			// the returned vectors length-weight-combined — confirm this
+			// matches the server's multi-input semantics.
+			curTextEmbeddings, err := e.client.Embed(strings.Join(txt, " "), false)
+			if err != nil {
+				return err
+			}
+
+			textLengths := make([]int, 0, len(txt))
+			for _, text := range txt {
+				textLengths = append(textLengths, len(text))
+			}
+
+			combined, err := embeddings.CombineVectors(curTextEmbeddings, textLengths)
+			if err != nil {
+				return err
+			}
+
+			emb[i] = combined
+
+			return nil
+		})
+	}
+	return emb, p.Wait()
+}
+
+// EmbedQuery embeds a single text.
+func (e TextEmbeddingsInference) EmbedQuery(_ context.Context, text string) ([]float32, error) {
+	if e.StripNewLines {
+		text = strings.ReplaceAll(text, "\n", " ")
+	}
+
+	emb, err := e.client.Embed(text, false)
+	if err != nil {
+		return nil, err
+	}
+
+	return emb[0], nil
+}