Skip to content

Commit

Permalink
feat: add support for claude v3 & mistral models
Browse files Browse the repository at this point in the history
  • Loading branch information
catpaladin committed Mar 12, 2024
1 parent f72c42f commit 90de74e
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 37 deletions.
90 changes: 65 additions & 25 deletions internal/bedrock/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,42 +75,46 @@ func (m *AWSModelConfig) InvokeModel(ctx context.Context, api ClientRuntimeAPI,

func (m *AWSModelConfig) constructPayload(message string) ([]byte, error) {
switch {
case strings.Contains(m.ModelID, "anthropic"):
case strings.Contains(m.ModelID, "sonnet"):
body := models.ClaudeMessagesInput{
AnthropicVersion: "bedrock-2023-05-31",
Messages: []models.ClaudeMessage{
{
Role: "user",
Content: []models.ClaudeContent{
{
Type: "text",
Text: message,
},
},
},
},
MaxTokens: m.MaxTokens,
Temperature: m.Temperature,
TopP: m.TopP,
TopK: m.TopK,
}

payload, err := json.Marshal(body)
if err != nil {
return []byte{}, err
}
return payload, nil
case strings.Contains(m.ModelID, "anthropic"):
body := models.ClaudeModelInputs{
Prompt: fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", message),
MaxTokensToSample: m.MaxTokens,
Temperature: m.Temperature,
TopP: m.TopP,
TopK: m.TopK,
}
// TODO: work on v3
//body := models.ClaudeMessagesInput{
// AnthropicVersion: "bedrock-2023-05-31",
// Messages: []models.ClaudeMessage{
// {
// Role: "user",
// Content: []models.ClaudeContent{
// {
// Type: "text",
// Text: message,
// },
// },
// },
// },
// MaxTokens: m.MaxTokens,
// Temperature: m.Temperature,
// TopP: m.TopP,
// TopK: m.TopK,
//}

payload, err := json.Marshal(body)
if err != nil {
return []byte{}, err
}
return payload, nil
case strings.Contains(m.ModelID, "cohere"):

body := models.CommandModelInput{
Prompt: message,
MaxTokensToSample: m.MaxTokens,
Expand All @@ -121,6 +125,26 @@ func (m *AWSModelConfig) constructPayload(message string) ([]byte, error) {
ReturnLiklihoods: "NONE",
NumGenerations: 1,
}

payload, err := json.Marshal(body)
if err != nil {
return []byte{}, err
}
return payload, nil
case strings.Contains(m.ModelID, "mistral"):
// handle the default being higher than the model allows
if m.TopK > 200 {
m.TopK = 200
}

body := models.MistralRequest{
Prompt: message,
MaxTokens: m.MaxTokens,
Temperature: m.Temperature,
TopP: m.TopP,
TopK: m.TopK,
}

payload, err := json.Marshal(body)
if err != nil {
return []byte{}, err
Expand All @@ -140,24 +164,40 @@ func (m *AWSModelConfig) processStreamingOutput(output *bedrockruntime.InvokeMod
case *types.ResponseStreamMemberChunk:
// nested switch case for stream outputs. ugh
switch {
case strings.Contains(m.ModelID, "sonnet"):
var resp models.ClaudeMessagesOutput
if err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&resp); err != nil {
return combinedResult, err
}

if resp.Delta.Type == "text_delta" {
handler(context.Background(), []byte(resp.Delta.Text))
combinedResult += resp.Delta.Text
}
case strings.Contains(m.ModelID, "anthropic"):
var resp models.ClaudeModelOutputs
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&resp)
if err != nil {
if err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&resp); err != nil {
return combinedResult, err
}

handler(context.Background(), []byte(resp.Completion))
combinedResult += resp.Completion
case strings.Contains(m.ModelID, "cohere"):
var resp models.CommandModelOutput
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&resp)
if err != nil {
if err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&resp); err != nil {
return combinedResult, err
}

handler(context.Background(), []byte(resp.Generations[0].Text))
combinedResult += resp.Generations[0].Text
case strings.Contains(m.ModelID, "mistral"):
var resp models.MistralResponse
if err := json.Unmarshal([]byte(string(v.Value.Bytes)), &resp); err != nil {
return combinedResult, err
}

handler(context.Background(), []byte(resp.Outputs[0].Text))
combinedResult += resp.Outputs[0].Text
default:
fmt.Println("Unable to determine AWS Model")
}
Expand Down
23 changes: 11 additions & 12 deletions internal/bedrock/models/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,20 @@ type ClaudeMessage struct {

// ClaudeContent contains content
type ClaudeContent struct {
Type string `json:"type"` // The type of the content. Valid values are image and text.
Text string `json:"text,omitempty"` // The text content.
Source []ClaudeSource `json:"source,omitempty"` // The content of the conversation turn.
}

// ClaudeSource contains source
type ClaudeSource struct {
Type string `json:"type"` // The encoding type for the image. You can specify base64.
MediaType string `json:"media_type"` // The type of the image. You can specify the following image formats.
Data string `json:"data"` // The base64 encoded image bytes for the image. The maximum image size is 3.75MB.
Type string `json:"type"` // The type of the content. Valid values are image and text.
Text string `json:"text,omitempty"` // The text content.
}

// ClaudeMessagesOutput is needed to unmarshal the new request type
// Supported Models: claude-instant-v1.2, claude-v2, claude-v2.1, claude-v3
type ClaudeMessagesOutput struct {
Content []ClaudeContent `json:"content"` // The content generated by the model.
StopReason string `json:"stop_reason"` // The reason why Anthropic Claude stopped generating the response.
Type string `json:"type"` // The type of the response. Valid values are image and text.
Index int `json:"index"` // The index of the response.
Delta TextDelta `json:"delta"` // The delta of the response.
}

// TextDelta contains the type and text of a single streamed content delta
// from a Claude v3 (messages API) response chunk.
type TextDelta struct {
Type string `json:"type"` // delta kind; consumers check for "text_delta" before using Text
Text string `json:"text"` // the incremental text produced in this chunk
}
22 changes: 22 additions & 0 deletions internal/bedrock/models/mistral.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Package models contains structs on model requests/responses
package models

// MistralRequest contains the request needed for mistral models.
// It is marshaled to JSON and sent as the InvokeModel request body.
type MistralRequest struct {
Prompt string `json:"prompt"` // the prompt text sent to the model
MaxTokens int `json:"max_tokens"` // maximum number of tokens to generate
Temperature float64 `json:"temperature,omitempty"` // sampling temperature; omitted when zero
TopP float64 `json:"top_p,omitempty"` // nucleus-sampling cutoff; omitted when zero
TopK int `json:"top_k,omitempty"` // top-k sampling cutoff; callers cap this at 200, the model's maximum
}

// MistralResponse contains the response obtained from mistral models.
// It is unmarshaled from the InvokeModel (streaming chunk) response body.
type MistralResponse struct {
Outputs []MistralOutput `json:"outputs"` // generated outputs; consumers read Outputs[0]
}

// MistralOutput contains the response text and stop response
type MistralOutput struct {
Text string `json:"text"`
StopResponse string `json:"stop_response"`
}

0 comments on commit 90de74e

Please sign in to comment.