-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: implement Messages object for Perplexity API with user and assi…
- Loading branch information
Showing
3 changed files
with
155 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
package perplexity | ||
|
||
import "fmt" | ||
|
||
// Message is a message object for the Perplexity API. | ||
type Message struct { | ||
Role string `json:"role" validate:"required,oneof=system user assistant"` | ||
Content string `json:"content"` | ||
} | ||
|
||
// Messages is an object that contains a list of messages for the Perplexity API. | ||
type Messages struct { | ||
systemMessage string | ||
messages []Message // A list of messages comprising the conversation so far. | ||
} | ||
|
||
// NewMessages returns a new Messages object. | ||
func NewMessages(opts ...MessagesOption) Messages { | ||
m := Messages{} | ||
for _, opt := range opts { | ||
opt(&m) | ||
} | ||
return m | ||
} | ||
|
||
// MessagesOption is an option for the NewMessages function. | ||
type MessagesOption func(*Messages) | ||
|
||
// WithSystemMessage sets the system message for the Messages object. | ||
func WithSystemMessage(content string) MessagesOption { | ||
return func(m *Messages) { | ||
m.systemMessage = content | ||
} | ||
} | ||
|
||
// AddUserMessage adds a user message to the Messages object. | ||
func (m *Messages) AddUserMessage(content string) error { | ||
if len(m.messages) > 0 { | ||
// Previous message should be an assistant message. | ||
if m.messages[len(m.messages)-1].Role != "assistant" { | ||
return fmt.Errorf("previous message should be an assistant message") | ||
} | ||
} | ||
m.messages = append(m.messages, Message{ | ||
Role: "user", | ||
Content: content, | ||
}) | ||
return nil | ||
} | ||
|
||
// AddAgentMessage adds an assistant message to the Messages object. | ||
func (m *Messages) AddAgentMessage(content string) error { | ||
if len(m.messages) == 0 { | ||
// First message should be a user message. | ||
return fmt.Errorf("first message should be a user message") | ||
} | ||
// Previous message should be a user message. | ||
if m.messages[len(m.messages)-1].Role != "user" { | ||
return fmt.Errorf("previous message should be a user message") | ||
} | ||
m.messages = append(m.messages, Message{ | ||
Role: "assistant", | ||
Content: content, | ||
}) | ||
return nil | ||
} | ||
|
||
func (m *Messages) GetMessages() []Message { | ||
var result []Message | ||
// system message is added in the first position | ||
if m.systemMessage != "" { | ||
result = append(result, Message{ | ||
Role: "system", | ||
Content: m.systemMessage, | ||
}) | ||
} | ||
// user and assistant messages are added in the following positions | ||
result = append(result, m.messages...) | ||
return result | ||
} | ||
|
||
// GetSystemMessage returns the system message. | ||
func (m *Messages) GetSystemMessage() string { | ||
return m.systemMessage | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
package perplexity_test | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/sgaunet/perplexity-go/v2" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestNewMessages(t *testing.T) { | ||
t.Run("creates a new Messages object", func(t *testing.T) { | ||
m := perplexity.NewMessages() | ||
assert.NotNil(t, m) | ||
}) | ||
} | ||
|
||
func TestWithSystemMessage(t *testing.T) { | ||
t.Run("sets the system message for the Messages object", func(t *testing.T) { | ||
m := perplexity.NewMessages(perplexity.WithSystemMessage("system message")) | ||
sysMsg := m.GetSystemMessage() | ||
assert.Equal(t, sysMsg, "system message") | ||
}) | ||
} | ||
|
||
func TestAddUserMessage(t *testing.T) { | ||
t.Run("adds a user message to the Messages object", func(t *testing.T) { | ||
m := perplexity.NewMessages() | ||
err := m.AddUserMessage("hello") | ||
assert.Nil(t, err) | ||
msgs := m.GetMessages() | ||
assert.Equal(t, len(msgs), 1) | ||
assert.Equal(t, msgs[0].Role, "user") | ||
assert.Equal(t, msgs[0].Content, "hello") | ||
}) | ||
} | ||
|
||
func TestAddAgentMessage(t *testing.T) { | ||
t.Run("adds an assistant message to the Messages object", func(t *testing.T) { | ||
m := perplexity.NewMessages() | ||
m.AddUserMessage("hello") | ||
err := m.AddAgentMessage("hello") | ||
assert.Nil(t, err) | ||
msgs := m.GetMessages() | ||
assert.Equal(t, len(msgs), 2) | ||
assert.Equal(t, msgs[1].Role, "assistant") | ||
assert.Equal(t, msgs[1].Content, "hello") | ||
}) | ||
} | ||
|
||
func TestAddTwiceUserMessage(t *testing.T) { | ||
t.Run("adds a user message to the Messages object", func(t *testing.T) { | ||
m := perplexity.NewMessages() | ||
err := m.AddUserMessage("hello") | ||
assert.Nil(t, err) | ||
err = m.AddUserMessage("hello") | ||
assert.NotNil(t, err) | ||
}) | ||
} | ||
|
||
func TestAddTwiceAgentMessage(t *testing.T) { | ||
t.Run("adds an assistant message to the Messages object", func(t *testing.T) { | ||
m := perplexity.NewMessages() | ||
m.AddUserMessage("hello") | ||
err := m.AddAgentMessage("hello") | ||
assert.Nil(t, err) | ||
err = m.AddAgentMessage("hello") | ||
assert.NotNil(t, err) | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters