Skip to content

Commit

Permalink
refactor: update models and default models #18 (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgaunet authored Jan 22, 2025
1 parent c14b01c commit 2f7925d
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 42 deletions.
7 changes: 3 additions & 4 deletions perplexity.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ const DefaultEndpoint = "https://api.perplexity.ai/chat/completions"
// DefautTimeout is the default timeout for the HTTP client.
const DefautTimeout = 10 * time.Second

// Llama31SonarSmall128kOnline is the default model for the Perplexity API.
const Llama31SonarSmall128kOnline = "llama-3.1-sonar-small-128k-online"
const Llama31SonarLarge128kOnline = "llama-3.1-sonar-large-128k-online"
const Llama31SonarHuge128kOnline = "llama-3.1-sonar-huge-128k-online"
// DefaultModel is the default model for the Perplexity API.
const DefaultModel = "sonar"
const ProModel = "sonar-pro"

// Client is a client for the Perplexity API.
type Client struct {
Expand Down
24 changes: 8 additions & 16 deletions perplexity_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ var ErrSearchDomainFilter = errors.New("search domain filter must be less than o
var ErrSearchRecencyFilter = errors.New("search recency filter is incompatible with images")

const (
DefaultModel = Llama31SonarSmall128kOnline
DefaultTemperature = 0.2
DefaultTopP = 0.9
DefaultTopK = 0
Expand All @@ -27,7 +26,7 @@ type CompletionRequest struct {
Messages []Message `json:"messages" validate:"required,dive"`
// Model: name of the model that will complete your prompt
// supported model: https://docs.perplexity.ai/guides/model-cards
Model string `json:"model" validate:"required,oneof=llama-3.1-sonar-small-128k-online llama-3.1-sonar-large-128k-online llama-3.1-sonar-huge-128k-online"`
Model string `json:"model" validate:"required,oneof=sonar sonar-pro llama-3.1-sonar-huge-128k-online"`
// MaxTokens: The maximum number of completion tokens returned by the API.
// The total number of tokens requested in max_tokens plus the number of
// prompt tokens sent in messages must not exceed the context window token limit of model requested.
Expand Down Expand Up @@ -109,31 +108,24 @@ func WithMessages(msg []Message) CompletionRequestOption {
}

// WithModel sets the model option (overrides the default model).
// Prefer the use of WithModelLlama31SonarSmall128kOnline, WithModelLlama31SonarLarge128kOnline, or WithModelLlama31SonarHuge128kOnline.
// Prefer the use of WithDefaultModel or WithProModel.
func WithModel(model string) CompletionRequestOption {
return func(r *CompletionRequest) {
r.Model = model
}
}

// WithModelLlama31SonarSmall128kOnline sets the model to llama-3.1-sonar-small-128k-online.
func WithModelLlama31SonarSmall128kOnline() CompletionRequestOption {
// WithDefaultModel sets the model to sonar.
func WithDefaultModel() CompletionRequestOption {
return func(r *CompletionRequest) {
r.Model = Llama31SonarSmall128kOnline
r.Model = DefaultModel
}
}

// WithModelLlama31SonarLarge128kOnline sets the model to llama-3.1-sonar-large-128k-online.
func WithModelLlama31SonarLarge128kOnline() CompletionRequestOption {
// WithProModel sets the model to sonar-pro.
func WithProModel() CompletionRequestOption {
return func(r *CompletionRequest) {
r.Model = Llama31SonarLarge128kOnline
}
}

// WithModelLlama31SonarHuge128kOnline sets the model to llama-3.1-sonar-huge-128k-online.
func WithModelLlama31SonarHuge128kOnline() CompletionRequestOption {
return func(r *CompletionRequest) {
r.Model = Llama31SonarHuge128kOnline
r.Model = ProModel
}
}

Expand Down
38 changes: 17 additions & 21 deletions perplexity_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,17 @@ func TestWithMessages(t *testing.T) {

func TestWithModel(t *testing.T) {
t.Run("creates a new CompletionRequest with model", func(t *testing.T) {
model := "llama-3.1-sonar-small-128k-online"
model := "sonar"
req := perplexity.NewCompletionRequest(perplexity.WithModel(model))
assert.Equal(t, req.Model, model)
})
t.Run("Test WithModelLlama31SonarSmall128kOnline", func(t *testing.T) {
req := perplexity.NewCompletionRequest(perplexity.WithModelLlama31SonarSmall128kOnline())
assert.Equal(t, perplexity.Llama31SonarSmall128kOnline, req.Model)
t.Run("Test WithDefaultModel", func(t *testing.T) {
req := perplexity.NewCompletionRequest(perplexity.WithDefaultModel())
assert.Equal(t, perplexity.DefaultModel, req.Model)
})
t.Run("Test WithModelLlama31SonarLarge128kOnline", func(t *testing.T) {
req := perplexity.NewCompletionRequest(perplexity.WithModelLlama31SonarLarge128kOnline())
assert.Equal(t, perplexity.Llama31SonarLarge128kOnline, req.Model)
})
t.Run("Test WithModelLlama31SonarHuge128kOnline", func(t *testing.T) {
req := perplexity.NewCompletionRequest(perplexity.WithModelLlama31SonarHuge128kOnline())
assert.Equal(t, perplexity.Llama31SonarHuge128kOnline, req.Model)
t.Run("Test WithProModel", func(t *testing.T) {
req := perplexity.NewCompletionRequest(perplexity.WithProModel())
assert.Equal(t, perplexity.ProModel, req.Model)
})
}

Expand Down Expand Up @@ -117,16 +113,16 @@ func TestValidate(t *testing.T) {

f("returns error if no message to send to the API", false)
f("returns error if model is empty", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel(""))
f("returns error if MaxTokens is negative", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithMaxTokens(-1))
f("returns error if Temperature is negative", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithTemperature(-1))
f("returns error if TopP is negative", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithTopP(-1))
f("returns error if TopK is negative", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithTopK(-1))
f("returns error if TopK is gt 2048", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithTopK(2049))
f("returns error if Temperature is gt 2", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithTemperature(2.1))
f("returns error if TopP is gt 1", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithTopP(1.1))
f("returns error if SearchDomainFilter contains more than 3 elements", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithSearchDomainFilter([]string{"filter1", "filter2", "filter3", "filter4"}))
f("returns error return_images and searchRecencyFilter are set", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithMaxTokens(10), perplexity.WithTemperature(0.5), perplexity.WithTopP(0.5), perplexity.WithSearchDomainFilter([]string{"filter1", "filter2"}), perplexity.WithReturnImages(true), perplexity.WithReturnRelatedQuestions(true), perplexity.WithSearchRecencyFilter("filter"), perplexity.WithTopK(10))
f("returns no error", true, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithMaxTokens(10), perplexity.WithTemperature(0.5), perplexity.WithTopP(0.5), perplexity.WithSearchDomainFilter([]string{"filter1", "filter2"}), perplexity.WithReturnRelatedQuestions(true), perplexity.WithTopK(10))
f("returns error if MaxTokens is negative", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithMaxTokens(-1))
f("returns error if Temperature is negative", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithTemperature(-1))
f("returns error if TopP is negative", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithTopP(-1))
f("returns error if TopK is negative", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithTopK(-1))
f("returns error if TopK is gt 2048", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithTopK(2049))
f("returns error if Temperature is gt 2", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithTemperature(2.1))
f("returns error if TopP is gt 1", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithTopP(1.1))
f("returns error if SearchDomainFilter contains more than 3 elements", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithSearchDomainFilter([]string{"filter1", "filter2", "filter3", "filter4"}))
f("returns error return_images and searchRecencyFilter are set", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithMaxTokens(10), perplexity.WithTemperature(0.5), perplexity.WithTopP(0.5), perplexity.WithSearchDomainFilter([]string{"filter1", "filter2"}), perplexity.WithReturnImages(true), perplexity.WithReturnRelatedQuestions(true), perplexity.WithSearchRecencyFilter("filter"), perplexity.WithTopK(10))
f("returns no error", true, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithMaxTokens(10), perplexity.WithTemperature(0.5), perplexity.WithTopP(0.5), perplexity.WithSearchDomainFilter([]string{"filter1", "filter2"}), perplexity.WithReturnRelatedQuestions(true), perplexity.WithTopK(10))
}

func TestValidateSearchRecencyFilter(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion perplexity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func TestGetCompletion(t *testing.T) {
defer r.Body.Close()
b, err := io.ReadAll(r.Body)
assert.Nil(t, err)
assert.Equal(t, string(b), `{"messages":[{"role":"user","content":"What's the capital of France?"}],"model":"llama-3.1-sonar-small-128k-online","max_tokens":0,"temperature":0.2,"top_p":0.9,"search_domain_filter":null,"return_images":false,"return_related_questions":false,"search_recency_filter":"","top_k":0,"stream":false,"presence_penalty":0,"frequency_penalty":1}`)
assert.Equal(t, string(b), `{"messages":[{"role":"user","content":"What's the capital of France?"}],"model":"sonar","max_tokens":0,"temperature":0.2,"top_p":0.9,"search_domain_filter":null,"return_images":false,"return_related_questions":false,"search_recency_filter":"","top_k":0,"stream":false,"presence_penalty":0,"frequency_penalty":1}`)
w.Header().Add("Content-Type", "application/json")
fmt.Fprintln(w, "{}")
}))
Expand Down

0 comments on commit 2f7925d

Please sign in to comment.