Skip to content

Commit

Permalink
googleai: move the PaLM provider into googleai (#541)
Browse files Browse the repository at this point in the history

Now it's consistent with the other Google AI providers

Re #410
  • Loading branch information
eliben authored Jan 22, 2024
1 parent 029ff8e commit 61b3cda
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 17 deletions.
1 change: 1 addition & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ linters:
- tagliatelle # As we're dealing with third parties we must accept snake case.
- wsl # We don't agree with wsl's style rules
- exhaustruct
- lll
- varnamelen
- nlreturn
- gomnd
Expand Down
4 changes: 2 additions & 2 deletions embeddings/vertexai_palm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/llms/vertexai"
"github.com/tmc/langchaingo/llms/googleai/palm"
)

func newVertexEmbedder(t *testing.T, opts ...Option) *EmbedderImpl {
Expand All @@ -17,7 +17,7 @@ func newVertexEmbedder(t *testing.T, opts ...Option) *EmbedderImpl {
return nil
}

llm, err := vertexai.New()
llm, err := palm.New()
require.NoError(t, err)

embedder, err := NewEmbedder(llm, opts...)
Expand Down
2 changes: 1 addition & 1 deletion llms/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
// 1. Hugging Face: llms/huggingface/
// 2. Local LLM: llms/local/
// 3. OpenAI: llms/openai/
// 4. Vertex AI: llms/vertexai/
// 4. Google AI: llms/googleai/
// 5. Cohere: llms/cohere/
//
// Each subpackage includes provider-specific LLM implementations and helper files for communication
Expand Down
4 changes: 1 addition & 3 deletions llms/googleai/googleai_llm.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
// package googleai implements a langchaingo provider for Google AI LLMs.
// See https://ai.google.dev/ for more details and documetnation.
//
// nolint: lll
// See https://ai.google.dev/ for more details.
package googleai

import (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package vertexaiclient
package palmclient

import (
"context"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package vertexai
// package palm implements a langchaingo provider for Google Vertex AI legacy
// PaLM models. Use the newer Gemini models via llms/googleai/vertex if
// possible.
package palm

import (
"context"
"errors"

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/vertexai/internal/vertexaiclient"
"github.com/tmc/langchaingo/llms/googleai/palm/internal/palmclient"
)

var (
Expand All @@ -18,7 +21,7 @@ var (

type LLM struct {
CallbacksHandler callbacks.Handler
client *vertexaiclient.PaLMClient
client *palmclient.PaLMClient
}

var _ llms.Model = (*LLM)(nil)
Expand All @@ -44,7 +47,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
msg0 := messages[0]
part := msg0.Parts[0]

results, err := o.client.CreateCompletion(ctx, &vertexaiclient.CompletionRequest{
results, err := o.client.CreateCompletion(ctx, &palmclient.CompletionRequest{
Prompts: []string{part.(llms.TextContent).Text},
MaxTokens: opts.MaxTokens,
Temperature: opts.Temperature,
Expand Down Expand Up @@ -73,7 +76,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten

// CreateEmbedding creates embeddings for the given input texts.
func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]float32, error) {
embeddings, err := o.client.CreateEmbedding(ctx, &vertexaiclient.EmbeddingRequest{
embeddings, err := o.client.CreateEmbedding(ctx, &palmclient.EmbeddingRequest{
Input: inputTexts,
})
if err != nil {
Expand All @@ -90,13 +93,13 @@ func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]flo
return embeddings, nil
}

// New returns a new VertexAI PaLM LLM.
// New returns a new palmclient PaLM LLM.
func New(opts ...Option) (*LLM, error) {
client, err := newClient(opts...)
return &LLM{client: client}, err
}

func newClient(opts ...Option) (*vertexaiclient.PaLMClient, error) {
func newClient(opts ...Option) (*palmclient.PaLMClient, error) {
// Ensure options are initialized only once.
initOptions.Do(initOpts)
options := &options{}
Expand All @@ -109,5 +112,5 @@ func newClient(opts ...Option) (*vertexaiclient.PaLMClient, error) {
return nil, ErrMissingProjectID
}

return vertexaiclient.New(options.projectID, options.clientOptions...)
return palmclient.New(options.projectID, options.clientOptions...)
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package vertexai
package palm

import (
"net/http"
Expand Down
3 changes: 3 additions & 0 deletions llms/googleai/vertex/new.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
// package vertex implements a langchaingo provider for Google Vertex AI LLMs,
// including the new Gemini models.
// See https://cloud.google.com/vertex-ai for more details.
package vertex

import (
Expand Down
3 changes: 2 additions & 1 deletion llms/googleai/vertex/vertex.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// DO NOT EDIT: this code is auto-generated from llms/googleai/googleai_llm.go
package vertex

// DO NOT EDIT: this code is auto-generated from llms/googleai/googleai_llm.go

import (
"context"
"errors"
Expand Down

0 comments on commit 61b3cda

Please sign in to comment.