Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep an async function marker when @pass_ is used #79

Merged
merged 1 commit into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,15 @@ Release Notes
backward incompatible changes will be released along with bumping major
version component.

4.1.0
`````

(unreleased)

* Fix a bug when a coroutine function wrapped with ``@picobox.pass_()``
lost its coroutine function marker, i.e. ``inspect.iscoroutinefunction()``
returned ``False``.

4.0.0
`````

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Bugs = "https://github.com/ikalnytskyi/picobox/issues"
source = "vcs"

[tool.hatch.envs.test]
dependencies = ["pytest", "flask"]
dependencies = ["pytest", "pytest-asyncio", "flask"]
scripts.run = "python -m pytest --strict-markers {args:-vv}"

[tool.hatch.envs.lint]
Expand Down
10 changes: 9 additions & 1 deletion src/picobox/_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def decorator(fn):
return fn

@functools.wraps(fn)
def wrapper(*args, **kwargs):
def fn_with_dependencies(*args, **kwargs):
signature = inspect.signature(fn)
arguments = signature.bind_partial(*args, **kwargs)

Expand All @@ -203,6 +203,14 @@ def wrapper(*args, **kwargs):
kwargs[as_] = self.get(key)
return fn(*args, **kwargs)

if inspect.iscoroutinefunction(fn):

@functools.wraps(fn)
async def wrapper(*args, **kwargs):
return await fn_with_dependencies(*args, **kwargs)
else:
wrapper = fn_with_dependencies

wrapper.__dependencies__ = [(key, as_)]
return wrapper

Expand Down
43 changes: 43 additions & 0 deletions tests/test_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,27 @@ def __init__(self, x):
assert Foo(*args, **kwargs).x == rv


@pytest.mark.asyncio()
@pytest.mark.parametrize(
("args", "kwargs", "rv"),
[
((1,), {}, 1),
((), {"x": 1}, 1),
((), {}, 42),
],
)
async def test_box_pass_coroutine(args, kwargs, rv, boxclass):
testbox = boxclass()
testbox.put("x", 42)

@testbox.pass_("x")
async def co(x):
return x

assert inspect.iscoroutinefunction(co)
assert await co(*args, **kwargs) == rv


@pytest.mark.parametrize(
("args", "kwargs", "rv"),
[
Expand Down Expand Up @@ -490,6 +511,28 @@ def fn(a, b, c, d):
assert len(fn()) == 3


@pytest.mark.asyncio()
async def test_box_pass_optimization_async(boxclass, request):
testbox = boxclass()
testbox.put("a", 1)
testbox.put("b", 1)
testbox.put("d", 1)

@testbox.pass_("a")
@testbox.pass_("b")
@testbox.pass_("d", as_="c")
async def fn(a, b, c):
backtrace = list(
itertools.dropwhile(
lambda frame: frame[2] != request.function.__name__,
traceback.extract_stack(),
)
)
return backtrace[1:-1]

assert len(await fn()) == 1


def test_chainbox_put_changes_box():
testbox = picobox.Box()
testchainbox = picobox.ChainBox(testbox)
Expand Down
46 changes: 46 additions & 0 deletions tests/test_stack.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test picobox's stack interface."""

import inspect
import itertools
import sys
import traceback
Expand Down Expand Up @@ -449,6 +450,28 @@ def __init__(self, x):
assert Foo(*args, **kwargs).x == rv


@pytest.mark.asyncio()
@pytest.mark.parametrize(
("args", "kwargs", "rv"),
[
((1,), {}, 1),
((), {"x": 1}, 1),
((), {}, 42),
],
)
async def test_box_pass_coroutine(boxclass, teststack, args, kwargs, rv):
testbox = boxclass()
testbox.put("x", 42)

@teststack.pass_("x")
async def co(x):
return x

with teststack.push(testbox):
assert inspect.iscoroutinefunction(co)
assert await co(*args, **kwargs) == rv


@pytest.mark.parametrize(
("args", "kwargs", "rv"),
[
Expand Down Expand Up @@ -567,6 +590,29 @@ def fn(a, b, c, d):
assert len(fn()) == 3


@pytest.mark.asyncio()
async def test_box_pass_optimization_async(boxclass, teststack, request):
testbox = boxclass()
testbox.put("a", 1)
testbox.put("b", 1)
testbox.put("d", 1)

@teststack.pass_("a")
@teststack.pass_("b")
@teststack.pass_("d", as_="c")
async def fn(a, b, c):
backtrace = list(
itertools.dropwhile(
lambda frame: frame[2] != request.function.__name__,
traceback.extract_stack(),
)
)
return backtrace[1:-1]

with teststack.push(testbox):
assert len(await fn()) == 1


def test_chainbox_put_changes_box(teststack):
testbox = picobox.Box()
testchainbox = picobox.ChainBox(testbox)
Expand Down