Skip to content

Commit

Permalink
Fixed blocking message send with asyncio (#120)
Browse files Browse the repository at this point in the history
  • Loading branch information
Carson-Shaar authored Dec 11, 2023
1 parent f4fb18c commit ea9a9b1
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 6 deletions.
24 changes: 24 additions & 0 deletions zt_backend/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import threading
import traceback
import sys
import asyncio
import trace

class ConnectionManager:
Expand Down Expand Up @@ -88,6 +89,7 @@ def kill(self):
user_states={}
user_timers={}
user_threads={}
user_message_tasks={}
notebook_state=UserState('')
run_mode = settings.run_mode

Expand All @@ -105,6 +107,7 @@ def ws_url():
async def run_code(websocket: WebSocket):
global current_thread
if(run_mode=='dev'):
message_send = asyncio.create_task(websocket_message_sender(notebook_state))
await manager.connect(websocket)
try:
while True:
Expand All @@ -115,6 +118,8 @@ async def run_code(websocket: WebSocket):
current_thread.start()
except WebSocketDisconnect:
manager.disconnect(websocket)
finally:
message_send.cancel()

@router.websocket("/ws/component_run")
async def component_run(websocket: WebSocket):
Expand Down Expand Up @@ -255,6 +260,7 @@ async def load_notebook(websocket: WebSocket):
userId = str(uuid.uuid4())
notebook_start.userId = userId
user_states[userId]=UserState(userId)
user_message_tasks[userId]=asyncio.create_task(websocket_message_sender(user_states[userId]))
timer_set(userId, 1800)
cells = []
components={}
Expand Down Expand Up @@ -318,14 +324,32 @@ async def stop_execution(websocket: WebSocket):
except WebSocketDisconnect:
manager.disconnect(websocket)

@router.on_event('shutdown')
def shutdown():
if current_thread:
current_thread.kill()
for user_id in user_threads:
if user_threads[user_id]:
user_threads[user_id].kill()
for user_id in user_timers:
if user_timers[user_id]:
user_timers[user_id].cancel()
for user_id in user_message_tasks:
if user_message_tasks[user_id]:
user_message_tasks[user_id].cancel()

def remove_user_state(user_id):
try:
if user_id in user_timers:
# Cancel and remove the associated timer
timer = user_timers[user_id]
message_sender = user_message_tasks[user_id]
if timer:
timer.cancel()
del user_timers[user_id]
if message_sender:
message_sender.cancel()
del user_message_tasks[user_id]
if user_id in user_states: del user_states[user_id]
logger.debug("User state removed for user %s", user_id)
except Exception as e:
Expand Down
9 changes: 4 additions & 5 deletions zt_backend/runner/execute_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def execute_request(request: request.Request, state: UserState):

for code_cell_id in downstream_cells:
code_cell = dependency_graph.cells[code_cell_id]
asyncio.run(execution_state.websocket.send_json({"cell_id": code_cell_id, "clear_output": True}))
execution_state.message_queue.put_nowait({"cell_id": code_cell_id, "clear_output": True})
execution_state.io_output = StringIO()
execute_cell(code_cell_id, code_cell, component_globals, dependency_graph, execution_state)
try:
Expand All @@ -101,7 +101,7 @@ def execute_request(request: request.Request, state: UserState):

cell_response = response.CellResponse(id=code_cell_id, layout=layout, components=execution_state.current_cell_components, output=execution_state.io_output.getvalue())
cell_outputs.append(cell_response)
asyncio.run(execution_state.websocket.send_json(cell_response.model_dump_json()))
execution_state.message_queue.put_nowait(cell_response.model_dump_json())
execution_state.current_cell_components.clear()
execution_state.current_cell_layout.clear()
execution_state.cell_outputs_dict['previous_dependecy_graph'] = dependency_graph
Expand All @@ -111,16 +111,15 @@ def execute_request(request: request.Request, state: UserState):
execution_response = response.Response(cells=cell_outputs)
if settings.run_mode=='dev':
globalStateUpdate(run_response=execution_response)
asyncio.run(execution_state.websocket.send_json({"complete": True}))
return execution_response
execution_state.message_queue.put_nowait({"complete": True})

def execute_cell(code_cell_id, code_cell, component_globals, dependency_graph, execution_state: UserContext):
class WebSocketStream:
def write(self, message):
user_state = UserContext.get_state()
if user_state:
user_state.io_output.write(message)
asyncio.run(user_state.websocket.send_json({"cell_id": code_cell_id, "output": message}))
user_state.message_queue.put_nowait({"cell_id": code_cell_id, "output": message})

def flush(self):
pass
Expand Down
2 changes: 2 additions & 0 deletions zt_backend/runner/user_state.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import threading
import asyncio

class UserState:
def __init__(self, user_id):
Expand All @@ -11,6 +12,7 @@ def __init__(self, user_id):
self.cell_outputs_dict = {}
self.websocket = None
self.io_output = None
self.message_queue = asyncio.Queue()

class UserContext:
_state = threading.local()
Expand Down
8 changes: 7 additions & 1 deletion zt_backend/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import OrderedDict
from zt_backend.runner.user_state import UserState
from zt_backend.models import request, notebook, response
from dictdiffer import diff
import logging
Expand Down Expand Up @@ -156,4 +157,9 @@ def save_toml():
def get_code_completions(cell_id:str, code: str, line: int, column: int) -> list:
script = jedi.Script(code)
completions = script.complete(line, column)
return {"cell_id": cell_id, "completions": [{"label": completion.name, "type": completion.type} for completion in completions]}
return {"cell_id": cell_id, "completions": [{"label": completion.name, "type": completion.type} for completion in completions]}

async def websocket_message_sender(execution_state: UserState):
while True:
message = await execution_state.message_queue.get()
await execution_state.websocket.send_json(message)

0 comments on commit ea9a9b1

Please sign in to comment.