From 13a7711405611edad596cfb7f75f35255d97bf67 Mon Sep 17 00:00:00 2001 From: Ihor Kalnytskyi Date: Mon, 27 Nov 2023 01:51:12 +0200 Subject: [PATCH] Keep an async function marker when @pass_ is used Unfortunately, when the `@picobox.pass_()` decorator is used, a wrapped coroutine function (i.e. an async function) loses its a coroutine function marker, i.e. `inspect.iscoroutinefunction()` returns `False` for such function. Turns out that there are a lot of software out there that support both sync and async interfaces, and may choose one based on the type of a passed function. For instance, Starlette, a web-framework, checks a provided route function for being a coroutine function before choosing how to execute in (i.e. in an event loop or in a thread pool). This patch fixes `@picobox.pass_()` to return a coroutine function when a wrapped function is also a coroutine function. Fixes #78 --- docs/index.rst | 9 +++++++++ pyproject.toml | 2 +- src/picobox/_box.py | 10 +++++++++- tests/test_box.py | 43 ++++++++++++++++++++++++++++++++++++++++++ tests/test_stack.py | 46 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 108 insertions(+), 2 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index eff5e2b..12c0e27 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -358,6 +358,15 @@ Release Notes backward incompatible changes will be released along with bumping major version component. +4.1.0 +````` + +(unreleased) + +* Fix a bug when a coroutine function wrapped with ``@picobox.pass_()`` + lost its coroutine function marker, i.e. ``inspect.iscoroutinefunction()`` + returned ``False``. + 4.0.0 ````` diff --git a/pyproject.toml b/pyproject.toml index d3c2d2d..1fdcd19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ Bugs = "https://github.com/ikalnytskyi/picobox/issues" source = "vcs" [tool.hatch.envs.test] -dependencies = ["pytest", "flask"] +dependencies = ["pytest", "pytest-asyncio", "flask"] scripts.run = "python -m pytest --strict-markers {args:-vv}" [tool.hatch.envs.lint] diff --git a/src/picobox/_box.py b/src/picobox/_box.py index 3c5aaca..77bed5a 100644 --- a/src/picobox/_box.py +++ b/src/picobox/_box.py @@ -187,7 +187,7 @@ def decorator(fn): return fn @functools.wraps(fn) - def wrapper(*args, **kwargs): + def fn_with_dependencies(*args, **kwargs): signature = inspect.signature(fn) arguments = signature.bind_partial(*args, **kwargs) @@ -203,6 +203,14 @@ def wrapper(*args, **kwargs): kwargs[as_] = self.get(key) return fn(*args, **kwargs) + if inspect.iscoroutinefunction(fn): + + @functools.wraps(fn) + async def wrapper(*args, **kwargs): + return await fn_with_dependencies(*args, **kwargs) + else: + wrapper = fn_with_dependencies + wrapper.__dependencies__ = [(key, as_)] return wrapper diff --git a/tests/test_box.py b/tests/test_box.py index 9b7cfb4..f37cb4b 100644 --- a/tests/test_box.py +++ b/tests/test_box.py @@ -386,6 +386,27 @@ def __init__(self, x): assert Foo(*args, **kwargs).x == rv +@pytest.mark.asyncio() +@pytest.mark.parametrize( + ("args", "kwargs", "rv"), + [ + ((1,), {}, 1), + ((), {"x": 1}, 1), + ((), {}, 42), + ], +) +async def test_box_pass_coroutine(args, kwargs, rv, boxclass): + testbox = boxclass() + testbox.put("x", 42) + + @testbox.pass_("x") + async def co(x): + return x + + assert inspect.iscoroutinefunction(co) + assert await co(*args, **kwargs) == rv + + @pytest.mark.parametrize( ("args", "kwargs", "rv"), [ @@ -490,6 +511,28 @@ def fn(a, b, c, d): assert len(fn()) == 3 +@pytest.mark.asyncio() +async def test_box_pass_optimization_async(boxclass, request): + testbox = boxclass() + testbox.put("a", 1) + testbox.put("b", 1) + testbox.put("d", 1) + + @testbox.pass_("a") + @testbox.pass_("b") + @testbox.pass_("d", as_="c") + async def fn(a, b, c): + backtrace = list( + itertools.dropwhile( + lambda frame: frame[2] != request.function.__name__, + traceback.extract_stack(), + ) + ) + return backtrace[1:-1] + + assert len(await fn()) == 1 + + def test_chainbox_put_changes_box(): testbox = picobox.Box() testchainbox = picobox.ChainBox(testbox) diff --git a/tests/test_stack.py b/tests/test_stack.py index a6be870..3570268 100644 --- a/tests/test_stack.py +++ b/tests/test_stack.py @@ -1,5 +1,6 @@ """Test picobox's stack interface.""" +import inspect import itertools import sys import traceback @@ -449,6 +450,28 @@ def __init__(self, x): assert Foo(*args, **kwargs).x == rv +@pytest.mark.asyncio() +@pytest.mark.parametrize( + ("args", "kwargs", "rv"), + [ + ((1,), {}, 1), + ((), {"x": 1}, 1), + ((), {}, 42), + ], +) +async def test_box_pass_coroutine(boxclass, teststack, args, kwargs, rv): + testbox = boxclass() + testbox.put("x", 42) + + @teststack.pass_("x") + async def co(x): + return x + + with teststack.push(testbox): + assert inspect.iscoroutinefunction(co) + assert await co(*args, **kwargs) == rv + + @pytest.mark.parametrize( ("args", "kwargs", "rv"), [ @@ -567,6 +590,29 @@ def fn(a, b, c, d): assert len(fn()) == 3 +@pytest.mark.asyncio() +async def test_box_pass_optimization_async(boxclass, teststack, request): + testbox = boxclass() + testbox.put("a", 1) + testbox.put("b", 1) + testbox.put("d", 1) + + @teststack.pass_("a") + @teststack.pass_("b") + @teststack.pass_("d", as_="c") + async def fn(a, b, c): + backtrace = list( + itertools.dropwhile( + lambda frame: frame[2] != request.function.__name__, + traceback.extract_stack(), + ) + ) + return backtrace[1:-1] + + with teststack.push(testbox): + assert len(await fn()) == 1 + + def test_chainbox_put_changes_box(teststack): testbox = picobox.Box() testchainbox = picobox.ChainBox(testbox)