From 7d7257240fcf36af298d899ceb2ffd7b9baed988 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Fri, 31 Jan 2025 22:36:29 +0000 Subject: [PATCH 01/15] Add Bedrock provider --- crates/goose/Cargo.toml | 5 + crates/goose/src/providers/bedrock.rs | 332 ++++++++++++++++++++++++++ crates/goose/src/providers/factory.rs | 3 + crates/goose/src/providers/mod.rs | 1 + 4 files changed, 341 insertions(+) create mode 100644 crates/goose/src/providers/bedrock.rs diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 053f55ab3c..a706d6a8e2 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -61,6 +61,11 @@ once_cell = "1.20.2" dirs = "6.0.0" rand = "0.8.5" +# For Bedrock provider +aws-config = { version = "1.1.7", features = ["behavior-version-latest"] } +aws-smithy-types = "1.2.12" +aws-sdk-bedrockruntime = "1.72.0" + [dev-dependencies] criterion = "0.5" tempfile = "3.15.0" diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs new file mode 100644 index 0000000000..c0b9f18232 --- /dev/null +++ b/crates/goose/src/providers/bedrock.rs @@ -0,0 +1,332 @@ +use std::collections::HashMap; + +use anyhow::{anyhow, bail, Result}; +use async_trait::async_trait; +use aws_sdk_bedrockruntime::{types as bedrock, Client}; +use aws_smithy_types::{Document, Number}; +use chrono::Utc; +use mcp_core::{Content, Role, Tool, ToolCall, ToolError, ToolResult}; +use serde_json::Value; + +use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; +use super::errors::ProviderError; +use crate::config::Config; +use crate::message::{Message, MessageContent}; +use crate::model::ModelConfig; + +pub const BEDROCK_DOC_LINK: &str = + "https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html"; + +pub const BEDROCK_DEFAULT_MODEL: &str = "anthropic.claude-3-5-sonnet-20240620-v1:0"; +pub const BEDROCK_KNOWN_MODELS: &[&str] = &[ + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "anthropic.claude-3-5-sonnet-20241022-v2:0", +]; + +#[derive(Debug, serde::Serialize)] +pub struct BedrockProvider { + #[serde(skip)] + client: Client, + model: ModelConfig, +} + +impl BedrockProvider { + pub fn from_env(model: ModelConfig) -> Result { + let config = Config::global(); + let sdk_config = tokio::task::block_in_place(|| { + let mut aws_config = aws_config::from_env(); + + if let Some(region) = config.get::("AWS_REGION").ok() { + aws_config = aws_config.region(aws_config::Region::new(region)); + } + + tokio::runtime::Handle::current().block_on(aws_config.load()) + }); + let client = Client::new(&sdk_config); + + Ok(Self { client, model }) + } +} + +#[async_trait] +impl Provider for BedrockProvider { + fn metadata() -> ProviderMetadata { + ProviderMetadata::new( + "bedrock", + "Amazon Bedrock", + "Run models through Amazon Bedrock", + BEDROCK_DEFAULT_MODEL, + BEDROCK_KNOWN_MODELS.iter().map(|s| s.to_string()).collect(), + BEDROCK_DOC_LINK, + vec![ConfigKey::new("AWS_REGION", false, false, None)], + ) + } + + fn get_model_config(&self) -> ModelConfig { + self.model.clone() + } + + #[tracing::instrument( + skip(self, system, messages, tools), + fields(model_config, input, output, input_tokens, output_tokens, total_tokens) + )] + async fn complete( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + let model_name = &self.model.model_name; + + let response = self + .client + .converse() + .tool_config(to_bedrock_tool_config(tools)?) + .model_id(model_name.to_string()) + .system(bedrock::SystemContentBlock::Text(system.to_string())) + .set_messages(Some( + messages + .iter() + .map(to_bedrock_message) + .collect::>()?, + )) + .send() + .await + .or_else(|err| Err(anyhow!("Failed to call Bedrock: {}", err)))?; + + let message = match response.output { + Some(bedrock::ConverseOutput::Message(message)) => message, + _ => { + return Err(ProviderError::RequestFailed( + "No output from Bedrock".to_string(), + )) + } + }; + + let usage = response + .usage + .as_ref() + .map(from_bedrock_usage) + .unwrap_or_default(); + + let message = from_bedrock_message(&message)?; + let provider_usage = ProviderUsage::new(model_name.to_string(), usage); + + Ok((message, provider_usage)) + } +} + +fn to_bedrock_message(message: &Message) -> Result { + bedrock::Message::builder() + .role(to_bedrock_role(&message.role)) + .set_content(Some( + message + .content + .iter() + .map(to_bedrock_message_content) + .collect::>()?, + )) + .build() + .map_err(|err| anyhow!("Failed to construct Bedrock message: {}", err)) +} + +fn to_bedrock_message_content(content: &MessageContent) -> Result { + Ok(match content { + MessageContent::Text(text) => bedrock::ContentBlock::Text(text.text.to_string()), + MessageContent::Image(_) => { + bail!("Image content is not supported by Bedrock provider yet") + } + MessageContent::ToolRequest(tool_req) => { + let tool_use_id = tool_req.id.to_string(); + let tool_use = if let Some(call) = tool_req.tool_call.as_ref().ok() { + bedrock::ToolUseBlock::builder() + .tool_use_id(tool_use_id) + .name(call.name.to_string()) + .input(to_bedrock_json(&call.arguments)) + .build() + } else { + bedrock::ToolUseBlock::builder() + .tool_use_id(tool_use_id) + .build() + }?; + bedrock::ContentBlock::ToolUse(tool_use) + } + MessageContent::ToolResponse(tool_res) => { + let content = match &tool_res.tool_result { + Ok(content) => Some( + content + .iter() + .map(to_bedrock_tool_result_content_block) + .collect::>()?, + ), + Err(_) => None, + }; + bedrock::ContentBlock::ToolResult( + bedrock::ToolResultBlock::builder() + .tool_use_id(tool_res.id.to_string()) + .status(if content.is_some() { + bedrock::ToolResultStatus::Success + } else { + bedrock::ToolResultStatus::Error + }) + .set_content(content) + .build()?, + ) + } + }) +} + +fn to_bedrock_tool_result_content_block( + content: &Content, +) -> Result { + Ok(match content { + Content::Text(text) => bedrock::ToolResultContentBlock::Text(text.text.to_string()), + Content::Image(_) => bail!("Image content is not supported by Bedrock provider yet"), + Content::Resource(_) => bail!("Resource content is not supported by Bedrock provider yet"), + }) +} + +fn to_bedrock_role(role: &Role) -> bedrock::ConversationRole { + match role { + Role::User => bedrock::ConversationRole::User, + Role::Assistant => bedrock::ConversationRole::Assistant, + } +} + +fn to_bedrock_tool_config(tools: &[Tool]) -> Result { + Ok(bedrock::ToolConfiguration::builder() + .set_tools(Some( + tools.iter().map(to_bedrock_tool).collect::>()?, + )) + .build()?) +} + +fn to_bedrock_tool(tool: &Tool) -> Result { + Ok(bedrock::Tool::ToolSpec( + bedrock::ToolSpecification::builder() + .name(tool.name.to_string()) + .description(tool.description.to_string()) + .input_schema(bedrock::ToolInputSchema::Json(to_bedrock_json( + &tool.input_schema, + ))) + .build()?, + )) +} + +fn to_bedrock_json(value: &Value) -> Document { + match value { + Value::Null => Document::Null, + Value::Bool(bool) => Document::Bool(*bool), + Value::Number(num) => { + if let Some(n) = num.as_u64() { + Document::Number(Number::PosInt(n)) + } else if let Some(n) = num.as_i64() { + Document::Number(Number::NegInt(n)) + } else if let Some(n) = num.as_f64() { + Document::Number(Number::Float(n)) + } else { + unreachable!() + } + } + Value::String(str) => Document::String(str.to_string()), + Value::Array(arr) => Document::Array(arr.into_iter().map(to_bedrock_json).collect()), + Value::Object(obj) => Document::Object(HashMap::from_iter( + obj.into_iter() + .map(|(key, val)| (key.to_string(), to_bedrock_json(val))), + )), + } +} + +fn from_bedrock_message(message: &bedrock::Message) -> Result { + let role = from_bedrock_role(message.role())?; + let content = message + .content() + .iter() + .map(from_bedrock_content_block) + .collect::>>()?; + let created = Utc::now().timestamp(); + + Ok(Message { + role, + content, + created, + }) +} + +fn from_bedrock_content_block(block: &bedrock::ContentBlock) -> Result { + Ok(match block { + bedrock::ContentBlock::Text(text) => MessageContent::text(text), + bedrock::ContentBlock::ToolUse(tool_use) => MessageContent::tool_request( + tool_use.tool_use_id.to_string(), + Ok(ToolCall::new( + tool_use.name.to_string(), + from_bedrock_json(&tool_use.input), + )), + ), + bedrock::ContentBlock::ToolResult(tool_res) => MessageContent::tool_response( + tool_res.tool_use_id.to_string(), + if tool_res.content.is_empty() { + Err(ToolError::ExecutionError( + "Empty content for tool use from Bedrock".to_string(), + )) + } else { + tool_res + .content + .iter() + .map(from_bedrock_tool_result_content_block) + .collect::>>() + }, + ), + _ => bail!("Unsupported content block type from Bedrock"), + }) +} + +fn from_bedrock_tool_result_content_block( + content: &bedrock::ToolResultContentBlock, +) -> ToolResult { + Ok(match content { + bedrock::ToolResultContentBlock::Text(text) => Content::text(text.to_string()), + _ => { + return Err(ToolError::ExecutionError( + "Unsupported tool result from Bedrock".to_string(), + )) + } + }) +} + +fn from_bedrock_role(role: &bedrock::ConversationRole) -> Result { + Ok(match role { + bedrock::ConversationRole::User => Role::User, + bedrock::ConversationRole::Assistant => Role::Assistant, + _ => bail!("Unknown role from Bedrock"), + }) +} + +fn from_bedrock_usage(usage: &bedrock::TokenUsage) -> Usage { + Usage { + input_tokens: Some(usage.input_tokens), + output_tokens: Some(usage.output_tokens), + total_tokens: Some(usage.total_tokens), + } +} + +fn from_bedrock_json(document: &Document) -> Value { + match document { + Document::Null => Value::Null, + Document::Bool(bool) => Value::Bool(*bool), + Document::Number(num) => match num { + Number::PosInt(i) => Value::Number((*i).into()), + Number::NegInt(i) => Value::Number((*i).into()), + Number::Float(f) => { + Value::Number(serde_json::Number::from_f64(*f).expect("Expected a valid f64")) + } + }, + Document::String(str) => Value::String(str.clone()), + Document::Array(arr) => Value::Array(arr.iter().map(from_bedrock_json).collect()), + Document::Object(obj) => Value::Object( + obj.iter() + .map(|(key, val)| (key.clone(), from_bedrock_json(val))) + .collect(), + ), + } +} diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index ed169aa7e8..d17fb8893f 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -2,6 +2,7 @@ use super::{ anthropic::AnthropicProvider, azure::AzureProvider, base::{Provider, ProviderMetadata}, + bedrock::BedrockProvider, databricks::DatabricksProvider, google::GoogleProvider, groq::GroqProvider, @@ -16,6 +17,7 @@ pub fn providers() -> Vec { vec![ AnthropicProvider::metadata(), AzureProvider::metadata(), + BedrockProvider::metadata(), DatabricksProvider::metadata(), GoogleProvider::metadata(), GroqProvider::metadata(), @@ -30,6 +32,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result Ok(Box::new(OpenAiProvider::from_env(model)?)), "anthropic" => Ok(Box::new(AnthropicProvider::from_env(model)?)), "azure_openai" => Ok(Box::new(AzureProvider::from_env(model)?)), + "bedrock" => Ok(Box::new(BedrockProvider::from_env(model)?)), "databricks" => Ok(Box::new(DatabricksProvider::from_env(model)?)), "groq" => Ok(Box::new(GroqProvider::from_env(model)?)), "ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)), diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index de6225767a..634224fd7e 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -1,6 +1,7 @@ pub mod anthropic; pub mod azure; pub mod base; +pub mod bedrock; pub mod databricks; pub mod errors; mod factory; From 01954b0bd1a9c8d0428cdca78971ec6bc7e17df0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Sat, 1 Feb 2025 16:13:57 +0000 Subject: [PATCH 02/15] Fix Clippy errors --- crates/goose/src/providers/bedrock.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index c0b9f18232..f4e5e72f3f 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -36,7 +36,7 @@ impl BedrockProvider { let sdk_config = tokio::task::block_in_place(|| { let mut aws_config = aws_config::from_env(); - if let Some(region) = config.get::("AWS_REGION").ok() { + if let Ok(region) = config.get::("AWS_REGION") { aws_config = aws_config.region(aws_config::Region::new(region)); } @@ -92,7 +92,7 @@ impl Provider for BedrockProvider { )) .send() .await - .or_else(|err| Err(anyhow!("Failed to call Bedrock: {}", err)))?; + .map_err(|err| anyhow!("Failed to call Bedrock: {}", err))?; let message = match response.output { Some(bedrock::ConverseOutput::Message(message)) => message, @@ -138,7 +138,7 @@ fn to_bedrock_message_content(content: &MessageContent) -> Result { let tool_use_id = tool_req.id.to_string(); - let tool_use = if let Some(call) = tool_req.tool_call.as_ref().ok() { + let tool_use = if let Ok(call) = tool_req.tool_call.as_ref() { bedrock::ToolUseBlock::builder() .tool_use_id(tool_use_id) .name(call.name.to_string()) @@ -229,7 +229,7 @@ fn to_bedrock_json(value: &Value) -> Document { } } Value::String(str) => Document::String(str.to_string()), - Value::Array(arr) => Document::Array(arr.into_iter().map(to_bedrock_json).collect()), + Value::Array(arr) => Document::Array(arr.iter().map(to_bedrock_json).collect()), Value::Object(obj) => Document::Object(HashMap::from_iter( obj.into_iter() .map(|(key, val)| (key.to_string(), to_bedrock_json(val))), From b48b9be38fb4e9401e871e642d55478ca99a85c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Sat, 1 Feb 2025 16:17:17 +0000 Subject: [PATCH 03/15] Return a `Result<>` from `from_bedrock_json` instead of using `expect` --- crates/goose/src/providers/bedrock.rs | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index f4e5e72f3f..c882b5086e 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -260,7 +260,7 @@ fn from_bedrock_content_block(block: &bedrock::ContentBlock) -> Result MessageContent::tool_response( @@ -310,23 +310,25 @@ fn from_bedrock_usage(usage: &bedrock::TokenUsage) -> Usage { } } -fn from_bedrock_json(document: &Document) -> Value { - match document { +fn from_bedrock_json(document: &Document) -> Result { + Ok(match document { Document::Null => Value::Null, Document::Bool(bool) => Value::Bool(*bool), Document::Number(num) => match num { Number::PosInt(i) => Value::Number((*i).into()), Number::NegInt(i) => Value::Number((*i).into()), - Number::Float(f) => { - Value::Number(serde_json::Number::from_f64(*f).expect("Expected a valid f64")) - } + Number::Float(f) => Value::Number( + serde_json::Number::from_f64(*f).ok_or(anyhow!("Expected a valid float"))?, + ), }, Document::String(str) => Value::String(str.clone()), - Document::Array(arr) => Value::Array(arr.iter().map(from_bedrock_json).collect()), + Document::Array(arr) => { + Value::Array(arr.iter().map(from_bedrock_json).collect::>()?) + } Document::Object(obj) => Value::Object( obj.iter() - .map(|(key, val)| (key.clone(), from_bedrock_json(val))) - .collect(), + .map(|(key, val)| Ok((key.clone(), from_bedrock_json(val)?))) + .collect::>()?, ), - } + }) } From 722cac09935426db630e76bf0f7bd5d3b6eb4c7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Sat, 1 Feb 2025 16:23:14 +0000 Subject: [PATCH 04/15] Remove `AWS_REGION` configuration and just rely on `aws_config::from_env()` --- crates/goose/src/providers/bedrock.rs | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index c882b5086e..0a086fc3a0 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -8,9 +8,8 @@ use chrono::Utc; use mcp_core::{Content, Role, Tool, ToolCall, ToolError, ToolResult}; use serde_json::Value; -use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; +use super::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; -use crate::config::Config; use crate::message::{Message, MessageContent}; use crate::model::ModelConfig; @@ -32,15 +31,8 @@ pub struct BedrockProvider { impl BedrockProvider { pub fn from_env(model: ModelConfig) -> Result { - let config = Config::global(); let sdk_config = tokio::task::block_in_place(|| { - let mut aws_config = aws_config::from_env(); - - if let Ok(region) = config.get::("AWS_REGION") { - aws_config = aws_config.region(aws_config::Region::new(region)); - } - - tokio::runtime::Handle::current().block_on(aws_config.load()) + tokio::runtime::Handle::current().block_on(aws_config::from_env().load()) }); let client = Client::new(&sdk_config); @@ -58,7 +50,7 @@ impl Provider for BedrockProvider { BEDROCK_DEFAULT_MODEL, BEDROCK_KNOWN_MODELS.iter().map(|s| s.to_string()).collect(), BEDROCK_DOC_LINK, - vec![ConfigKey::new("AWS_REGION", false, false, None)], + vec![], ) } From 4975a7778d9b9d69a78f6cb88260ee565d4cb361 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Sat, 1 Feb 2025 16:40:45 +0000 Subject: [PATCH 05/15] Add Bedrock provider tests --- crates/goose/src/providers/bedrock.rs | 7 ++++++ crates/goose/tests/providers.rs | 32 ++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index 0a086fc3a0..e08dd7abd6 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -40,6 +40,13 @@ impl BedrockProvider { } } +impl Default for BedrockProvider { + fn default() -> Self { + let model = ModelConfig::new(BedrockProvider::metadata().default_model); + BedrockProvider::from_env(model).expect("Failed to initialize Bedrock provider") + } +} + #[async_trait] impl Provider for BedrockProvider { fn metadata() -> ProviderMetadata { diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index 6a5f4b9dab..332f3ee765 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -3,7 +3,9 @@ use dotenv::dotenv; use goose::message::{Message, MessageContent}; use goose::providers::base::Provider; use goose::providers::errors::ProviderError; -use goose::providers::{anthropic, azure, databricks, google, groq, ollama, openai, openrouter}; +use goose::providers::{ + anthropic, azure, bedrock, databricks, google, groq, ollama, openai, openrouter, +}; use mcp_core::content::Content; use mcp_core::tool::Tool; use std::collections::HashMap; @@ -374,6 +376,34 @@ async fn test_azure_provider() -> Result<()> { .await } +#[tokio::test] +async fn test_bedrock_provider_long_term_credentials() -> Result<()> { + test_provider( + "Bedrock", + &["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], + None, + bedrock::BedrockProvider::default, + ) + .await +} + +#[tokio::test] +async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> { + let env_mods = HashMap::from_iter([ + // Ensure to unset long-term credentials to use AWS Profile provider + ("AWS_ACCESS_KEY_ID", None), + ("AWS_SECRET_ACCESS_KEY", None), + ]); + + test_provider( + "Bedrock AWS Profile Credentials", + &["AWS_PROFILE"], + Some(env_mods), + bedrock::BedrockProvider::default, + ) + .await +} + #[tokio::test] async fn test_databricks_provider() -> Result<()> { test_provider( From 19d08d694fa35107c4941d6cf9d567594ba320d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Sat, 1 Feb 2025 17:43:22 +0000 Subject: [PATCH 06/15] Add truncate agent tests for Bedrock provider --- crates/goose/tests/truncate_agent.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/crates/goose/tests/truncate_agent.rs b/crates/goose/tests/truncate_agent.rs index d3702d5f80..fbd4967414 100644 --- a/crates/goose/tests/truncate_agent.rs +++ b/crates/goose/tests/truncate_agent.rs @@ -8,7 +8,7 @@ use goose::model::ModelConfig; use goose::providers::base::Provider; use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider}; use goose::providers::{ - azure::AzureProvider, ollama::OllamaProvider, openai::OpenAiProvider, + azure::AzureProvider, bedrock::BedrockProvider, ollama::OllamaProvider, openai::OpenAiProvider, openrouter::OpenRouterProvider, }; use goose::providers::{google::GoogleProvider, groq::GroqProvider}; @@ -18,6 +18,7 @@ enum ProviderType { Azure, OpenAi, Anthropic, + Bedrock, Databricks, Google, Groq, @@ -35,6 +36,7 @@ impl ProviderType { ], ProviderType::OpenAi => &["OPENAI_API_KEY"], ProviderType::Anthropic => &["ANTHROPIC_API_KEY"], + ProviderType::Bedrock => &["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], ProviderType::Databricks => &["DATABRICKS_HOST"], ProviderType::Google => &["GOOGLE_API_KEY"], ProviderType::Groq => &["GROQ_API_KEY"], @@ -66,6 +68,7 @@ impl ProviderType { ProviderType::Azure => Box::new(AzureProvider::from_env(model_config)?), ProviderType::OpenAi => Box::new(OpenAiProvider::from_env(model_config)?), ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?), + ProviderType::Bedrock => Box::new(BedrockProvider::from_env(model_config)?), ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?), ProviderType::Google => Box::new(GoogleProvider::from_env(model_config)?), ProviderType::Groq => Box::new(GroqProvider::from_env(model_config)?), @@ -200,6 +203,16 @@ mod tests { .await } + #[tokio::test] + async fn test_truncate_agent_with_bedrock() -> Result<()> { + run_test_with_config(TestConfig { + provider_type: ProviderType::Bedrock, + model: "anthropic.claude-3-5-sonnet-20241022-v2:0", + context_window: 200_000, + }) + .await + } + #[tokio::test] async fn test_truncate_agent_with_databricks() -> Result<()> { run_test_with_config(TestConfig { From f712ed8085af87fd5e0167653dcd2be696f5fa70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Sat, 1 Feb 2025 17:44:35 +0000 Subject: [PATCH 07/15] Use `futures::executor::block_on` to load AWS Config Tokio's runtime panics on single threaded tests. --- crates/goose/src/providers/bedrock.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index e08dd7abd6..f72c5068d8 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -31,9 +31,7 @@ pub struct BedrockProvider { impl BedrockProvider { pub fn from_env(model: ModelConfig) -> Result { - let sdk_config = tokio::task::block_in_place(|| { - tokio::runtime::Handle::current().block_on(aws_config::from_env().load()) - }); + let sdk_config = futures::executor::block_on(aws_config::load_from_env()); let client = Client::new(&sdk_config); Ok(Self { client, model }) From b700083164e92646f6ba5e31f0d2a38c575b7b9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Sat, 1 Feb 2025 17:45:54 +0000 Subject: [PATCH 08/15] Properly map Bedrock errors to `ProviderError`s --- crates/goose/src/providers/bedrock.rs | 46 +++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index f72c5068d8..849a377b2b 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use anyhow::{anyhow, bail, Result}; use async_trait::async_trait; +use aws_sdk_bedrockruntime::operation::converse::ConverseError; use aws_sdk_bedrockruntime::{types as bedrock, Client}; use aws_smithy_types::{Document, Number}; use chrono::Utc; @@ -75,21 +76,52 @@ impl Provider for BedrockProvider { ) -> Result<(Message, ProviderUsage), ProviderError> { let model_name = &self.model.model_name; - let response = self + let mut request = self .client .converse() - .tool_config(to_bedrock_tool_config(tools)?) - .model_id(model_name.to_string()) .system(bedrock::SystemContentBlock::Text(system.to_string())) + .model_id(model_name.to_string()) .set_messages(Some( messages .iter() .map(to_bedrock_message) .collect::>()?, - )) - .send() - .await - .map_err(|err| anyhow!("Failed to call Bedrock: {}", err))?; + )); + + if !tools.is_empty() { + request = request.tool_config(to_bedrock_tool_config(tools)?); + } + + let response = request.send().await; + + let response = match response { + Ok(response) => response, + Err(err) => { + return Err(match err.into_service_error() { + ConverseError::AccessDeniedException(err) => { + ProviderError::Authentication(format!("Failed to call Bedrock: {}", err)) + } + ConverseError::ThrottlingException(err) => { + ProviderError::RateLimitExceeded(format!("Failed to call Bedrock: {}", err)) + } + ConverseError::ValidationException(err) + if err + .message() + .unwrap_or_default() + .contains("Input is too long for requested model.") => + { + ProviderError::ContextLengthExceeded(format!( + "Failed to call Bedrock: {}", + err + )) + } + ConverseError::ModelErrorException(err) => { + ProviderError::ExecutionError(format!("Failed to call Bedrock: {}", err)) + } + err => ProviderError::ServerError(format!("Failed to call Bedrock: {}", err,)), + }); + } + }; let message = match response.output { Some(bedrock::ConverseOutput::Message(message)) => message, From db8a827a2e8267b87bb93676b4d0125c90215e72 Mon Sep 17 00:00:00 2001 From: Alice Hau Date: Tue, 4 Feb 2025 16:24:46 -0500 Subject: [PATCH 09/15] update error messages to be more descriptive and add AWS_REGION detection to truncate_agent tests --- crates/goose/src/providers/bedrock.rs | 12 ++++++------ crates/goose/tests/truncate_agent.rs | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index 849a377b2b..b73c34f524 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -52,7 +52,7 @@ impl Provider for BedrockProvider { ProviderMetadata::new( "bedrock", "Amazon Bedrock", - "Run models through Amazon Bedrock", + "Run models through Amazon Bedrock. You may have to set AWS_ACCESS_KEY_ID, AWS_ACCESS_KEY, and AWS_REGION as env vars before configuring.", BEDROCK_DEFAULT_MODEL, BEDROCK_KNOWN_MODELS.iter().map(|s| s.to_string()).collect(), BEDROCK_DOC_LINK, @@ -99,10 +99,10 @@ impl Provider for BedrockProvider { Err(err) => { return Err(match err.into_service_error() { ConverseError::AccessDeniedException(err) => { - ProviderError::Authentication(format!("Failed to call Bedrock: {}", err)) + ProviderError::Authentication(format!("Failed to call Bedrock: {:?}", err)) } ConverseError::ThrottlingException(err) => { - ProviderError::RateLimitExceeded(format!("Failed to call Bedrock: {}", err)) + ProviderError::RateLimitExceeded(format!("Failed to call Bedrock: {:?}", err)) } ConverseError::ValidationException(err) if err @@ -111,14 +111,14 @@ impl Provider for BedrockProvider { .contains("Input is too long for requested model.") => { ProviderError::ContextLengthExceeded(format!( - "Failed to call Bedrock: {}", + "Failed to call Bedrock: {:?}", err )) } ConverseError::ModelErrorException(err) => { - ProviderError::ExecutionError(format!("Failed to call Bedrock: {}", err)) + ProviderError::ExecutionError(format!("Failed to call Bedrock: {:?}", err)) } - err => ProviderError::ServerError(format!("Failed to call Bedrock: {}", err,)), + err => ProviderError::ServerError(format!("Failed to call Bedrock: {:?}", err,)), }); } }; diff --git a/crates/goose/tests/truncate_agent.rs b/crates/goose/tests/truncate_agent.rs index fbd4967414..be2b14ecc0 100644 --- a/crates/goose/tests/truncate_agent.rs +++ b/crates/goose/tests/truncate_agent.rs @@ -36,7 +36,7 @@ impl ProviderType { ], ProviderType::OpenAi => &["OPENAI_API_KEY"], ProviderType::Anthropic => &["ANTHROPIC_API_KEY"], - ProviderType::Bedrock => &["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], + ProviderType::Bedrock => &["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION"], ProviderType::Databricks => &["DATABRICKS_HOST"], ProviderType::Google => &["GOOGLE_API_KEY"], ProviderType::Groq => &["GROQ_API_KEY"], From b9f244f480be2e83bd1ed88d0458ee01231e5fe2 Mon Sep 17 00:00:00 2001 From: Alice Hau Date: Tue, 4 Feb 2025 16:35:00 -0500 Subject: [PATCH 10/15] migrate bedrock formatting into formats/ utils --- crates/goose/src/providers/bedrock.rs | 236 +----------------- crates/goose/src/providers/formats/bedrock.rs | 228 +++++++++++++++++ crates/goose/src/providers/formats/mod.rs | 1 + 3 files changed, 242 insertions(+), 223 deletions(-) create mode 100644 crates/goose/src/providers/formats/bedrock.rs diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index b73c34f524..48308ac270 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -6,14 +6,19 @@ use aws_sdk_bedrockruntime::operation::converse::ConverseError; use aws_sdk_bedrockruntime::{types as bedrock, Client}; use aws_smithy_types::{Document, Number}; use chrono::Utc; -use mcp_core::{Content, Role, Tool, ToolCall, ToolError, ToolResult}; +use mcp_core::{Role, Tool}; use serde_json::Value; use super::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; -use crate::message::{Message, MessageContent}; +use crate::message::Message; use crate::model::ModelConfig; +// Import the migrated helper functions from providers/formats/bedrock.rs +use super::formats::bedrock::{ + from_bedrock_message, from_bedrock_usage, to_bedrock_message, to_bedrock_tool_config, +}; + pub const BEDROCK_DOC_LINK: &str = "https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html"; @@ -101,9 +106,9 @@ impl Provider for BedrockProvider { ConverseError::AccessDeniedException(err) => { ProviderError::Authentication(format!("Failed to call Bedrock: {:?}", err)) } - ConverseError::ThrottlingException(err) => { - ProviderError::RateLimitExceeded(format!("Failed to call Bedrock: {:?}", err)) - } + ConverseError::ThrottlingException(err) => ProviderError::RateLimitExceeded( + format!("Failed to call Bedrock: {:?}", err), + ), ConverseError::ValidationException(err) if err .message() @@ -118,7 +123,9 @@ impl Provider for BedrockProvider { ConverseError::ModelErrorException(err) => { ProviderError::ExecutionError(format!("Failed to call Bedrock: {:?}", err)) } - err => ProviderError::ServerError(format!("Failed to call Bedrock: {:?}", err,)), + err => { + ProviderError::ServerError(format!("Failed to call Bedrock: {:?}", err,)) + } }); } }; @@ -144,220 +151,3 @@ impl Provider for BedrockProvider { Ok((message, provider_usage)) } } - -fn to_bedrock_message(message: &Message) -> Result { - bedrock::Message::builder() - .role(to_bedrock_role(&message.role)) - .set_content(Some( - message - .content - .iter() - .map(to_bedrock_message_content) - .collect::>()?, - )) - .build() - .map_err(|err| anyhow!("Failed to construct Bedrock message: {}", err)) -} - -fn to_bedrock_message_content(content: &MessageContent) -> Result { - Ok(match content { - MessageContent::Text(text) => bedrock::ContentBlock::Text(text.text.to_string()), - MessageContent::Image(_) => { - bail!("Image content is not supported by Bedrock provider yet") - } - MessageContent::ToolRequest(tool_req) => { - let tool_use_id = tool_req.id.to_string(); - let tool_use = if let Ok(call) = tool_req.tool_call.as_ref() { - bedrock::ToolUseBlock::builder() - .tool_use_id(tool_use_id) - .name(call.name.to_string()) - .input(to_bedrock_json(&call.arguments)) - .build() - } else { - bedrock::ToolUseBlock::builder() - .tool_use_id(tool_use_id) - .build() - }?; - bedrock::ContentBlock::ToolUse(tool_use) - } - MessageContent::ToolResponse(tool_res) => { - let content = match &tool_res.tool_result { - Ok(content) => Some( - content - .iter() - .map(to_bedrock_tool_result_content_block) - .collect::>()?, - ), - Err(_) => None, - }; - bedrock::ContentBlock::ToolResult( - bedrock::ToolResultBlock::builder() - .tool_use_id(tool_res.id.to_string()) - .status(if content.is_some() { - bedrock::ToolResultStatus::Success - } else { - bedrock::ToolResultStatus::Error - }) - .set_content(content) - .build()?, - ) - } - }) -} - -fn to_bedrock_tool_result_content_block( - content: &Content, -) -> Result { - Ok(match content { - Content::Text(text) => bedrock::ToolResultContentBlock::Text(text.text.to_string()), - Content::Image(_) => bail!("Image content is not supported by Bedrock provider yet"), - Content::Resource(_) => bail!("Resource content is not supported by Bedrock provider yet"), - }) -} - -fn to_bedrock_role(role: &Role) -> bedrock::ConversationRole { - match role { - Role::User => bedrock::ConversationRole::User, - Role::Assistant => bedrock::ConversationRole::Assistant, - } -} - -fn to_bedrock_tool_config(tools: &[Tool]) -> Result { - Ok(bedrock::ToolConfiguration::builder() - .set_tools(Some( - tools.iter().map(to_bedrock_tool).collect::>()?, - )) - .build()?) -} - -fn to_bedrock_tool(tool: &Tool) -> Result { - Ok(bedrock::Tool::ToolSpec( - bedrock::ToolSpecification::builder() - .name(tool.name.to_string()) - .description(tool.description.to_string()) - .input_schema(bedrock::ToolInputSchema::Json(to_bedrock_json( - &tool.input_schema, - ))) - .build()?, - )) -} - -fn to_bedrock_json(value: &Value) -> Document { - match value { - Value::Null => Document::Null, - Value::Bool(bool) => Document::Bool(*bool), - Value::Number(num) => { - if let Some(n) = num.as_u64() { - Document::Number(Number::PosInt(n)) - } else if let Some(n) = num.as_i64() { - Document::Number(Number::NegInt(n)) - } else if let Some(n) = num.as_f64() { - Document::Number(Number::Float(n)) - } else { - unreachable!() - } - } - Value::String(str) => Document::String(str.to_string()), - Value::Array(arr) => Document::Array(arr.iter().map(to_bedrock_json).collect()), - Value::Object(obj) => Document::Object(HashMap::from_iter( - obj.into_iter() - .map(|(key, val)| (key.to_string(), to_bedrock_json(val))), - )), - } -} - -fn from_bedrock_message(message: &bedrock::Message) -> Result { - let role = from_bedrock_role(message.role())?; - let content = message - .content() - .iter() - .map(from_bedrock_content_block) - .collect::>>()?; - let created = Utc::now().timestamp(); - - Ok(Message { - role, - content, - created, - }) -} - -fn from_bedrock_content_block(block: &bedrock::ContentBlock) -> Result { - Ok(match block { - bedrock::ContentBlock::Text(text) => MessageContent::text(text), - bedrock::ContentBlock::ToolUse(tool_use) => MessageContent::tool_request( - tool_use.tool_use_id.to_string(), - Ok(ToolCall::new( - tool_use.name.to_string(), - from_bedrock_json(&tool_use.input)?, - )), - ), - bedrock::ContentBlock::ToolResult(tool_res) => MessageContent::tool_response( - tool_res.tool_use_id.to_string(), - if tool_res.content.is_empty() { - Err(ToolError::ExecutionError( - "Empty content for tool use from Bedrock".to_string(), - )) - } else { - tool_res - .content - .iter() - .map(from_bedrock_tool_result_content_block) - .collect::>>() - }, - ), - _ => bail!("Unsupported content block type from Bedrock"), - }) -} - -fn from_bedrock_tool_result_content_block( - content: &bedrock::ToolResultContentBlock, -) -> ToolResult { - Ok(match content { - bedrock::ToolResultContentBlock::Text(text) => Content::text(text.to_string()), - _ => { - return Err(ToolError::ExecutionError( - "Unsupported tool result from Bedrock".to_string(), - )) - } - }) -} - -fn from_bedrock_role(role: &bedrock::ConversationRole) -> Result { - Ok(match role { - bedrock::ConversationRole::User => Role::User, - bedrock::ConversationRole::Assistant => Role::Assistant, - _ => bail!("Unknown role from Bedrock"), - }) -} - -fn from_bedrock_usage(usage: &bedrock::TokenUsage) -> Usage { - Usage { - input_tokens: Some(usage.input_tokens), - output_tokens: Some(usage.output_tokens), - total_tokens: Some(usage.total_tokens), - } -} - -fn from_bedrock_json(document: &Document) -> Result { - Ok(match document { - Document::Null => Value::Null, - Document::Bool(bool) => Value::Bool(*bool), - Document::Number(num) => match num { - Number::PosInt(i) => Value::Number((*i).into()), - Number::NegInt(i) => Value::Number((*i).into()), - Number::Float(f) => Value::Number( - serde_json::Number::from_f64(*f).ok_or(anyhow!("Expected a valid float"))?, - ), - }, - Document::String(str) => Value::String(str.clone()), - Document::Array(arr) => { - Value::Array(arr.iter().map(from_bedrock_json).collect::>()?) - } - Document::Object(obj) => Value::Object( - obj.iter() - .map(|(key, val)| Ok((key.clone(), from_bedrock_json(val)?))) - .collect::>()?, - ), - }) -} diff --git a/crates/goose/src/providers/formats/bedrock.rs b/crates/goose/src/providers/formats/bedrock.rs new file mode 100644 index 0000000000..88b588c0db --- /dev/null +++ b/crates/goose/src/providers/formats/bedrock.rs @@ -0,0 +1,228 @@ +use std::collections::HashMap; + +use anyhow::{anyhow, bail, Result}; +use aws_sdk_bedrockruntime::types as bedrock; +use aws_smithy_types::{Document, Number}; +use chrono::Utc; +use mcp_core::{Content, Role, Tool, ToolCall, ToolError, ToolResult}; +use serde_json::Value; + +use super::super::base::Usage; +use crate::message::{Message, MessageContent}; + +pub fn to_bedrock_message(message: &Message) -> Result { + bedrock::Message::builder() + .role(to_bedrock_role(&message.role)) + .set_content(Some( + message + .content + .iter() + .map(to_bedrock_message_content) + .collect::>()?, + )) + .build() + .map_err(|err| anyhow!("Failed to construct Bedrock message: {}", err)) +} + +pub fn to_bedrock_message_content(content: &MessageContent) -> Result { + Ok(match content { + MessageContent::Text(text) => bedrock::ContentBlock::Text(text.text.to_string()), + MessageContent::Image(_) => { + bail!("Image content is not supported by Bedrock provider yet") + } + MessageContent::ToolRequest(tool_req) => { + let tool_use_id = tool_req.id.to_string(); + let tool_use = if let Ok(call) = tool_req.tool_call.as_ref() { + bedrock::ToolUseBlock::builder() + .tool_use_id(tool_use_id) + .name(call.name.to_string()) + .input(to_bedrock_json(&call.arguments)) + .build() + } else { + bedrock::ToolUseBlock::builder() + .tool_use_id(tool_use_id) + .build() + }?; + bedrock::ContentBlock::ToolUse(tool_use) + } + MessageContent::ToolResponse(tool_res) => { + let content = match &tool_res.tool_result { + Ok(content) => Some( + content + .iter() + .map(to_bedrock_tool_result_content_block) + .collect::>()?, + ), + Err(_) => None, + }; + bedrock::ContentBlock::ToolResult( + bedrock::ToolResultBlock::builder() + .tool_use_id(tool_res.id.to_string()) + .status(if content.is_some() { + bedrock::ToolResultStatus::Success + } else { + bedrock::ToolResultStatus::Error + }) + .set_content(content) + .build()?, + ) + } + }) +} + +pub fn to_bedrock_tool_result_content_block( + content: &Content, +) -> Result { + Ok(match content { + Content::Text(text) => bedrock::ToolResultContentBlock::Text(text.text.to_string()), + Content::Image(_) => bail!("Image content is not supported by Bedrock provider yet"), + Content::Resource(_) => bail!("Resource content is not supported by Bedrock provider yet"), + }) +} + +pub fn to_bedrock_role(role: &Role) -> bedrock::ConversationRole { + match role { + Role::User => bedrock::ConversationRole::User, + Role::Assistant => bedrock::ConversationRole::Assistant, + } +} + +pub fn to_bedrock_tool_config(tools: &[Tool]) -> Result { + Ok(bedrock::ToolConfiguration::builder() + .set_tools(Some( + tools.iter().map(to_bedrock_tool).collect::>()?, + )) + .build()?) +} + +pub fn to_bedrock_tool(tool: &Tool) -> Result { + Ok(bedrock::Tool::ToolSpec( + bedrock::ToolSpecification::builder() + .name(tool.name.to_string()) + .description(tool.description.to_string()) + .input_schema(bedrock::ToolInputSchema::Json(to_bedrock_json( + &tool.input_schema, + ))) + .build()?, + )) +} + +pub fn to_bedrock_json(value: &Value) -> Document { + match value { + Value::Null => Document::Null, + Value::Bool(bool) => Document::Bool(*bool), + Value::Number(num) => { + if let Some(n) = num.as_u64() { + Document::Number(Number::PosInt(n)) + } else if let Some(n) = num.as_i64() { + Document::Number(Number::NegInt(n)) + } else if let Some(n) = num.as_f64() { + Document::Number(Number::Float(n)) + } else { + unreachable!() + } + } + Value::String(str) => Document::String(str.to_string()), + Value::Array(arr) => Document::Array(arr.iter().map(to_bedrock_json).collect()), + Value::Object(obj) => Document::Object(HashMap::from_iter( + obj.into_iter() + .map(|(key, val)| (key.to_string(), to_bedrock_json(val))), + )), + } +} + +pub fn from_bedrock_message(message: &bedrock::Message) -> Result { + let role = from_bedrock_role(message.role())?; + let content = message + .content() + .iter() + .map(from_bedrock_content_block) + .collect::>>()?; + let created = Utc::now().timestamp(); + + Ok(Message { + role, + content, + created, + }) +} + +pub fn from_bedrock_content_block(block: &bedrock::ContentBlock) -> Result { + Ok(match block { + bedrock::ContentBlock::Text(text) => MessageContent::text(text), + bedrock::ContentBlock::ToolUse(tool_use) => MessageContent::tool_request( + tool_use.tool_use_id.to_string(), + Ok(ToolCall::new( + tool_use.name.to_string(), + from_bedrock_json(&tool_use.input)?, + )), + ), + bedrock::ContentBlock::ToolResult(tool_res) => MessageContent::tool_response( + tool_res.tool_use_id.to_string(), + if tool_res.content.is_empty() { + Err(ToolError::ExecutionError( + "Empty content for tool use from Bedrock".to_string(), + )) + } else { + tool_res + .content + .iter() + .map(from_bedrock_tool_result_content_block) + .collect::>>() + }, + ), + _ => bail!("Unsupported content block type from Bedrock"), + }) +} + +pub fn from_bedrock_tool_result_content_block( + content: &bedrock::ToolResultContentBlock, +) -> ToolResult { + Ok(match content { + bedrock::ToolResultContentBlock::Text(text) => Content::text(text.to_string()), + _ => { + return Err(ToolError::ExecutionError( + "Unsupported tool result from Bedrock".to_string(), + )) + } + }) +} + +pub fn from_bedrock_role(role: &bedrock::ConversationRole) -> Result { + Ok(match role { + bedrock::ConversationRole::User => Role::User, + bedrock::ConversationRole::Assistant => Role::Assistant, + _ => bail!("Unknown role from Bedrock"), + }) +} + +pub fn from_bedrock_usage(usage: &bedrock::TokenUsage) -> Usage { + Usage { + input_tokens: Some(usage.input_tokens), + output_tokens: Some(usage.output_tokens), + total_tokens: Some(usage.total_tokens), + } +} + +pub fn from_bedrock_json(document: &Document) -> Result { + Ok(match document { + Document::Null => Value::Null, + Document::Bool(bool) => Value::Bool(*bool), + Document::Number(num) => match num { + Number::PosInt(i) => Value::Number((*i).into()), + Number::NegInt(i) => Value::Number((*i).into()), + Number::Float(f) => Value::Number( + serde_json::Number::from_f64(*f).ok_or(anyhow!("Expected a valid float"))?, + ), + }, + Document::String(str) => Value::String(str.clone()), + Document::Array(arr) => { + Value::Array(arr.iter().map(from_bedrock_json).collect::>()?) + } + Document::Object(obj) => Value::Object( + obj.iter() + .map(|(key, val)| Ok((key.clone(), from_bedrock_json(val)?))) + .collect::>()?, + ), + }) +} diff --git a/crates/goose/src/providers/formats/mod.rs b/crates/goose/src/providers/formats/mod.rs index 7134682858..780f384884 100644 --- a/crates/goose/src/providers/formats/mod.rs +++ b/crates/goose/src/providers/formats/mod.rs @@ -1,3 +1,4 @@ pub mod anthropic; +pub mod bedrock; pub mod google; pub mod openai; From ade209306a3c9757c9561036881fc2e0d8e3c873 Mon Sep 17 00:00:00 2001 From: Alice Hau Date: Tue, 4 Feb 2025 16:46:55 -0500 Subject: [PATCH 11/15] add debug trace on bedrock completion --- crates/goose/src/providers/bedrock.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index 48308ac270..2e8287d977 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -13,6 +13,7 @@ use super::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use crate::message::Message; use crate::model::ModelConfig; +use crate::providers::utils::emit_debug_trace; // Import the migrated helper functions from providers/formats/bedrock.rs use super::formats::bedrock::{ @@ -146,8 +147,16 @@ impl Provider for BedrockProvider { .unwrap_or_default(); let message = from_bedrock_message(&message)?; - let provider_usage = ProviderUsage::new(model_name.to_string(), usage); + + // Add debug trace with input context + let debug_payload = serde_json::json!({ + "system": system, + "messages": messages, + "tools": tools + }); + emit_debug_trace(&self.model, &debug_payload, &serde_json::to_value(&message).unwrap_or_default(), &usage); + let provider_usage = ProviderUsage::new(model_name.to_string(), usage); Ok((message, provider_usage)) } } From a9d52e9fe384e3a9163662e4690601aba662c491 Mon Sep 17 00:00:00 2001 From: Alice Hau Date: Tue, 4 Feb 2025 16:50:51 -0500 Subject: [PATCH 12/15] format --- crates/goose/src/providers/bedrock.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index 2e8287d977..c15e8a51c0 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -147,14 +147,19 @@ impl Provider for BedrockProvider { .unwrap_or_default(); let message = from_bedrock_message(&message)?; - + // Add debug trace with input context let debug_payload = serde_json::json!({ "system": system, "messages": messages, "tools": tools }); - emit_debug_trace(&self.model, &debug_payload, &serde_json::to_value(&message).unwrap_or_default(), &usage); + emit_debug_trace( + &self.model, + &debug_payload, + &serde_json::to_value(&message).unwrap_or_default(), + &usage, + ); let provider_usage = ProviderUsage::new(model_name.to_string(), usage); Ok((message, provider_usage)) From 9b109b23e4d84386143dd91fe3415763bd3aed2c Mon Sep 17 00:00:00 2001 From: Alice Hau Date: Tue, 4 Feb 2025 16:52:56 -0500 Subject: [PATCH 13/15] remove unused imports --- crates/goose/src/providers/bedrock.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index c15e8a51c0..8867ddb70d 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -1,15 +1,10 @@ -use std::collections::HashMap; - -use anyhow::{anyhow, bail, Result}; +use anyhow::Result; use async_trait::async_trait; use aws_sdk_bedrockruntime::operation::converse::ConverseError; use aws_sdk_bedrockruntime::{types as bedrock, Client}; -use aws_smithy_types::{Document, Number}; -use chrono::Utc; -use mcp_core::{Role, Tool}; -use serde_json::Value; +use mcp_core:: Tool; -use super::base::{Provider, ProviderMetadata, ProviderUsage, Usage}; +use super::base::{Provider, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; use crate::message::Message; use crate::model::ModelConfig; From 190fd8a11ec26ee01d98fe3f6eaf1df5d5f0d0d8 Mon Sep 17 00:00:00 2001 From: Alice Hau Date: Tue, 4 Feb 2025 16:53:36 -0500 Subject: [PATCH 14/15] format --- crates/goose/src/providers/bedrock.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index 8867ddb70d..ad40a321c1 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -2,7 +2,7 @@ use anyhow::Result; use async_trait::async_trait; use aws_sdk_bedrockruntime::operation::converse::ConverseError; use aws_sdk_bedrockruntime::{types as bedrock, Client}; -use mcp_core:: Tool; +use mcp_core::Tool; use super::base::{Provider, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; From fa1fe463d2c4fb0cd8e3f2575bac4f0c7b20941f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Varl=C4=B1?= Date: Tue, 4 Feb 2025 17:52:26 +0000 Subject: [PATCH 15/15] cherry picked 06d396f --- crates/goose/src/providers/formats/bedrock.rs | 48 +++++++++++++++++-- crates/goose/tests/truncate_agent.rs | 2 +- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/crates/goose/src/providers/formats/bedrock.rs b/crates/goose/src/providers/formats/bedrock.rs index 88b588c0db..812fda263c 100644 --- a/crates/goose/src/providers/formats/bedrock.rs +++ b/crates/goose/src/providers/formats/bedrock.rs @@ -1,10 +1,11 @@ use std::collections::HashMap; +use std::path::Path; use anyhow::{anyhow, bail, Result}; use aws_sdk_bedrockruntime::types as bedrock; use aws_smithy_types::{Document, Number}; use chrono::Utc; -use mcp_core::{Content, Role, Tool, ToolCall, ToolError, ToolResult}; +use mcp_core::{Content, ResourceContents, Role, Tool, ToolCall, ToolError, ToolResult}; use serde_json::Value; use super::super::base::Usage; @@ -50,7 +51,7 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result Some( content .iter() - .map(to_bedrock_tool_result_content_block) + .map(|c| to_bedrock_tool_result_content_block(&tool_res.id, c)) .collect::>()?, ), Err(_) => None, @@ -71,12 +72,15 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result Result { Ok(match content { Content::Text(text) => bedrock::ToolResultContentBlock::Text(text.text.to_string()), Content::Image(_) => bail!("Image content is not supported by Bedrock provider yet"), - Content::Resource(_) => bail!("Resource content is not supported by Bedrock provider yet"), + Content::Resource(resource) => bedrock::ToolResultContentBlock::Document( + to_bedrock_document(tool_use_id, &resource.resource)?, + ), }) } @@ -131,6 +135,44 @@ pub fn to_bedrock_json(value: &Value) -> Document { } } +fn to_bedrock_document( + tool_use_id: &str, + content: &ResourceContents, +) -> Result { + let (uri, text) = match content { + ResourceContents::TextResourceContents { uri, text, .. } => (uri, text), + ResourceContents::BlobResourceContents { .. } => { + bail!("Blob resource content is not supported by Bedrock provider yet") + } + }; + + let filename = Path::new(uri) + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or(uri); + + let (name, format) = match filename.split_once('.') { + Some((name, "txt")) => (name, bedrock::DocumentFormat::Txt), + Some((name, "csv")) => (name, bedrock::DocumentFormat::Csv), + Some((name, "md")) => (name, bedrock::DocumentFormat::Md), + Some((name, "html")) => (name, bedrock::DocumentFormat::Html), + Some((name, _)) => (name, bedrock::DocumentFormat::Txt), + _ => (filename, bedrock::DocumentFormat::Txt), + }; + + // Since we can't use the full path (due to character limit and also Bedrock does not accept `/` etc.), + // and Bedrock wants document names to be unique, we're adding `tool_use_id` as a prefix to make + // document names unique. + let name = format!("{tool_use_id}-{name}"); + + bedrock::DocumentBlock::builder() + .format(format) + .name(name) + .source(bedrock::DocumentSource::Bytes(text.as_bytes().into())) + .build() + .map_err(|err| anyhow!("Failed to construct Bedrock document: {}", err)) +} + pub fn from_bedrock_message(message: &bedrock::Message) -> Result { let role = from_bedrock_role(message.role())?; let content = message diff --git a/crates/goose/tests/truncate_agent.rs b/crates/goose/tests/truncate_agent.rs index be2b14ecc0..f8375086f5 100644 --- a/crates/goose/tests/truncate_agent.rs +++ b/crates/goose/tests/truncate_agent.rs @@ -36,7 +36,7 @@ impl ProviderType { ], ProviderType::OpenAi => &["OPENAI_API_KEY"], ProviderType::Anthropic => &["ANTHROPIC_API_KEY"], - ProviderType::Bedrock => &["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION"], + ProviderType::Bedrock => &["AWS_PROFILE", "AWS_REGION"], ProviderType::Databricks => &["DATABRICKS_HOST"], ProviderType::Google => &["GOOGLE_API_KEY"], ProviderType::Groq => &["GROQ_API_KEY"],