Skip to content

Commit

Permalink
🎨 Use allocator for messages
Browse files Browse the repository at this point in the history
Now we're using `oxc_allocator` for storing messages. This let's us store messages more easily and idiomatically
  • Loading branch information
darkdarcool committed Feb 13, 2024
1 parent 9aeb004 commit a04b1a2
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 55 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ rand = "0.8.5"
futures = "0.3.30"
homedir = "0.2.1"
repair_json = "0.1.0"
oxc_allocator = "0.7.0"
97 changes: 51 additions & 46 deletions src/copilot.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use std::io::Write;

use crate::{gh, headers::{CopilotCompletionHeaders, Headers}, prompts, utils};
use crate::{
gh,
headers::{CopilotCompletionHeaders, Headers},
utils,
};
use futures::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -50,71 +54,70 @@ struct GhCopilotResponse {
}

#[derive(Deserialize, Serialize, Clone)]
struct Message {
content: String,
role: String,
pub struct Message<'alloc> {
content: &'alloc str,
role: &'alloc str,
}

#[derive(Debug)]
pub struct Completion {
pub content: String,
pub finish_reason: String
pub finish_reason: String,
}

pub struct CopilotManager<'a> {
pub struct CopilotManager<'a, 'alloc> {
vscode_sid: String,
device_id: String,
auth: &'a gh::GithubAuth,
client: &'a Client,
history: Vec<Message>,
allocator: &'alloc oxc_allocator::Allocator,
history: Vec<Message<'alloc>>,
}

