diff --git a/llms/databricks/clients/llama/v3.1/llama31.go b/llms/databricks/clients/llama/v3.1/llama31.go
new file mode 100644
index 000000000..87f316487
--- /dev/null
+++ b/llms/databricks/clients/llama/v3.1/llama31.go
@@ -0,0 +1,159 @@
+package databricksclientsllama31
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"strings"
+
+	"github.com/tmc/langchaingo/llms"
+	"github.com/tmc/langchaingo/llms/databricks"
+)
+
+// Llama31 implements the databricks.Model interface for the Llama 3.1 model.
+type Llama31 struct{}
+
+var _ databricks.Model = (*Llama31)(nil)
+
+// NewLlama31 creates a new Llama31 instance.
+func NewLlama31() *Llama31 {
+	return &Llama31{}
+}
+
+// FormatPayload implements databricks.Model to convert langchaingo llms.MessageContent to the Llama payload.
+func (l *Llama31) FormatPayload(_ context.Context, messages []llms.MessageContent, options ...llms.CallOption) ([]byte, error) {
+	// Initialize payload options with defaults
+	opts := llms.CallOptions{}
+	for _, opt := range options {
+		opt(&opts)
+	}
+
+	// Transform llms.MessageContent to LlamaMessage
+	llamaMessages := []LlamaMessage{}
+	for _, msg := range messages {
+		var contentBuilder strings.Builder
+		for _, part := range msg.Parts {
+			switch p := part.(type) {
+			case llms.TextContent:
+				contentBuilder.WriteString(p.Text)
+			case llms.ImageURLContent:
+				contentBuilder.WriteString(fmt.Sprintf("[Image: %s]", p.URL))
+			case llms.BinaryContent:
+				contentBuilder.WriteString(fmt.Sprintf("[Binary Content: %s]", p.MIMEType))
+			default:
+				return nil, fmt.Errorf("unsupported content part type: %T", p)
+			}
+		}
+
+		llamaMessages = append(llamaMessages, LlamaMessage{
+			Role:    MapRole(msg.Role),
+			Content: contentBuilder.String(),
+		})
+	}
+
+	// Construct the LlamaPayload
+	payload := LlamaPayload{
+		Model:            "llama-3.1",
+		Messages:         llamaMessages,
+		Temperature:      opts.Temperature,
+		MaxTokens:        opts.MaxTokens,
+		TopP:             opts.TopP,
+		FrequencyPenalty: opts.FrequencyPenalty,
+		PresencePenalty:  opts.PresencePenalty,
+		Stop:             opts.StopWords, // Stop sequences
+	}
+
+	if opts.StreamingFunc != nil {
+		payload.Stream = true
+	}
+
+	// Serialize to JSON
+	jsonPayload, err := json.Marshal(payload)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal payload: %w", err)
+	}
+
+	return jsonPayload, nil
+}
+
+// FormatResponse parses the LlamaResponse JSON and converts it to a ContentResponse structure.
+func (l *Llama31) FormatResponse(_ context.Context, response []byte) (*llms.ContentResponse, error) {
+	return formatResponse[LlamaChoice](response)
+}
+
+// FormatStreamResponse parses a streamed LlamaResponse chunk and converts it to a ContentResponse structure.
+func (l *Llama31) FormatStreamResponse(_ context.Context, response []byte) (*llms.ContentResponse, error) {
+	// The "data:" prefix is commonly used in Server-Sent Events (SSE) or streaming APIs
+	// to delimit individual chunks of data being sent from the server. It indicates
+	// that the following text is a data payload. Before parsing the JSON, we remove
+	// this prefix to work with the raw JSON payload.
+	response = bytes.TrimPrefix(response, []byte("data: "))
+
+	if string(response) == "[DONE]" || len(response) == 0 {
+		return &llms.ContentResponse{
+			Choices: []*llms.ContentChoice{{
+				Content: "",
+			}},
+		}, nil
+	}
+	return formatResponse[LlamaChoiceDelta](response)
+}
+
+func formatResponse[T LlamaChoiceDelta | LlamaChoice](response []byte) (*llms.ContentResponse, error) {
+	// Parse the LlamaResponse JSON
+	var llamaResp LlamaResponse[T]
+	err := json.Unmarshal(response, &llamaResp)
+	if err != nil {
+		return nil, fmt.Errorf("failed to unmarshal LlamaResponse: %w", err)
+	}
+
+	// Initialize ContentResponse
+	contentResponse := &llms.ContentResponse{
+		Choices: []*llms.ContentChoice{},
+	}
+
+	// Map LlamaResponse choices to ContentChoice
+	for _, llamaChoice := range llamaResp.Choices {
+		var contentChoice *llms.ContentChoice
+		switch choice := any(llamaChoice).(type) {
+		case LlamaChoice:
+			contentChoice = &llms.ContentChoice{
+				Content:    choice.Message.Content,
+				StopReason: choice.FinishReason,
+				GenerationInfo: map[string]any{
+					"index": choice.Index,
+				},
+			}
+
+			// If the LlamaMessage indicates a function/tool call, populate FuncCall and ToolCalls
+			if choice.Message.Role == RoleIPython {
+				funcCall := &llms.FunctionCall{
+					Name:      "tool_function_name",   // Placeholder: the response does not carry a structured function name
+					Arguments: choice.Message.Content, // Raw ipython content is passed through as the call arguments
+				}
+				contentChoice.FuncCall = funcCall
+				contentChoice.ToolCalls = []llms.ToolCall{
+					{
+						ID:           fmt.Sprintf("tool-call-%d", choice.Index),
+						Type:         "function",
+						FunctionCall: funcCall,
+					},
+				}
+			}
+		case LlamaChoiceDelta:
+			contentChoice = &llms.ContentChoice{
+				Content:    choice.Delta.Content,
+				StopReason: choice.FinishReason,
+				GenerationInfo: map[string]any{
+					"index": choice.Index,
+				},
+			}
+		}
+
+		// Append the ContentChoice to the ContentResponse
+		contentResponse.Choices = append(contentResponse.Choices, contentChoice)
+	}
+
+	return contentResponse, nil
+}
diff --git a/llms/databricks/clients/llama/v3.1/map_role.go b/llms/databricks/clients/llama/v3.1/map_role.go
new file mode 100644
index 000000000..edc760aab
--- /dev/null
+++ b/llms/databricks/clients/llama/v3.1/map_role.go
@@ -0,0 +1,23 @@
+package databricksclientsllama31
+
+import (
+	"github.com/tmc/langchaingo/llms"
+)
+
+// MapRole maps llms.ChatMessageType to the Llama 3.1 Role.
+func MapRole(chatRole llms.ChatMessageType) Role {
+	switch chatRole {
+	case llms.ChatMessageTypeAI:
+		return RoleAssistant
+	case llms.ChatMessageTypeHuman:
+		return RoleUser
+	case llms.ChatMessageTypeSystem:
+		return RoleSystem
+	case llms.ChatMessageTypeFunction, llms.ChatMessageTypeTool:
+		return RoleIPython // Mapping tools and functions to ipython
+	case llms.ChatMessageTypeGeneric:
+		return RoleUser // Defaulting generic to user
+	default:
+		return Role(chatRole)
+	}
+}
diff --git a/llms/databricks/clients/llama/v3.1/types.go b/llms/databricks/clients/llama/v3.1/types.go
new file mode 100644
index 000000000..f9004eb30
--- /dev/null
+++ b/llms/databricks/clients/llama/v3.1/types.go
@@ -0,0 +1,65 @@
+package databricksclientsllama31
+
+type Role string
+
+const (
+	RoleSystem    Role = "system"    // The system role provides instructions or context for the model
+	RoleUser      Role = "user"      // The user role represents inputs from the user
+	RoleAssistant Role = "assistant" // The assistant role represents responses from the model
+	RoleIPython   Role = "ipython"   // The ipython role carries tool/function results returned to the model
+)
+
+// LlamaMessage represents a message in the LLM.
+type LlamaMessage struct {
+	Role    Role   `json:"role"`    // Role of the message sender (e.g., "system", "user", "assistant")
+	Content string `json:"content"` // The content of the message
+}
+
+// LlamaMessageDelta represents a message streamed by the LLM.
+type LlamaMessageDelta struct {
+	Content string `json:"content"` // The content of the message
+}
+
+// LlamaPayload represents the payload structure for the Llama model.
+type LlamaPayload struct {
+	Model            string         `json:"model"`                       // Model to use (e.g., "llama-3.1")
+	Messages         []LlamaMessage `json:"messages"`                    // List of structured messages
+	Temperature      float64        `json:"temperature,omitempty"`       // Sampling temperature (0 to 1)
+	MaxTokens        int            `json:"max_tokens,omitempty"`        // Maximum number of tokens to generate
+	TopP             float64        `json:"top_p,omitempty"`             // Top-p (nucleus) sampling
+	FrequencyPenalty float64        `json:"frequency_penalty,omitempty"` // Penalizes new tokens based on frequency
+	PresencePenalty  float64        `json:"presence_penalty,omitempty"`  // Penalizes tokens based on presence
+	Stop             []string       `json:"stop,omitempty"`              // List of stop sequences to end generation
+	Stream           bool           `json:"stream,omitempty"`            // Enable token-by-token streaming
+}
+
+// LlamaResponse represents the response structure for the Llama model (full or streamed).
+type LlamaResponse[T LlamaChoice | LlamaChoiceDelta] struct {
+	ID      string     `json:"id"`      // Unique ID of the response
+	Object  string     `json:"object"`  // Type of response (e.g., "chat.completion")
+	Created int64      `json:"created"` // Timestamp of creation
+	Model   string     `json:"model"`   // Model used (e.g., "llama-3.1")
+	Choices []T        `json:"choices"` // List of response choices
+	Usage   LlamaUsage `json:"usage"`   // Token usage details
+}
+
+// LlamaChoice represents a choice in the Llama response.
+type LlamaChoice struct {
+	Index        int          `json:"index"`         // Index of the choice
+	Message      LlamaMessage `json:"message"`       // The message content
+	FinishReason string       `json:"finish_reason"` // Why the response stopped (e.g., "stop")
+}
+
+// LlamaChoiceDelta represents a streamed choice in the Llama response.
+type LlamaChoiceDelta struct {
+	Index        int               `json:"index"`         // Index of the choice
+	Delta        LlamaMessageDelta `json:"delta"`         // The message content
+	FinishReason string            `json:"finish_reason"` // Why the response stopped (e.g., "stop")
+}
+
+// LlamaUsage represents the token usage details of a response.
+type LlamaUsage struct {
+	PromptTokens     int `json:"prompt_tokens"`     // Tokens used for the prompt
+	CompletionTokens int `json:"completion_tokens"` // Tokens used for the completion
+	TotalTokens      int `json:"total_tokens"`      // Total tokens used
+}
diff --git a/llms/databricks/clients/mistral/v1/map_role.go b/llms/databricks/clients/mistral/v1/map_role.go
new file mode 100644
index 000000000..4abde22f2
--- /dev/null
+++ b/llms/databricks/clients/mistral/v1/map_role.go
@@ -0,0 +1,20 @@
+package databricksclientsmistralv1
+
+import "github.com/tmc/langchaingo/llms"
+
+// MapRole maps llms.ChatMessageType to the Mistral Role,
+// defaulting unknown roles to the raw chat role.
+func MapRole(chatRole llms.ChatMessageType) Role {
+	switch chatRole {
+	case llms.ChatMessageTypeAI:
+		return RoleAssistant
+	case llms.ChatMessageTypeHuman, llms.ChatMessageTypeGeneric:
+		return RoleUser
+	case llms.ChatMessageTypeSystem:
+		return RoleSystem
+	case llms.ChatMessageTypeTool, llms.ChatMessageTypeFunction:
+		return RoleTool
+	default:
+		return Role(chatRole)
+	}
+}
diff --git a/llms/databricks/clients/mistral/v1/mistralv1.go b/llms/databricks/clients/mistral/v1/mistralv1.go
new file mode 100644
index 000000000..a79246511
--- /dev/null
+++ b/llms/databricks/clients/mistral/v1/mistralv1.go
@@ -0,0 +1,182 @@
+package databricksclientsmistralv1
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"strings"
+
+	"github.com/tmc/langchaingo/llms"
+	"github.com/tmc/langchaingo/llms/databricks"
+)
+
+// Mistral1 implements the databricks.Model interface for Mistral models served on Databricks.
+type Mistral1 struct {
+	Model string `json:"model"`
+}
+
+var _ databricks.Model = (*Mistral1)(nil)
+
+// NewMistral1 creates a new Mistral1 instance.
+func NewMistral1(model string) *Mistral1 {
+	return &Mistral1{
+		Model: model,
+	}
+}
+
+// FormatPayload implements databricks.Model to convert langchaingo llms.MessageContent to a Mistral payload.
+func (l *Mistral1) FormatPayload(_ context.Context, messages []llms.MessageContent, options ...llms.CallOption) ([]byte, error) {
+	// Initialize payload options with defaults
+	opts := llms.CallOptions{}
+	for _, opt := range options {
+		opt(&opts)
+	}
+
+	// Convert langchaingo MessageContent to Mistral ChatMessage
+	var mistralMessages []ChatMessage // nolint: prealloc
+	for _, msg := range messages {
+		// Process parts to handle the actual content and tool calls correctly
+		var toolCalls []ToolCall
+		var contentParts []string
+		for _, part := range msg.Parts {
+			switch p := part.(type) {
+			case llms.TextContent:
+				// Text parts go directly as content
+				contentParts = append(contentParts, p.Text)
+			case llms.ImageURLContent:
+				// Append structured description for the image
+				contentParts = append(contentParts, fmt.Sprintf("Image: %s", p.URL))
+				if p.Detail != "" {
+					contentParts = append(contentParts, fmt.Sprintf("Detail: %s", p.Detail))
+				}
+			case llms.ToolCall:
+				// Convert tool calls into structured Mistral ToolCall objects
+				toolCalls = append(toolCalls, ToolCall{
+					ID:   p.ID,
+					Type: ToolTypeFunction, // Assuming ToolTypeFunction
+					Function: FunctionCall{
+						Name:      p.FunctionCall.Name,
+						Arguments: p.FunctionCall.Arguments,
+					},
+				})
+			case llms.ToolCallResponse:
+				// Handle tool call responses as content
+				contentParts = append(contentParts, p.Content)
+			default:
+				return nil, fmt.Errorf("unknown content part type: %T", part)
+			}
+		}
+
+		mistralMessage := ChatMessage{
+			Role:      MapRole(msg.Role),
+			Content:   strings.Join(contentParts, "\n"), // Join all parts into a single content string
+			ToolCalls: toolCalls,
+		}
+		mistralMessages = append(mistralMessages, mistralMessage)
+	}
+
+	// Build the payload, applying the call options (temperature, max_tokens, etc.)
+	payload := ChatCompletionPayload{
+		Model:       l.Model, // Use the model configured on the client
+		Messages:    mistralMessages,
+		Temperature: opts.Temperature,
+		MaxTokens:   opts.MaxTokens,
+		TopP:        opts.TopP,
+		RandomSeed:  opts.Seed,
+	}
+
+	if opts.StreamingFunc != nil {
+		payload.Stream = true
+	}
+
+	// Marshal the payload to JSON
+	payloadBytes, err := json.Marshal(payload)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal payload: %w", err)
+	}
+
+	return payloadBytes, nil
+}
+
+// FormatResponse parses a full (non-streaming) chat completion response into an llms.ContentResponse.
+func (l *Mistral1) FormatResponse(_ context.Context, response []byte) (*llms.ContentResponse, error) {
+	var resp ChatCompletionResponse
+	if err := json.Unmarshal(response, &resp); err != nil {
+		return nil, fmt.Errorf("failed to parse response: %w", err)
+	}
+
+	return &llms.ContentResponse{
+		Choices: mapChoices(resp.Choices),
+	}, nil
+}
+
+// FormatStreamResponse parses a single streamed chat completion chunk into an llms.ContentResponse.
+func (l *Mistral1) FormatStreamResponse(_ context.Context, response []byte) (*llms.ContentResponse, error) {
+	// The "data:" prefix is commonly used in Server-Sent Events (SSE) or streaming APIs
+	// to delimit individual chunks of data being sent from the server. It indicates
+	// that the following text is a data payload. Before parsing the JSON, we remove
+	// this prefix to work with the raw JSON payload.
+	response = bytes.TrimPrefix(response, []byte("data: "))
+
+	if string(response) == "[DONE]" || len(response) == 0 {
+		return &llms.ContentResponse{
+			Choices: []*llms.ContentChoice{{
+				Content: "",
+			}},
+		}, nil
+	}
+
+	var streamResp ChatCompletionStreamResponse
+	if err := json.Unmarshal(response, &streamResp); err != nil {
+		return nil, fmt.Errorf("failed to parse streaming response: %w", err)
+	}
+
+	return &llms.ContentResponse{
+		Choices: mapChoices(streamResp.Choices),
+	}, nil
+}
+
+// mapChoices maps Mistral response choices (full or streamed) to llms.ContentChoice.
+func mapChoices[T ChatCompletionResponseChoice | ChatCompletionResponseChoiceStream](choices []T) []*llms.ContentChoice {
+	var contentChoices []*llms.ContentChoice // nolint: prealloc
+
+	for _, choice := range choices {
+		var index int
+		var message ChatMessage
+		var finishReason FinishReason
+		switch c := any(choice).(type) {
+		case ChatCompletionResponseChoice:
+			index = c.Index
+			message = c.Message
+			finishReason = c.FinishReason
+		case ChatCompletionResponseChoiceStream:
+			index = c.Index
+			message = c.Delta
+			finishReason = c.FinishReason
+		}
+
+		contentChoice := &llms.ContentChoice{
+			Content:    message.Content,
+			StopReason: string(finishReason),
+			GenerationInfo: map[string]any{
+				"index": index,
+			},
+		}
+
+		for _, toolCall := range message.ToolCalls {
+			contentChoice.ToolCalls = append(contentChoice.ToolCalls, llms.ToolCall{
+				ID:   toolCall.ID,
+				Type: string(toolCall.Type),
+				FunctionCall: &llms.FunctionCall{
+					Name:      toolCall.Function.Name,
+					Arguments: toolCall.Function.Arguments,
+				},
+			})
+		}
+
+		contentChoices = append(contentChoices, contentChoice)
+	}
+
+	return contentChoices
+}
diff --git a/llms/databricks/clients/mistral/v1/types.go b/llms/databricks/clients/mistral/v1/types.go
new file mode 100644
index 000000000..7d8f105ea
--- /dev/null
+++ b/llms/databricks/clients/mistral/v1/types.go
@@ -0,0 +1,82 @@
+package databricksclientsmistralv1
+
+// Role the role of the chat message.
+type Role string
+
+// Role the role of the chat message.
+const (
+	RoleUser      Role = "user"
+	RoleAssistant Role = "assistant"
+	RoleSystem    Role = "system"
+	RoleTool      Role = "tool"
+)
+
+// FinishReason the reason that a chat message was finished.
+type FinishReason string
+
+// FinishReason the reason that a chat message was finished.
+const (
+	FinishReasonStop   FinishReason = "stop"
+	FinishReasonLength FinishReason = "length"
+	FinishReasonError  FinishReason = "error"
+)
+
+// ResponseFormat the format that the response must adhere to.
+type ResponseFormat string
+
+// ResponseFormat the format that the response must adhere to.
+const (
+	ResponseFormatText       ResponseFormat = "text"
+	ResponseFormatJSONObject ResponseFormat = "json_object"
+)
+
+// ToolType type of tool defined for the llm.
+type ToolType string
+
+// ToolType type of tool defined for the llm.
+const (
+	ToolTypeFunction ToolType = "function"
+)
+
+// ToolChoice the choice of tool to use.
+type ToolChoice string
+
+// ToolChoice the choice of tool to use.
+const (
+	ToolChoiceAny  ToolChoice = "any"
+	ToolChoiceAuto ToolChoice = "auto"
+	ToolChoiceNone ToolChoice = "none"
+)
+
+// Tool definition of a tool that the llm can call.
+type Tool struct {
+	Type     ToolType `json:"type"`
+	Function Function `json:"function"`
+}
+
+// Function definition of a function that the llm can call including its parameters.
+type Function struct {
+	Name        string `json:"name"`
+	Description string `json:"description"`
+	Parameters  any    `json:"parameters"`
+}
+
+// FunctionCall represents a request to call an external tool by the llm.
+type FunctionCall struct {
+	Name      string `json:"name"`
+	Arguments string `json:"arguments"`
+}
+
+// ToolCall represents the call to a tool by the llm.
+type ToolCall struct {
+	ID       string       `json:"id"`
+	Type     ToolType     `json:"type"`
+	Function FunctionCall `json:"function"`
+}
+
+// ChatMessage represents a single message in a chat.
+type ChatMessage struct {
+	Role      Role       `json:"role"`
+	Content   string     `json:"content"`
+	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+}
diff --git a/llms/databricks/clients/mistral/v1/types_payload.go b/llms/databricks/clients/mistral/v1/types_payload.go
new file mode 100644
index 000000000..853a78618
--- /dev/null
+++ b/llms/databricks/clients/mistral/v1/types_payload.go
@@ -0,0 +1,16 @@
+package databricksclientsmistralv1
+
+// ChatCompletionPayload represents the payload for the chat completion request.
+type ChatCompletionPayload struct {
+	Model          string         `json:"model"`                  // The model to use for completion
+	Messages       []ChatMessage  `json:"messages"`               // The messages to use for completion
+	Temperature    float64        `json:"temperature,omitempty"`  // The temperature to use for sampling. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or TopP but not both.
+	MaxTokens      int            `json:"max_tokens,omitempty"`
+	TopP           float64        `json:"top_p,omitempty"`        // An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or Temperature but not both.
+	RandomSeed     int            `json:"random_seed,omitempty"`
+	SafePrompt     bool           `json:"safe_prompt,omitempty"`  // Adds a Mistral defined safety message to the system prompt to enforce guardrailing
+	Tools          []Tool         `json:"tools,omitempty"`
+	ToolChoice     string         `json:"tool_choice,omitempty"`
+	ResponseFormat ResponseFormat `json:"response_format,omitempty"`
+	Stream         bool           `json:"stream,omitempty"`       // Enable token-by-token streaming
+}
diff --git a/llms/databricks/clients/mistral/v1/types_response.go b/llms/databricks/clients/mistral/v1/types_response.go
new file mode 100644
index 000000000..0272bde07
--- /dev/null
+++ b/llms/databricks/clients/mistral/v1/types_response.go
@@ -0,0 +1,43 @@
+package databricksclientsmistralv1
+
+// ChatCompletionResponse represents the response from the chat completion endpoint.
+type ChatCompletionResponse struct {
+	ID      string                         `json:"id"`
+	Object  string                         `json:"object"`
+	Created int                            `json:"created"`
+	Model   string                         `json:"model"`
+	Choices []ChatCompletionResponseChoice `json:"choices"`
+	Usage   UsageInfo                      `json:"usage"`
+}
+
+// ChatCompletionStreamResponse represents the streamed response from the chat completion endpoint.
+type ChatCompletionStreamResponse struct {
+	ID      string                               `json:"id"`
+	Model   string                               `json:"model"`
+	Choices []ChatCompletionResponseChoiceStream `json:"choices"`
+	Created int                                  `json:"created,omitempty"`
+	Object  string                               `json:"object,omitempty"`
+	Usage   UsageInfo                            `json:"usage,omitempty"`
+	Error   error                                `json:"error,omitempty"`
+}
+
+// ChatCompletionResponseChoice represents a choice in the chat completion response.
+type ChatCompletionResponseChoice struct {
+	Index        int          `json:"index"`
+	Message      ChatMessage  `json:"message"`
+	FinishReason FinishReason `json:"finish_reason,omitempty"`
+}
+
+// ChatCompletionResponseChoiceStream represents a streamed choice in the chat completion response.
+type ChatCompletionResponseChoiceStream struct {
+	Index        int          `json:"index"`
+	Delta        ChatMessage  `json:"delta"`
+	FinishReason FinishReason `json:"finish_reason,omitempty"`
+}
+
+// UsageInfo represents the usage information of a response.
+type UsageInfo struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+	CompletionTokens int `json:"completion_tokens,omitempty"`
+}
diff --git a/llms/databricks/databricksllm.go b/llms/databricks/databricksllm.go
new file mode 100644
index 000000000..c38019976
--- /dev/null
+++ b/llms/databricks/databricksllm.go
@@ -0,0 +1,156 @@
+package databricks
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+
+	"github.com/tmc/langchaingo/llms"
+)
+
+// Option is a function that applies a configuration to the LLM.
+type Option func(*LLM)
+
+// LLM is a Databricks LLM implementation.
+type LLM struct {
+	url        string // The constructed or provided URL
+	token      string // The token for authentication
+	httpClient *http.Client
+	model      Model
+}
+
+var _ llms.Model = (*LLM)(nil)
+
+// New creates a new Databricks LLM implementation.
+func New(model Model, opts ...Option) (*LLM, error) {
+	llm := &LLM{
+		model: model,
+	}
+
+	// Apply all options to customize the LLM.
+	for _, opt := range opts {
+		opt(llm)
+	}
+
+	// Validate URL
+	if llm.url == "" {
+		return nil, fmt.Errorf("URL must be provided or constructed using options")
+	}
+
+	if llm.httpClient == nil {
+		if llm.token == "" {
+			return nil, fmt.Errorf("token must be provided")
+		}
+		llm.httpClient = NewHTTPClient(llm.token)
+	}
+
+	return llm, nil
+}
+
+// Call implements the llms.Model interface for a single prompt.
+func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
+	return llms.GenerateFromSinglePrompt(ctx, o, prompt, options...)
+}
+
+// GenerateContent implements the Model interface.
+func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { // nolint: lll, cyclop, funlen
+	payload, err := o.model.FormatPayload(ctx, messages, options...)
+	if err != nil {
+		return nil, err
+	}
+
+	opts := llms.CallOptions{}
+	for _, opt := range options {
+		opt(&opts)
+	}
+
+	request, err := http.NewRequestWithContext(ctx, http.MethodPost, o.url, bytes.NewBuffer(payload))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+
+	resp, err := o.httpClient.Do(request)
+	if err != nil {
+		return nil, fmt.Errorf("failed to do request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if opts.StreamingFunc != nil {
+		return o.stream(ctx, resp.Body, opts)
+	}
+
+	bodyBytes, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+	return o.model.FormatResponse(ctx, bodyBytes)
+}
+
+func (o *LLM) stream(ctx context.Context, body io.Reader, opts llms.CallOptions) (*llms.ContentResponse, error) {
+	fullChoiceContent := []strings.Builder{}
+	scanner := bufio.NewScanner(body)
+	finalResponse := &llms.ContentResponse{}
+
+	for scanner.Scan() {
+		scannedBytes := scanner.Bytes()
+		if len(scannedBytes) == 0 {
+			continue
+		}
+
+		contentResponse, err := o.model.FormatStreamResponse(ctx, scannedBytes)
+		if err != nil {
+			return nil, err
+		}
+
+		if len(contentResponse.Choices) == 0 {
+			continue
+		}
+
+		index, err := concatenateAnswers(contentResponse.Choices, &fullChoiceContent)
+		if err != nil {
+			return nil, err
+		}
+
+		if index == nil {
+			continue
+		}
+
+		if err := opts.StreamingFunc(ctx, []byte(contentResponse.Choices[*index].Content)); err != nil {
+			return nil, err
+		}
+
+		finalResponse = contentResponse
+	}
+
+	for index := range finalResponse.Choices {
+		finalResponse.Choices[index].Content = fullChoiceContent[index].String()
+	}
+
+	return finalResponse, nil
+}
+
+func concatenateAnswers(choices []*llms.ContentChoice, fullChoiceContent *[]strings.Builder) (*int, error) {
+	var lastModifiedIndex *int
+
+	for choiceIndex := range choices {
+		if len(*fullChoiceContent) <= choiceIndex {
+			*fullChoiceContent = append(*fullChoiceContent, strings.Builder{})
+		}
+
+		if choices[choiceIndex].Content == "" {
+			continue
+		}
+
+		lastModifiedIndex = &choiceIndex
+
+		if _, err := (*fullChoiceContent)[choiceIndex].WriteString(choices[choiceIndex].Content); err != nil {
+			return lastModifiedIndex, err
+		}
+	}
+
+	return lastModifiedIndex, nil
+}
diff --git a/llms/databricks/databricksllm_test.go b/llms/databricks/databricksllm_test.go
new file mode 100644
index 000000000..4f3043373
--- /dev/null
+++ b/llms/databricks/databricksllm_test.go
@@ -0,0 +1,118 @@
+package databricks_test
+
+import (
+	"context"
+	"os"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/tmc/langchaingo/llms"
+	"github.com/tmc/langchaingo/llms/databricks"
+	databricksclientsllama31 "github.com/tmc/langchaingo/llms/databricks/clients/llama/v3.1"
+	databricksclientsmistralv1 "github.com/tmc/langchaingo/llms/databricks/clients/mistral/v1"
+)
+
+func testModelStream(t *testing.T, model databricks.Model, url string) {
+	t.Helper()
+
+	const envVarToken = "DATABRICKS_TOKEN"
+
+	if os.Getenv(envVarToken) == "" {
+		t.Skipf("%s not set", envVarToken)
+	}
+
+	dbllm, err := databricks.New(
+		model,
+		databricks.WithFullURL(url),
+		databricks.WithToken(os.Getenv(envVarToken)),
+	)
+	require.NoError(t, err, "failed to create databricks LLM")
+
+	ctx := context.Background()
+	resp, err := dbllm.GenerateContent(ctx,
+		[]llms.MessageContent{
+			{
+				Role: llms.ChatMessageTypeHuman,
+				Parts: []llms.ContentPart{
+					llms.TextContent{Text: "Brazil is a country?"},
+				},
+			},
+		},
+		llms.WithStreamingFunc(func(_ context.Context, chunk []byte) error {
+			require.NotEmpty(t, chunk, "unexpected empty chunk in streaming response")
+			return nil
+		}),
+	)
+	require.NoError(t, err, "failed to generate content")
+
+	assert.NotEmpty(t, resp.Choices, "expected at least one choice from model")
+}
+
+func testModel(t *testing.T, model databricks.Model, url string) {
+	t.Helper()
+
+	const envVarToken = "DATABRICKS_TOKEN"
+
+	if os.Getenv(envVarToken) == "" {
+		t.Skipf("%s not set", envVarToken)
+	}
+
+	dbllm, err := databricks.New(
+		model,
+		databricks.WithFullURL(url),
+		databricks.WithToken(os.Getenv(envVarToken)),
+	)
+	require.NoError(t, err, "failed to create databricks LLM")
+
+	ctx := context.Background()
+	resp, err := dbllm.GenerateContent(ctx,
+		[]llms.MessageContent{
+			{
+				Role: llms.ChatMessageTypeHuman,
+				Parts: []llms.ContentPart{
+					llms.TextContent{Text: "Brazil is a country?"},
+				},
+			},
+		},
+	)
+	require.NoError(t, err, "failed to generate content")
+
+	assert.NotEmpty(t, resp.Choices, "expected at least one choice from model")
+}
+
+func TestDatabricksLlama31(t *testing.T) {
+	t.Parallel()
+
+	const envVar = "DATABRICKS_LLAMA31_URL"
+	url := os.Getenv(envVar)
+	if url == "" {
+		t.Skipf("%s not set", envVar)
+	}
+
+	llama31 := databricksclientsllama31.NewLlama31()
+	testModelStream(t, llama31, url)
+	testModel(t, llama31, url)
+}
+
+func TestDatabricksMistral1(t *testing.T) {
+	t.Parallel()
+
+	const envVarURL = "DATABRICKS_MISTAL1_URL"
+	const envVarModel = "DATABRICKS_MISTAL1_MODEL"
+
+	model := os.Getenv(envVarModel)
+	url := os.Getenv(envVarURL)
+
+	if url == "" {
+		t.Skipf("%s not set", envVarURL)
+	}
+	if model == "" {
+		t.Skipf("%s not set", envVarModel)
+	}
+
+	mistral1 := databricksclientsmistralv1.NewMistral1(model)
+	testModelStream(t, mistral1, url)
+	testModel(t, mistral1, url)
+}
diff --git a/llms/databricks/http_client.go b/llms/databricks/http_client.go
new file mode 100644
index 000000000..e3b3dc020
--- /dev/null
+++ b/llms/databricks/http_client.go
@@ -0,0 +1,45 @@
+package databricks
+
+import (
+	"fmt"
+	"io"
+	"net/http"
+)
+
+// TokenRoundTripper is a custom RoundTripper that adds a Bearer token to each request.
+type TokenRoundTripper struct {
+	Token     string
+	Transport http.RoundTripper
+}
+
+// RoundTrip executes a single HTTP transaction and adds the Bearer token to the request.
+func (t *TokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	req.Header.Set("Authorization", "Bearer "+t.Token)
+	req.Header.Set("Content-Type", "application/json")
+	// Use the underlying transport to perform the request
+	resp, err := t.Transport.RoundTrip(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to make request: %w", err)
+	}
+
+	// Handle status codes
+	if resp.StatusCode >= 400 {
+		// Read the response body for detailed error message (optional)
+		body, _ := io.ReadAll(resp.Body)
+		resp.Body.Close() // Ensure the body is closed to avoid resource leaks
+
+		return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body))
+	}
+
+	return resp, nil
+}
+
+// NewHTTPClient creates a new HTTP client that injects the Bearer token into every request.
+func NewHTTPClient(token string) *http.Client {
+	return &http.Client{
+		Transport: &TokenRoundTripper{
+			Token:     token,
+			Transport: http.DefaultTransport, // Use http.DefaultTransport as the fallback
+		},
+	}
+}
diff --git a/llms/databricks/model.go b/llms/databricks/model.go
new file mode 100644
index 000000000..aca213c40
--- /dev/null
+++ b/llms/databricks/model.go
@@ -0,0 +1,14 @@
+package databricks
+
+import (
+	"context"
+
+	"github.com/tmc/langchaingo/llms"
+)
+
+// Model is the interface that wraps the methods to format the payload and response.
+type Model interface {
+	FormatPayload(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) ([]byte, error)
+	FormatResponse(ctx context.Context, response []byte) (*llms.ContentResponse, error)
+	FormatStreamResponse(ctx context.Context, response []byte) (*llms.ContentResponse, error)
+}
diff --git a/llms/databricks/options.go b/llms/databricks/options.go
new file mode 100644
index 000000000..370c4b2bc
--- /dev/null
+++ b/llms/databricks/options.go
@@ -0,0 +1,34 @@
+package databricks
+
+import (
+	"fmt"
+	"net/http"
+)
+
+// WithFullURL sets the full URL for the LLM.
+func WithFullURL(fullURL string) Option {
+	return func(llm *LLM) {
+		llm.url = fullURL
+	}
+}
+
+// WithURLComponents constructs the URL from individual components.
+func WithURLComponents(databricksInstance, modelName, modelVersion string) Option {
+	return func(llm *LLM) {
+		llm.url = fmt.Sprintf("https://%s/model/%s/%s/invocations", databricksInstance, modelName, modelVersion)
+	}
+}
+
+// WithToken sets the token used for authentication.
+func WithToken(token string) Option {
+	return func(llm *LLM) {
+		llm.token = token
+	}
+}
+
+// WithHTTPClient sets the HTTP client for the LLM.
+func WithHTTPClient(client *http.Client) Option {
+	return func(llm *LLM) {
+		llm.httpClient = client
+	}
+}
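For reviewers, a minimal usage sketch of the API added by this diff, wired up the same way as the tests. The prompt text is illustrative; the endpoint URL and token are read from the same environment variables the tests use (DATABRICKS_LLAMA31_URL, DATABRICKS_TOKEN).

```go
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/databricks"
	databricksclientsllama31 "github.com/tmc/langchaingo/llms/databricks/clients/llama/v3.1"
)

func main() {
	// Build the LLM from a serving-endpoint URL and a personal access token.
	llm, err := databricks.New(
		databricksclientsllama31.NewLlama31(),
		databricks.WithFullURL(os.Getenv("DATABRICKS_LLAMA31_URL")),
		databricks.WithToken(os.Getenv("DATABRICKS_TOKEN")),
	)
	if err != nil {
		log.Fatal(err)
	}

	// Stream the answer chunk by chunk; the returned ContentResponse carries
	// the concatenated content per choice.
	resp, err := llm.GenerateContent(context.Background(),
		[]llms.MessageContent{
			{
				Role:  llms.ChatMessageTypeHuman,
				Parts: []llms.ContentPart{llms.TextContent{Text: "Name three Brazilian cities."}},
			},
		},
		llms.WithStreamingFunc(func(_ context.Context, chunk []byte) error {
			fmt.Print(string(chunk)) // Print each streamed chunk as it arrives
			return nil
		}),
	)
	if err != nil {
		log.Fatal(err)
	}
	if len(resp.Choices) > 0 {
		fmt.Println("\nfull answer:", resp.Choices[0].Content)
	}
}
```

Note that GenerateContent switches to the streaming path automatically when a streaming func is supplied, since FormatPayload sets `stream: true` in that case.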
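And a sketch of the Mistral client constructed via WithURLComponents rather than a full URL. The workspace host, endpoint name, and version below are placeholders rather than values from this PR; the served model name mirrors the DATABRICKS_MISTAL1_MODEL variable used in the tests.

```go
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/tmc/langchaingo/llms/databricks"
	databricksclientsmistralv1 "github.com/tmc/langchaingo/llms/databricks/clients/mistral/v1"
)

func main() {
	// The served model name comes from the environment, mirroring the tests.
	model := os.Getenv("DATABRICKS_MISTAL1_MODEL")

	llm, err := databricks.New(
		databricksclientsmistralv1.NewMistral1(model),
		// WithURLComponents builds https://<instance>/model/<name>/<version>/invocations;
		// the instance, endpoint name, and version here are placeholders.
		databricks.WithURLComponents("my-workspace.cloud.databricks.com", "mistral-endpoint", "1"),
		databricks.WithToken(os.Getenv("DATABRICKS_TOKEN")),
	)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("configured endpoint client:", llm != nil)
}
```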