Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support extending the system prompt #1167

Merged
merged 2 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions crates/goose-cli/src/cli_prompt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/// Returns a system prompt extension that explains CLI-specific functionality
pub fn get_cli_prompt() -> String {
String::from(
"You are being accessed through a command-line interface. The following slash commands are available
- you can let the user know about them if they need help:

- /exit or /quit - Exit the session
- /t - Toggle between Light/Dark/Ansi themes
- /? or /help - Display help message

Additional keyboard shortcuts:
- Ctrl+C - Interrupt the current interaction (resets to before the interrupted request)
- Ctrl+J - Add a newline
- Up/Down arrows - Navigate command history"
)
}
5 changes: 5 additions & 0 deletions crates/goose-cli/src/commands/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ pub async fn build_session(

let prompt = Box::new(RustylinePrompt::new());

// Add CLI-specific system prompt extension
agent
.extend_system_prompt(crate::cli_prompt::get_cli_prompt())
.await;

display_session_info(resume, &provider_name, &model, &session_file);
Session::new(agent, prompt, session_file)
}
Expand Down
1 change: 1 addition & 0 deletions crates/goose-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub static APP_STRATEGY: Lazy<AppStrategyArgs> = Lazy::new(|| AppStrategyArgs {
app_name: "goose".to_string(),
});

mod cli_prompt;
mod commands;
mod log_usage;
mod logging;
Expand Down
35 changes: 35 additions & 0 deletions crates/goose-server/src/routes/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ struct VersionsResponse {
default_version: String,
}

#[derive(Deserialize)]
struct ExtendPromptRequest {
extension: String,
}

#[derive(Serialize)]
struct ExtendPromptResponse {
success: bool,
}

#[derive(Deserialize)]
struct CreateAgentRequest {
version: Option<String>,
Expand Down Expand Up @@ -61,6 +71,30 @@ async fn get_versions() -> Json<VersionsResponse> {
})
}

async fn extend_prompt(
State(state): State<AppState>,
headers: HeaderMap,
Json(payload): Json<ExtendPromptRequest>,
) -> Result<Json<ExtendPromptResponse>, StatusCode> {
// Verify secret key
let secret_key = headers
.get("X-Secret-Key")
.and_then(|value| value.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;

if secret_key != state.secret_key {
return Err(StatusCode::UNAUTHORIZED);
}

let mut agent = state.agent.lock().await;
if let Some(ref mut agent) = *agent {
agent.extend_system_prompt(payload.extension).await;
Ok(Json(ExtendPromptResponse { success: true }))
} else {
Err(StatusCode::NOT_FOUND)
}
}

async fn create_agent(
State(state): State<AppState>,
headers: HeaderMap,
Expand Down Expand Up @@ -132,6 +166,7 @@ pub fn routes(state: AppState) -> Router {
Router::new()
.route("/agent/versions", get(get_versions))
.route("/agent/providers", get(list_providers))
.route("/agent/prompt", post(extend_prompt))
.route("/agent", post(create_agent))
.with_state(state)
}
3 changes: 3 additions & 0 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,7 @@ pub trait Agent: Send + Sync {

/// Get the total usage of the agent
async fn usage(&self) -> Vec<ProviderUsage>;

/// Add custom text to be included in the system prompt
async fn extend_system_prompt(&mut self, extension: String);
}
19 changes: 18 additions & 1 deletion crates/goose/src/agents/capabilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub struct Capabilities {
resource_capable_extensions: HashSet<String>,
provider: Box<dyn Provider>,
provider_usage: Mutex<Vec<ProviderUsage>>,
system_prompt_extensions: Vec<String>,
}

/// A flattened representation of a resource used by the agent to prepare inference
Expand Down Expand Up @@ -88,6 +89,7 @@ impl Capabilities {
resource_capable_extensions: HashSet::new(),
provider,
provider_usage: Mutex::new(Vec::new()),
system_prompt_extensions: Vec::new(),
}
}

Expand Down Expand Up @@ -164,6 +166,11 @@ impl Capabilities {
Ok(())
}

/// Add a system prompt extension
pub fn add_system_prompt_extension(&mut self, extension: String) {
self.system_prompt_extensions.push(extension);
}

/// Get a reference to the provider
pub fn provider(&self) -> &dyn Provider {
&*self.provider
Expand Down Expand Up @@ -303,7 +310,17 @@ impl Capabilities {
context.insert("extensions", serde_json::to_value(extensions_info).unwrap());
context.insert("current_date_time", Value::String(current_date_time));

load_prompt_file("system.md", &context).expect("Prompt should render")
let base_prompt = load_prompt_file("system.md", &context).expect("Prompt should render");

if self.system_prompt_extensions.is_empty() {
base_prompt
} else {
format!(
"{}\n\n# Additional Instructions:\n\n{}",
base_prompt,
self.system_prompt_extensions.join("\n\n")
)
}
}

/// Find and return a reference to the appropriate client for a tool call
Expand Down
5 changes: 5 additions & 0 deletions crates/goose/src/agents/reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ impl Agent for ReferenceAgent {
let capabilities = self.capabilities.lock().await;
capabilities.get_usage().await
}

async fn extend_system_prompt(&mut self, extension: String) {
let mut capabilities = self.capabilities.lock().await;
capabilities.add_system_prompt_extension(extension);
}
}

register_agent!("reference", ReferenceAgent);
5 changes: 5 additions & 0 deletions crates/goose/src/agents/truncate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,11 @@ impl Agent for TruncateAgent {
let capabilities = self.capabilities.lock().await;
capabilities.get_usage().await
}

async fn extend_system_prompt(&mut self, extension: String) {
let mut capabilities = self.capabilities.lock().await;
capabilities.add_system_prompt_extension(extension);
}
}

register_agent!("truncate", TruncateAgent);
Loading