diff --git a/Cargo.toml b/Cargo.toml index 95ca979..809fa34 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,3 +19,5 @@ futures = "0.3.30" homedir = "0.2.1" repair_json = "0.1.0" oxc_allocator = "0.7.0" +syntect = "5.0" +crossterm = "0.27.0" diff --git a/src/copilot.rs b/src/copilot.rs index 4840c1e..0a9a7f4 100644 --- a/src/copilot.rs +++ b/src/copilot.rs @@ -4,12 +4,17 @@ use crate::{ gh, headers::{CopilotCompletionHeaders, Headers}, utils, + term }; + use futures::StreamExt; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::json; +// crossterm for writing + + #[derive(Serialize, Deserialize, Debug)] struct ContentFilterResult { filtered: bool, @@ -72,6 +77,7 @@ pub struct CopilotManager<'a, 'alloc> { client: &'a Client, allocator: &'alloc oxc_allocator::Allocator, history: Vec>, + full_message: String, } impl<'a, 'alloc> CopilotManager<'a, 'alloc> { @@ -101,9 +107,11 @@ impl<'a, 'alloc> CopilotManager<'a, 'alloc> { client, allocator, history, + full_message: String::new(), } } + #[allow(unused_assignments)] pub async fn ask(&mut self, prompt: &String, log: bool) -> Completion { let url = "https://api.githubcopilot.com/chat/completions"; let headers = CopilotCompletionHeaders { @@ -113,12 +121,18 @@ impl<'a, 'alloc> CopilotManager<'a, 'alloc> { } .to_headers(); - let history = &mut self.history; + let mut transport_history = Vec::new(); - history.push(Message { - content: self.allocator.alloc_str(prompt), - role: self.allocator.alloc_str("user"), - }); + { + let history = &mut self.history; + + history.push(Message { + content: self.allocator.alloc_str(prompt), + role: self.allocator.alloc_str("user"), + }); + + transport_history = history.clone(); + } // no chat history for this let data = json!({ @@ -128,7 +142,7 @@ impl<'a, 'alloc> CopilotManager<'a, 'alloc> { "stream": true, "temperature": 0.1, "top_p": 1, - "messages": history + "messages": transport_history }); // we need to stream the response @@ -181,8 +195,9 @@ impl<'a, 'alloc> CopilotManager<'a, 'alloc> { // There might be content in the delta, let's handle it let delta = &choice.delta; if let Some(content) = &delta.content { - print!("{}", content); - std::io::stdout().flush().unwrap(); + if log { + self.handle_content(content).await; + }//std::io::stdout().flush().unwrap(); message.push_str(content); } } @@ -203,14 +218,47 @@ impl<'a, 'alloc> CopilotManager<'a, 'alloc> { } // add the response to the history - history.push(Message { - content: self.allocator.alloc_str(&message), - role: self.allocator.alloc_str("system"), - }); + { + let history = &mut self.history; + + history.push(Message { + content: self.allocator.alloc_str(&message), + role: self.allocator.alloc_str("system"), + }); + } + + self.full_message = String::new(); Completion { content: message, finish_reason, } } + + async fn handle_content(&mut self, content: &String) { + // tokio sleep for 10 ms + // tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + self.full_message.push_str(content); + let line_count = self.full_message.split("\n").count(); + + if self.full_message.ends_with("\n") { + let highlighted = term::highlight_line(&self.full_message); + let escaped: Vec = term::to_terminal_escaped(&highlighted) + .split("\n") + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()) + .collect(); + + let mut escaped_len = escaped.len(); + while line_count > escaped_len { + print!("\n"); + escaped_len += 1; + } + + print!("{}", escaped.last().unwrap()); + std::io::stdout().flush().unwrap(); + // self.full_message = String::new(); + } + } } \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 2f5faa3..fb3de29 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,12 +4,27 @@ mod headers; mod prompts; mod urls; mod utils; +mod term; + +use crossterm::{ + execute, + terminal::{EnterAlternateScreen, LeaveAlternateScreen}, +}; +use std::io::{stdout, Write}; use oxc_allocator; use rustyline::DefaultEditor; +fn move_up_one_line() { + print!("\x1b[1A"); + std::io::stdout().flush().unwrap(); +} + #[tokio::main] async fn main() { + // enter alternate screen + execute!(stdout(), EnterAlternateScreen).unwrap(); + let auth_manager = gh::AuthenticationManager::new(); let auth = auth_manager.cache_auth().await.unwrap(); @@ -24,13 +39,21 @@ async fn main() { loop { let input = rl.readline("You: ").unwrap(); + move_up_one_line(); + if input == "exit" { break; } - let msg = copilot_m.ask(&input, true).await; + let _msg = copilot_m.ask(&input, true).await; + // reset the forground color + print!("\033[0m"); + // syntax highlighting + // let highlighted = term::highlight_text(&msg.content); + // println!("{}", highlighted); - println!("===COPILOT==="); - println!("{:#?}", msg); } + + // leave alternate screen + execute!(stdout(), LeaveAlternateScreen).unwrap(); } diff --git a/src/term.rs b/src/term.rs new file mode 100644 index 0000000..b0da4e7 --- /dev/null +++ b/src/term.rs @@ -0,0 +1,19 @@ +use syntect::{self, highlighting::Style}; + +pub fn highlight_line(text: &String) -> Vec<(Style, &str)> { + // using syntect, apply markdown syntax highlighting to the text + let syntax_set = syntect::parsing::SyntaxSet::load_defaults_newlines(); + let syntax = syntax_set.find_syntax_by_extension("md").unwrap(); + let h = syntect::highlighting::ThemeSet::load_defaults(); + let mut highlighter = syntect::easy::HighlightLines::new(syntax, &h.themes["base16-mocha.dark"]); + + let highlighted = highlighter.highlight_line(text, &syntax_set).unwrap(); + // let escaped = syntect::util::as_24_bit_terminal_escaped(&highlighted, false); + highlighted +} + +pub fn to_terminal_escaped(highlighted: &Vec<(Style, &str)>) -> String { + // convert the highlighted text to a string with terminal escape sequences + let escaped = syntect::util::as_24_bit_terminal_escaped(highlighted, false); + escaped +}