diff --git a/granian/_futures.py b/granian/_futures.py index a819d068..9793af18 100644 --- a/granian/_futures.py +++ b/granian/_futures.py @@ -30,7 +30,7 @@ class _CBSchedulerAIO(_BaseCBScheduler): def __init__(self, loop, ctx, cb, aio_tenter, aio_texit): super().__init__() - self._schedule_fn = _cbsched_schedule(loop, ctx, loop.create_task, cb) + self._schedule_fn = _cbsched_aioschedule(loop, ctx, cb) def _new_cbscheduler(loop, cb, impl_asyncio=False): @@ -43,3 +43,14 @@ def _schedule(watcher): loop.call_soon_threadsafe(run, cb(watcher), context=ctx) return _schedule + + +def _cbsched_aioschedule(loop, ctx, cb): + def _run(coro, watcher): + task = loop.create_task(coro) + watcher.taskref(task) + + def _schedule(watcher): + loop.call_soon_threadsafe(_run, cb(watcher), watcher, context=ctx) + + return _schedule diff --git a/src/asgi/callbacks.rs b/src/asgi/callbacks.rs index e437017a..053fea64 100644 --- a/src/asgi/callbacks.rs +++ b/src/asgi/callbacks.rs @@ -1,6 +1,9 @@ use pyo3::prelude::*; use pyo3::types::PyDict; -use std::{net::SocketAddr, sync::Arc}; +use std::{ + net::SocketAddr, + sync::{Arc, OnceLock}, +}; use tokio::sync::oneshot; use super::{ @@ -38,12 +41,19 @@ macro_rules! callback_impl_done_err { }; } +macro_rules! callback_impl_taskref { + ($self:expr, $py:expr, $task:expr) => { + let _ = $self.aio_taskref.set($task.clone_ref($py)); + }; +} + #[pyclass(frozen)] pub(crate) struct CallbackWatcherHTTP { #[pyo3(get)] proto: Py, #[pyo3(get)] scope: Py, + aio_taskref: OnceLock, } impl CallbackWatcherHTTP { @@ -51,6 +61,7 @@ impl CallbackWatcherHTTP { Self { proto: Py::new(py, proto).unwrap(), scope: scope.unbind(), + aio_taskref: OnceLock::new(), } } } @@ -64,6 +75,10 @@ impl CallbackWatcherHTTP { fn err(&self, err: Bound) { callback_impl_done_err!(self, &PyErr::from_value(err)); } + + fn taskref(&self, py: Python, task: PyObject) { + callback_impl_taskref!(self, py, task); + } } #[pyclass(frozen)] @@ -72,6 +87,7 @@ pub(crate) struct CallbackWatcherWebsocket { proto: Py, #[pyo3(get)] scope: Py, + aio_taskref: OnceLock, } impl CallbackWatcherWebsocket { @@ -79,6 +95,7 @@ impl CallbackWatcherWebsocket { Self { proto: Py::new(py, proto).unwrap(), scope: scope.unbind(), + aio_taskref: OnceLock::new(), } } } @@ -92,6 +109,10 @@ impl CallbackWatcherWebsocket { fn err(&self, err: Bound) { callback_impl_done_err!(self, &PyErr::from_value(err)); } + + fn taskref(&self, py: Python, task: PyObject) { + callback_impl_taskref!(self, py, task); + } } // NOTE: we cannot use single `impl` function as structs with pyclass won't handle diff --git a/src/rsgi/callbacks.rs b/src/rsgi/callbacks.rs index 94a97c74..d00ef51e 100644 --- a/src/rsgi/callbacks.rs +++ b/src/rsgi/callbacks.rs @@ -1,4 +1,5 @@ use pyo3::prelude::*; +use std::sync::OnceLock; use tokio::sync::oneshot; use super::{ @@ -33,12 +34,19 @@ macro_rules! callback_impl_done_err { }; } +macro_rules! callback_impl_taskref { + ($self:expr, $py:expr, $task:expr) => { + let _ = $self.aio_taskref.set($task.clone_ref($py)); + }; +} + #[pyclass(frozen)] pub(crate) struct CallbackWatcherHTTP { #[pyo3(get)] proto: Py, #[pyo3(get)] scope: Py, + aio_taskref: OnceLock, } impl CallbackWatcherHTTP { @@ -46,6 +54,7 @@ impl CallbackWatcherHTTP { Self { proto: Py::new(py, proto).unwrap(), scope: Py::new(py, scope).unwrap(), + aio_taskref: OnceLock::new(), } } } @@ -59,6 +68,10 @@ impl CallbackWatcherHTTP { fn err(&self, err: Bound) { callback_impl_done_err!(self, &PyErr::from_value(err)); } + + fn taskref(&self, py: Python, task: PyObject) { + callback_impl_taskref!(self, py, task); + } } #[pyclass(frozen)] @@ -67,6 +80,7 @@ pub(crate) struct CallbackWatcherWebsocket { proto: Py, #[pyo3(get)] scope: Py, + aio_taskref: OnceLock, } impl CallbackWatcherWebsocket { @@ -74,6 +88,7 @@ impl CallbackWatcherWebsocket { Self { proto: Py::new(py, proto).unwrap(), scope: Py::new(py, scope).unwrap(), + aio_taskref: OnceLock::new(), } } } @@ -87,6 +102,10 @@ impl CallbackWatcherWebsocket { fn err(&self, err: Bound) { callback_impl_done_err!(self, &PyErr::from_value(err)); } + + fn taskref(&self, py: Python, task: PyObject) { + callback_impl_taskref!(self, py, task); + } } #[inline]