Skip to content

Commit f791360

Browse files
committed
Fix streaming issues and add citations support in chat messages
1 parent 5dceffd commit f791360

File tree

9 files changed

+887
-86
lines changed

9 files changed

+887
-86
lines changed

Cargo.lock

+525-42
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

moly-kit/Cargo.toml

+3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ edition = "2021"
55

66
[dependencies]
77
futures = "0.3.31"
8+
url = "2.4.10"
9+
link-preview = { version = "0.1.1", features = ["fetch"] }
10+
robius-open = "0.1.1"
811

912
makepad-widgets = { git = "https://github.com/makepad/makepad", branch = "rik" }
1013
makepad-code-editor = { git = "https://github.com/makepad/makepad", branch = "rik" }

moly-kit/src/clients/moly.rs

+17-7
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,18 @@ enum Role {
7777
Assistant,
7878
}
7979

80+
/// The Choice object as part of a streaming response.
8081
#[derive(Clone, Debug, Deserialize)]
8182
struct Choice {
8283
pub delta: IncomingMessage,
8384
}
8485

85-
/// Response from the completions endpoint.
86+
/// Response from the completions endpoint
8687
#[derive(Clone, Debug, Deserialize)]
87-
struct Completation {
88+
struct Completion {
8889
pub choices: Vec<Choice>,
90+
#[serde(default)]
91+
pub citations: Option<Vec<String>>,
8992
}
9093

9194
#[derive(Clone, Debug, Default)]
@@ -186,11 +189,12 @@ impl BotClient for MolyClient {
186189
Box::new(self.clone())
187190
}
188191

192+
/// Stream pieces of content back as a ChatDelta instead of just a String.
189193
fn send_stream(
190194
&mut self,
191195
bot: &BotId,
192196
messages: &[Message],
193-
) -> MolyStream<'static, Result<String, ()>> {
197+
) -> MolyStream<'static, Result<ChatDelta, ()>> {
194198
let moly_messages: Vec<OutcomingMessage> = messages
195199
.iter()
196200
.filter_map(|m| m.clone().try_into().ok())
@@ -255,22 +259,28 @@ impl BotClient for MolyClient {
255259
.filter(|m| m.trim() != "[DONE]");
256260

257261
for m in messages {
258-
let completition: Completation = match serde_json::from_str(m) {
259-
Ok(completition) => completition,
262+
let completion: Completion = match serde_json::from_str(m) {
263+
Ok(c) => c,
260264
Err(error) => {
261265
log!("Error: {:?}", error);
262266
yield Err(());
263267
return;
264268
}
265269
};
266270

267-
let text = completition
271+
// Combine all partial choices content
272+
let content_delta = completion
268273
.choices
269274
.iter()
270275
.map(|c| c.delta.content.as_str())
271276
.collect::<String>();
272277

273-
yield Ok(text);
278+
let citations = completion.citations.clone();
279+
280+
yield Ok(ChatDelta {
281+
content_delta,
282+
citations,
283+
});
274284
}
275285

276286
buffer = incomplete_message.to_vec();

moly-kit/src/clients/multi.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ impl BotClient for MultiClient {
2828
&mut self,
2929
bot: &BotId,
3030
messages: &[Message],
31-
) -> MolyStream<'static, Result<String, ()>> {
31+
) -> MolyStream<'static, Result<ChatDelta, ()>> {
3232
let mut client = self
3333
.clients_with_bots
3434
.lock()
@@ -43,7 +43,7 @@ impl BotClient for MultiClient {
4343
})
4444
.expect("no client for bot");
4545

46-
client.send_stream(&bot, messages)
46+
client.send_stream(bot, messages)
4747
}
4848

4949
fn clone_box(&self) -> Box<dyn BotClient> {

moly-kit/src/protocol.rs

+32-10
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,15 @@ pub struct Message {
7676
///
7777
/// If `false`, it means the message will not change anymore.
7878
pub is_writing: bool,
79+
/// Citations for the message.
80+
pub citations: Vec<String>,
81+
}
82+
83+
/// A new structure to hold both text delta and optional metadata like citations.
84+
#[derive(Clone, Debug)]
85+
pub struct ChatDelta {
86+
pub content_delta: String,
87+
pub citations: Option<Vec<String>>,
7988
}
8089

8190
/// A standard interface to fetch bots information and send messages to them.
@@ -92,7 +101,7 @@ pub trait BotClient: Send {
92101
&mut self,
93102
bot: &BotId,
94103
messages: &[Message],
95-
) -> MolyStream<'static, Result<String, ()>>;
104+
) -> MolyStream<'static, Result<ChatDelta, ()>>;
96105

