diff --git a/marimo/_runtime/dataflow.py b/marimo/_runtime/dataflow.py index 8eb232b89fb..696d81e78f5 100644 --- a/marimo/_runtime/dataflow.py +++ b/marimo/_runtime/dataflow.py @@ -601,7 +601,7 @@ def _get_ancestors( ) -> set[CellId_t]: # Get the transitive closure of parents defining unsubstituted refs graph = self._graph - substitutions = set(kwargs.values()) + substitutions = set(kwargs.keys()) unsubstituted_refs = cell_impl.refs - substitutions parent_ids = set( [ diff --git a/tests/_ast/cell_data/named_cells.py b/tests/_ast/cell_data/named_cells.py index 0fdb11dce0e..10c94a2c9d3 100644 --- a/tests/_ast/cell_data/named_cells.py +++ b/tests/_ast/cell_data/named_cells.py @@ -23,5 +23,26 @@ def h(y): return (z,) +@app.cell +def unhashable_defined(): + unhashable = {0, 1, 2} + unhashable + return (unhashable,) + + +@app.cell +def unhashable_override_required(unhashable): + assert unhashable == {0, 1} + unhashable + return + + +@app.cell +def multiple(): + A = 0 + B = 1 + (A, B) + return (A, B) + if __name__ == "__main__": app.run() diff --git a/tests/_ast/test_cell.py b/tests/_ast/test_cell.py index 80924b15259..62c768a75d6 100644 --- a/tests/_ast/test_cell.py +++ b/tests/_ast/test_cell.py @@ -187,6 +187,35 @@ def test_import() -> None: assert g.run(x=1) == (None, {"y": 2}) assert h.run(y=2) == (3, {"z": 3}) + @staticmethod + def test_unhashable_import() -> None: + from cell_data.named_cells import ( + unhashable_defined, + unhashable_override_required, + ) + + assert unhashable_defined.name == "unhashable_defined" + assert ( + unhashable_override_required.name == "unhashable_override_required" + ) + + assert unhashable_override_required.run(unhashable={0, 1}) == ( + {0, 1}, + {}, + ) + assert unhashable_defined.run() == ( + {0, 1, 2}, + {"unhashable": {0, 1, 2}}, + ) + + @staticmethod + def test_direct_call() -> None: + from cell_data.named_cells import h, multiple, unhashable_defined + + assert h(1) == 2 + assert multiple() == (0, 1) + assert unhashable_defined() == {0, 1, 2} + def help_smoke() -> None: app = App()