Skip to content

Commit

Permalink
embeddings: use EmbedderClient for openai embeddings
Browse files (browse the repository at this point in the history)
Applies the refactored code from #357 (the shared embedderclient.BatchedEmbed helper) to remove the duplicated batching code here.

Updates #356
  • Loading branch information
eliben committed Nov 20, 2023
1 parent 26bd994 commit 35140b3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 58 deletions.
35 changes: 3 additions & 32 deletions embeddings/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"

"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/embeddings/internal/embedderclient"
"github.com/tmc/langchaingo/llms/openai"
)

Expand All @@ -30,38 +31,8 @@ func NewOpenAI(opts ...Option) (OpenAI, error) {

// EmbedDocuments creates one vector embedding for each of the texts.
func (e OpenAI) EmbedDocuments(ctx context.Context, texts []string) ([][]float32, error) {
batchedTexts := embeddings.BatchTexts(
embeddings.MaybeRemoveNewLines(texts, e.StripNewLines),
e.BatchSize,
)

emb := make([][]float32, 0, len(texts))
for _, texts := range batchedTexts {
curTextEmbeddings, err := e.client.CreateEmbedding(ctx, texts)
if err != nil {
return nil, err
}

// If the size of this batch is 1, don't average/combine the vectors.
if len(texts) == 1 {
emb = append(emb, curTextEmbeddings[0])
continue
}

textLengths := make([]int, 0, len(texts))
for _, text := range texts {
textLengths = append(textLengths, len(text))
}

combined, err := embeddings.CombineVectors(curTextEmbeddings, textLengths)
if err != nil {
return nil, err
}

emb = append(emb, combined)
}

return emb, nil
texts = embeddings.MaybeRemoveNewLines(texts, e.StripNewLines)
return embedderclient.BatchedEmbed(ctx, e.client, texts, e.BatchSize)
}

// EmbedQuery embeds a single text.
Expand Down
29 changes: 3 additions & 26 deletions embeddings/openai/openaichat/openai_chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"

"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/embeddings/internal/embedderclient"
"github.com/tmc/langchaingo/llms/openai"
)

Expand All @@ -29,32 +30,8 @@ func NewChatOpenAI(opts ...ChatOption) (ChatOpenAI, error) {
}

func (e ChatOpenAI) EmbedDocuments(ctx context.Context, texts []string) ([][]float32, error) {
batchedTexts := embeddings.BatchTexts(
embeddings.MaybeRemoveNewLines(texts, e.StripNewLines),
e.BatchSize,
)

emb := make([][]float32, 0, len(texts))
for _, texts := range batchedTexts {
curTextEmbeddings, err := e.client.CreateEmbedding(ctx, texts)
if err != nil {
return nil, err
}

textLengths := make([]int, 0, len(texts))
for _, text := range texts {
textLengths = append(textLengths, len(text))
}

combined, err := embeddings.CombineVectors(curTextEmbeddings, textLengths)
if err != nil {
return nil, err
}

emb = append(emb, combined)
}

return emb, nil
texts = embeddings.MaybeRemoveNewLines(texts, e.StripNewLines)
return embedderclient.BatchedEmbed(ctx, e.client, texts, e.BatchSize)
}

func (e ChatOpenAI) EmbedQuery(ctx context.Context, text string) ([]float32, error) {
Expand Down

0 comments on commit 35140b3

Please sign in to comment.