chore(models): adding option to set model and default model
kiraum committed Sep 30, 2024
1 parent 5a46c69 commit facbc9b
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions src/main.rs
@@ -42,6 +42,9 @@ struct Opt {
 
     #[structopt(long, help = "List available models")]
     list_models: bool,
+
+    #[structopt(long, help = "Set the model to use")]
+    set_model: Option<String>,
 }
 
 #[tokio::main]
@@ -65,10 +68,6 @@ async fn main() -> Result<(), Box<dyn Error>> {
         .expect("Error: SRC_ACCESS_TOKEN environment variable is not set.");
     let endpoint =
         env::var("SRC_ENDPOINT").expect("Error: SRC_ENDPOINT environment variable is not set.");
-    let chat_completions_url = format!(
-        "{}/.api/completions/stream?api-version=1&client-name=defaultclient&client-version=6.0.0'",
-        endpoint
-    );
 
     if opt.list_models {
         list_available_models(&endpoint).await?;
@@ -82,6 +81,11 @@ async fn main() -> Result<(), Box<dyn Error>> {
         process::exit(0);
     }
 
+    let chat_completions_url = format!(
+        "{}/.api/completions/stream?api-version=1&client-name=defaultclient&client-version=6.0.0'",
+        endpoint
+    );
+
     debug!("Chat completions URL: {}", chat_completions_url);
 
     let mut headers = HeaderMap::new();
@@ -102,6 +106,10 @@ async fn main() -> Result<(), Box<dyn Error>> {
     pb.enable_steady_tick(100);
     pb.set_message("Processing...");
 
+    let model = opt
+        .set_model
+        .unwrap_or_else(|| "anthropic::2023-06-01::claude-3.5-sonnet".to_string());
+
     let result = match (opt.jql, opt.message) {
         (Some(jql), Some(message)) => {
             debug!(
@@ -127,7 +135,8 @@ async fn main() -> Result<(), Box<dyn Error>> {
             debug!("Jira data fetched, length: {} characters", jira_data.len());
 
             let batch_summaries =
-                process_jira_data(&message, jira_data, &chat_completions_url, &headers).await?;
+                process_jira_data(&message, jira_data, &chat_completions_url, &headers, &model)
+                    .await?;
             debug!(
                 "Jira data processed, {} batch summaries",
                 batch_summaries.len()
@@ -141,13 +150,14 @@ async fn main() -> Result<(), Box<dyn Error>> {
             );
 
             debug!("Final query length: {} characters", final_query.len());
-            let final_answer = cody_chat(&final_query, &chat_completions_url, &headers).await?;
+            let final_answer =
+                cody_chat(&final_query, &chat_completions_url, &headers, &model).await?;
             println!("Answer:\n{}", final_answer);
             Ok(())
         }
         (None, Some(message)) => {
             debug!("Only message provided, no JQL. Message: {:?}", message);
-            let answer = cody_chat(&message, &chat_completions_url, &headers).await?;
+            let answer = cody_chat(&message, &chat_completions_url, &headers, &model).await?;
             pb.finish_and_clear();
             println!("Answer:\n{}", answer);
             Ok(())
@@ -320,6 +330,7 @@ async fn process_jira_data(
     jira_data: String,
     chat_completions_url: &str,
     headers: &HeaderMap,
+    model: &str,
 ) -> Result<Vec<String>, Box<dyn Error>> {
     const BATCH_SIZE: usize = 200_000;
 
@@ -364,7 +375,7 @@ async fn process_jira_data(
             i + 1,
             batch_query.len()
         );
-        let batch_summary = cody_chat(&batch_query, chat_completions_url, headers).await?;
+        let batch_summary = cody_chat(&batch_query, chat_completions_url, headers, model).await?;
         debug!("Processed batch {} answer:\n{}", i + 1, batch_summary);
         batch_summaries.push(batch_summary);
     }
@@ -376,6 +387,7 @@ async fn cody_chat(
     query: &str,
     chat_completions_url: &str,
     headers: &HeaderMap,
+    model: &str,
 ) -> Result<String, Box<dyn Error>> {
     let final_prompt = format!(
         r#"
@@ -390,7 +402,7 @@ async fn cody_chat(
         "Sending chat query of length: {} characters",
         final_prompt.len()
     );
-    let response = chat_completions(&final_prompt, chat_completions_url, headers).await?;
+    let response = chat_completions(&final_prompt, chat_completions_url, headers, model).await?;
     debug!(
         "Received chat response of length: {} characters",
         response.len()
@@ -402,11 +414,12 @@ async fn chat_completions(
     query: &str,
     chat_completions_url: &str,
     headers: &HeaderMap,
+    model: &str,
 ) -> Result<String, Box<dyn Error>> {
     let data = json!({
         "maxTokensToSample": 4000,
         "messages": [{"role": "user", "content": query}],
-        "model": "anthropic::2023-06-01::claude-3.5-sonnet",
+        "model": model,
         "temperature": 0.2,
         "topK": -1,
         "topP": -1,
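
With this change, the model sent in the request body comes from the new --set-model flag when it is provided, and falls back to anthropic::2023-06-01::claude-3.5-sonnet otherwise. A minimal usage sketch, assuming the compiled binary is invoked as `cody` and the existing message option is exposed as --message (neither name is shown in this diff); the value passed to --set-model should be one of the identifiers reported by --list-models:

    # discover which model identifiers the endpoint accepts
    cody --list-models

    # override the default model for this run
    cody --set-model "<model id from --list-models>" --message "Summarize the open tickets"

    # omit --set-model to keep the default claude-3.5-sonnet model
    cody --message "Summarize the open tickets"

Because set_model is an Option<String>, omitting the flag leaves it as None and unwrap_or_else supplies the default, so existing invocations keep their previous behavior.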
