diff --git a/zt_backend/router.py b/zt_backend/router.py index 7223ba00..24abbd93 100644 --- a/zt_backend/router.py +++ b/zt_backend/router.py @@ -13,6 +13,7 @@ import threading import traceback import sys +import asyncio import trace class ConnectionManager: @@ -88,6 +89,7 @@ def kill(self): user_states={} user_timers={} user_threads={} +user_message_tasks={} notebook_state=UserState('') run_mode = settings.run_mode @@ -105,6 +107,7 @@ def ws_url(): async def run_code(websocket: WebSocket): global current_thread if(run_mode=='dev'): + message_send = asyncio.create_task(websocket_message_sender(notebook_state)) await manager.connect(websocket) try: while True: @@ -115,6 +118,8 @@ async def run_code(websocket: WebSocket): current_thread.start() except WebSocketDisconnect: manager.disconnect(websocket) + finally: + message_send.cancel() @router.websocket("/ws/component_run") async def component_run(websocket: WebSocket): @@ -255,6 +260,7 @@ async def load_notebook(websocket: WebSocket): userId = str(uuid.uuid4()) notebook_start.userId = userId user_states[userId]=UserState(userId) + user_message_tasks[userId]=asyncio.create_task(websocket_message_sender(user_states[userId])) timer_set(userId, 1800) cells = [] components={} @@ -318,14 +324,32 @@ async def stop_execution(websocket: WebSocket): except WebSocketDisconnect: manager.disconnect(websocket) +@router.on_event('shutdown') +def shutdown(): + if current_thread: + current_thread.kill() + for user_id in user_threads: + if user_threads[user_id]: + user_threads[user_id].kill() + for user_id in user_timers: + if user_timers[user_id]: + user_timers[user_id].cancel() + for user_id in user_message_tasks: + if user_message_tasks[user_id]: + user_message_tasks[user_id].cancel() + def remove_user_state(user_id): try: if user_id in user_timers: # Cancel and remove the associated timer timer = user_timers[user_id] + message_sender = user_message_tasks[user_id] if timer: timer.cancel() del user_timers[user_id] + if message_sender: + message_sender.cancel() + del user_message_tasks[user_id] if user_id in user_states: del user_states[user_id] logger.debug("User state removed for user %s", user_id) except Exception as e: diff --git a/zt_backend/runner/execute_code.py b/zt_backend/runner/execute_code.py index 159a3dfb..1fe9b51b 100644 --- a/zt_backend/runner/execute_code.py +++ b/zt_backend/runner/execute_code.py @@ -86,7 +86,7 @@ def execute_request(request: request.Request, state: UserState): for code_cell_id in downstream_cells: code_cell = dependency_graph.cells[code_cell_id] - asyncio.run(execution_state.websocket.send_json({"cell_id": code_cell_id, "clear_output": True})) + execution_state.message_queue.put_nowait({"cell_id": code_cell_id, "clear_output": True}) execution_state.io_output = StringIO() execute_cell(code_cell_id, code_cell, component_globals, dependency_graph, execution_state) try: @@ -101,7 +101,7 @@ def execute_request(request: request.Request, state: UserState): cell_response = response.CellResponse(id=code_cell_id, layout=layout, components=execution_state.current_cell_components, output=execution_state.io_output.getvalue()) cell_outputs.append(cell_response) - asyncio.run(execution_state.websocket.send_json(cell_response.model_dump_json())) + execution_state.message_queue.put_nowait(cell_response.model_dump_json()) execution_state.current_cell_components.clear() execution_state.current_cell_layout.clear() execution_state.cell_outputs_dict['previous_dependecy_graph'] = dependency_graph @@ -111,8 +111,7 @@ def execute_request(request: request.Request, state: UserState): execution_response = response.Response(cells=cell_outputs) if settings.run_mode=='dev': globalStateUpdate(run_response=execution_response) - asyncio.run(execution_state.websocket.send_json({"complete": True})) - return execution_response + execution_state.message_queue.put_nowait({"complete": True}) def execute_cell(code_cell_id, code_cell, component_globals, dependency_graph, execution_state: UserContext): class WebSocketStream: @@ -120,7 +119,7 @@ def write(self, message): user_state = UserContext.get_state() if user_state: user_state.io_output.write(message) - asyncio.run(user_state.websocket.send_json({"cell_id": code_cell_id, "output": message})) + user_state.message_queue.put_nowait({"cell_id": code_cell_id, "output": message}) def flush(self): pass diff --git a/zt_backend/runner/user_state.py b/zt_backend/runner/user_state.py index f2d711b3..adeef8f2 100644 --- a/zt_backend/runner/user_state.py +++ b/zt_backend/runner/user_state.py @@ -1,4 +1,5 @@ import threading +import asyncio class UserState: def __init__(self, user_id): @@ -11,6 +12,7 @@ def __init__(self, user_id): self.cell_outputs_dict = {} self.websocket = None self.io_output = None + self.message_queue = asyncio.Queue() class UserContext: _state = threading.local() diff --git a/zt_backend/utils.py b/zt_backend/utils.py index 4f192c9e..a5f90d70 100644 --- a/zt_backend/utils.py +++ b/zt_backend/utils.py @@ -1,4 +1,5 @@ from typing import OrderedDict +from zt_backend.runner.user_state import UserState from zt_backend.models import request, notebook, response from dictdiffer import diff import logging @@ -156,4 +157,9 @@ def save_toml(): def get_code_completions(cell_id:str, code: str, line: int, column: int) -> list: script = jedi.Script(code) completions = script.complete(line, column) - return {"cell_id": cell_id, "completions": [{"label": completion.name, "type": completion.type} for completion in completions]} \ No newline at end of file + return {"cell_id": cell_id, "completions": [{"label": completion.name, "type": completion.type} for completion in completions]} + +async def websocket_message_sender(execution_state: UserState): + while True: + message = await execution_state.message_queue.get() + await execution_state.websocket.send_json(message) \ No newline at end of file