From b8fc0e35a7efc28308ef5eb3722c1ce052a24df4 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Thu, 18 Jan 2024 06:27:04 -0800 Subject: [PATCH] googleai: move downloadImageData to separate file (#530) --- .golangci.yaml | 1 + llms/googleai/download.go | 33 +++++++++++++++++++++++++ llms/googleai/googleai_llm.go | 41 ++++---------------------------- llms/googleai/googleai_option.go | 1 - 4 files changed, 39 insertions(+), 37 deletions(-) create mode 100644 llms/googleai/download.go diff --git a/.golangci.yaml b/.golangci.yaml index ace7e0b09..7c8215d4f 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -20,6 +20,7 @@ linters: - exhaustruct - varnamelen - nlreturn + - gomnd - wrapcheck # TODO: we should probably enable this one (at least for new code). - testpackage - nolintlint # see https://github.com/golangci/golangci-lint/issues/3228. diff --git a/llms/googleai/download.go b/llms/googleai/download.go new file mode 100644 index 000000000..ec2c618c6 --- /dev/null +++ b/llms/googleai/download.go @@ -0,0 +1,33 @@ +package googleai + +import ( + "fmt" + "io" + "net/http" + "strings" +) + +// downloadImageData downloads the content from the given URL and returns the +// image type and data. The image type is the second part of the response's +// MIME (e.g. "png" from "image/png"). +func downloadImageData(url string) (string, []byte, error) { + resp, err := http.Get(url) //nolint + if err != nil { + return "", nil, fmt.Errorf("failed to fetch image from url: %w", err) + } + defer resp.Body.Close() + + urlData, err := io.ReadAll(resp.Body) + if err != nil { + return "", nil, fmt.Errorf("failed to read image bytes: %w", err) + } + + mimeType := resp.Header.Get("Content-Type") + + parts := strings.Split(mimeType, "/") + if len(parts) != 2 { + return "", nil, ErrInvalidMimeType + } + + return parts[1], urlData, nil +} diff --git a/llms/googleai/googleai_llm.go b/llms/googleai/googleai_llm.go index bc51b5827..0dc0859d4 100644 --- a/llms/googleai/googleai_llm.go +++ b/llms/googleai/googleai_llm.go @@ -8,9 +8,7 @@ import ( "context" "errors" "fmt" - "io" "log" - "net/http" "strings" "github.com/google/generative-ai-go/genai" @@ -115,35 +113,6 @@ func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageC return response, nil } -// downloadImageData downloads the content from the given URL and returns it as -// a *genai.Blob. -func downloadImageData(url string) (*genai.Blob, error) { - resp, err := http.Get(url) //nolint - if err != nil { - return nil, fmt.Errorf("failed to fetch image from url: %w", err) - } - defer resp.Body.Close() - - urlData, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read image bytes: %w", err) - } - - mimeType := resp.Header.Get("Content-Type") - - // The convenience function genai.ImageData requires just the right part of - // the mime type, so we need to parse it - parts := strings.Split(mimeType, "/") - - if len(parts) != 2 { //nolint - return nil, ErrInvalidMimeType - } - - blob := genai.ImageData(parts[1], urlData) - - return &blob, nil -} - // convertCandidates converts a sequence of genai.Candidate to a response. func convertCandidates(candidates []*genai.Candidate) (*llms.ContentResponse, error) { var contentResponse llms.ContentResponse @@ -197,7 +166,6 @@ func convertParts(parts []llms.ContentPart) ([]genai.Part, error) { convertedParts := make([]genai.Part, 0, len(parts)) for _, part := range parts { var out genai.Part - var err error switch p := part.(type) { case llms.TextContent: @@ -205,10 +173,11 @@ func convertParts(parts []llms.ContentPart) ([]genai.Part, error) { case llms.BinaryContent: out = genai.Blob{MIMEType: p.MIMEType, Data: p.Data} case llms.ImageURLContent: - out, err = downloadImageData(p.URL) - } - if err != nil { - return nil, err + typ, data, err := downloadImageData(p.URL) + if err != nil { + return nil, err + } + out = genai.ImageData(typ, data) } convertedParts = append(convertedParts, out) diff --git a/llms/googleai/googleai_option.go b/llms/googleai/googleai_option.go index bef2ef2be..300b6a811 100644 --- a/llms/googleai/googleai_option.go +++ b/llms/googleai/googleai_option.go @@ -1,4 +1,3 @@ -//nolint:gomnd package googleai // options is a set of options for GoogleAI clients.