Databricks LLMs service endpoints implementation #1074

Open · wants to merge 12 commits into base: main
163 changes: 163 additions & 0 deletions llms/databricks/clients/llama/v3.1/llama31.go
@@ -0,0 +1,163 @@
package databricksclientsllama31

import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"

"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/databricks"
)

// Llama31 implements the databricks.Model interface for the Llama 3.1 model.
type Llama31 struct{}

var _ databricks.Model = (*Llama31)(nil)

// NewLlama31 creates a new Llama31 instance.
func NewLlama31() *Llama31 {
return &Llama31{}
}

// FormatPayload implements databricks.Model by converting langchaingo llms.MessageContent values into the Llama request payload.
func (l *Llama31) FormatPayload(_ context.Context, messages []llms.MessageContent, options ...llms.CallOption) ([]byte, error) {
// Initialize payload options with defaults
opts := llms.CallOptions{}
for _, opt := range options {
opt(&opts)
}

// Transform llms.MessageContent to LlamaMessage
llamaMessages := []LlamaMessage{}
for _, msg := range messages {
var contentBuilder strings.Builder
for _, part := range msg.Parts {
switch p := part.(type) {
case llms.TextContent:
contentBuilder.WriteString(p.Text)
case llms.ImageURLContent:
contentBuilder.WriteString(fmt.Sprintf("[Image: %s]", p.URL))
case llms.BinaryContent:
contentBuilder.WriteString(fmt.Sprintf("[Binary Content: %s]", p.MIMEType))
default:
return nil, fmt.Errorf("unsupported content part type: %T", p)
}
}

llamaMessages = append(llamaMessages, LlamaMessage{
Role: MapRole(msg.Role),
Content: contentBuilder.String(),
})
}

// Construct the LlamaPayload
payload := LlamaPayload{
Model: "llama-3.1",
Messages: llamaMessages,
Temperature: opts.Temperature,
MaxTokens: opts.MaxTokens,
TopP: opts.TopP,
FrequencyPenalty: opts.FrequencyPenalty,
PresencePenalty: opts.PresencePenalty,
Stop: opts.StopWords, // Stop sequences from the call options
}

if opts.StreamingFunc != nil {
payload.Stream = true
}

// Serialize to JSON
jsonPayload, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal payload: %w", err)
}

return jsonPayload, nil
}

// FormatResponse parses the LlamaResponse JSON and converts it to a ContentResponse structure.
func (l *Llama31) FormatResponse(_ context.Context, response []byte) (*llms.ContentResponse, error) {
return formatResponse[LlamaChoice](response)
}

// FormatStreamResponse parses a streamed LlamaResponse chunk (an SSE data line) and converts it to a ContentResponse structure.
func (l *Llama31) FormatStreamResponse(_ context.Context, response []byte) (*llms.ContentResponse, error) {
// The "data:" prefix is commonly used in Server-Sent Events (SSE) or streaming APIs
// to delimit individual chunks of data being sent from the server. It indicates
// that the following text is a data payload. Before parsing the JSON, we remove
// this prefix to work with the raw JSON payload.
response = bytes.TrimPrefix(response, []byte("data: "))

if string(response) == "[DONE]" || len(response) == 0 {
return &llms.ContentResponse{
Choices: []*llms.ContentChoice{{
Content: "",
}},
}, nil
}
return formatResponse[LlamaChoiceDelta](response)
}

func formatResponse[T LlamaChoiceDelta | LlamaChoice](response []byte) (*llms.ContentResponse, error) {
fmt.Printf("response: %+v\n", string(response))

// Parse the LlamaResponse JSON
var llamaResp LlamaResponse[T]
err := json.Unmarshal(response, &llamaResp)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal LlamaResponse: %w", err)
}

// Initialize ContentResponse
contentResponse := &llms.ContentResponse{
Choices: []*llms.ContentChoice{},
}

fmt.Printf("llamaResp: %+v\n", llamaResp)

// Map LlamaResponse choices to ContentChoice
for _, llamaChoice := range llamaResp.Choices {
var contentChoice *llms.ContentChoice
switch choice := any(llamaChoice).(type) {
case LlamaChoice:
contentChoice = &llms.ContentChoice{
Content: choice.Message.Content,
StopReason: choice.FinishReason,
GenerationInfo: map[string]any{
"index": choice.Index,
},
}

// If the LlamaMessage indicates a function/tool call, populate FuncCall or ToolCalls
if choice.Message.Role == RoleIPython {
funcCall := &llms.FunctionCall{
Name: "tool_function_name", // Replace with actual function name if included in response
Arguments: choice.Message.Content, // Replace with parsed arguments if available
}
contentChoice.FuncCall = funcCall
contentChoice.ToolCalls = []llms.ToolCall{
{
ID: fmt.Sprintf("tool-call-%d", choice.Index),
Type: "function",
FunctionCall: funcCall,
},
}
}
case LlamaChoiceDelta:
contentChoice = &llms.ContentChoice{
Content: choice.Delta.Content,
StopReason: choice.FinishReason,
GenerationInfo: map[string]any{
"index": choice.Index,
},
}
}

// Append the ContentChoice to the ContentResponse
contentResponse.Choices = append(contentResponse.Choices, contentChoice)
}

return contentResponse, nil
}
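
A minimal usage sketch of the two entry points above, assuming only the langchaingo llms message types and the constructor defined in this file; the import path is inferred from the file layout in this PR, and the streamed chunk is a hand-written sample, not captured endpoint output.

package main

import (
	"context"
	"fmt"

	"github.com/tmc/langchaingo/llms"
	llama31 "github.com/tmc/langchaingo/llms/databricks/clients/llama/v3.1"
)

func main() {
	model := llama31.NewLlama31()

	// Format a single user message into the Llama request payload.
	messages := []llms.MessageContent{
		{
			Role:  llms.ChatMessageTypeHuman,
			Parts: []llms.ContentPart{llms.TextContent{Text: "Hello, Llama!"}},
		},
	}
	payload, err := model.FormatPayload(context.Background(), messages, llms.WithTemperature(0.2))
	if err != nil {
		panic(err)
	}
	fmt.Println(string(payload))

	// Streamed chunks arrive as SSE lines; FormatStreamResponse strips the "data: " prefix.
	chunk := []byte(`data: {"id":"1","object":"chat.completion.chunk","created":0,"model":"llama-3.1","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":""}]}`)
	resp, err := model.FormatStreamResponse(context.Background(), chunk)
	if err != nil {
		panic(err)
	}
	fmt.Println(resp.Choices[0].Content) // "Hi"
}
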
23 changes: 23 additions & 0 deletions llms/databricks/clients/llama/v3.1/map_role.go
@@ -0,0 +1,23 @@
package databricksclientsllama31

import (
"github.com/tmc/langchaingo/llms"
)

// MapRole maps llms.ChatMessageType to Role.
func MapRole(chatRole llms.ChatMessageType) Role {
switch chatRole {
case llms.ChatMessageTypeAI:
return RoleAssistant
case llms.ChatMessageTypeHuman:
return RoleUser
case llms.ChatMessageTypeSystem:
return RoleSystem
case llms.ChatMessageTypeFunction, llms.ChatMessageTypeTool:
return RoleIPython // Mapping tools and functions to ipython
case llms.ChatMessageTypeGeneric:
return RoleUser // Defaulting generic to user
default:
return Role(chatRole)
}
}
65 changes: 65 additions & 0 deletions llms/databricks/clients/llama/v3.1/types.go
@@ -0,0 +1,65 @@
package databricksclientsllama31

type Role string

const (
RoleSystem Role = "system" // The system role provides instructions or context for the model
RoleUser Role = "user" // The user role represents inputs from the user
RoleAssistant Role = "assistant" // The assistant role represents responses from the model
RoleIPython Role = "ipython" // The ipython role carries tool or code-execution output returned to the model
)

// LlamaMessage represents a message exchanged with the model.
type LlamaMessage struct {
Role Role `json:"role"` // Role of the message sender (e.g., "system", "user", "assistant")
Content string `json:"content"` // The content of the message
}

// LlamaMessageDelta represents a message streamed by the LLM.
type LlamaMessageDelta struct {
Content string `json:"content"` // The content of the message
}

// LlamaPayload represents the payload structure for the Llama model.
type LlamaPayload struct {
Model string `json:"model"` // Model to use (e.g., "llama-3.1")
Messages []LlamaMessage `json:"messages"` // List of structured messages
Temperature float64 `json:"temperature,omitempty"` // Sampling temperature (0 to 1)
MaxTokens int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate
TopP float64 `json:"top_p,omitempty"` // Top-p (nucleus) sampling
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` // Penalizes new tokens based on frequency
PresencePenalty float64 `json:"presence_penalty,omitempty"` // Penalizes tokens based on presence
Stop []string `json:"stop,omitempty"` // List of stop sequences to end generation
Stream bool `json:"stream,omitempty"` // Enable token-by-token streaming
}

// LlamaResponse represents the response structure for the Llama model (either a full completion or a streamed chunk).
type LlamaResponse[T LlamaChoice | LlamaChoiceDelta] struct {
ID string `json:"id"` // Unique ID of the response
Object string `json:"object"` // Type of response (e.g., "chat.completion")
Created int64 `json:"created"` // Timestamp of creation
Model string `json:"model"` // Model used (e.g., "llama-3.1")
Choices []T `json:"choices"` // List of response choices
Usage LlamaUsage `json:"usage"` // Token usage details
}

// LlamaChoice represents a choice in the Llama response.
type LlamaChoice struct {
Index int `json:"index"` // Index of the choice
Message LlamaMessage `json:"message"` // The message content
FinishReason string `json:"finish_reason"` // Why the response stopped (e.g., "stop")
}

// LlamaChoiceDelta represents a streamed (delta) choice in the Llama response.
type LlamaChoiceDelta struct {
Index int `json:"index"` // Index of the choice
Delta LlamaMessageDelta `json:"delta"` // The message content
FinishReason string `json:"finish_reason"` // Why the response stopped (e.g., "stop")
}

// LlamaUsage represents the token usage details of a response.
type LlamaUsage struct {
PromptTokens int `json:"prompt_tokens"` // Tokens used for the prompt
CompletionTokens int `json:"completion_tokens"` // Tokens used for the completion
TotalTokens int `json:"total_tokens"` // Total tokens used
}
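
To make the generic parameter concrete, here is a short illustrative sketch (within the same package) of how the two instantiations of LlamaResponse line up with full and streamed responses; the JSON literals are representative samples, not real endpoint output, and the helper is not part of the client API.

package databricksclientsllama31

import "encoding/json"

// decodeExamples shows how LlamaResponse is instantiated for both response shapes.
func decodeExamples() error {
	// Full completion: each choice carries a complete message.
	full := []byte(`{"id":"1","object":"chat.completion","created":0,"model":"llama-3.1","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`)
	var fullResp LlamaResponse[LlamaChoice]
	if err := json.Unmarshal(full, &fullResp); err != nil {
		return err
	}

	// Streamed chunk: each choice carries only the delta for this batch of tokens.
	chunk := []byte(`{"id":"1","object":"chat.completion.chunk","created":0,"model":"llama-3.1","choices":[{"index":0,"delta":{"content":"Hel"},"finish_reason":""}]}`)
	var chunkResp LlamaResponse[LlamaChoiceDelta]
	return json.Unmarshal(chunk, &chunkResp)
}
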
20 changes: 20 additions & 0 deletions llms/databricks/clients/mistral/v1/map_role.go
@@ -0,0 +1,20 @@
package databricksclientsmistralv1

import "github.com/tmc/langchaingo/llms"

// MapRole maps llms.ChatMessageType to Role.
func MapRole(chatRole llms.ChatMessageType) Role {
switch chatRole {
case llms.ChatMessageTypeAI:
return RoleAssistant
case llms.ChatMessageTypeHuman, llms.ChatMessageTypeGeneric:
return RoleUser
case llms.ChatMessageTypeSystem:
return RoleSystem
case llms.ChatMessageTypeTool, llms.ChatMessageTypeFunction:
return RoleTool
default:
return Role(chatRole)
}
}