From e8d2fa960c7b29a82dd5a1a2040f93d7114ac193 Mon Sep 17 00:00:00 2001 From: Thomas Date: Thu, 21 Nov 2024 09:42:47 -0500 Subject: [PATCH 1/8] databricks llms --- llms/databricks/clients/llama/v3.1/llama31.go | 125 ++++++++++++++ .../databricks/clients/llama/v3.1/map_role.go | 23 +++ llms/databricks/clients/llama/v3.1/types.go | 47 ++++++ .../databricks/clients/mistral/v1/map_role.go | 20 +++ .../clients/mistral/v1/mistralv1.go | 159 ++++++++++++++++++ llms/databricks/clients/mistral/v1/types.go | 76 +++++++++ .../clients/mistral/v1/types_payload.go | 15 ++ .../clients/mistral/v1/types_response.go | 43 +++++ llms/databricks/databricksllm.go | 123 ++++++++++++++ llms/databricks/databricksllm_test.go | 81 +++++++++ llms/databricks/http_client.go | 45 +++++ llms/databricks/model.go | 13 ++ llms/databricks/options.go | 34 ++++ 13 files changed, 804 insertions(+) create mode 100644 llms/databricks/clients/llama/v3.1/llama31.go create mode 100644 llms/databricks/clients/llama/v3.1/map_role.go create mode 100644 llms/databricks/clients/llama/v3.1/types.go create mode 100644 llms/databricks/clients/mistral/v1/map_role.go create mode 100644 llms/databricks/clients/mistral/v1/mistralv1.go create mode 100644 llms/databricks/clients/mistral/v1/types.go create mode 100644 llms/databricks/clients/mistral/v1/types_payload.go create mode 100644 llms/databricks/clients/mistral/v1/types_response.go create mode 100644 llms/databricks/databricksllm.go create mode 100644 llms/databricks/databricksllm_test.go create mode 100644 llms/databricks/http_client.go create mode 100644 llms/databricks/model.go create mode 100644 llms/databricks/options.go diff --git a/llms/databricks/clients/llama/v3.1/llama31.go b/llms/databricks/clients/llama/v3.1/llama31.go new file mode 100644 index 000000000..0e83a16a4 --- /dev/null +++ b/llms/databricks/clients/llama/v3.1/llama31.go @@ -0,0 +1,125 @@ +package databricksclientsllama31 + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/llms/databricks" +) + +type Llama31 struct{} + +var _ databricks.Model = (*Llama31)(nil) + +func NewLlama31() *Llama31 { + return &Llama31{} +} + +func (l *Llama31) FormatPayload(_ context.Context, messages []llms.MessageContent, options ...llms.CallOption) ([]byte, error) { + // Initialize payload options with defaults + opts := llms.CallOptions{} + for _, opt := range options { + opt(&opts) + } + + // Transform llms.MessageContent to LlamaMessage + llamaMessages := []LlamaMessage{} + for _, msg := range messages { + var contentBuilder strings.Builder + for _, part := range msg.Parts { + switch p := part.(type) { + case llms.TextContent: + contentBuilder.WriteString(p.Text) + case llms.ImageURLContent: + contentBuilder.WriteString(fmt.Sprintf("[Image: %s]", p.URL)) + case llms.BinaryContent: + contentBuilder.WriteString(fmt.Sprintf("[Binary Content: %s]", p.MIMEType)) + default: + return nil, fmt.Errorf("unsupported content part type: %T", p) + } + } + + llamaMessages = append(llamaMessages, LlamaMessage{ + Role: MapRole(msg.Role), + Content: contentBuilder.String(), + }) + } + + // Construct the LlamaPayload + payload := LlamaPayload{ + Model: "llama-3.1", + Messages: llamaMessages, + Temperature: opts.Temperature, + MaxTokens: opts.MaxTokens, + TopP: opts.TopP, + FrequencyPenalty: opts.FrequencyPenalty, + PresencePenalty: opts.PresencePenalty, + Stop: opts.StopWords, // Add stop sequences if needed + } + + // Serialize to JSON + jsonPayload, err := json.Marshal(payload) 
+ if err != nil { + return nil, fmt.Errorf("failed to marshal payload: %w", err) + } + + return jsonPayload, nil +} + +// FormatResponse parses the LlamaResponse JSON and converts it to a ContentResponse structure. +func (l *Llama31) FormatResponse(_ context.Context, response []byte) (*llms.ContentResponse, error) { + return formatResponse(response) +} + +// FormatStreamResponse parses the LlamaResponse JSON and converts it to a ContentResponse structure. +func (l *Llama31) FormatStreamResponse(_ context.Context, response []byte) (*llms.ContentResponse, error) { + return formatResponse(response) +} + +func formatResponse(response []byte) (*llms.ContentResponse, error) { + // Parse the LlamaResponse JSON + var llamaResp LlamaResponse + err := json.Unmarshal(response, &llamaResp) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal LlamaResponse: %w", err) + } + + // Initialize ContentResponse + contentResponse := &llms.ContentResponse{ + Choices: []*llms.ContentChoice{}, + } + + // Map LlamaResponse choices to ContentChoice + for _, llamaChoice := range llamaResp.Choices { + contentChoice := &llms.ContentChoice{ + Content: llamaChoice.Message.Content, + StopReason: llamaChoice.FinishReason, + GenerationInfo: map[string]any{ + "index": llamaChoice.Index, + }, + } + + // If the LlamaMessage indicates a function/tool call, populate FuncCall or ToolCalls + if llamaChoice.Message.Role == RoleIPython { + funcCall := &llms.FunctionCall{ + Name: "tool_function_name", // Replace with actual function name if included in response + Arguments: llamaChoice.Message.Content, // Replace with parsed arguments if available + } + contentChoice.FuncCall = funcCall + contentChoice.ToolCalls = []llms.ToolCall{ + { + ID: fmt.Sprintf("tool-call-%d", llamaChoice.Index), + Type: "function", + FunctionCall: funcCall, + }, + } + } + + contentResponse.Choices = append(contentResponse.Choices, contentChoice) + } + + return contentResponse, nil +} diff --git a/llms/databricks/clients/llama/v3.1/map_role.go b/llms/databricks/clients/llama/v3.1/map_role.go new file mode 100644 index 000000000..edc760aab --- /dev/null +++ b/llms/databricks/clients/llama/v3.1/map_role.go @@ -0,0 +1,23 @@ +package databricksclientsllama31 + +import ( + "github.com/tmc/langchaingo/llms" +) + +// MapRole maps ChatMessageType to LlamaRole. 
+func MapRole(chatRole llms.ChatMessageType) Role { + switch chatRole { + case llms.ChatMessageTypeAI: + return RoleAssistant + case llms.ChatMessageTypeHuman: + return RoleUser + case llms.ChatMessageTypeSystem: + return RoleSystem + case llms.ChatMessageTypeFunction, llms.ChatMessageTypeTool: + return RoleIPython // Mapping tools and functions to ipython + case llms.ChatMessageTypeGeneric: + return RoleUser // Defaulting generic to user + default: + return Role(chatRole) + } +} diff --git a/llms/databricks/clients/llama/v3.1/types.go b/llms/databricks/clients/llama/v3.1/types.go new file mode 100644 index 000000000..1ed5d8ec9 --- /dev/null +++ b/llms/databricks/clients/llama/v3.1/types.go @@ -0,0 +1,47 @@ +package databricksclientsllama31 + +type Role string + +const ( + RoleSystem Role = "system" // The system role provides instructions or context for the model + RoleUser Role = "user" // The user role represents inputs from the user + RoleAssistant Role = "assistant" // The assistant role represents responses from the model + RoleIPython Role = "ipython" // The ipython role represents responses from the model +) + +type LlamaMessage struct { + Role Role `json:"role"` // Role of the message sender (e.g., "system", "user", "assistant") + Content string `json:"content"` // The content of the message +} + +type LlamaPayload struct { + Model string `json:"model"` // Model to use (e.g., "llama-3.1") + Messages []LlamaMessage `json:"messages"` // List of structured messages + Temperature float64 `json:"temperature,omitempty"` // Sampling temperature (0 to 1) + MaxTokens int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate + TopP float64 `json:"top_p,omitempty"` // Top-p (nucleus) sampling + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` // Penalizes new tokens based on frequency + PresencePenalty float64 `json:"presence_penalty,omitempty"` // Penalizes tokens based on presence + Stop []string `json:"stop,omitempty"` // List of stop sequences to end generation +} + +type LlamaResponse struct { + ID string `json:"id"` // Unique ID of the response + Object string `json:"object"` // Type of response (e.g., "chat.completion") + Created int64 `json:"created"` // Timestamp of creation + Model string `json:"model"` // Model used (e.g., "llama-3.1") + Choices []LlamaChoice `json:"choices"` // List of response choices + Usage LlamaUsage `json:"usage"` // Token usage details +} + +type LlamaChoice struct { + Index int `json:"index"` // Index of the choice + Message LlamaMessage `json:"message"` // The message content + FinishReason string `json:"finish_reason"` // Why the response stopped (e.g., "stop") +} + +type LlamaUsage struct { + PromptTokens int `json:"prompt_tokens"` // Tokens used for the prompt + CompletionTokens int `json:"completion_tokens"` // Tokens used for the completion + TotalTokens int `json:"total_tokens"` // Total tokens used +} diff --git a/llms/databricks/clients/mistral/v1/map_role.go b/llms/databricks/clients/mistral/v1/map_role.go new file mode 100644 index 000000000..4abde22f2 --- /dev/null +++ b/llms/databricks/clients/mistral/v1/map_role.go @@ -0,0 +1,20 @@ +package databricksclientsmistralv1 + +import "github.com/tmc/langchaingo/llms" + +// mapRole maps llms.ChatMessageType to Role. +// Map function. 
+func MapRole(chatRole llms.ChatMessageType) Role { + switch chatRole { + case llms.ChatMessageTypeAI: + return RoleAssistant + case llms.ChatMessageTypeHuman, llms.ChatMessageTypeGeneric: + return RoleUser + case llms.ChatMessageTypeSystem: + return RoleSystem + case llms.ChatMessageTypeTool, llms.ChatMessageTypeFunction: + return RoleTool + default: + return Role(chatRole) + } +} diff --git a/llms/databricks/clients/mistral/v1/mistralv1.go b/llms/databricks/clients/mistral/v1/mistralv1.go new file mode 100644 index 000000000..9135b7b0f --- /dev/null +++ b/llms/databricks/clients/mistral/v1/mistralv1.go @@ -0,0 +1,159 @@ +package databricksclientsmistralv1 + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/llms/databricks" +) + +type Mistral1 struct { + Model string `json:"model"` +} + +var _ databricks.Model = (*Mistral1)(nil) + +func NewMistral1(model string) *Mistral1 { + return &Mistral1{ + Model: model, + } +} + +func (l *Mistral1) FormatPayload(_ context.Context, messages []llms.MessageContent, options ...llms.CallOption) ([]byte, error) { + // Initialize payload options with defaults + opts := llms.CallOptions{} + for _, opt := range options { + opt(&opts) + } + + // Convert langchaingo MessageContent to Mistral ChatMessage + var mistralMessages []ChatMessage // nolint: prealloc + for _, msg := range messages { + // Process parts to handle the actual content and tool calls correctly + var toolCalls []ToolCall + var contentParts []string + for _, part := range msg.Parts { + switch p := part.(type) { + case llms.TextContent: + // Text parts go directly as content + contentParts = append(contentParts, p.Text) + case llms.ImageURLContent: + // Append structured description for the image + contentParts = append(contentParts, fmt.Sprintf("Image: %s", p.URL)) + if p.Detail != "" { + contentParts = append(contentParts, fmt.Sprintf("Detail: %s", p.Detail)) + } + case llms.ToolCall: + // Convert tool calls into structured Mistral ToolCall objects + toolCalls = append(toolCalls, ToolCall{ + ID: p.ID, + Type: ToolTypeFunction, // Assuming ToolTypeFunction + Function: FunctionCall{ + Name: p.FunctionCall.Name, + Arguments: p.FunctionCall.Arguments, + }, + }) + case llms.ToolCallResponse: + // Handle tool call responses as content + contentParts = append(contentParts, p.Content) + default: + return nil, fmt.Errorf("unknown content part type: %T", part) + } + } + + mistralMessage := ChatMessage{ + Role: MapRole(msg.Role), + Content: fmt.Sprintf("%s", contentParts), + ToolCalls: toolCalls, + } + mistralMessages = append(mistralMessages, mistralMessage) + } + + // Handle options (example: temperature, max_tokens, etc.) + payload := ChatCompletionPayload{ + Model: "mistral-7b", // Replace with the desired model + Messages: mistralMessages, + Temperature: opts.Temperature, + MaxTokens: opts.MaxTokens, + TopP: opts.TopP, + RandomSeed: opts.Seed, + } + + // Marshal the payload to JSON + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal payload: %w", err) + } + + return payloadBytes, nil +} + +// Refactored FormatResponse. 
+func (l *Mistral1) FormatResponse(_ context.Context, response []byte) (*llms.ContentResponse, error) { + var resp ChatCompletionResponse + if err := json.Unmarshal(response, &resp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &llms.ContentResponse{ + Choices: mapChoices(resp.Choices), + }, nil +} + +// Refactored FormatStreamResponse. +func (l *Mistral1) FormatStreamResponse(_ context.Context, response []byte) (*llms.ContentResponse, error) { + var streamResp ChatCompletionStreamResponse + if err := json.Unmarshal(response, &streamResp); err != nil { + return nil, fmt.Errorf("failed to parse streaming response: %w", err) + } + + return &llms.ContentResponse{ + Choices: mapChoices(streamResp.Choices), + }, nil +} + +// Helper function to map choices. +func mapChoices[T ChatCompletionResponseChoice | ChatCompletionResponseChoiceStream](choices []T) []*llms.ContentChoice { + var contentChoices []*llms.ContentChoice // nolint: prealloc + + for _, choice := range choices { + var index int + var message ChatMessage + var finishReason FinishReason + switch c := any(choice).(type) { + case ChatCompletionResponseChoice: + index = c.Index + message = c.Message + finishReason = c.FinishReason + case ChatCompletionResponseChoiceStream: + index = c.Index + message = c.Delta + finishReason = c.FinishReason + } + + contentChoice := &llms.ContentChoice{ + Content: message.Content, + StopReason: string(finishReason), + GenerationInfo: map[string]any{ + "index": index, + }, + } + + for _, toolCall := range message.ToolCalls { + contentChoice.ToolCalls = append(contentChoice.ToolCalls, llms.ToolCall{ + ID: toolCall.ID, + Type: string(toolCall.Type), + FunctionCall: &llms.FunctionCall{ + Name: toolCall.Function.Name, + Arguments: toolCall.Function.Arguments, + }, + }) + } + + contentChoices = append(contentChoices, contentChoice) + } + + return contentChoices +} diff --git a/llms/databricks/clients/mistral/v1/types.go b/llms/databricks/clients/mistral/v1/types.go new file mode 100644 index 000000000..009a9b443 --- /dev/null +++ b/llms/databricks/clients/mistral/v1/types.go @@ -0,0 +1,76 @@ +package databricksclientsmistralv1 + +type Role string + +const ( + RoleUser Role = "user" + RoleAssistant Role = "assistant" + RoleSystem Role = "system" + RoleTool Role = "tool" +) + +// FinishReason the reason that a chat message was finished. +type FinishReason string + +const ( + FinishReasonStop FinishReason = "stop" + FinishReasonLength FinishReason = "length" + FinishReasonError FinishReason = "error" +) + +// ResponseFormat the format that the response must adhere to. +type ResponseFormat string + +const ( + ResponseFormatText ResponseFormat = "text" + ResponseFormatJSONObject ResponseFormat = "json_object" +) + +// ToolType type of tool defined for the llm. +type ToolType string + +const ( + ToolTypeFunction ToolType = "function" +) + +// ToolChoice the choice of tool to use. +type ToolChoice string + +const ( + ToolChoiceAny ToolChoice = "any" + ToolChoiceAuto ToolChoice = "auto" + ToolChoiceNone ToolChoice = "none" +) + +// Tool definition of a tool that the llm can call. +type Tool struct { + Type ToolType `json:"type"` + Function Function `json:"function"` +} + +// Function definition of a function that the llm can call including its parameters. +type Function struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters any `json:"parameters"` +} + +// FunctionCall represents a request to call an external tool by the llm. 
+type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// ToolCall represents the call to a tool by the llm. +type ToolCall struct { + ID string `json:"id"` + Type ToolType `json:"type"` + Function FunctionCall `json:"function"` +} + +// ChatMessage represents a single message in a chat. +type ChatMessage struct { + Role Role `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` +} diff --git a/llms/databricks/clients/mistral/v1/types_payload.go b/llms/databricks/clients/mistral/v1/types_payload.go new file mode 100644 index 000000000..93ba9f3f2 --- /dev/null +++ b/llms/databricks/clients/mistral/v1/types_payload.go @@ -0,0 +1,15 @@ +package databricksclientsmistralv1 + +// ChatCompletionPayload represents the payload for the chat completion request. +type ChatCompletionPayload struct { + Model string `json:"model"` // The model to use for completion + Messages []ChatMessage `json:"messages"` // The messages to use for completion + Temperature float64 `json:"temperature,omitempty"` // The temperature to use for sampling. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or TopP but not both. + MaxTokens int `json:"max_tokens,omitempty"` + TopP float64 `json:"top_p,omitempty"` // An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or Temperature but not both. + RandomSeed int `json:"random_seed,omitempty"` + SafePrompt bool `json:"safe_prompt,omitempty"` // Adds a Mistral defined safety message to the system prompt to enforce guardrailing + Tools []Tool `json:"tools,omitempty"` + ToolChoice string `json:"tool_choice,omitempty"` + ResponseFormat ResponseFormat `json:"response_format,omitempty"` +} diff --git a/llms/databricks/clients/mistral/v1/types_response.go b/llms/databricks/clients/mistral/v1/types_response.go new file mode 100644 index 000000000..0272bde07 --- /dev/null +++ b/llms/databricks/clients/mistral/v1/types_response.go @@ -0,0 +1,43 @@ +package databricksclientsmistralv1 + +// ChatCompletionResponse represents the response from the chat completion endpoint. +type ChatCompletionResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionResponseChoice `json:"choices"` + Usage UsageInfo `json:"usage"` +} + +// ChatCompletionStreamResponse represents the streamed response from the chat completion endpoint. +type ChatCompletionStreamResponse struct { + ID string `json:"id"` + Model string `json:"model"` + Choices []ChatCompletionResponseChoiceStream `json:"choices"` + Created int `json:"created,omitempty"` + Object string `json:"object,omitempty"` + Usage UsageInfo `json:"usage,omitempty"` + Error error `json:"error,omitempty"` +} + +// ChatCompletionResponseChoice represents a choice in the chat completion response. +type ChatCompletionResponseChoice struct { + Index int `json:"index"` + Message ChatMessage `json:"message"` + FinishReason FinishReason `json:"finish_reason,omitempty"` +} + +// ChatCompletionResponseChoice represents a choice in the chat completion response. 
+type ChatCompletionResponseChoiceStream struct { + Index int `json:"index"` + Delta ChatMessage `json:"delta"` + FinishReason FinishReason `json:"finish_reason,omitempty"` +} + +// UsageInfo represents the usage information of a response. +type UsageInfo struct { + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` + CompletionTokens int `json:"completion_tokens,omitempty"` +} diff --git a/llms/databricks/databricksllm.go b/llms/databricks/databricksllm.go new file mode 100644 index 000000000..0290e1869 --- /dev/null +++ b/llms/databricks/databricksllm.go @@ -0,0 +1,123 @@ +package databricks + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "net/http" + + "github.com/tmc/langchaingo/llms" +) + +// Option is a function that applies a configuration to the LLM. +type Option func(*LLM) + +// LLM is a databricks LLM implementation. +type LLM struct { + url string // The constructed or provided URL + token string // The token for authentication + httpClient *http.Client + model Model +} + +var _ llms.Model = (*LLM)(nil) + +// New creates a new llamafile LLM implementation. +func New(model Model, opts ...Option) (*LLM, error) { + llm := &LLM{ + model: model, + } + + // Apply all options to customize the LLM. + for _, opt := range opts { + opt(llm) + } + + // Validate URL + if llm.url == "" { + return nil, fmt.Errorf("URL must be provided or constructed using options") + } + + if llm.httpClient == nil { + if llm.token == "" { + return nil, fmt.Errorf("token must be provided") + } + llm.httpClient = NewHTTPClient(llm.token) + } + + return llm, nil +} + +// Call Implement the call interface for LLM. +func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) { + return llms.GenerateFromSinglePrompt(ctx, o, prompt, options...) +} + +// GenerateContent implements the Model interface. +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { // nolint: lll, cyclop, funlen + payload, err := o.model.FormatPayload(ctx, messages, options...) 
+ if err != nil { + return nil, err + } + + opts := llms.CallOptions{} + for _, opt := range options { + opt(&opts) + } + + fmt.Printf("payload: %v\n", string(payload)) + + request, err := http.NewRequestWithContext(ctx, http.MethodPost, o.url, bytes.NewBuffer(payload)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := o.httpClient.Do(request) + if err != nil { + return nil, fmt.Errorf("failed to do request: %w", err) + } + defer resp.Body.Close() + + // Create a buffer to save a copy of the body + var buffer bytes.Buffer + teeReader := io.TeeReader(resp.Body, &buffer) + + if err := o.stream(ctx, &buffer, opts); err != nil { + return nil, err + } + + bodyBytes, err := io.ReadAll(teeReader) + if err != nil { + return nil, err + } + fmt.Printf("bodyBytes: %v\n", string(bodyBytes)) + return o.model.FormatResponse(ctx, bodyBytes) +} + +func (o *LLM) stream(ctx context.Context, resp io.Reader, opts llms.CallOptions) error { + if opts.StreamingFunc == nil { + return nil + } + + scanner := bufio.NewScanner(resp) + for scanner.Scan() { + contentResponse, err := o.model.FormatResponse(ctx, scanner.Bytes()) + if err != nil { + return err + } + + fmt.Printf("contentResponse: %v\n", *contentResponse) + + if len(contentResponse.Choices) == 0 { + continue + } + + if err := opts.StreamingFunc(ctx, []byte(contentResponse.Choices[0].Content)); err != nil { + return err + } + } + + return scanner.Err() +} diff --git a/llms/databricks/databricksllm_test.go b/llms/databricks/databricksllm_test.go new file mode 100644 index 000000000..4d3b3bb91 --- /dev/null +++ b/llms/databricks/databricksllm_test.go @@ -0,0 +1,81 @@ +package databricks_test + +import ( + "context" + "fmt" + "os" + "testing" + + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/llms/databricks" + databricksclientsllama31 "github.com/tmc/langchaingo/llms/databricks/clients/llama/v3.1" + databricksclientsmistralv1 "github.com/tmc/langchaingo/llms/databricks/clients/mistral/v1" +) + +func testModel(t *testing.T, model databricks.Model, url string) { + t.Helper() + + const envVarToken = "DATABRICKS_TOKEN" + + if os.Getenv(envVarToken) == "" { + t.Skipf("%s not set", envVarToken) + } + + dbllm, err := databricks.New(model, databricks.WithFullURL(url), databricks.WithToken(os.Getenv(envVarToken))) + if err != nil { + t.Fatalf("failed to create databricks LLM: %v", err) + } + + ctx := context.Background() + resp, err := dbllm.GenerateContent(ctx, []llms.MessageContent{ + { + Role: llms.ChatMessageTypeHuman, + Parts: []llms.ContentPart{ + llms.TextContent{Text: "Brazil is a country? 
the answer should just be yes or no"}, + }, + }, + }, llms.WithStreamingFunc(func(_ context.Context, chunk []byte) error { + fmt.Printf("string(chunk): %v\n", string(chunk)) + return nil + })) + if err != nil { + t.Fatalf("failed to generate content: %v", err) + } + + if len(resp.Choices) < 1 { + t.Fatalf("empty response from model") + } +} + +func TestDatabricksLlama31(t *testing.T) { + t.Parallel() + + const envVar = "DATABRICKS_LLAMA31_URL" + + if os.Getenv(envVar) == "" { + t.Skipf("%s not set", envVar) + } + + testModel(t, databricksclientsllama31.NewLlama31(), os.Getenv(envVar)) + + t.Error() +} + +func TestDatabricksMistal1(t *testing.T) { + t.Parallel() + + const envVarURL = "DATABRICKS_MISTAL1_URL" + const envVarModel = "DATABRICKS_MISTAL1_MODEL" + + if os.Getenv(envVarURL) == "" { + t.Skipf("%s not set", envVarURL) + } + + if os.Getenv(envVarModel) == "" { + t.Skipf("%s not set", envVarModel) + } + + testModel(t, databricksclientsmistralv1.NewMistral1(os.Getenv(envVarModel)), os.Getenv(envVarURL)) + + t.Error() +} diff --git a/llms/databricks/http_client.go b/llms/databricks/http_client.go new file mode 100644 index 000000000..e3b3dc020 --- /dev/null +++ b/llms/databricks/http_client.go @@ -0,0 +1,45 @@ +package databricks + +import ( + "fmt" + "io" + "net/http" +) + +// TokenRoundTripper is a custom RoundTripper that adds a Bearer token to each request. +type TokenRoundTripper struct { + Token string + Transport http.RoundTripper +} + +// RoundTrip executes a single HTTP transaction and adds the Bearer token to the request. +func (t *TokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set("Authorization", "Bearer "+t.Token) + req.Header.Set("content-type", "application/json") + // Use the underlying transport to perform the request + resp, err := t.Transport.RoundTrip(req) + if err != nil { + return nil, fmt.Errorf("failed to make request: %w", err) + } + + // Handle status codes + if resp.StatusCode >= 400 { + // Read the response body for detailed error message (optional) + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() // Ensure the body is closed to avoid resource leaks + + return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body)) + } + + return resp, nil +} + +// NewClientWithToken creates a new HTTP client with a Bearer token. +func NewHTTPClient(token string) *http.Client { + return &http.Client{ + Transport: &TokenRoundTripper{ + Token: token, + Transport: http.DefaultTransport, // Use http.DefaultTransport as the fallback + }, + } +} diff --git a/llms/databricks/model.go b/llms/databricks/model.go new file mode 100644 index 000000000..4bf2b8535 --- /dev/null +++ b/llms/databricks/model.go @@ -0,0 +1,13 @@ +package databricks + +import ( + "context" + + "github.com/tmc/langchaingo/llms" +) + +type Model interface { + FormatPayload(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) ([]byte, error) + FormatResponse(ctx context.Context, response []byte) (*llms.ContentResponse, error) + FormatStreamResponse(ctx context.Context, response []byte) (*llms.ContentResponse, error) +} diff --git a/llms/databricks/options.go b/llms/databricks/options.go new file mode 100644 index 000000000..370c4b2bc --- /dev/null +++ b/llms/databricks/options.go @@ -0,0 +1,34 @@ +package databricks + +import ( + "fmt" + "net/http" +) + +// WithFullURL sets the full URL for the LLM. 
+func WithFullURL(fullURL string) Option { + return func(llm *LLM) { + llm.url = fullURL + } +} + +// WithURLComponents constructs the URL from individual components. +func WithURLComponents(databricksInstance, modelName, modelVersion string) Option { + return func(llm *LLM) { + llm.url = fmt.Sprintf("https://%s/model/%s/%s/invocations", databricksInstance, modelName, modelVersion) + } +} + +// WithToken pass the token for authentication. +func WithToken(token string) Option { + return func(llm *LLM) { + llm.token = token + } +} + +// WithHTTPClient sets the HTTP client for the LLM. +func WithHTTPClient(client *http.Client) Option { + return func(llm *LLM) { + llm.httpClient = client + } +} From 047e82ab8554a752f1c90a8bbbbc9a6f8c7496fe Mon Sep 17 00:00:00 2001 From: Thomas Date: Fri, 22 Nov 2024 12:52:36 -0500 Subject: [PATCH 2/8] add streaming to llama --- llms/databricks/clients/llama/v3.1/llama31.go | 4 ++++ llms/databricks/clients/llama/v3.1/types.go | 1 + 2 files changed, 5 insertions(+) diff --git a/llms/databricks/clients/llama/v3.1/llama31.go b/llms/databricks/clients/llama/v3.1/llama31.go index 0e83a16a4..8a82834f3 100644 --- a/llms/databricks/clients/llama/v3.1/llama31.go +++ b/llms/databricks/clients/llama/v3.1/llama31.go @@ -60,6 +60,10 @@ func (l *Llama31) FormatPayload(_ context.Context, messages []llms.MessageConten Stop: opts.StopWords, // Add stop sequences if needed } + if opts.StreamingFunc != nil { + payload.Streaming = true + } + // Serialize to JSON jsonPayload, err := json.Marshal(payload) if err != nil { diff --git a/llms/databricks/clients/llama/v3.1/types.go b/llms/databricks/clients/llama/v3.1/types.go index 1ed5d8ec9..48f6979c1 100644 --- a/llms/databricks/clients/llama/v3.1/types.go +++ b/llms/databricks/clients/llama/v3.1/types.go @@ -23,6 +23,7 @@ type LlamaPayload struct { FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` // Penalizes new tokens based on frequency PresencePenalty float64 `json:"presence_penalty,omitempty"` // Penalizes tokens based on presence Stop []string `json:"stop,omitempty"` // List of stop sequences to end generation + Streaming bool `json:"streaming,omitempty"` // Enable token-by-token streaming } type LlamaResponse struct { From 65ca2629a721e4ce4d79495bcb4f99db5d2cbcfe Mon Sep 17 00:00:00 2001 From: Thomas Date: Sat, 23 Nov 2024 21:58:59 -0500 Subject: [PATCH 3/8] fix streaming --- llms/databricks/clients/llama/v3.1/llama31.go | 77 +++++++++++++------ llms/databricks/clients/llama/v3.1/types.go | 26 +++++-- .../clients/mistral/v1/mistralv1.go | 19 +++++ .../clients/mistral/v1/types_payload.go | 1 + llms/databricks/databricksllm.go | 76 +++++++++++++----- llms/databricks/databricksllm_test.go | 62 ++++++++++++--- 6 files changed, 199 insertions(+), 62 deletions(-) diff --git a/llms/databricks/clients/llama/v3.1/llama31.go b/llms/databricks/clients/llama/v3.1/llama31.go index 8a82834f3..908826444 100644 --- a/llms/databricks/clients/llama/v3.1/llama31.go +++ b/llms/databricks/clients/llama/v3.1/llama31.go @@ -1,6 +1,7 @@ package databricksclientsllama31 import ( + "bytes" "context" "encoding/json" "fmt" @@ -61,7 +62,7 @@ func (l *Llama31) FormatPayload(_ context.Context, messages []llms.MessageConten } if opts.StreamingFunc != nil { - payload.Streaming = true + payload.Stream = true } // Serialize to JSON @@ -75,17 +76,32 @@ func (l *Llama31) FormatPayload(_ context.Context, messages []llms.MessageConten // FormatResponse parses the LlamaResponse JSON and converts it to a ContentResponse structure. 
func (l *Llama31) FormatResponse(_ context.Context, response []byte) (*llms.ContentResponse, error) { - return formatResponse(response) + return formatResponse[LlamaChoice](response) } // FormatStreamResponse parses the LlamaResponse JSON and converts it to a ContentResponse structure. func (l *Llama31) FormatStreamResponse(_ context.Context, response []byte) (*llms.ContentResponse, error) { - return formatResponse(response) + // The "data:" prefix is commonly used in Server-Sent Events (SSE) or streaming APIs + // to delimit individual chunks of data being sent from the server. It indicates + // that the following text is a data payload. Before parsing the JSON, we remove + // this prefix to work with the raw JSON payload. + response = bytes.TrimPrefix(response, []byte("data: ")) + + if string(response) == "[DONE]" || len(response) == 0 { + return &llms.ContentResponse{ + Choices: []*llms.ContentChoice{{ + Content: "", + }}, + }, nil + } + return formatResponse[LlamaChoiceDelta](response) } -func formatResponse(response []byte) (*llms.ContentResponse, error) { +func formatResponse[T LlamaChoiceDelta | LlamaChoice](response []byte) (*llms.ContentResponse, error) { + fmt.Printf("response: %+v\n", string(response)) + // Parse the LlamaResponse JSON - var llamaResp LlamaResponse + var llamaResp LlamaResponse[T] err := json.Unmarshal(response, &llamaResp) if err != nil { return nil, fmt.Errorf("failed to unmarshal LlamaResponse: %w", err) @@ -96,32 +112,47 @@ func formatResponse(response []byte) (*llms.ContentResponse, error) { Choices: []*llms.ContentChoice{}, } + fmt.Printf("llamaResp: %+v\n", llamaResp) + // Map LlamaResponse choices to ContentChoice for _, llamaChoice := range llamaResp.Choices { - contentChoice := &llms.ContentChoice{ - Content: llamaChoice.Message.Content, - StopReason: llamaChoice.FinishReason, - GenerationInfo: map[string]any{ - "index": llamaChoice.Index, - }, - } + var contentChoice *llms.ContentChoice + switch choice := any(llamaChoice).(type) { + case LlamaChoice: + contentChoice = &llms.ContentChoice{ + Content: choice.Message.Content, + StopReason: choice.FinishReason, + GenerationInfo: map[string]any{ + "index": choice.Index, + }, + } - // If the LlamaMessage indicates a function/tool call, populate FuncCall or ToolCalls - if llamaChoice.Message.Role == RoleIPython { - funcCall := &llms.FunctionCall{ - Name: "tool_function_name", // Replace with actual function name if included in response - Arguments: llamaChoice.Message.Content, // Replace with parsed arguments if available + // If the LlamaMessage indicates a function/tool call, populate FuncCall or ToolCalls + if choice.Message.Role == RoleIPython { + funcCall := &llms.FunctionCall{ + Name: "tool_function_name", // Replace with actual function name if included in response + Arguments: choice.Message.Content, // Replace with parsed arguments if available + } + contentChoice.FuncCall = funcCall + contentChoice.ToolCalls = []llms.ToolCall{ + { + ID: fmt.Sprintf("tool-call-%d", choice.Index), + Type: "function", + FunctionCall: funcCall, + }, + } } - contentChoice.FuncCall = funcCall - contentChoice.ToolCalls = []llms.ToolCall{ - { - ID: fmt.Sprintf("tool-call-%d", llamaChoice.Index), - Type: "function", - FunctionCall: funcCall, + case LlamaChoiceDelta: + contentChoice = &llms.ContentChoice{ + Content: choice.Delta.Content, + StopReason: choice.FinishReason, + GenerationInfo: map[string]any{ + "index": choice.Index, }, } } + // Append the ContentChoice to the ContentResponse contentResponse.Choices = 
append(contentResponse.Choices, contentChoice) } diff --git a/llms/databricks/clients/llama/v3.1/types.go b/llms/databricks/clients/llama/v3.1/types.go index 48f6979c1..949542e6a 100644 --- a/llms/databricks/clients/llama/v3.1/types.go +++ b/llms/databricks/clients/llama/v3.1/types.go @@ -14,6 +14,10 @@ type LlamaMessage struct { Content string `json:"content"` // The content of the message } +type LlamaMessageDelta struct { + Content string `json:"content"` // The content of the message +} + type LlamaPayload struct { Model string `json:"model"` // Model to use (e.g., "llama-3.1") Messages []LlamaMessage `json:"messages"` // List of structured messages @@ -23,16 +27,16 @@ type LlamaPayload struct { FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` // Penalizes new tokens based on frequency PresencePenalty float64 `json:"presence_penalty,omitempty"` // Penalizes tokens based on presence Stop []string `json:"stop,omitempty"` // List of stop sequences to end generation - Streaming bool `json:"streaming,omitempty"` // Enable token-by-token streaming + Stream bool `json:"stream,omitempty"` // Enable token-by-token streaming } -type LlamaResponse struct { - ID string `json:"id"` // Unique ID of the response - Object string `json:"object"` // Type of response (e.g., "chat.completion") - Created int64 `json:"created"` // Timestamp of creation - Model string `json:"model"` // Model used (e.g., "llama-3.1") - Choices []LlamaChoice `json:"choices"` // List of response choices - Usage LlamaUsage `json:"usage"` // Token usage details +type LlamaResponse[T LlamaChoice | LlamaChoiceDelta] struct { + ID string `json:"id"` // Unique ID of the response + Object string `json:"object"` // Type of response (e.g., "chat.completion") + Created int64 `json:"created"` // Timestamp of creation + Model string `json:"model"` // Model used (e.g., "llama-3.1") + Choices []T `json:"choices"` // List of response choices + Usage LlamaUsage `json:"usage"` // Token usage details } type LlamaChoice struct { @@ -41,6 +45,12 @@ type LlamaChoice struct { FinishReason string `json:"finish_reason"` // Why the response stopped (e.g., "stop") } +type LlamaChoiceDelta struct { + Index int `json:"index"` // Index of the choice + Delta LlamaMessageDelta `json:"delta"` // The message content + FinishReason string `json:"finish_reason"` // Why the response stopped (e.g., "stop") +} + type LlamaUsage struct { PromptTokens int `json:"prompt_tokens"` // Tokens used for the prompt CompletionTokens int `json:"completion_tokens"` // Tokens used for the completion diff --git a/llms/databricks/clients/mistral/v1/mistralv1.go b/llms/databricks/clients/mistral/v1/mistralv1.go index 9135b7b0f..3c8f48bd2 100644 --- a/llms/databricks/clients/mistral/v1/mistralv1.go +++ b/llms/databricks/clients/mistral/v1/mistralv1.go @@ -1,6 +1,7 @@ package databricksclientsmistralv1 import ( + "bytes" "context" "encoding/json" "fmt" @@ -81,6 +82,10 @@ func (l *Mistral1) FormatPayload(_ context.Context, messages []llms.MessageConte RandomSeed: opts.Seed, } + if opts.StreamingFunc != nil { + payload.Stream = true + } + // Marshal the payload to JSON payloadBytes, err := json.Marshal(payload) if err != nil { @@ -104,6 +109,20 @@ func (l *Mistral1) FormatResponse(_ context.Context, response []byte) (*llms.Con // Refactored FormatStreamResponse. 
func (l *Mistral1) FormatStreamResponse(_ context.Context, response []byte) (*llms.ContentResponse, error) { + // The "data:" prefix is commonly used in Server-Sent Events (SSE) or streaming APIs + // to delimit individual chunks of data being sent from the server. It indicates + // that the following text is a data payload. Before parsing the JSON, we remove + // this prefix to work with the raw JSON payload. + response = bytes.TrimPrefix(response, []byte("data: ")) + + if string(response) == "[DONE]" || len(response) == 0 { + return &llms.ContentResponse{ + Choices: []*llms.ContentChoice{{ + Content: "", + }}, + }, nil + } + var streamResp ChatCompletionStreamResponse if err := json.Unmarshal(response, &streamResp); err != nil { return nil, fmt.Errorf("failed to parse streaming response: %w", err) diff --git a/llms/databricks/clients/mistral/v1/types_payload.go b/llms/databricks/clients/mistral/v1/types_payload.go index 93ba9f3f2..853a78618 100644 --- a/llms/databricks/clients/mistral/v1/types_payload.go +++ b/llms/databricks/clients/mistral/v1/types_payload.go @@ -12,4 +12,5 @@ type ChatCompletionPayload struct { Tools []Tool `json:"tools,omitempty"` ToolChoice string `json:"tool_choice,omitempty"` ResponseFormat ResponseFormat `json:"response_format,omitempty"` + Stream bool `json:"stream,omitempty"` // Enable token-by-token streaming } diff --git a/llms/databricks/databricksllm.go b/llms/databricks/databricksllm.go index 0290e1869..95519e024 100644 --- a/llms/databricks/databricksllm.go +++ b/llms/databricks/databricksllm.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "strings" "github.com/tmc/langchaingo/llms" ) @@ -80,44 +81,79 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten } defer resp.Body.Close() - // Create a buffer to save a copy of the body - var buffer bytes.Buffer - teeReader := io.TeeReader(resp.Body, &buffer) - - if err := o.stream(ctx, &buffer, opts); err != nil { - return nil, err + if opts.StreamingFunc != nil { + return o.stream(ctx, resp.Body, opts) } - bodyBytes, err := io.ReadAll(teeReader) + bodyBytes, err := io.ReadAll(resp.Body) if err != nil { return nil, err } - fmt.Printf("bodyBytes: %v\n", string(bodyBytes)) + // fmt.Printf("bodyBytes: %v\n", string(bodyBytes)) return o.model.FormatResponse(ctx, bodyBytes) } -func (o *LLM) stream(ctx context.Context, resp io.Reader, opts llms.CallOptions) error { - if opts.StreamingFunc == nil { - return nil - } +func (o *LLM) stream(ctx context.Context, body io.Reader, opts llms.CallOptions) (*llms.ContentResponse, error) { + fullChoiceContent := []strings.Builder{} + scanner := bufio.NewScanner(body) + finalResponse := &llms.ContentResponse{} - scanner := bufio.NewScanner(resp) for scanner.Scan() { - contentResponse, err := o.model.FormatResponse(ctx, scanner.Bytes()) - if err != nil { - return err + scannedBytes := scanner.Bytes() + if len(scannedBytes) == 0 { + continue } - fmt.Printf("contentResponse: %v\n", *contentResponse) + contentResponse, err := o.model.FormatStreamResponse(ctx, scannedBytes) + if err != nil { + return nil, err + } if len(contentResponse.Choices) == 0 { continue } - if err := opts.StreamingFunc(ctx, []byte(contentResponse.Choices[0].Content)); err != nil { - return err + index, err := concatenateAnswers(contentResponse.Choices, &fullChoiceContent) + if err != nil { + return nil, err + } + + if index == nil { + continue + } + + if err := opts.StreamingFunc(ctx, []byte(fullChoiceContent[*index].String())); err != nil { + return nil, err + } + + finalResponse 
= contentResponse + } + + for index := range finalResponse.Choices { + finalResponse.Choices[index].Content = fullChoiceContent[index].String() + } + + return finalResponse, nil +} + +func concatenateAnswers(choices []*llms.ContentChoice, fullChoiceContent *[]strings.Builder) (*int, error) { + var lastModifiedIndex *int + + for choiceIndex := range choices { + if len(*fullChoiceContent) <= choiceIndex { + *fullChoiceContent = append(*fullChoiceContent, strings.Builder{}) + } + + if choices[choiceIndex].Content == "" { + continue + } + + lastModifiedIndex = &choiceIndex + + if _, err := (*fullChoiceContent)[choiceIndex].WriteString(choices[choiceIndex].Content); err != nil { + return lastModifiedIndex, err } } - return scanner.Err() + return lastModifiedIndex, nil } diff --git a/llms/databricks/databricksllm_test.go b/llms/databricks/databricksllm_test.go index 4d3b3bb91..b89b166da 100644 --- a/llms/databricks/databricksllm_test.go +++ b/llms/databricks/databricksllm_test.go @@ -2,7 +2,6 @@ package databricks_test import ( "context" - "fmt" "os" "testing" @@ -12,7 +11,7 @@ import ( databricksclientsmistralv1 "github.com/tmc/langchaingo/llms/databricks/clients/mistral/v1" ) -func testModel(t *testing.T, model databricks.Model, url string) { +func testModelStream(t *testing.T, model databricks.Model, url string) { t.Helper() const envVarToken = "DATABRICKS_TOKEN" @@ -31,11 +30,13 @@ func testModel(t *testing.T, model databricks.Model, url string) { { Role: llms.ChatMessageTypeHuman, Parts: []llms.ContentPart{ - llms.TextContent{Text: "Brazil is a country? the answer should just be yes or no"}, + llms.TextContent{Text: "Brazil is a country?"}, }, }, }, llms.WithStreamingFunc(func(_ context.Context, chunk []byte) error { - fmt.Printf("string(chunk): %v\n", string(chunk)) + if len(chunk) == 0 { + t.Fatalf("empty chunk") + } return nil })) if err != nil { @@ -47,18 +48,53 @@ func testModel(t *testing.T, model databricks.Model, url string) { } } +func testModel(t *testing.T, model databricks.Model, url string) { + t.Helper() + + const envVarToken = "DATABRICKS_TOKEN" + + if os.Getenv(envVarToken) == "" { + t.Skipf("%s not set", envVarToken) + } + + dbllm, err := databricks.New(model, databricks.WithFullURL(url), databricks.WithToken(os.Getenv(envVarToken))) + if err != nil { + t.Fatalf("failed to create databricks LLM: %v", err) + } + + ctx := context.Background() + resp, err := dbllm.GenerateContent(ctx, []llms.MessageContent{ + { + Role: llms.ChatMessageTypeHuman, + Parts: []llms.ContentPart{ + llms.TextContent{Text: "Brazil is a country?"}, + }, + }, + }) + if err != nil { + t.Fatalf("failed to generate content: %v", err) + } + + if len(resp.Choices) < 1 { + t.Fatalf("empty response from model") + } +} + func TestDatabricksLlama31(t *testing.T) { t.Parallel() const envVar = "DATABRICKS_LLAMA31_URL" - if os.Getenv(envVar) == "" { + url := os.Getenv(envVar) + + if url == "" { t.Skipf("%s not set", envVar) } - testModel(t, databricksclientsllama31.NewLlama31(), os.Getenv(envVar)) + llama31 := databricksclientsllama31.NewLlama31() - t.Error() + testModelStream(t, llama31, url) + testModel(t, llama31, url) } func TestDatabricksMistal1(t *testing.T) { @@ -67,15 +103,19 @@ func TestDatabricksMistal1(t *testing.T) { const envVarURL = "DATABRICKS_MISTAL1_URL" const envVarModel = "DATABRICKS_MISTAL1_MODEL" - if os.Getenv(envVarURL) == "" { + model := os.Getenv(envVarModel) + url := os.Getenv(envVarURL) + + if url == "" { t.Skipf("%s not set", envVarURL) } - if os.Getenv(envVarModel) == "" { + if model == 
"" { t.Skipf("%s not set", envVarModel) } - testModel(t, databricksclientsmistralv1.NewMistral1(os.Getenv(envVarModel)), os.Getenv(envVarURL)) + mistral1 := databricksclientsmistralv1.NewMistral1(model) - t.Error() + testModelStream(t, mistral1, url) + testModel(t, mistral1, url) } From 4bbd06e5f3e542d52baacfdfff63a90f8ed956c7 Mon Sep 17 00:00:00 2001 From: Thomas Date: Mon, 25 Nov 2024 11:54:37 -0500 Subject: [PATCH 4/8] do not accumulate while streaming --- llms/databricks/databricksllm.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/databricks/databricksllm.go b/llms/databricks/databricksllm.go index 95519e024..390d609f8 100644 --- a/llms/databricks/databricksllm.go +++ b/llms/databricks/databricksllm.go @@ -122,7 +122,7 @@ func (o *LLM) stream(ctx context.Context, body io.Reader, opts llms.CallOptions) continue } - if err := opts.StreamingFunc(ctx, []byte(fullChoiceContent[*index].String())); err != nil { + if err := opts.StreamingFunc(ctx, []byte(contentResponse.Choices[*index].Content)); err != nil { return nil, err } From 2b59ab1ef428bc76dd32c7c59daac69ba6eca2fa Mon Sep 17 00:00:00 2001 From: Thomas Date: Mon, 25 Nov 2024 15:41:11 -0500 Subject: [PATCH 5/8] add comment to each exported --- llms/databricks/clients/llama/v3.1/llama31.go | 3 +++ llms/databricks/clients/llama/v3.1/types.go | 7 +++++++ llms/databricks/clients/mistral/v1/mistralv1.go | 3 +++ llms/databricks/clients/mistral/v1/types.go | 6 ++++++ llms/databricks/model.go | 1 + 5 files changed, 20 insertions(+) diff --git a/llms/databricks/clients/llama/v3.1/llama31.go b/llms/databricks/clients/llama/v3.1/llama31.go index 908826444..a42833b43 100644 --- a/llms/databricks/clients/llama/v3.1/llama31.go +++ b/llms/databricks/clients/llama/v3.1/llama31.go @@ -11,14 +11,17 @@ import ( "github.com/tmc/langchaingo/llms/databricks" ) +// LlamaPayload represents the payload structure for the Llama model. type Llama31 struct{} var _ databricks.Model = (*Llama31)(nil) +// NewLlama31 creates a new Llama31 instance. func NewLlama31() *Llama31 { return &Llama31{} } +// FormatPayload implements databricks.Model to convert langchaingo llms.MessageContent to llama payload. func (l *Llama31) FormatPayload(_ context.Context, messages []llms.MessageContent, options ...llms.CallOption) ([]byte, error) { // Initialize payload options with defaults opts := llms.CallOptions{} diff --git a/llms/databricks/clients/llama/v3.1/types.go b/llms/databricks/clients/llama/v3.1/types.go index 949542e6a..1345e6afa 100644 --- a/llms/databricks/clients/llama/v3.1/types.go +++ b/llms/databricks/clients/llama/v3.1/types.go @@ -9,15 +9,18 @@ const ( RoleIPython Role = "ipython" // The ipython role represents responses from the model ) +// LlamaMessage represents a message in the LLM. type LlamaMessage struct { Role Role `json:"role"` // Role of the message sender (e.g., "system", "user", "assistant") Content string `json:"content"` // The content of the message } +// LlamaMessageDelta represents a message streamed by the LLM. type LlamaMessageDelta struct { Content string `json:"content"` // The content of the message } +// LlamaPayload represents the payload structure for the Llama model. 
type LlamaPayload struct { Model string `json:"model"` // Model to use (e.g., "llama-3.1") Messages []LlamaMessage `json:"messages"` // List of structured messages @@ -30,6 +33,7 @@ type LlamaPayload struct { Stream bool `json:"stream,omitempty"` // Enable token-by-token streaming } +// LlamaResponse represents the response structure for the Llama model. (full answer or streamed one) type LlamaResponse[T LlamaChoice | LlamaChoiceDelta] struct { ID string `json:"id"` // Unique ID of the response Object string `json:"object"` // Type of response (e.g., "chat.completion") @@ -39,18 +43,21 @@ type LlamaResponse[T LlamaChoice | LlamaChoiceDelta] struct { Usage LlamaUsage `json:"usage"` // Token usage details } +// LlamaChoice represents a choice in the Llama response. type LlamaChoice struct { Index int `json:"index"` // Index of the choice Message LlamaMessage `json:"message"` // The message content FinishReason string `json:"finish_reason"` // Why the response stopped (e.g., "stop") } +// LlamaChoiceDelta represents a choice in the Llama response. type LlamaChoiceDelta struct { Index int `json:"index"` // Index of the choice Delta LlamaMessageDelta `json:"delta"` // The message content FinishReason string `json:"finish_reason"` // Why the response stopped (e.g., "stop") } +// LlamaUsage represents the token usage details of a response. type LlamaUsage struct { PromptTokens int `json:"prompt_tokens"` // Tokens used for the prompt CompletionTokens int `json:"completion_tokens"` // Tokens used for the completion diff --git a/llms/databricks/clients/mistral/v1/mistralv1.go b/llms/databricks/clients/mistral/v1/mistralv1.go index 3c8f48bd2..a79246511 100644 --- a/llms/databricks/clients/mistral/v1/mistralv1.go +++ b/llms/databricks/clients/mistral/v1/mistralv1.go @@ -10,18 +10,21 @@ import ( "github.com/tmc/langchaingo/llms/databricks" ) +// Mistral1 represents the payload structure for the Mistral model. type Mistral1 struct { Model string `json:"model"` } var _ databricks.Model = (*Mistral1)(nil) +// NewMistral1 creates a new Mistral1 instance. func NewMistral1(model string) *Mistral1 { return &Mistral1{ Model: model, } } +// FormatPayload implements databricks.Model to convert langchaingo llms.MessageContent to llama payload. func (l *Mistral1) FormatPayload(_ context.Context, messages []llms.MessageContent, options ...llms.CallOption) ([]byte, error) { // Initialize payload options with defaults opts := llms.CallOptions{} diff --git a/llms/databricks/clients/mistral/v1/types.go b/llms/databricks/clients/mistral/v1/types.go index 009a9b443..7d8f105ea 100644 --- a/llms/databricks/clients/mistral/v1/types.go +++ b/llms/databricks/clients/mistral/v1/types.go @@ -1,7 +1,9 @@ package databricksclientsmistralv1 +// Role the role of the chat message. type Role string +// Role the role of the chat message. const ( RoleUser Role = "user" RoleAssistant Role = "assistant" @@ -12,6 +14,7 @@ const ( // FinishReason the reason that a chat message was finished. type FinishReason string +// FinishReason the reason that a chat message was finished. const ( FinishReasonStop FinishReason = "stop" FinishReasonLength FinishReason = "length" @@ -21,6 +24,7 @@ const ( // ResponseFormat the format that the response must adhere to. type ResponseFormat string +// ResponseFormat the format that the response must adhere to. const ( ResponseFormatText ResponseFormat = "text" ResponseFormatJSONObject ResponseFormat = "json_object" @@ -29,6 +33,7 @@ const ( // ToolType type of tool defined for the llm. 
type ToolType string +// ToolType type of tool defined for the llm. const ( ToolTypeFunction ToolType = "function" ) @@ -36,6 +41,7 @@ const ( // ToolChoice the choice of tool to use. type ToolChoice string +// ToolChoice the choice of tool to use. const ( ToolChoiceAny ToolChoice = "any" ToolChoiceAuto ToolChoice = "auto" diff --git a/llms/databricks/model.go b/llms/databricks/model.go index 4bf2b8535..aca213c40 100644 --- a/llms/databricks/model.go +++ b/llms/databricks/model.go @@ -6,6 +6,7 @@ import ( "github.com/tmc/langchaingo/llms" ) +// Model is the interface that wraps the methods to format the payload and response. type Model interface { FormatPayload(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) ([]byte, error) FormatResponse(ctx context.Context, response []byte) (*llms.ContentResponse, error) From 842b1c09f6cf103753bf826861b537644084a8ec Mon Sep 17 00:00:00 2001 From: Thomas Date: Mon, 25 Nov 2024 15:43:00 -0500 Subject: [PATCH 6/8] fix golint-ci --- llms/databricks/clients/llama/v3.1/types.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/databricks/clients/llama/v3.1/types.go b/llms/databricks/clients/llama/v3.1/types.go index 1345e6afa..f9004eb30 100644 --- a/llms/databricks/clients/llama/v3.1/types.go +++ b/llms/databricks/clients/llama/v3.1/types.go @@ -33,7 +33,7 @@ type LlamaPayload struct { Stream bool `json:"stream,omitempty"` // Enable token-by-token streaming } -// LlamaResponse represents the response structure for the Llama model. (full answer or streamed one) +// LlamaResponse represents the response structure for the Llama model. (full answer or streamed one). type LlamaResponse[T LlamaChoice | LlamaChoiceDelta] struct { ID string `json:"id"` // Unique ID of the response Object string `json:"object"` // Type of response (e.g., "chat.completion") From 366a2f53ca17b025d68e5c10fe184beea79bb5de Mon Sep 17 00:00:00 2001 From: Thomas Date: Thu, 16 Jan 2025 18:07:15 +0100 Subject: [PATCH 7/8] use stretchr/testify --- llms/databricks/clients/llama/v3.1/llama31.go | 3 - llms/databricks/databricksllm.go | 3 - llms/databricks/databricksllm_test.go | 85 +++++++++---------- 3 files changed, 41 insertions(+), 50 deletions(-) diff --git a/llms/databricks/clients/llama/v3.1/llama31.go b/llms/databricks/clients/llama/v3.1/llama31.go index a42833b43..fb0f19a17 100644 --- a/llms/databricks/clients/llama/v3.1/llama31.go +++ b/llms/databricks/clients/llama/v3.1/llama31.go @@ -101,7 +101,6 @@ func (l *Llama31) FormatStreamResponse(_ context.Context, response []byte) (*llm } func formatResponse[T LlamaChoiceDelta | LlamaChoice](response []byte) (*llms.ContentResponse, error) { - fmt.Printf("response: %+v\n", string(response)) // Parse the LlamaResponse JSON var llamaResp LlamaResponse[T] @@ -115,8 +114,6 @@ func formatResponse[T LlamaChoiceDelta | LlamaChoice](response []byte) (*llms.Co Choices: []*llms.ContentChoice{}, } - fmt.Printf("llamaResp: %+v\n", llamaResp) - // Map LlamaResponse choices to ContentChoice for _, llamaChoice := range llamaResp.Choices { var contentChoice *llms.ContentChoice diff --git a/llms/databricks/databricksllm.go b/llms/databricks/databricksllm.go index 390d609f8..c38019976 100644 --- a/llms/databricks/databricksllm.go +++ b/llms/databricks/databricksllm.go @@ -68,8 +68,6 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten opt(&opts) } - fmt.Printf("payload: %v\n", string(payload)) - request, err := http.NewRequestWithContext(ctx, http.MethodPost, o.url, 
bytes.NewBuffer(payload)) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) @@ -89,7 +87,6 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten if err != nil { return nil, err } - // fmt.Printf("bodyBytes: %v\n", string(bodyBytes)) return o.model.FormatResponse(ctx, bodyBytes) } diff --git a/llms/databricks/databricksllm_test.go b/llms/databricks/databricksllm_test.go index b89b166da..4f3043373 100644 --- a/llms/databricks/databricksllm_test.go +++ b/llms/databricks/databricksllm_test.go @@ -5,6 +5,9 @@ import ( "os" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms/databricks" databricksclientsllama31 "github.com/tmc/langchaingo/llms/databricks/clients/llama/v3.1" @@ -20,32 +23,31 @@ func testModelStream(t *testing.T, model databricks.Model, url string) { t.Skipf("%s not set", envVarToken) } - dbllm, err := databricks.New(model, databricks.WithFullURL(url), databricks.WithToken(os.Getenv(envVarToken))) - if err != nil { - t.Fatalf("failed to create databricks LLM: %v", err) - } + dbllm, err := databricks.New( + model, + databricks.WithFullURL(url), + databricks.WithToken(os.Getenv(envVarToken)), + ) + require.NoError(t, err, "failed to create databricks LLM") ctx := context.Background() - resp, err := dbllm.GenerateContent(ctx, []llms.MessageContent{ - { - Role: llms.ChatMessageTypeHuman, - Parts: []llms.ContentPart{ - llms.TextContent{Text: "Brazil is a country?"}, + resp, err := dbllm.GenerateContent(ctx, + []llms.MessageContent{ + { + Role: llms.ChatMessageTypeHuman, + Parts: []llms.ContentPart{ + llms.TextContent{Text: "Brazil is a country?"}, + }, }, }, - }, llms.WithStreamingFunc(func(_ context.Context, chunk []byte) error { - if len(chunk) == 0 { - t.Fatalf("empty chunk") - } - return nil - })) - if err != nil { - t.Fatalf("failed to generate content: %v", err) - } - - if len(resp.Choices) < 1 { - t.Fatalf("empty response from model") - } + llms.WithStreamingFunc(func(_ context.Context, chunk []byte) error { + require.NotEmpty(t, chunk, "unexpected empty chunk in streaming response") + return nil + }), + ) + require.NoError(t, err, "failed to generate content") + + assert.NotEmpty(t, resp.Choices, "expected at least one choice from model") } func testModel(t *testing.T, model databricks.Model, url string) { @@ -57,47 +59,44 @@ func testModel(t *testing.T, model databricks.Model, url string) { t.Skipf("%s not set", envVarToken) } - dbllm, err := databricks.New(model, databricks.WithFullURL(url), databricks.WithToken(os.Getenv(envVarToken))) - if err != nil { - t.Fatalf("failed to create databricks LLM: %v", err) - } + dbllm, err := databricks.New( + model, + databricks.WithFullURL(url), + databricks.WithToken(os.Getenv(envVarToken)), + ) + require.NoError(t, err, "failed to create databricks LLM") ctx := context.Background() - resp, err := dbllm.GenerateContent(ctx, []llms.MessageContent{ - { - Role: llms.ChatMessageTypeHuman, - Parts: []llms.ContentPart{ - llms.TextContent{Text: "Brazil is a country?"}, + resp, err := dbllm.GenerateContent(ctx, + []llms.MessageContent{ + { + Role: llms.ChatMessageTypeHuman, + Parts: []llms.ContentPart{ + llms.TextContent{Text: "Brazil is a country?"}, + }, }, }, - }) - if err != nil { - t.Fatalf("failed to generate content: %v", err) - } + ) + require.NoError(t, err, "failed to generate content") - if len(resp.Choices) < 1 { - t.Fatalf("empty response from model") - } + 
assert.NotEmpty(t, resp.Choices, "expected at least one choice from model") } func TestDatabricksLlama31(t *testing.T) { t.Parallel() const envVar = "DATABRICKS_LLAMA31_URL" - url := os.Getenv(envVar) - if url == "" { t.Skipf("%s not set", envVar) } llama31 := databricksclientsllama31.NewLlama31() - testModelStream(t, llama31, url) testModel(t, llama31, url) } -func TestDatabricksMistal1(t *testing.T) { +func TestDatabricksMistral1(t *testing.T) { t.Parallel() const envVarURL = "DATABRICKS_MISTAL1_URL" @@ -109,13 +108,11 @@ func TestDatabricksMistal1(t *testing.T) { if url == "" { t.Skipf("%s not set", envVarURL) } - if model == "" { t.Skipf("%s not set", envVarModel) } mistral1 := databricksclientsmistralv1.NewMistral1(model) - testModelStream(t, mistral1, url) testModel(t, mistral1, url) } From 313cdad4e5b08d15ede9fb4998d0d4a6a43d232c Mon Sep 17 00:00:00 2001 From: Thomas Date: Thu, 16 Jan 2025 18:23:16 +0100 Subject: [PATCH 8/8] gofumpt --- llms/databricks/clients/llama/v3.1/llama31.go | 1 - 1 file changed, 1 deletion(-) diff --git a/llms/databricks/clients/llama/v3.1/llama31.go b/llms/databricks/clients/llama/v3.1/llama31.go index fb0f19a17..87f316487 100644 --- a/llms/databricks/clients/llama/v3.1/llama31.go +++ b/llms/databricks/clients/llama/v3.1/llama31.go @@ -101,7 +101,6 @@ func (l *Llama31) FormatStreamResponse(_ context.Context, response []byte) (*llm } func formatResponse[T LlamaChoiceDelta | LlamaChoice](response []byte) (*llms.ContentResponse, error) { - // Parse the LlamaResponse JSON var llamaResp LlamaResponse[T] err := json.Unmarshal(response, &llamaResp)
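---

Usage note (not part of the patches above): the following is a minimal sketch of how the package introduced by this series might be wired up, mirroring the setup in databricksllm_test.go. The environment variable names (DATABRICKS_LLAMA31_URL, DATABRICKS_TOKEN) are taken from the test file and are assumptions, not requirements; any serving-endpoint URL and personal access token can be passed to WithFullURL and WithToken. The streaming callback is optional.

    package main

    import (
    	"context"
    	"fmt"
    	"os"

    	"github.com/tmc/langchaingo/llms"
    	"github.com/tmc/langchaingo/llms/databricks"
    	databricksclientsllama31 "github.com/tmc/langchaingo/llms/databricks/clients/llama/v3.1"
    )

    func main() {
    	// Assumed environment variables, mirroring the tests in this series.
    	url := os.Getenv("DATABRICKS_LLAMA31_URL")
    	token := os.Getenv("DATABRICKS_TOKEN")

    	// Wire the Llama 3.1 payload formatter into the generic Databricks client.
    	llm, err := databricks.New(
    		databricksclientsllama31.NewLlama31(),
    		databricks.WithFullURL(url),
    		databricks.WithToken(token),
    	)
    	if err != nil {
    		panic(err)
    	}

    	ctx := context.Background()
    	resp, err := llm.GenerateContent(ctx,
    		[]llms.MessageContent{
    			{
    				Role:  llms.ChatMessageTypeHuman,
    				Parts: []llms.ContentPart{llms.TextContent{Text: "Is Brazil a country? Answer yes or no."}},
    			},
    		},
    		// Optional: stream chunks as they arrive instead of waiting for the full answer.
    		llms.WithStreamingFunc(func(_ context.Context, chunk []byte) error {
    			fmt.Print(string(chunk))
    			return nil
    		}),
    	)
    	if err != nil {
    		panic(err)
    	}
    	fmt.Println(resp.Choices[0].Content)
    }

The same sketch works for the Mistral client by swapping in databricksclientsmistralv1.NewMistral1(model) and the corresponding endpoint URL, as the tests do.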