Skip to content

Commit

Permalink
tidy: Tweaks to enable pytest on marimo nbs (#3238)
Browse files Browse the repository at this point in the history
I was really interested in having tests in notebooks (re discussion
yesterday), so I played around with it a tiny bit this afternoon. As is,
pytest does not work out of the box since:
 1. the wrapped functions become Cell instances
2. coroutine detection through the `_is_coroutine` spec should be as an
attribute

This change enables both of these minor adjustments. Now, pytest will
pick up
marimo `test_*` functions, but they will fail since directly invoking a
cell is
currently not allowed. I put in a little unit test to fight against a
possible
regression until such time when a top level fn spec is enabled (see
#2293).

@akshayka OR @mscolnick

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
dmadisetti and pre-commit-ci[bot] authored Dec 19, 2024
1 parent 24cd269 commit 50e6285
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 17 deletions.
3 changes: 3 additions & 0 deletions marimo/_ast/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations

import functools
import inspect
import random
import string
Expand Down Expand Up @@ -464,6 +465,8 @@ def _register(func: Callable[..., Any]) -> Cell:
)
cell._cell.configure(cell_config)
self._register_cell(cell, app=app)
# NB. in place metadata update.
functools.wraps(func)(cell)
return cell

if func is None:
Expand Down
11 changes: 6 additions & 5 deletions marimo/_ast/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ def defs(self) -> set[str]:
"""The definitions made by this cell"""
return self._cell.defs

@property
def _is_coroutine(self) -> bool:
"""Whether this cell is a coroutine function.
Expand All @@ -379,15 +380,15 @@ def _help(self) -> Html:
from marimo._output.formatting import as_html
from marimo._output.md import md

signature_prefix = "Async " if self._is_coroutine() else ""
signature_prefix = "Async " if self._is_coroutine else ""
execute_str_refs = (
f"output, defs = await {self.name}.run(**refs)"
if self._is_coroutine()
if self._is_coroutine
else f"output, defs = {self.name}.run(**refs)"
)
execute_str_no_refs = (
f"output, defs = await {self.name}.run()"
if self._is_coroutine()
if self._is_coroutine
else f"output, defs = {self.name}.run()"
)

Expand Down Expand Up @@ -524,15 +525,15 @@ def add(mo, x, y):
from the cell's defined names to their values.
"""
assert self._app is not None
if self._is_coroutine():
if self._is_coroutine:
return self._app.run_cell_async(cell=self, kwargs=refs)
else:
return self._app.run_cell_sync(cell=self, kwargs=refs)

def __call__(self, *args: Any, **kwargs: Any) -> None:
del args
del kwargs
if self._is_coroutine():
if self._is_coroutine:
call_str = f"`outputs, defs = await {self.name}.run()`"
else:
call_str = f"`outputs, defs = {self.name}.run()`"
Expand Down
2 changes: 1 addition & 1 deletion marimo/_runtime/app/script_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def run(self) -> RunOutput:
"please raise an issue."
)

if cell._is_coroutine():
if cell._is_coroutine:
is_async = True
break

Expand Down
6 changes: 3 additions & 3 deletions marimo/_runtime/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,9 @@ def sanitize_inputs(
except TypeError as e:
raise CloneError(
f"Could not clone reference `{ref}` of type "
f"{getattr(glbls[ref], '__module__', '<module>')}."
f"{glbls[ref].__class__.__name__}"
" try wrapping the object in a `zero_copy`"
f"{getattr(glbls[ref], '__module__', '<module>')}. "
f"{glbls[ref].__class__.__name__} "
"try wrapping the object in a `zero_copy` "
"call. If this is a common object type, consider "
"making an issue on the marimo GitHub "
"repository to never deepcopy."
Expand Down
4 changes: 4 additions & 0 deletions marimo/_runtime/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import weakref
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

from marimo._ast.cell import Cell
from marimo._ast.visitor import Name, VariableData

if TYPE_CHECKING:
Expand Down Expand Up @@ -130,6 +131,9 @@ def is_instance_by_name(obj: object, name: str) -> bool:


def is_unclonable_type(obj: object) -> bool:
# Cell objects in particular are hidden by functools.wraps.
if isinstance(obj, Cell):
return True
return any([is_instance_by_name(obj, name) for name in UNCLONABLE_TYPES])


Expand Down
2 changes: 1 addition & 1 deletion marimo/_server/export/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def export_as_script(
for cell in file_manager.app.cell_manager.cells():
if not cell:
continue
if cell._is_coroutine():
if cell._is_coroutine:
from click import UsageError

raise UsageError(
Expand Down
14 changes: 7 additions & 7 deletions tests/_ast/test_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def f() -> tuple[int]:
assert cell.name == "f"
assert not cell.refs
assert cell.defs == set(["x"])
assert not cell._is_coroutine()
assert not cell._is_coroutine
assert cell.run() == ("output", {"x": 4})

@staticmethod
Expand All @@ -39,7 +39,7 @@ async def f(asyncio) -> tuple[int]:
assert cell.name == "f"
assert cell.refs == {"asyncio"}
assert cell.defs == {"x"}
assert cell._is_coroutine()
assert cell._is_coroutine

import asyncio

Expand Down Expand Up @@ -118,10 +118,10 @@ def h(x):
y = x
return (y,)

assert g._is_coroutine()
assert g._is_coroutine
# h is a coroutine because it depends on the execution of an async
# function
assert h._is_coroutine()
assert h._is_coroutine

@staticmethod
def test_async_chain() -> None:
Expand All @@ -143,9 +143,9 @@ def h(y):
z = y
return (z,)

assert f._is_coroutine()
assert g._is_coroutine()
assert h._is_coroutine()
assert f._is_coroutine
assert g._is_coroutine
assert h._is_coroutine

@staticmethod
def test_empty_cell() -> None:
Expand Down
21 changes: 21 additions & 0 deletions tests/_runtime/test_pytest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations

import pytest

import marimo

app = marimo.App()


@pytest.mark.xfail(
reason=(
"Invoking a cell is not directly supported, and as such should fail "
"until #2293. However, the decorated function _should_ be picked up "
"by pytest. The hook in conftest.py ensures this."
),
raises=RuntimeError,
strict=True,
)
@app.cell
def test_cell_is_invoked():
assert True
24 changes: 24 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Any, Generator

import pytest
from _pytest import runner

from marimo._ast.app import CellManager
from marimo._ast.cell import CellId_t
Expand Down Expand Up @@ -491,3 +492,26 @@ def get_with_id(self, cell_id: CellId_t, code: str) -> ExecutionRequest:
@pytest.fixture
def exec_req() -> ExecReqProvider:
return ExecReqProvider()


# # A pytest hook to fail when raw marimo cells are not collected.
@pytest.hookimpl
def pytest_make_collect_report(collector):
report = runner.pytest_make_collect_report(collector)
# Defined within the file does not seem to hook correctly, as such filter
# for the test_pytest specific file here.
if "test_pytest" in str(collector.path):
collected = {fn.name: fn for fn in collector.collect()}
from tests._runtime.test_pytest import app

invalid = []
for name in app._cell_manager.names():
if name.startswith("test_") and name not in collected:
invalid.append(f"'{name}'")
if invalid:
tests = ", ".join([f"'{test}'" for test in collected.keys()])
report.outcome = "failed"
report.longrepr = (
f"Cannot collect test(s) {', '.join(invalid)} from {tests}"
)
return report

0 comments on commit 50e6285

Please sign in to comment.