impl<'a> CopilotManager<'a> {
pub fn new(auth: &'a gh::GithubAuth, client: &'a Client) -> CopilotManager<'a> {
impl<'a, 'alloc> CopilotManager<'a, 'alloc> {
pub fn new(
auth: &'a gh::GithubAuth,
client: &'a Client,
allocator: &'a oxc_allocator::Allocator,
prompt: &'static str
) -> CopilotManager<'a, 'alloc>
where
'a: 'alloc,
{
let vscode_sid = crate::utils::generate_vscode_session_id();
let device_id = crate::utils::random_hex_string(6);

let mut history = Vec::new();

history.push(Message {
content: allocator.alloc_str(prompt),
role: allocator.alloc_str("system"),
});

CopilotManager {
vscode_sid,
device_id,
auth,
client,
history: Vec::new(),
allocator,
history,
}
}

fn construct_message_history(
&self,
system_prompt: &str,
current_history: &Vec<Message>,
) -> Vec<Message> {
let system_message = Message {
content: system_prompt.to_string(),
role: "system".to_string(),
};

// return system message and the current history
vec![system_message]
.into_iter()
.chain(current_history.iter().cloned())
.collect()
}

pub async fn ask(&mut self, prompt: &String, log: bool) -> Completion {
let url = "https://api.githubcopilot.com/chat/completions";
let headers = CopilotCompletionHeaders {
token: &self.auth.copilot_auth.token,
vscode_sid: &self.vscode_sid,
device_id: &self.device_id,
}.to_headers();
}
.to_headers();

let mut history =
self.construct_message_history(prompts::COPILOT_INSTRUCTIONS, &self.history);
let history = &mut self.history;

// add current user prompt to history
history.push(Message {
content: prompt.to_string(),
role: "user".to_string(),
content: self.allocator.alloc_str(prompt),
role: self.allocator.alloc_str("user"),
});

// no chat history for this
Expand Down Expand Up @@ -148,10 +151,13 @@ impl<'a> CopilotManager<'a> {
let body_str = String::from_utf8_lossy(&body);

buffer.push_str(&body_str);
// the data may be split into multiple chunks, BUT it's always dilimited by \n\ndata:

let lines = buffer.split("\n\ndata: ").map(|s| s.to_string()).map(|s| s.replacen("data:", "", 1)).collect::<Vec<String>>();

// the data may be split into multiple chunks, BUT it's always dilimited by \n\ndata:
let lines = buffer
.split("\n\ndata: ")
.map(|s| s.to_string())
.map(|s| s.replacen("data:", "", 1))
.collect::<Vec<String>>();

let mut processed_buffer = String::new();
for line in lines {
Expand All @@ -162,15 +168,17 @@ impl<'a> CopilotManager<'a> {

let parsed = serde_json::from_str::<GhCopilotResponse>(&line);


match parsed {
Ok(parsed) => {
// If the choice actually exists
if parsed.choices.len() > 0 {
let choice = &parsed.choices[0];
// If there is a finish reason in the choice, we break the loop
if let Some(freason) = &choice.finish_reason {
finish_reason = freason.clone().to_string();
break 'outerloop;
}
// There might be content in the delta, let's handle it
let delta = &choice.delta;
if let Some(content) = &delta.content {
print!("{}", content);
Expand All @@ -184,10 +192,9 @@ impl<'a> CopilotManager<'a> {
processed_buffer.push_str(&line);
}
}

// Add the incomplete line to the buffer to be processed in the next iteration
buffer = processed_buffer.clone();
}

}

if log {
Expand All @@ -197,15 +204,13 @@ impl<'a> CopilotManager<'a> {

// add the response to the history
history.push(Message {
content: message.clone(),
role: "system".to_string(),
content: self.allocator.alloc_str(&message),
role: self.allocator.alloc_str("system"),
});

self.history = history;

Completion {
content: message.clone(),
finish_reason
content: message,
finish_reason,
}
}
}
}
9 changes: 5 additions & 4 deletions src/gh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ use serde_json;

use crate::{
headers::{self, Headers},
utils,
urls
urls, utils,
};

#[derive(Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -228,7 +227,8 @@ impl AuthenticationManager {
let headers = headers::GithubUserHeaders {
token: &auth.access_token,
token_type: &auth.token_type,
}.to_headers();
}
.to_headers();

let req = reqwest::Client::new()
.get(urls::GH_AUTH_TOKEN_URL)
Expand All @@ -251,7 +251,8 @@ impl AuthenticationManager {
) -> Result<GithubCopilotAuth, String> {
let headers = headers::GithubInternalHeaders {
token: &auth.access_token,
}.to_headers();
}
.to_headers();

let req = reqwest::Client::new()
.get(urls::GH_COPILOT_INTERNAL_AUTH_URL)
Expand Down
2 changes: 1 addition & 1 deletion src/headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ impl<'a> Headers for GithubInternalHeaders<'a> {
}
}

pub (crate) struct CopilotCompletionHeaders<'a> {
pub(crate) struct CopilotCompletionHeaders<'a> {
pub token: &'a String,
pub vscode_sid: &'a String,
pub device_id: &'a String,
Expand Down
8 changes: 5 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ mod copilot;
mod gh;
mod headers;
mod prompts;
mod utils;
mod urls;
mod utils;

use oxc_allocator;
use rustyline::DefaultEditor;

#[tokio::main]
Expand All @@ -14,7 +15,9 @@ async fn main() {

let client = reqwest::Client::new();

let mut copilot_m = copilot::CopilotManager::new(&auth, &client);
let allocator = oxc_allocator::Allocator::default();

let mut copilot_m = copilot::CopilotManager::new(&auth, &client, &allocator, prompts::COPILOT_INSTRUCTIONS);

let mut rl = DefaultEditor::new().unwrap();

Expand All @@ -29,6 +32,5 @@ async fn main() {

println!("===COPILOT===");
println!("{:#?}", msg);

}
}
2 changes: 1 addition & 1 deletion src/urls.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub const DEVICE_CODE_LOGIN_URL: &str = "https://github.com/login/device/code";
pub const DEVICE_CODE_TOKEN_CHECK_URL: &str = "https://github.com/login/oauth/access_token";
pub const GH_AUTH_TOKEN_URL: &str = "https://api.github.com/user";
pub const GH_COPILOT_INTERNAL_AUTH_URL: &str = "https://api.github.com/copilot_internal/v2/token";
pub const GH_COPILOT_INTERNAL_AUTH_URL: &str = "https://api.github.com/copilot_internal/v2/token";

0 comments on commit a04b1a2

Please sign in to comment.