Skip to content

Commit

Permalink
tidy: remove shadow work around for ast rewrite (rebase ec657dd)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmadisetti committed Jan 17, 2025
1 parent 9bd5ed9 commit 83cee0a
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 46 deletions.
23 changes: 23 additions & 0 deletions marimo/_save/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,27 @@ def generic_visit(self, node: ast.AST) -> tuple[ast.Module, ast.Module]: # typ
)


class MangleArguments(ast.NodeTransformer):
"""Mangles arguments names to prevent shadowing issues in analysis."""

def __init__(
self, prefix: str, args: set[str], *arg: Any, **kwargs: Any
) -> None:
super().__init__(*arg, **kwargs)
self.prefix = prefix
self.args = args

def visit_Name(self, node: ast.Name) -> ast.Name:
if node.id in self.args:
node.id = f"{self.prefix}{node.id}"
return node

def generic_visit(self, node: ast.AST) -> ast.AST:
if hasattr(node, "name") and node.name in self.args:
node.name = f"{self.prefix}{node.name}"
return super().generic_visit(node)


class DeprivateVisitor(ast.NodeTransformer):
"""Removes the mangling of private variables from a module."""

Expand All @@ -180,12 +201,14 @@ def visit_Return(self, node: ast.Return) -> ast.Expr:

def strip_function(fn: Callable[..., Any]) -> ast.Module:
code, _ = inspect.getsourcelines(fn)
args = list(fn.__code__.co_varnames)
function_ast = ast.parse(textwrap.dedent("".join(code)))
body = function_ast.body.pop()
assert isinstance(body, (ast.FunctionDef, ast.AsyncFunctionDef)), (
"Expected a function definition"
)
extracted = ast.Module(body.body, type_ignores=[])
module = RemoveReturns().visit(extracted)
module = MangleArguments("*", args).visit(module)
assert isinstance(module, ast.Module), "Expected a module"
return module
20 changes: 0 additions & 20 deletions marimo/_save/hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,6 @@ class SerialRefs(NamedTuple):
stateful_refs: set[Name]


class ShadowedRef:
"""Stub for scoped variables that may shadow global references"""


def hash_module(
code: Optional[CodeType], hash_type: str = DEFAULT_HASH
) -> bytes:
Expand Down Expand Up @@ -753,22 +749,6 @@ def serialize_and_dequeue_stateful_content_refs(
refs, inclusive=False
)

for ref in transitive_state_refs:
if ref in scope and isinstance(scope[ref], ShadowedRef):
# TODO(akshayka, dmadisetti): Lift this restriction once
# function args are rewritten.
#
# This makes more sense as a NameError, but the marimo's
# explainer text for NameError's doesn't make sense in this
# context. ("Definition expected in ...")
raise RuntimeError(
f"The cached function declares an argument '{ref}'"
"but a captured function or class uses the "
f"global variable '{ref}'. Please rename "
"the argument, or restructure the use "
f"of the global variable."
)

