Skip to content

Commit

Permalink
Merge pull request #545 from mheers/main
Browse files Browse the repository at this point in the history
huggingface: allow setting a custom inference endpoint
  • Loading branch information
tmc authored Jan 25, 2024
2 parents 2545ace + 288cc58 commit 56b3d1c
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 8 deletions.
3 changes: 2 additions & 1 deletion llms/huggingface/huggingfacellm.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ func New(opts ...Option) (*LLM, error) {
options := &options{
token: os.Getenv(tokenEnvVarName),
model: defaultModel,
url: defaultURL,
}

for _, opt := range opts {
Expand All @@ -86,7 +87,7 @@ func New(opts ...Option) (*LLM, error) {
return nil, ErrMissingToken
}

c, err := huggingfaceclient.New(options.token, options.model)
c, err := huggingfaceclient.New(options.token, options.model, options.url)
if err != nil {
return nil, err
}
Expand Down
10 changes: 10 additions & 0 deletions llms/huggingface/huggingfacellm_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package huggingface
const (
// tokenEnvVarName is the environment variable New reads to seed the API
// token before any options are applied.
tokenEnvVarName = "HUGGINGFACEHUB_API_TOKEN"
// defaultModel is the model New starts with before any options are applied.
defaultModel = "gpt2"
// defaultURL is the inference endpoint New starts with; WithURL replaces it.
defaultURL = "https://api-inference.huggingface.co"
)

// options collects the settings New assembles (defaults first, then each
// Option in order) before constructing the underlying client.
type options struct {
// token is the API token, seeded from the tokenEnvVarName environment variable.
token string
// model is the model name, seeded with defaultModel.
model string
// url is the inference endpoint, seeded with defaultURL and replaced by WithURL.
url string
}

// Option is a function that modifies the options New uses; New accepts any
// number of them and applies each in turn.
type Option func(*options)
Expand All @@ -27,3 +29,11 @@ func WithModel(model string) Option {
opts.model = model
}
}

// WithURL returns an Option that sets the HuggingFace endpoint URL handed to
// the client. When this option is not supplied, the default URL is used.
func WithURL(url string) Option {
	setURL := func(o *options) { o.url = url }
	return setURL
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,20 @@ var (
ErrEmptyResponse = errors.New("empty response")
)

const huggingfaceAPIBaseURL = "https://api-inference.huggingface.co"

// Client holds the credentials and target used to call the HuggingFace
// inference API.
type Client struct {
// Token is the API token; New rejects an empty token with ErrInvalidToken.
Token string
// Model is the model name given to New.
Model string
// url is the base URL of the inference endpoint (the tests point it at a
// mock server).
url string
}

func New(token string, model string) (*Client, error) {
func New(token, model, url string) (*Client, error) {
if token == "" {
return nil, ErrInvalidToken
}
return &Client{
Token: token,
Model: model,
url: huggingfaceAPIBaseURL,
url: url,
}, nil
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,8 @@ func TestRunInference(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
client, err := New("token", "model")
client, err := New("token", "model", server.URL)
require.NoError(t, err)
// Override the URL to point to our mock server.
client.url = server.URL

resp, err := client.RunInference(context.TODO(), tc.req)
assert.Equal(t, tc.expected, resp)
Expand Down

0 comments on commit 56b3d1c

Please sign in to comment.