Skip to content

Commit

Permalink
feat: apply pytest rewrites for better testing (#3294)
Browse files Browse the repository at this point in the history
merry merry 🎄 

---

Slight improvement on pytest errors, by manually running pytest's
assertion rewrite magic. So that takes us from this:


![image](https://github.com/user-attachments/assets/e2ddf88b-a6c1-4f89-a98b-20f70b0e8c6d)

to this


![image](https://github.com/user-attachments/assets/317a9d4b-faa1-4d9b-9f54-fa5e8c747e39)

Also a little fun- this rewrite can be applied in general and shown in
editor.


![image](https://github.com/user-attachments/assets/18934182-732a-4dad-8e91-12b9abd32f7c)

Not sure the overhead this might have- so not sure about setting it
generally, but could set `test_rewrite=True` if the cell name starts
with `test_`. Wanted to run this by you all first


@akshayka OR @mscolnick
  • Loading branch information
dmadisetti authored Dec 26, 2024
1 parent a985144 commit 4e4d7f0
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 8 deletions.
14 changes: 8 additions & 6 deletions marimo/_ast/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,20 +460,22 @@ def cell_decorator(
)

def _register(func: Callable[..., Any]) -> Cell:
# Use PYTEST_VERSION here, opposed to PYTEST_CURRENT_TEST, in
# order to allow execution during test collection.
is_top_level_pytest = (
"PYTEST_VERSION" in os.environ
and "PYTEST_CURRENT_TEST" not in os.environ
)
cell = cell_factory(
func,
cell_id=self.create_cell_id(),
anonymous_file=app._app._anonymous_file if app else False,
test_rewrite=is_top_level_pytest,
)
cell._cell.configure(cell_config)
self._register_cell(cell, app=app)
# Manually set the signature for pytest.
# Use PYTEST_VERSION here, opposed to PYTEST_CURRENT_TEST, in
# order to allow execution during test collection.
if (
"PYTEST_VERSION" in os.environ
and "PYTEST_CURRENT_TEST" not in os.environ
):
if is_top_level_pytest:
func = wrap_fn_for_pytest(func, cell)
# NB. in place metadata update.
functools.wraps(func)(cell)
Expand Down
25 changes: 24 additions & 1 deletion marimo/_ast/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tokenize import TokenInfo, tokenize
from typing import TYPE_CHECKING, Any, Callable, Optional

from marimo import _loggers
from marimo._ast.cell import (
Cell,
CellId_t,
Expand All @@ -23,6 +24,8 @@
from marimo._utils.tmpdir import get_tmpdir
from marimo._utils.variables import is_local

LOGGER = _loggers.marimo_logger()

if TYPE_CHECKING:
from collections.abc import Iterator

Expand Down Expand Up @@ -96,6 +99,7 @@ def compile_cell(
cell_id: CellId_t,
source_position: Optional[SourcePosition] = None,
carried_imports: list[ImportData] | None = None,
test_rewrite: bool = False,
) -> CellImpl:
# Replace non-breaking spaces with regular spaces -- some frontends
# send nbsp in place of space, which is a syntax error.
Expand Down Expand Up @@ -163,6 +167,21 @@ def compile_cell(
# since there is an actual file to read from.
cache(filename, code)

# pytest assertion rewriting, gives more context for assertion failures.
if test_rewrite:
# pytest is not required, so fail gracefully if needed
try:
from _pytest.assertion.rewrite import ( # type: ignore
rewrite_asserts,
)

rewrite_asserts(module, code.encode("utf-8"), module_path=filename)
# general catch-all, in case of internal pytest API changes
except Exception:
LOGGER.warning(
"pytest is not installed, skipping assertion rewriting"
)

flags = ast.PyCF_ALLOW_TOP_LEVEL_AWAIT
body = compile(module, filename, mode="exec", flags=flags)
last_expr = compile(expr, filename, mode="eval", flags=flags)
Expand Down Expand Up @@ -214,6 +233,7 @@ def cell_factory(
f: Callable[..., Any],
cell_id: CellId_t,
anonymous_file: bool = False,
test_rewrite: bool = False,
) -> Cell:
"""Creates a cell from a function.
Expand Down Expand Up @@ -340,6 +360,9 @@ def cell_factory(
return Cell(
_name=f.__name__,
_cell=compile_cell(
cell_code, cell_id=cell_id, source_position=source_position
cell_code,
cell_id=cell_id,
source_position=source_position,
test_rewrite=test_rewrite,
),
)
2 changes: 1 addition & 1 deletion marimo/_ast/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def wrap_fn_for_pytest(
local = {"stub": cell.__call__, "Any": Any}
eval(compile(fn, inspect.getfile(func), "exec"), local)

# The remaining expected attributes is needed to ensure attribute count
# The remaining expected attributes are needed to ensure attribute count
# matches.
cell._pytest_reserved = reserved

Expand Down
1 change: 1 addition & 0 deletions marimo/_messaging/context.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright 2024 Marimo. All rights reserved.
import uuid
from contextvars import ContextVar
from dataclasses import dataclass
Expand Down
15 changes: 15 additions & 0 deletions tests/_ast/test_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,18 @@ def test_cell_extra_refs_fail(mo): # noqa: ARG001
@app.cell
def test_cell_args_resolved_by_name(mo): # noqa: ARG001
assert x # noqa: F821


@app.cell
def test_cell_assert_rewritten():
import pytest

a = 1
b = 2

with pytest.raises(AssertionError) as exc_info:
assert a + b == a * b

# Check expansion works. Without rewrite, this just produces
# "AssertionError", without showing the expanded expression.
assert "assert (1 + 2) == (1 * 2)" in str(exc_info.value)

0 comments on commit 4e4d7f0

Please sign in to comment.