From 451aa4dca713ad4256e4d80a52f5b08a3fac79ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dominik=20Schr=C3=B6ter?= Date: Thu, 3 Oct 2024 16:10:23 +0200 Subject: [PATCH 1/2] Add bearer token to ollama --- go.mod | 2 ++ go.sum | 10 ++++++++++ main.go | 14 ++++++++++++++ 3 files changed, 26 insertions(+) diff --git a/go.mod b/go.mod index 25ab1c8..0b648bf 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ toolchain go1.22.2 require ( github.com/gin-gonic/gin v1.10.0 + github.com/hashicorp/go-retryablehttp v0.7.7 github.com/tmc/langchaingo v0.1.12 ) @@ -22,6 +23,7 @@ require ( github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/leodido/go-urn v1.4.0 // indirect diff --git a/go.sum b/go.sum index b313bca..611de4e 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4= +github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= @@ -32,6 +34,12 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= +github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU= +github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= @@ -40,6 +48,8 @@ github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZY github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/main.go b/main.go index 7c5a3bb..b670f46 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ import ( "time" "github.com/gin-gonic/gin" + retryablehttp "github.com/hashicorp/go-retryablehttp" "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms/ollama" "github.com/tmc/langchaingo/llms/openai" @@ -140,9 +141,22 @@ func createLLM() (llms.Model, error) { if host == "" { host = "http://127.0.0.1:11434" } + // custom http client (retryable http client) if bearer token is wanted + retryClient := retryablehttp.NewClient() + retryClient.RetryMax = 10 + bearerToken := os.Getenv("OLLAMA_BEARER_TOKEN") + if bearerToken != "" { + retryClient.RequestLogHook = func(l retryablehttp.Logger, r *http.Request, i int) { + r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", bearerToken)) + shortenedBearerToken := fmt.Sprintf("%s...", r.Header.Get("Authorization")[:5]) + log.Printf("Request with bearer %s token to %s %s", shortenedBearerToken, r.Method, r.URL) + } + } + return ollama.New( ollama.WithModel(llmModel), ollama.WithServerURL(host), + ollama.WithHTTPClient(retryClient.StandardClient()), ) default: return nil, fmt.Errorf("unsupported LLM provider: %s", llmProvider) From 12c9819b4f1014662cd27a492aa51ed9518f8f30 Mon Sep 17 00:00:00 2001 From: Icereed Date: Fri, 4 Oct 2024 08:25:38 +0200 Subject: [PATCH 2/2] Fix code scanning alert no. 2: Clear-text logging of sensitive information Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> --- Dockerfile | 2 +- http_client_bearer.go | 34 ++++++++++++++++++++++++ http_client_bearer_test.go | 54 ++++++++++++++++++++++++++++++++++++++ main.go | 24 ++++++++--------- 4 files changed, 101 insertions(+), 13 deletions(-) create mode 100644 http_client_bearer.go create mode 100644 http_client_bearer_test.go diff --git a/Dockerfile b/Dockerfile index 13b8366..b4b81ec 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,7 +14,7 @@ RUN go mod download COPY . . # Build the Go binary -RUN CGO_ENABLED=0 GOOS=linux go build -o paperless-gpt main.go +RUN CGO_ENABLED=0 GOOS=linux go build -o paperless-gpt . # Stage 2: Build Vite frontend FROM node:20 AS frontend diff --git a/http_client_bearer.go b/http_client_bearer.go new file mode 100644 index 0000000..09c479e --- /dev/null +++ b/http_client_bearer.go @@ -0,0 +1,34 @@ +package main + +import ( + "fmt" + "net/http" +) + +// HttpTransportWithBearer wraps the default RoundTripper to add the Authorization header. +type HttpTransportWithBearer struct { + BaseTransport http.RoundTripper + Token string +} + +// RoundTrip implements the RoundTripper interface to modify the request. +func (t *HttpTransportWithBearer) RoundTrip(req *http.Request) (*http.Response, error) { + // Clone the request to avoid side effects + reqClone := req.Clone(req.Context()) + + // Add the Authorization header + reqClone.Header.Set("Authorization", fmt.Sprintf("Bearer %s", t.Token)) + + // Use the base RoundTripper to perform the request + return t.BaseTransport.RoundTrip(reqClone) +} + +func NewHttpClientWithBearerTransport(token string) *http.Client { + // Create a new HTTP client with the custom transport + return &http.Client{ + Transport: &HttpTransportWithBearer{ + BaseTransport: http.DefaultTransport, + Token: token, + }, + } +} diff --git a/http_client_bearer_test.go b/http_client_bearer_test.go new file mode 100644 index 0000000..2c90081 --- /dev/null +++ b/http_client_bearer_test.go @@ -0,0 +1,54 @@ +package main + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +// TestHttpClientWithBearerTransport tests the addition of the Authorization header. +func TestHttpClientWithBearerTransport(t *testing.T) { + // Define the expected Bearer token + token := "test_bearer_token" + + // Set up a test HTTP server + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Retrieve the Authorization header from the request + authHeader := r.Header.Get("Authorization") + expectedHeader := fmt.Sprintf("Bearer %s", token) + + // Check if the Authorization header matches the expected value + if authHeader != expectedHeader { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // Return a success response + w.WriteHeader(http.StatusOK) + io.WriteString(w, "Success") + })) + defer testServer.Close() + + // Create an HTTP client with the custom transport + client := NewHttpClientWithBearerTransport(token) + + // Create a new HTTP request to the test server + req, err := http.NewRequest("GET", testServer.URL, nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + // Perform the request using the custom client + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + // Check if the status code is 200 OK + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status code 200 OK, got %d", resp.StatusCode) + } +} diff --git a/main.go b/main.go index b670f46..19703ff 100644 --- a/main.go +++ b/main.go @@ -14,7 +14,6 @@ import ( "time" "github.com/gin-gonic/gin" - retryablehttp "github.com/hashicorp/go-retryablehttp" "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms/ollama" "github.com/tmc/langchaingo/llms/openai" @@ -141,22 +140,23 @@ func createLLM() (llms.Model, error) { if host == "" { host = "http://127.0.0.1:11434" } - // custom http client (retryable http client) if bearer token is wanted - retryClient := retryablehttp.NewClient() - retryClient.RetryMax = 10 + ollamaOptions := []ollama.Option{ + ollama.WithModel(llmModel), + ollama.WithServerURL(host), + } bearerToken := os.Getenv("OLLAMA_BEARER_TOKEN") if bearerToken != "" { - retryClient.RequestLogHook = func(l retryablehttp.Logger, r *http.Request, i int) { - r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", bearerToken)) - shortenedBearerToken := fmt.Sprintf("%s...", r.Header.Get("Authorization")[:5]) - log.Printf("Request with bearer %s token to %s %s", shortenedBearerToken, r.Method, r.URL) - } + log.Println("Using bearer token for OLLAMA authentication") + ollamaOptions = append( + ollamaOptions, + ollama.WithHTTPClient( + NewHttpClientWithBearerTransport(bearerToken), + ), + ) } return ollama.New( - ollama.WithModel(llmModel), - ollama.WithServerURL(host), - ollama.WithHTTPClient(retryClient.StandardClient()), + ollamaOptions..., ) default: return nil, fmt.Errorf("unsupported LLM provider: %s", llmProvider)