Skip to content

Commit

Permalink
fix: Send system prompts to OAI backend
Browse files Browse the repository at this point in the history
  • Loading branch information
Hugoch committed Sep 17, 2024
1 parent 134f6d1 commit 3758d0d
Showing 1 changed file with 33 additions and 16 deletions.
49 changes: 33 additions & 16 deletions src/requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ pub struct OpenAITextGenerationResponse {
pub choices: Vec<OpenAITextGenerationChoice>,
}

/// JSON body of a chat-completion request sent to an OpenAI-compatible backend.
///
/// Fields are declared in the order they should appear in the serialized
/// payload; keep that order stable when editing.
#[derive(Clone, Serialize, Deserialize)]
pub struct OpenAITextGenerationRequest {
    /// Name of the model the backend should run.
    pub model: String,
    /// Conversation messages, in the order they are sent to the backend.
    pub messages: Vec<OpenAITextGenerationMessage>,
    /// Optional cap on the number of generated tokens.
    pub max_tokens: Option<u64>,
    /// Whether the backend should stream the response.
    pub stream: bool,
}

impl OpenAITextGenerationBackend {
pub fn new(api_key: String, base_url: String, model_name: String) -> Self {
Self {
Expand All @@ -92,24 +100,33 @@ impl TextGenerationBackend for OpenAITextGenerationBackend {
async fn generate(&self, request: Arc<TextGenerationRequest>, sender: Sender<TextGenerationAggregatedResponse>) {
let url = format!("{base_url}/v1/chat/completions", base_url = self.base_url);
let mut aggregated_response = TextGenerationAggregatedResponse::default();
//debug!("Requesting {url} with prompt: {prompt}, max tokens: {max_tokens}", prompt = request.prompt, max_tokens = request.max_tokens);
let messages = match &request.system_prompt {
None => vec![
OpenAITextGenerationMessage {
role: "user".to_string(),
content: request.prompt.clone(),
}
],
Some(system_prompt) => vec![
OpenAITextGenerationMessage {
role: "system".to_string(),
content: system_prompt.clone(),
},
OpenAITextGenerationMessage {
role: "user".to_string(),
content: request.prompt.clone(),
}
]
};
let body = OpenAITextGenerationRequest {
model: self.model_name.clone(),
messages,
max_tokens: request.num_decode_tokens,
stream: true,
};
let req = reqwest::Client::new().post(url)
.header("Authorization", format!("Bearer {token}", token = self.api_key))
.json(&serde_json::json!({
"model": self.model_name,
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": request.prompt
}
],
"max_tokens": request.num_decode_tokens,
"stream": true,
}));
.json(&serde_json::json!(body));
// start timer
aggregated_response.start(request.num_prompt_tokens);
let mut es = EventSource::new(req).unwrap();
Expand Down

0 comments on commit 3758d0d

Please sign in to comment.