
Lazy model switch #210

Merged: 30 commits, Aug 16, 2024

Commits
28688fc
model loader external interface changes
noxware Aug 5, 2024
9832f1d
simplify model loader unused error
noxware Aug 5, 2024
ed529c2
allow model loader to be used from other threads
noxware Aug 6, 2024
cd8dbc0
allow model loader to work without file details
noxware Aug 6, 2024
f470215
have access to unloaded models when sending messages
noxware Aug 6, 2024
0d829c2
load model when sending message
noxware Aug 6, 2024
c3304d7
save used model immediately
noxware Aug 6, 2024
ac8f590
unused declarations
noxware Aug 6, 2024
62321cf
do not load model on chat change
noxware Aug 6, 2024
54724b1
display chat last used model
noxware Aug 8, 2024
2125fa5
grey out unloaded model
noxware Aug 8, 2024
fe9b454
on model change, record used model but don't load it
noxware Aug 8, 2024
4fa7879
blocking loading and nice error reporting
noxware Aug 9, 2024
078672f
details
noxware Aug 9, 2024
cba8bd6
more imperative update_selected_model_info
noxware Aug 12, 2024
7f83e82
remove unused model loader error wrapper
noxware Aug 12, 2024
b7c0a08
rename model loader load funcs
noxware Aug 12, 2024
a67e23b
revert some borrowing changes
noxware Aug 12, 2024
7dc8971
extract `hex_rgb_color` fn to utils
noxware Aug 12, 2024
2d0dd6e
fix - default to loaded model when no model in chat
noxware Aug 12, 2024
c6f59f6
extract fallback logic of the previous fix
noxware Aug 13, 2024
483c348
restore load model on selector change
noxware Aug 13, 2024
f1865ea
fix - refcell borrow crash when referenced file deleted
noxware Aug 13, 2024
7a69f74
remove "resume chat" visuals
noxware Aug 13, 2024
a0d50bd
always show "chat with model" in downloads table
noxware Aug 13, 2024
ab90163
always show "chat with model" in discover
noxware Aug 13, 2024
b5bfedd
set last used model when clicking "chat with model"
noxware Aug 13, 2024
323018f
optimization: always set last used model when load is called
noxware Aug 13, 2024
e2a27af
remove leftover reference to play_arrow.svg
noxware Aug 14, 2024
c9947d0
Merge pull request #216 from moxin-org/remove-resume-chat
jmbejar Aug 16, 2024
4 changes: 0 additions & 4 deletions resources/icons/play_arrow.svg

This file was deleted.

2 changes: 1 addition & 1 deletion src/chat/chat_history_card.rs
@@ -314,7 +314,7 @@ impl WidgetMatchEvent for ChatHistoryCard {
ChatHistoryCardAction::ChatSelected(self.chat_id),
);
let store = scope.data.get_mut::<Store>().unwrap();
store.select_chat(self.chat_id);
store.chats.set_current_chat(self.chat_id);
self.redraw(cx);
}
}
29 changes: 15 additions & 14 deletions src/chat/chat_panel.rs
@@ -512,22 +512,19 @@ impl WidgetMatchEvent for ChatPanel {

if let ModelSelectorAction::Selected(downloaded_file) = action.cast() {
store.load_model(&downloaded_file.file);

if let Some(chat) = store.chats.get_current_chat() {
chat.borrow_mut().last_used_file_id = Some(downloaded_file.file.id.clone());
chat.borrow().save();
}
self.redraw(cx)
}

match action.cast() {
ChatAction::Start(file_id) => {
let downloaded_file = store
.downloads
.downloaded_files
.iter()
.find(|file| file.file.id == file_id)
.expect("Attempted to start chat with a no longer existing file")
.clone();

store
.chats
.create_empty_chat_and_load_file(&downloaded_file.file);
if let Some(file) = store.downloads.get_file(&file_id) {
store.chats.create_empty_chat_and_load_file(file);
}
}
_ => {}
}
@@ -538,7 +535,11 @@
self.redraw(cx);
}
ChatLineAction::Edit(id, updated, regenerate) => {
store.chats.edit_chat_message(id, updated, regenerate);
if regenerate {
Review comment from noxware (Collaborator, Author): The flag was completely changing what the function does, so I split the function (motivated by this).

store.edit_chat_message_regenerating(id, updated)
} else {
store.edit_chat_message(id, updated);
}
self.redraw(cx);
}
_ => {}
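The refactor described in the review comment above can be sketched in isolation. This is a simplified stand-in with assumed types, not the PR's actual `Store`: the point is the shape of the change, replacing a behavior-changing boolean flag with two explicit entry points.

```rust
// Sketch of splitting a flag-controlled function in two.
// `Store` and the (id, content) message layout are stand-ins.

struct Store {
    // (message id, message content)
    messages: Vec<(usize, String)>,
}

impl Store {
    // Plain edit: replace the message text in place.
    fn edit_chat_message(&mut self, id: usize, updated: String) {
        if let Some(msg) = self.messages.iter_mut().find(|m| m.0 == id) {
            msg.1 = updated;
        }
    }

    // Regenerating edit: also drop everything after the edited message,
    // so the assistant's reply can be produced again.
    fn edit_chat_message_regenerating(&mut self, id: usize, updated: String) {
        self.edit_chat_message(id, updated);
        if let Some(pos) = self.messages.iter().position(|m| m.0 == id) {
            self.messages.truncate(pos + 1);
        }
    }
}

fn main() {
    let mut store = Store {
        messages: vec![(1, "hi".to_string()), (2, "hello!".to_string())],
    };
    store.edit_chat_message_regenerating(1, "hi there".to_string());
    // The edited prompt remains; the stale reply is gone.
    println!("{:?}", store.messages); // [(1, "hi there")]
}
```

Splitting keeps each call site self-documenting: `edit_chat_message_regenerating(id, updated)` says what will happen, where `edit_chat_message(id, updated, true)` did not.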
@@ -602,7 +603,7 @@ impl ChatPanel {
} else if store.chats.loaded_model.is_none() {
State::NoModelSelected
} else {
let is_loading = store.chats.get_currently_loading_model().is_some();
let is_loading = store.chats.model_loader.is_loading();

store.chats.get_current_chat().map_or(
State::ModelSelectedWithEmptyChat { is_loading },
@@ -762,7 +763,7 @@ impl ChatPanel {
} | State::ModelSelectedWithEmptyChat { is_loading: false }
) {
let store = scope.data.get_mut::<Store>().unwrap();
store.chats.send_chat_message(prompt.clone());
store.send_chat_message(prompt.clone());

let prompt_input = self.text_input(id!(main_prompt_input.prompt));
prompt_input.set_text_and_redraw(cx, "");
103 changes: 56 additions & 47 deletions src/chat/model_selector.rs
@@ -1,6 +1,9 @@
use crate::{
data::store::Store,
shared::{actions::ChatAction, utils::format_model_size},
shared::{
actions::ChatAction,
utils::{format_model_size, hex_rgb_color},
},
};
use makepad_widgets::*;

@@ -48,7 +51,7 @@ live_design! {

cursor: Hand,

content = <View> {
content = <View> {
width: Fill,
height: Fit,
flow: Overlay,
@@ -239,7 +242,8 @@ impl Widget for ModelSelector {
.apply_over(cx, live! {height: (height)});

let rotate_angle = self.rotate_animation_progress * std::f64::consts::PI;
self.view(id!(icon_drop.icon)).apply_over(cx, live! {draw_bg: {rotation: (rotate_angle)}});
self.view(id!(icon_drop.icon))
.apply_over(cx, live! {draw_bg: {rotation: (rotate_angle)}});

self.redraw(cx);
}
@@ -349,13 +353,14 @@ impl ModelSelector {
fn hide_options(&mut self, cx: &mut Cx) {
self.open = false;
self.view(id!(options)).apply_over(cx, live! { height: 0 });
self.view(id!(icon_drop.icon)).apply_over(cx, live! {draw_bg: {rotation: (0.0)}});
self.view(id!(icon_drop.icon))
.apply_over(cx, live! {draw_bg: {rotation: (0.0)}});
self.animator_cut(cx, id!(open.hide));
self.redraw(cx);
}

fn update_loading_model_state(&mut self, cx: &mut Cx, store: &Store) {
if store.chats.get_currently_loading_model().is_some() {
if store.chats.model_loader.is_loading() {
self.model_selector_loading(id!(loading))
.show_and_animate(cx);
} else {
@@ -364,54 +369,59 @@
}

fn update_selected_model_info(&mut self, cx: &mut Cx, store: &Store) {
self.view(id!(choose)).apply_over(
cx,
live! {
visible: false
},
);

if let Some(file) = &store.chats.get_currently_loading_model() {
// When a model is being loaded, show the "loading state"
let caption = format!("Loading {}", file.name);
self.view(id!(selected)).apply_over(
cx,
live! {
visible: true
label = { text: (caption) }
architecture_tag = { visible: false }
params_size_tag = { visible: false }
file_size_tag = { visible: false }
},
);
} else {
let Some(downloaded_file) = store.get_loaded_downloaded_file() else {
error!("Error displaying current loaded model");
return;
self.view(id!(choose)).set_visible(false);

let is_loading = store.chats.model_loader.is_loading();
let loaded_file = store.chats.loaded_model.as_ref();

let file = store
.chats
.get_current_chat()
.and_then(|c| c.borrow().last_used_file_id.clone())
.and_then(|file_id| store.downloads.get_file(&file_id).cloned())
.or_else(|| loaded_file.cloned());

if let Some(file) = file {
let selected_view = self.view(id!(selected));
selected_view.set_visible(true);

let text_color = if Some(&file.id) == loaded_file.map(|f| &f.id) {
hex_rgb_color(0x000000)
} else {
hex_rgb_color(0x667085)
};

// When a model is loaded, show the model info
let filename = downloaded_file.file.name;

let architecture = downloaded_file.model.architecture;
let architecture_visible = !architecture.trim().is_empty();

let param_size = downloaded_file.model.size;
let param_size_visible = !param_size.trim().is_empty();
let caption = if is_loading {
format!("Loading {}", file.name.trim())
} else {
file.name.trim().to_string()
};

let size = format_model_size(&downloaded_file.file.size).unwrap_or("".to_string());
let size_visible = !size.trim().is_empty();
let file_size = format_model_size(file.size.trim()).unwrap_or("".into());
let is_file_size_visible = !file_size.is_empty() && !is_loading;

self.view(id!(selected)).apply_over(
selected_view.apply_over(
cx,
live! {
visible: true
label = { text: (filename) }
architecture_tag = { visible: (architecture_visible), caption = { text: (architecture) }}
params_size_tag = { visible: (param_size_visible), caption = { text: (param_size) }}
file_size_tag = { visible: (size_visible), caption = { text: (size) }}
label = { text: (caption), draw_text: { color: (text_color) }}
file_size_tag = { visible: (is_file_size_visible), caption = { text: (file_size), draw_text: { color: (text_color) }}}
},
);

if let Some(model) = store.downloads.get_model_by_file_id(&file.id) {
let architecture = model.architecture.trim();
let params_size = model.size.trim();
let is_architecture_visible = !architecture.is_empty() && !is_loading;
let is_params_size_visible = !params_size.is_empty() && !is_loading;

selected_view.apply_over(
cx,
live! {
architecture_tag = { visible: (is_architecture_visible), caption = { text: (architecture), draw_text: { color: (text_color) }}}
params_size_tag = { visible: (is_params_size_visible), caption = { text: (params_size), draw_text: { color: (text_color) }}}
},
);
}
}

self.redraw(cx);
@@ -449,6 +459,5 @@ fn options_to_display(store: &Store) -> bool {
}

fn no_active_model(store: &Store) -> bool {
store.get_loaded_downloaded_file().is_none()
&& store.chats.get_currently_loading_model().is_none()
store.get_loaded_downloaded_file().is_none() && store.get_loading_file().is_none()
}
93 changes: 55 additions & 38 deletions src/data/chats/chat.rs
@@ -12,6 +12,8 @@ use serde::{Deserialize, Serialize};

use crate::data::filesystem::{read_from_file, write_to_file};

use super::model_loader::ModelLoader;

pub type ChatID = u128;

#[derive(Clone, Debug)]
@@ -203,7 +205,13 @@ impl Chat {
}
}
}
pub fn send_message_to_model(&mut self, prompt: String, loaded_file: &File, backend: &Backend) {
pub fn send_message_to_model(
&mut self,
prompt: String,
wanted_file: &File,
mut model_loader: ModelLoader,
backend: &Backend,
) {
let (tx, rx) = channel();
let mut messages: Vec<_> = self
.messages
@@ -228,7 +236,7 @@
content: system_prompt.clone(),
role: Role::System,
name: None,
}
},
);
} else {
messages.insert(
@@ -237,15 +245,15 @@
content: "You are a helpful, respectful, and honest assistant.".to_string(),
role: Role::System,
name: None,
}
},
);
}

let ip = &self.inferences_params;
let cmd = Command::Chat(
ChatRequestData {
messages,
model: loaded_file.name.clone(),
model: wanted_file.name.clone(),
frequency_penalty: Some(ip.frequency_penalty),
logprobs: None,
top_logprobs: None,
@@ -277,51 +285,60 @@
content: prompt.clone(),
});

self.last_used_file_id = Some(loaded_file.id.clone());

self.messages.push(ChatMessage {
id: next_id + 1,
role: Role::Assistant,
username: Some(loaded_file.name.clone()),
username: Some(wanted_file.name.clone()),
content: "".to_string(),
});

let store_chat_tx = self.messages_update_sender.clone();
backend.command_sender.send(cmd).unwrap();
self.is_streaming = true;
thread::spawn(move || loop {
if let Ok(response) = rx.recv() {
match response {
Ok(ChatResponse::ChatResponseChunk(data)) => {
let mut is_done = false;

let _ = store_chat_tx.send(ChatTokenArrivalAction::AppendDelta(
data.choices[0].delta.content.clone(),
));

if let Some(_reason) = &data.choices[0].finish_reason {
is_done = true;
let _ = store_chat_tx.send(ChatTokenArrivalAction::StreamingDone);
}

SignalToUI::set_ui_signal();
if is_done {
let store_chat_tx = self.messages_update_sender.clone();
let wanted_file = wanted_file.clone();
let command_sender = backend.command_sender.clone();
thread::spawn(move || {
if let Err(err) = model_loader.load(wanted_file.id, command_sender.clone()) {
eprintln!("Error loading model: {}", err);
Review comment from noxware (Collaborator, Author): Future idea: eventually, all the eprintln! calls we have around the codebase could trigger a Cx::post_action(ErrorNotification(...)) with the latest updates in makepad, to get feedback in the UI.

return;
}
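The "future idea" in the review comment above can be sketched without makepad: route load errors through a channel to whichever component drives the UI, instead of printing them. `ErrorNotification`, the error message, and the function shape here are assumptions for illustration; the real mechanism would be `Cx::post_action`, as the comment suggests.

```rust
use std::sync::mpsc;

// Hypothetical notification type; the real one would be a makepad action.
#[derive(Debug, PartialEq)]
struct ErrorNotification(String);

// Instead of `eprintln!`, failures are sent to whoever owns the receiver
// (in the app, the UI event loop).
fn load_model(file_id: &str, errors: &mpsc::Sender<ErrorNotification>) -> Result<(), ()> {
    // Simulate a load failure to exercise the error path.
    let msg = format!("Error loading model: {file_id} not found");
    let _ = errors.send(ErrorNotification(msg));
    Err(())
}

fn main() {
    let (tx, rx) = mpsc::channel();
    if load_model("missing.gguf", &tx).is_err() {
        // A UI layer would render this as a visible notification.
        let note = rx.recv().unwrap();
        println!("{}", note.0); // Error loading model: missing.gguf not found
    }
}
```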

command_sender.send(cmd).unwrap();

loop {
if let Ok(response) = rx.recv() {
match response {
Ok(ChatResponse::ChatResponseChunk(data)) => {
let mut is_done = false;

let _ = store_chat_tx.send(ChatTokenArrivalAction::AppendDelta(
data.choices[0].delta.content.clone(),
));

if let Some(_reason) = &data.choices[0].finish_reason {
is_done = true;
let _ = store_chat_tx.send(ChatTokenArrivalAction::StreamingDone);
}

SignalToUI::set_ui_signal();
if is_done {
break;
}
}
Ok(ChatResponse::ChatFinalResponseData(data)) => {
let _ = store_chat_tx.send(ChatTokenArrivalAction::AppendDelta(
data.choices[0].message.content.clone(),
));
let _ = store_chat_tx.send(ChatTokenArrivalAction::StreamingDone);
SignalToUI::set_ui_signal();
break;
}
Err(err) => eprintln!("Error receiving response chunk: {:?}", err),
}
Ok(ChatResponse::ChatFinalResponseData(data)) => {
let _ = store_chat_tx.send(ChatTokenArrivalAction::AppendDelta(
data.choices[0].message.content.clone(),
));
let _ = store_chat_tx.send(ChatTokenArrivalAction::StreamingDone);
SignalToUI::set_ui_signal();
break;
}
Err(err) => eprintln!("Error receiving response chunk: {:?}", err),
}
} else {
break;
};
} else {
break;
};
}
});

self.update_title_based_on_first_message();