Skip to content

Commit

Permalink
Keep task reference in asyncio task-impl (#436)
Browse files Browse the repository at this point in the history
  • Loading branch information
gi0baro committed Dec 24, 2024
1 parent af1d130 commit 6674a71
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 2 deletions.
13 changes: 12 additions & 1 deletion granian/_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class _CBSchedulerAIO(_BaseCBScheduler):

def __init__(self, loop, ctx, cb, aio_tenter, aio_texit):
super().__init__()
self._schedule_fn = _cbsched_schedule(loop, ctx, loop.create_task, cb)
self._schedule_fn = _cbsched_aioschedule(loop, ctx, cb)


def _new_cbscheduler(loop, cb, impl_asyncio=False):
Expand All @@ -43,3 +43,14 @@ def _schedule(watcher):
loop.call_soon_threadsafe(run, cb(watcher), context=ctx)

return _schedule


def _cbsched_aioschedule(loop, ctx, cb):
def _run(coro, watcher):
task = loop.create_task(coro)
watcher.taskref(task)

def _schedule(watcher):
loop.call_soon_threadsafe(_run, cb(watcher), watcher, context=ctx)

return _schedule
23 changes: 22 additions & 1 deletion src/asgi/callbacks.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use pyo3::prelude::*;
use pyo3::types::PyDict;
use std::{net::SocketAddr, sync::Arc};
use std::{
net::SocketAddr,
sync::{Arc, OnceLock},
};
use tokio::sync::oneshot;

use super::{
Expand Down Expand Up @@ -38,19 +41,27 @@ macro_rules! callback_impl_done_err {
};
}

macro_rules! callback_impl_taskref {
($self:expr, $py:expr, $task:expr) => {
let _ = $self.aio_taskref.set($task.clone_ref($py));
};
}

#[pyclass(frozen)]
pub(crate) struct CallbackWatcherHTTP {
#[pyo3(get)]
proto: Py<HTTPProtocol>,
#[pyo3(get)]
scope: Py<PyDict>,
aio_taskref: OnceLock<PyObject>,
}

impl CallbackWatcherHTTP {
pub fn new(py: Python, proto: HTTPProtocol, scope: Bound<PyDict>) -> Self {
Self {
proto: Py::new(py, proto).unwrap(),
scope: scope.unbind(),
aio_taskref: OnceLock::new(),
}
}
}
Expand All @@ -64,6 +75,10 @@ impl CallbackWatcherHTTP {
fn err(&self, err: Bound<PyAny>) {
callback_impl_done_err!(self, &PyErr::from_value(err));
}

fn taskref(&self, py: Python, task: PyObject) {
callback_impl_taskref!(self, py, task);
}
}

#[pyclass(frozen)]
Expand All @@ -72,13 +87,15 @@ pub(crate) struct CallbackWatcherWebsocket {
proto: Py<WebsocketProtocol>,
#[pyo3(get)]
scope: Py<PyDict>,
aio_taskref: OnceLock<PyObject>,
}

impl CallbackWatcherWebsocket {
pub fn new(py: Python, proto: WebsocketProtocol, scope: Bound<PyDict>) -> Self {
Self {
proto: Py::new(py, proto).unwrap(),
scope: scope.unbind(),
aio_taskref: OnceLock::new(),
}
}
}
Expand All @@ -92,6 +109,10 @@ impl CallbackWatcherWebsocket {
fn err(&self, err: Bound<PyAny>) {
callback_impl_done_err!(self, &PyErr::from_value(err));
}

fn taskref(&self, py: Python, task: PyObject) {
callback_impl_taskref!(self, py, task);
}
}

// NOTE: we cannot use single `impl` function as structs with pyclass won't handle
Expand Down
19 changes: 19 additions & 0 deletions src/rsgi/callbacks.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use pyo3::prelude::*;
use std::sync::OnceLock;
use tokio::sync::oneshot;

use super::{
Expand Down Expand Up @@ -33,19 +34,27 @@ macro_rules! callback_impl_done_err {
};
}

macro_rules! callback_impl_taskref {
($self:expr, $py:expr, $task:expr) => {
let _ = $self.aio_taskref.set($task.clone_ref($py));
};
}

#[pyclass(frozen)]
pub(crate) struct CallbackWatcherHTTP {
#[pyo3(get)]
proto: Py<HTTPProtocol>,
#[pyo3(get)]
scope: Py<HTTPScope>,
aio_taskref: OnceLock<PyObject>,
}

impl CallbackWatcherHTTP {
pub fn new(py: Python, proto: HTTPProtocol, scope: HTTPScope) -> Self {
Self {
proto: Py::new(py, proto).unwrap(),
scope: Py::new(py, scope).unwrap(),
aio_taskref: OnceLock::new(),
}
}
}
Expand All @@ -59,6 +68,10 @@ impl CallbackWatcherHTTP {
fn err(&self, err: Bound<PyAny>) {
callback_impl_done_err!(self, &PyErr::from_value(err));
}

fn taskref(&self, py: Python, task: PyObject) {
callback_impl_taskref!(self, py, task);
}
}

#[pyclass(frozen)]
Expand All @@ -67,13 +80,15 @@ pub(crate) struct CallbackWatcherWebsocket {
proto: Py<WebsocketProtocol>,
#[pyo3(get)]
scope: Py<WebsocketScope>,
aio_taskref: OnceLock<PyObject>,
}

impl CallbackWatcherWebsocket {
pub fn new(py: Python, proto: WebsocketProtocol, scope: WebsocketScope) -> Self {
Self {
proto: Py::new(py, proto).unwrap(),
scope: Py::new(py, scope).unwrap(),
aio_taskref: OnceLock::new(),
}
}
}
Expand All @@ -87,6 +102,10 @@ impl CallbackWatcherWebsocket {
fn err(&self, err: Bound<PyAny>) {
callback_impl_done_err!(self, &PyErr::from_value(err));
}

fn taskref(&self, py: Python, task: PyObject) {
callback_impl_taskref!(self, py, task);
}
}

#[inline]
Expand Down

0 comments on commit 6674a71

Please sign in to comment.