Skip to content

Commit 23632f6

Browse files
committed
feat: add custom HTTP transport with headers for OpenAI client
Closes #237
1 parent 1647219 commit 23632f6

File tree

1 file changed

+44
-2
lines changed

1 file changed

+44
-2
lines changed

main.go

+44-2
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ import (
88
"paperless-gpt/ocr"
99
"path/filepath"
1010
"runtime"
11-
"strconv"
1211
"slices"
13-
"strings"
12+
"strconv"
13+
"strings"
1414
"sync"
1515
"text/template"
1616
"time"
@@ -639,9 +639,23 @@ func createLLM() (llms.Model, error) {
639639
if openaiAPIKey == "" {
640640
return nil, fmt.Errorf("OpenAI API key is not set")
641641
}
642+
643+
// Create custom transport that adds headers
644+
customTransport := &headerTransport{
645+
transport: http.DefaultTransport,
646+
headers: map[string]string{
647+
"X-Title": "paperless-gpt",
648+
},
649+
}
650+
651+
// Create custom client with the transport
652+
httpClient := http.DefaultClient
653+
httpClient.Transport = customTransport
654+
642655
return openai.New(
643656
openai.WithModel(llmModel),
644657
openai.WithToken(openaiAPIKey),
658+
openai.WithHTTPClient(httpClient),
645659
)
646660
case "ollama":
647661
host := os.Getenv("OLLAMA_HOST")
@@ -663,9 +677,23 @@ func createVisionLLM() (llms.Model, error) {
663677
if openaiAPIKey == "" {
664678
return nil, fmt.Errorf("OpenAI API key is not set")
665679
}
680+
681+
// Create custom transport that adds headers
682+
customTransport := &headerTransport{
683+
transport: http.DefaultTransport,
684+
headers: map[string]string{
685+
"X-Title": "paperless-gpt",
686+
},
687+
}
688+
689+
// Create custom client with the transport
690+
httpClient := http.DefaultClient
691+
httpClient.Transport = customTransport
692+
666693
return openai.New(
667694
openai.WithModel(visionLlmModel),
668695
openai.WithToken(openaiAPIKey),
696+
openai.WithHTTPClient(httpClient),
669697
)
670698
case "ollama":
671699
host := os.Getenv("OLLAMA_HOST")
@@ -681,3 +709,17 @@ func createVisionLLM() (llms.Model, error) {
681709
return nil, nil
682710
}
683711
}
712+
713+
// headerTransport is a custom http.RoundTripper that adds custom headers to requests
714+
type headerTransport struct {
715+
transport http.RoundTripper
716+
headers map[string]string
717+
}
718+
719+
// RoundTrip implements the http.RoundTripper interface
720+
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
721+
for key, value := range t.headers {
722+
req.Header.Add(key, value)
723+
}
724+
return t.transport.RoundTrip(req)
725+
}

0 commit comments

Comments
 (0)