# Filter for relevant stateful cases.
refs |= set(
filter(
Expand Down
17 changes: 5 additions & 12 deletions marimo/_save/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from marimo._save.hash import (
DEFAULT_HASH,
BlockHasher,
ShadowedRef,
cache_attempt_from_hash,
content_cache_attempt_from_base,
)
Expand Down Expand Up @@ -115,21 +114,13 @@ def _set_context(self, fn: Callable[..., Any]) -> None:
# checking a single frame- should be good enough.
f_locals = inspect.stack()[2 + self._frame_offset][0].f_locals
self.scope = {**ctx.globals, **f_locals}
# In case scope shadows variables
#
# TODO(akshayka, dmadisetti): rewrite function args with an AST pass
# to make them unique, deterministically based on function body; this
# will allow for lifting the error when a ShadowedRef is also used
# as a regular ref.
for arg in self._args:
self.scope[arg] = ShadowedRef()

# Scoped refs are references particular to this block, that may not be
# defined out of the context of the block, or the cell.
# For instance, the args of the invoked function are restricted to the
# block.
cell_id = ctx.cell_id or ctx.execution_context.cell_id or ""
self.scoped_refs = set(self._args)
self.scoped_refs = set([f"*{k}" for k in self._args])
# # As are the "locals" not in globals
self.scoped_refs |= set(f_locals.keys()) - set(ctx.globals.keys())
# Defined in the cell, and currently available in scope
Expand Down Expand Up @@ -184,16 +175,18 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
self._set_context(args[0])
return self

# Rewrite scoped args to prevent shadowed variables
arg_dict = {f"*{k}": v for (k, v) in zip(self._args, args)}
kwargs = {"*{k}": v for (k, v) in kwargs.items()}
# Capture the call case
arg_dict = {k: v for (k, v) in zip(self._args, args)}
scope = {**self.scope, **get_context().globals, **arg_dict, **kwargs}
assert self._loader is not None, UNEXPECTED_FAILURE_BOILERPLATE
attempt = content_cache_attempt_from_base(
self.base_block,
scope,
self._loader(),
scoped_refs=self.scoped_refs,
required_refs=set(self._args),
required_refs=set([f"*{k}" for k in self._args]),
as_fn=True,
)

Expand Down
68 changes: 54 additions & 14 deletions tests/_save/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,9 @@ def call(v):
from tests._save.mocks import MockLoader

_loader = MockLoader()
with persistent_cache("else", _loader=_loader) as cache:
call(False)
# fmt: off
with persistent_cache("else", _loader=_loader) as cache: call(False)
# fmt: on

with pytest.raises(BlockException):
app.run()
Expand Down Expand Up @@ -1194,11 +1195,45 @@ def __(v):
app.run()

@staticmethod
def test_transitive_shadowed_state_fails() -> None:
def test_internal_shadowed() -> None:
app = App()
app._anonymous_file = True

# Add a unit test to denote a known failure case
@app.cell
def __():
import marimo as mo

return (mo,)

@app.cell
def __(mo):
state0, set_state0 = mo.state(1)
state1, set_state1 = mo.state(1)
state2, set_state2 = mo.state(10)

state, set_state = mo.state(100)

@mo.cache
def h(state):
x = state()
def g():
global state
def f(state):
return x + state()
return state() + g(state2)
return g()

assert g(state0) == 111
assert g.hits == 0
assert g(state1) == 111
assert g.hits == 1

app.run()

@staticmethod
def test_transitive_shadowed_state_passes() -> None:
app = App()
app._anonymous_file = True

@app.cell
def __():
Expand All @@ -1208,29 +1243,33 @@ def __():

@app.cell
def __(mo):
state0, set_state0 = mo.state(1)
state1, set_state1 = mo.state(1)
state2, set_state2 = mo.state(2)
state2, set_state2 = mo.state(10)

state, set_state = mo.state(3)
state, set_state = mo.state(100)

# Example of a case where things start to get very tricky. There
# comes a point where you might also have to capture frame levels
# as well if you mix scope.
#
# This is solved by throwing an exception when state
# shadowing occurs.
def f():
# This is solved by rewriting potential name collisions
def h(state):
return state()

def f():
return state() + h(state2)

@mo.cache
def g(state):
return state() + f()

# Cannot resolved shadowed ref.
with pytest.raises(RuntimeError) as e:
app.run()
assert g(state0) == 1111
assert g.hits == 0
assert g(state1) == 1111
assert g.hits == 1

assert "rename the argument" in str(e)
app.run()

@staticmethod
def test_shadowed_state_mismatch() -> None:
Expand All @@ -1248,6 +1287,7 @@ def __(mo):
state1, set_state1 = mo.state(1)
state2, set_state2 = mo.state(2)

# Here as a var for shadowing
state, set_state = mo.state(3)

@mo.cache
Expand All @@ -1256,7 +1296,7 @@ def g(state):

a = g(state1)
b = g(state2)

assert g.hits == 0
A = g(state1)
B = g(state2)
assert g.hits == 2
Expand Down

0 comments on commit 83cee0a

Please sign in to comment.