97106
/// Interrupt the bot's current operation.
98107
// TODO: There may be many chats with the same bot/model/agent so maybe this
@@ -109,24 +118,37 @@ pub trait BotClient: Send {
109118
fn clone_box(&self) -> Box<dyn BotClient>;
110119

111120
/// Send a message to a bot expecting a full response at once.
112-
// TODO: messages may end up being a little bit more complex, using string while thinking.
113-
// TODO: Should support a way of passing, unknown, backend-specific, inference parameters.
114121
fn send(
115122
&mut self,
116123
bot: &BotId,
117124
messages: &[Message],
118-
) -> MolyFuture<'static, Result<String, ()>> {
125+
) -> MolyFuture<'static, Result<Message, ()>> {
119126
let stream = self.send_stream(bot, messages);
127+
let bot = bot.clone();
120128

121129
let future = async move {
122-
let parts = stream.collect::<Vec<_>>().await;
123-
124-
if parts.contains(&Err(())) {
125-
return Err(());
130+
let mut content = String::new();
131+
let mut citations = Vec::new();
132+
133+
let mut stream = stream;
134+
while let Some(delta) = stream.next().await {
135+
match delta {
136+
Ok(chat_delta) => {
137+
content.push_str(&chat_delta.content_delta);
138+
if let Some(cits) = chat_delta.citations {
139+
citations = cits;
140+
}
141+
}
142+
Err(()) => return Err(()),
143+
}
126144
}
127145

128-
let message = parts.into_iter().filter_map(Result::ok).collect::<String>();
129-
Ok(message)
146+
Ok(Message {
147+
from: EntityId::Bot(bot),
148+
body: content,
149+
is_writing: false,
150+
citations,
151+
})
130152
};
131153

132154
moly_future(future)

moly-kit/src/widgets/chat.rs

+26-21
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ pub enum ChatTask {
8888
/// When received back, it will re-write the message history with the given messages.
8989
SetMessages(Vec<Message>),
9090

91-
/// When received back, it will insert a message at the given index with the given text and entity.
92-
InsertMessage(usize, EntityId, String),
91+
/// When received back, it will insert a message at the given index.
92+
InsertMessage(usize, Message),
9393

9494
/// When received back, it will delete the message at the given index.
9595
DeleteMessage(usize),
@@ -292,8 +292,12 @@ impl Chat {
292292
if !text.is_empty() {
293293
composition.push(ChatTask::InsertMessage(
294294
next_index,
295-
EntityId::User,
296-
text.clone(),
295+
Message {
296+
from: EntityId::User,
297+
body: text.clone(),
298+
is_writing: false,
299+
citations: vec![],
300+
},
297301
));
298302
}
299303

@@ -322,6 +326,7 @@ impl Chat {
322326
from: EntityId::Bot(bot_id.clone()),
323327
body: String::new(),
324328
is_writing: true,
329+
citations: vec![],
325330
});
326331

327332
self.dispatch(cx, ChatTask::ScrollToBottom.into());
@@ -343,24 +348,28 @@ impl Chat {
343348
let mut message_stream = client.send_stream(&bot_id, &context);
344349

345350
while let Some(delta) = message_stream.next().await {
346-
let delta = delta.unwrap_or_else(|_| "An error occurred".to_string());
351+
let delta = match delta {
352+
Ok(chat_delta) => chat_delta,
353+
Err(_) => ChatDelta {
354+
content_delta: "An error occurred".to_string(),
355+
citations: None,
356+
},
357+
};
347358

348359
ui.defer_with_redraw(move |me, cx, _scope| {
349360
me.messages_ref().write_with(|messages| {
350361
let last_index = messages.messages.len() - 1;
351-
let message = messages
352-
.messages
353-
.last_mut()
354-
.expect("no message where to put delta");
362+
let message = messages.messages.last_mut().expect("no message where to put delta");
355363

356-
message.body.push_str(&delta);
357-
let updated_body = message.body.clone();
364+
// Append new text
365+
message.body.push_str(&delta.content_delta);
358366

359-
me.dispatch(
360-
cx,
361-
ChatTask::UpdateMessage(last_index, updated_body).into(),
362-
);
367+
// If the chunk contains citations, store them
368+
if let Some(cits) = &delta.citations {
369+
message.citations = cits.clone();
370+
}
363371

372+
me.dispatch(cx, ChatTask::UpdateMessage(last_index, message.body.clone()).into());
364373
if messages.is_at_bottom() {
365374
me.dispatch(cx, ChatTask::ScrollToBottom.into());
366375
}
@@ -458,14 +467,10 @@ impl Chat {
458467
self.messages_ref().write().messages.remove(*index);
459468
self.redraw(cx);
460469
}
461-
ChatTask::InsertMessage(index, entity, text) => {
470+
ChatTask::InsertMessage(index, message) => {
462471
self.messages_ref().write().messages.insert(
463472
*index,
464-
Message {
465-
from: entity.clone(),
466-
body: text.clone(),
467-
is_writing: false,
468-
},
473+
message.clone(),
469474
);
470475
self.redraw(cx);
471476
}

0 commit comments

Comments
 (0)