From 47e892ef1855c879fd2f841bea58a3a4a6a67eff Mon Sep 17 00:00:00 2001 From: Giovanni Barillari Date: Thu, 18 Jan 2024 17:42:29 +0100 Subject: [PATCH] Allow cancel on `PyFutureAwaitable` --- src/callbacks.rs | 53 ++++++++++++++++++++++++++++++++++++++++-------- src/runtime.rs | 22 ++++++++++++++------ 2 files changed, 61 insertions(+), 14 deletions(-) diff --git a/src/callbacks.rs b/src/callbacks.rs index 1298dc9a..88cb663d 100644 --- a/src/callbacks.rs +++ b/src/callbacks.rs @@ -1,4 +1,6 @@ use pyo3::{prelude::*, pyclass::IterNextOutput, sync::GILOnceCell}; +use std::sync::Arc; +use tokio::sync::Notify; static CONTEXTVARS: GILOnceCell = GILOnceCell::new(); static CONTEXT: GILOnceCell = GILOnceCell::new(); @@ -74,27 +76,34 @@ impl PyIterAwaitable { #[pyclass] pub(crate) struct PyFutureAwaitable { - fut_spawner: Option) + Send>>, + fut_spawner: Option, Arc, Py) + Send>>, result: Option>, event_loop: PyObject, + callback: Option, + cancel_tx: Arc, py_block: bool, + py_cancelled: bool, } impl PyFutureAwaitable { pub(crate) fn new( - fut_spawner: Box) + Send>, + fut_spawner: Box, Arc, Py) + Send>, event_loop: PyObject, ) -> Self { Self { fut_spawner: Some(fut_spawner), result: None, event_loop, + callback: None, + cancel_tx: Arc::new(Notify::new()), py_block: true, + py_cancelled: false, } } - pub(crate) fn set_result(mut pyself: PyRefMut<'_, Self>, result: PyResult) { + pub(crate) fn set_result(mut pyself: PyRefMut<'_, Self>, result: PyResult) -> Option { pyself.result = Some(result); + pyself.callback.take() } } @@ -118,18 +127,46 @@ impl PyFutureAwaitable { self.event_loop.clone_ref(py) } - fn add_done_callback(mut pyself: PyRefMut<'_, Self>, cb: PyObject, context: PyObject) -> PyResult<()> { + #[pyo3(signature = (cb, context=None))] + fn add_done_callback(mut pyself: PyRefMut<'_, Self>, cb: PyObject, context: Option) -> PyResult<()> { + pyself.callback = Some(cb); if let Some(spawner) = pyself.fut_spawner.take() { - (spawner)(cb, context, pyself.into()); + (spawner)(context, pyself.cancel_tx.clone(), pyself.into()); } Ok(()) } - fn cancel(&self) -> bool { - false + #[allow(unused)] + fn remove_done_callback(&mut self, cb: PyObject) -> i32 { + self.callback = None; + 1 + } + + #[allow(unused)] + #[pyo3(signature = (msg=None))] + fn cancel(&mut self, msg: Option) -> bool { + if self.done() { + return false; + } + self.py_cancelled = true; + self.cancel_tx.notify_one(); + true + } + + fn done(&self) -> bool { + self.result.is_some() || self.py_cancelled + } + + fn result(pyself: PyRef<'_, Self>) -> PyResult { + match &pyself.result { + Some(res) => { + let py = pyself.py(); + res.as_ref().map(|v| v.clone_ref(py)).map_err(|err| err.clone_ref(py)) + } + _ => Ok(pyself.py().None()), + } } - fn result(&self) {} fn exception(&self) {} fn __iter__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> { diff --git a/src/runtime.rs b/src/runtime.rs index 423b4b01..8494ae1b 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -9,6 +9,7 @@ use std::{ }; use tokio::{ runtime::Builder, + sync::Notify, task::{JoinHandle, LocalSet}, }; @@ -223,15 +224,24 @@ where let task_locals = get_current_locals::(py)?; let event_loop = task_locals.event_loop(py).to_object(py); let event_loop_aw = event_loop.clone(); - let fut_spawner = move |cb: PyObject, context: PyObject, aw: Py| { + let fut_spawner = move |context: Option, cancel_tx: Arc, aw: Py| { rt.spawn(async move { - let result = fut.await; + let result = tokio::select! { + result = fut => { + result + }, + () = cancel_tx.notified() => { + Err(pyo3::exceptions::asyncio::CancelledError::new_err("Task cancelled")) + } + }; Python::with_gil(|py| { - PyFutureAwaitable::set_result(aw.borrow_mut(py), result.map(|v| v.into_py(py))); - let kwctx = pyo3::types::PyDict::new(py); - kwctx.set_item(pyo3::intern!(py, "context"), context).unwrap(); - let _ = event_loop.call_method(py, pyo3::intern!(py, "call_soon_threadsafe"), (cb, aw), Some(kwctx)); + if let Some(cb) = PyFutureAwaitable::set_result(aw.borrow_mut(py), result.map(|v| v.into_py(py))) { + let kwctx = pyo3::types::PyDict::new(py); + kwctx.set_item(pyo3::intern!(py, "context"), context).unwrap(); + let _ = + event_loop.call_method(py, pyo3::intern!(py, "call_soon_threadsafe"), (cb, aw), Some(kwctx)); + } }); }); };