From 2f7925d6a45b3d60edad730124fb119d95aa7c43 Mon Sep 17 00:00:00 2001 From: Sylvain <1552102+sgaunet@users.noreply.github.com> Date: Wed, 22 Jan 2025 20:52:45 +0100 Subject: [PATCH] refactor: update models and default models #18 (#19) --- perplexity.go | 7 +++---- perplexity_request.go | 24 ++++++++---------------- perplexity_request_test.go | 38 +++++++++++++++++--------------------- perplexity_test.go | 2 +- 4 files changed, 29 insertions(+), 42 deletions(-) diff --git a/perplexity.go b/perplexity.go index 33ac851..f67b9d0 100644 --- a/perplexity.go +++ b/perplexity.go @@ -16,10 +16,9 @@ const DefaultEndpoint = "https://api.perplexity.ai/chat/completions" // DefautTimeout is the default timeout for the HTTP client. const DefautTimeout = 10 * time.Second -// Llama31SonarSmall128kOnline is the default model for the Perplexity API. -const Llama31SonarSmall128kOnline = "llama-3.1-sonar-small-128k-online" -const Llama31SonarLarge128kOnline = "llama-3.1-sonar-large-128k-online" -const Llama31SonarHuge128kOnline = "llama-3.1-sonar-huge-128k-online" +// DefaultModel is the default model for the Perplexity API. +const DefaultModel = "sonar" +const ProModel = "sonar-pro" // Client is a client for the Perplexity API. 
type Client struct { diff --git a/perplexity_request.go b/perplexity_request.go index 7fb10e5..6def7e2 100644 --- a/perplexity_request.go +++ b/perplexity_request.go @@ -10,7 +10,6 @@ var ErrSearchDomainFilter = errors.New("search domain filter must be less than o var ErrSearchRecencyFilter = errors.New("search recency filter is incompatible with images") const ( - DefaultModel = Llama31SonarSmall128kOnline DefaultTemperature = 0.2 DefaultTopP = 0.9 DefaultTopK = 0 @@ -27,7 +26,7 @@ type CompletionRequest struct { Messages []Message `json:"messages" validate:"required,dive"` // Model: name of the model that will complete your prompt // supported model: https://docs.perplexity.ai/guides/model-cards - Model string `json:"model" validate:"required,oneof=llama-3.1-sonar-small-128k-online llama-3.1-sonar-large-128k-online llama-3.1-sonar-huge-128k-online"` + Model string `json:"model" validate:"required,oneof=sonar sonar-pro llama-3.1-sonar-huge-128k-online"` // MaxTokens: The maximum number of completion tokens returned by the API. // The total number of tokens requested in max_tokens plus the number of // prompt tokens sent in messages must not exceed the context window token limit of model requested. @@ -109,31 +108,24 @@ func WithMessages(msg []Message) CompletionRequestOption { } // WithModel sets the model option (overrides the default model). -// Prefer the use of WithModelLlama31SonarSmall128kOnline, WithModelLlama31SonarLarge128kOnline, or WithModelLlama31SonarHuge128kOnline. +// Prefer the use of WithDefaultModel or WithProModel. func WithModel(model string) CompletionRequestOption { return func(r *CompletionRequest) { r.Model = model } } -// WithModelLlama31SonarSmall128kOnline sets the model to llama-3.1-sonar-small-128k-online. -func WithModelLlama31SonarSmall128kOnline() CompletionRequestOption { +// WithDefaultModel sets the model to sonar. 
+func WithDefaultModel() CompletionRequestOption { return func(r *CompletionRequest) { - r.Model = Llama31SonarSmall128kOnline + r.Model = DefaultModel } } -// WithModelLlama31SonarLarge128kOnline sets the model to llama-3.1-sonar-large-128k-online. -func WithModelLlama31SonarLarge128kOnline() CompletionRequestOption { +// WithProModel sets the model to sonar-pro. +func WithProModel() CompletionRequestOption { return func(r *CompletionRequest) { - r.Model = Llama31SonarLarge128kOnline - } - } - -// WithModelLlama31SonarHuge128kOnline sets the model to llama-3.1-sonar-huge-128k-online. -func WithModelLlama31SonarHuge128kOnline() CompletionRequestOption { - return func(r *CompletionRequest) { - r.Model = Llama31SonarHuge128kOnline + r.Model = ProModel } } diff --git a/perplexity_request_test.go b/perplexity_request_test.go index 2a8b164..0f04daa 100644 --- a/perplexity_request_test.go +++ b/perplexity_request_test.go @@ -22,21 +22,17 @@ func TestWithMessages(t *testing.T) { func TestWithModel(t *testing.T) { t.Run("creates a new CompletionRequest with model", func(t *testing.T) { - model := "llama-3.1-sonar-small-128k-online" + model := "sonar" req := perplexity.NewCompletionRequest(perplexity.WithModel(model)) assert.Equal(t, req.Model, model) }) - t.Run("Test WithModelLlama31SonarSmall128kOnline", func(t *testing.T) { - req := perplexity.NewCompletionRequest(perplexity.WithModelLlama31SonarSmall128kOnline()) - assert.Equal(t, perplexity.Llama31SonarSmall128kOnline, req.Model) + t.Run("Test WithDefaultModel", func(t *testing.T) { + req := perplexity.NewCompletionRequest(perplexity.WithDefaultModel()) + assert.Equal(t, perplexity.DefaultModel, req.Model) }) - t.Run("Test WithModelLlama31SonarLarge128kOnline", func(t *testing.T) { - req := perplexity.NewCompletionRequest(perplexity.WithModelLlama31SonarLarge128kOnline()) - assert.Equal(t, perplexity.Llama31SonarLarge128kOnline, req.Model) - }) - t.Run("Test WithModelLlama31SonarHuge128kOnline", func(t *testing.T) 
{ - req := perplexity.NewCompletionRequest(perplexity.WithModelLlama31SonarHuge128kOnline()) - assert.Equal(t, perplexity.Llama31SonarHuge128kOnline, req.Model) + t.Run("Test WithProModel", func(t *testing.T) { + req := perplexity.NewCompletionRequest(perplexity.WithProModel()) + assert.Equal(t, perplexity.ProModel, req.Model) }) } @@ -117,16 +113,16 @@ func TestValidate(t *testing.T) { f("returns error if no message to send to the API", false) f("returns error if model is empty", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("")) - f("returns error if MaxTokens is negative", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithMaxTokens(-1)) - f("returns error if Temperature is negative", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithTemperature(-1)) - f("returns error if TopP is negative", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithTopP(-1)) - f("returns error if TopK is negative", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithTopK(-1)) - f("returns error if TopK is gt 2048", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithTopK(2049)) - f("returns error if Temperature is gt 2", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithTemperature(2.1)) - f("returns error if TopP is gt 1", false, 
perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithTopP(1.1)) - f("returns error if SearchDomainFilter contains more than 3 elements", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithSearchDomainFilter([]string{"filter1", "filter2", "filter3", "filter4"})) - f("returns error return_images and searchRecencyFilter are set", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithMaxTokens(10), perplexity.WithTemperature(0.5), perplexity.WithTopP(0.5), perplexity.WithSearchDomainFilter([]string{"filter1", "filter2"}), perplexity.WithReturnImages(true), perplexity.WithReturnRelatedQuestions(true), perplexity.WithSearchRecencyFilter("filter"), perplexity.WithTopK(10)) - f("returns no error", true, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("llama-3.1-sonar-small-128k-online"), perplexity.WithMaxTokens(10), perplexity.WithTemperature(0.5), perplexity.WithTopP(0.5), perplexity.WithSearchDomainFilter([]string{"filter1", "filter2"}), perplexity.WithReturnRelatedQuestions(true), perplexity.WithTopK(10)) + f("returns error if MaxTokens is negative", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithMaxTokens(-1)) + f("returns error if Temperature is negative", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithTemperature(-1)) + f("returns error if TopP is negative", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithTopP(-1)) + f("returns error if TopK 
is negative", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithTopK(-1)) + f("returns error if TopK is gt 2048", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithTopK(2049)) + f("returns error if Temperature is gt 2", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithTemperature(2.1)) + f("returns error if TopP is gt 1", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithTopP(1.1)) + f("returns error if SearchDomainFilter contains more than 3 elements", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithSearchDomainFilter([]string{"filter1", "filter2", "filter3", "filter4"})) + f("returns error return_images and searchRecencyFilter are set", false, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithMaxTokens(10), perplexity.WithTemperature(0.5), perplexity.WithTopP(0.5), perplexity.WithSearchDomainFilter([]string{"filter1", "filter2"}), perplexity.WithReturnImages(true), perplexity.WithReturnRelatedQuestions(true), perplexity.WithSearchRecencyFilter("filter"), perplexity.WithTopK(10)) + f("returns no error", true, perplexity.WithMessages([]perplexity.Message{{Role: "user", Content: "hello"}}), perplexity.WithModel("sonar"), perplexity.WithMaxTokens(10), perplexity.WithTemperature(0.5), perplexity.WithTopP(0.5), perplexity.WithSearchDomainFilter([]string{"filter1", "filter2"}), perplexity.WithReturnRelatedQuestions(true), perplexity.WithTopK(10)) } func TestValidateSearchRecencyFilter(t *testing.T) { diff --git a/perplexity_test.go b/perplexity_test.go index 613fee7..a800f9a 
100644 --- a/perplexity_test.go +++ b/perplexity_test.go @@ -48,7 +48,7 @@ func TestGetCompletion(t *testing.T) { defer r.Body.Close() b, err := io.ReadAll(r.Body) assert.Nil(t, err) - assert.Equal(t, string(b), `{"messages":[{"role":"user","content":"What's the capital of France?"}],"model":"llama-3.1-sonar-small-128k-online","max_tokens":0,"temperature":0.2,"top_p":0.9,"search_domain_filter":null,"return_images":false,"return_related_questions":false,"search_recency_filter":"","top_k":0,"stream":false,"presence_penalty":0,"frequency_penalty":1}`) + assert.Equal(t, string(b), `{"messages":[{"role":"user","content":"What's the capital of France?"}],"model":"sonar","max_tokens":0,"temperature":0.2,"top_p":0.9,"search_domain_filter":null,"return_images":false,"return_related_questions":false,"search_recency_filter":"","top_k":0,"stream":false,"presence_penalty":0,"frequency_penalty":1}`) w.Header().Add("Content-Type", "application/json") fmt.Fprintln(w, "{}") }))