diff --git a/frontend/src/core/cells/__tests__/cells.test.ts b/frontend/src/core/cells/__tests__/cells.test.ts index d9c501e0977..a5be6ca6361 100644 --- a/frontend/src/core/cells/__tests__/cells.test.ts +++ b/frontend/src/core/cells/__tests__/cells.test.ts @@ -1331,6 +1331,32 @@ describe("cell reducer", () => { }); }); + it("can set cell codes with new cell ids, while preserving the old cell data", () => { + actions.setCellCodes({ + codes: ["code1", "code2", "code3"], + ids: ["3", "4", "5"] as CellId[], + codeIsStale: false, + }); + expect(state.cellData["3" as CellId].code).toBe("code1"); + expect(state.cellData["4" as CellId].code).toBe("code2"); + expect(state.cellData["5" as CellId].code).toBe("code3"); + + // Update with some new cell ids and some old cell ids + actions.setCellIds({ cellIds: ["1", "2", "3", "4"] as CellId[] }); + actions.setCellCodes({ + codes: ["new1", "new2", "code1", "code2"], + ids: ["1", "2", "3", "4"] as CellId[], + codeIsStale: false, + }); + expect(state.cellData["1" as CellId].code).toBe("new1"); + expect(state.cellData["2" as CellId].code).toBe("new2"); + expect(state.cellData["3" as CellId].code).toBe("code1"); + expect(state.cellData["4" as CellId].code).toBe("code2"); + expect(state.cellIds.inOrderIds).toEqual(["1", "2", "3", "4"]); + // Cell 5 data is preserved (possibly used for tracing), but it's not in the cellIds + expect(state.cellData["5" as CellId]).not.toBeUndefined(); + }); + it("can fold and unfold all cells", () => { actions.foldAll(); expect(foldAllBulk).toHaveBeenCalled(); diff --git a/frontend/src/core/cells/cells.ts b/frontend/src/core/cells/cells.ts index 6e6ba2a94df..333833827d1 100644 --- a/frontend/src/core/cells/cells.ts +++ b/frontend/src/core/cells/cells.ts @@ -789,36 +789,56 @@ const { "Expected codes and ids to have the same length", ); - for (let i = 0; i < action.codes.length; i++) { - const cellId = action.ids[i]; - const code = action.codes[i]; + let nextState = { ...state }; + + const cellReducer = ( + cell: CellData | undefined, + code: string, + cellId: CellId, + ) => { + if (!cell) { + return createCell({ id: cellId, code }); + } - state = updateCellData(state, cellId, (cell) => { - // No change - if (cell.code.trim() === code.trim()) { - return cell; - } + // No change + if (cell.code.trim() === code.trim()) { + return cell; + } - // Update codemirror if mounted - const cellHandle = state.cellHandles[cellId].current; - if (cellHandle?.editorView) { - updateEditorCodeFromPython(cellHandle.editorView, code); - } + // Update codemirror if mounted + const cellHandle = nextState.cellHandles[cellId].current; + if (cellHandle?.editorView) { + updateEditorCodeFromPython(cellHandle.editorView, code); + } - // If code is stale, we don't promote it to lastCodeRun - const lastCodeRun = action.codeIsStale ? cell.lastCodeRun : code; + // If code is stale, we don't promote it to lastCodeRun + const lastCodeRun = action.codeIsStale ? cell.lastCodeRun : code; - return { - ...cell, - code: code, - // Mark as edited if the code has changed - edited: lastCodeRun ? lastCodeRun.trim() !== code.trim() : false, - lastCodeRun, - }; - }); + return { + ...cell, + code: code, + // Mark as edited if the code has changed + edited: lastCodeRun + ? lastCodeRun.trim() !== code.trim() + : Boolean(code), + lastCodeRun, + }; + }; + + for (let i = 0; i < action.codes.length; i++) { + const cellId = action.ids[i]; + const code = action.codes[i]; + + nextState = { + ...nextState, + cellData: { + ...nextState.cellData, + [cellId]: cellReducer(nextState.cellData[cellId], code, cellId), + }, + }; } - return state; + return nextState; }, setStdinResponse: ( state, diff --git a/marimo/_ast/app.py b/marimo/_ast/app.py index 2eaf84e07a3..dd051b01382 100644 --- a/marimo/_ast/app.py +++ b/marimo/_ast/app.py @@ -1,11 +1,7 @@ # Copyright 2024 Marimo. All rights reserved. from __future__ import annotations -import functools import inspect -import os -import random -import string from dataclasses import asdict, dataclass, field from typing import ( TYPE_CHECKING, @@ -22,15 +18,13 @@ from marimo import _loggers from marimo._ast.cell import Cell, CellConfig, CellId_t -from marimo._ast.compiler import cell_factory +from marimo._ast.cell_manager import CellManager from marimo._ast.errors import ( CycleError, DeleteNonlocalError, MultipleDefinitionError, UnparsableError, ) -from marimo._ast.names import DEFAULT_CELL_NAME -from marimo._ast.pytest import wrap_fn_for_pytest from marimo._config.config import WidthType from marimo._messaging.mimetypes import KnownMimeType from marimo._output.hypertext import Html @@ -106,22 +100,6 @@ def update(self, updates: dict[str, Any]) -> _AppConfig: return self -@dataclass -class CellData: - """A cell together with some metadata""" - - cell_id: CellId_t - # User code comprising the cell - code: str - # User-provided name for cell (or default) - name: str - # Cell config - config: CellConfig - - # The original cell, or None if cell was not parsable - cell: Optional[Cell] - - class _Namespace(Mapping[str, object]): def __init__( self, dictionary: dict[str, object], owner: Cell | App @@ -427,196 +405,6 @@ async def embed(self) -> AppEmbedResult: ) -class CellManager: - """ - A manager for cells. - - This holds the cells that have been registered with the app, and - provides methods to access them. - """ - - def __init__(self, prefix: str = "") -> None: - self._cell_data: dict[CellId_t, CellData] = {} - self.prefix = prefix - self.unparsable = False - self.random_seed = random.Random(42) - - def create_cell_id(self) -> CellId_t: - # 4 random letters - return self.prefix + "".join( - self.random_seed.choices(string.ascii_letters, k=4) - ) - - def cell_decorator( - self, - func: Callable[..., Any] | None, - column: Optional[int], - disabled: bool, - hide_code: bool, - app: InternalApp | None = None, - ) -> Cell | Callable[..., Cell]: - cell_config = CellConfig( - column=column, disabled=disabled, hide_code=hide_code - ) - - def _register(func: Callable[..., Any]) -> Cell: - # Use PYTEST_VERSION here, opposed to PYTEST_CURRENT_TEST, in - # order to allow execution during test collection. - is_top_level_pytest = ( - "PYTEST_VERSION" in os.environ - and "PYTEST_CURRENT_TEST" not in os.environ - ) - cell = cell_factory( - func, - cell_id=self.create_cell_id(), - anonymous_file=app._app._anonymous_file if app else False, - test_rewrite=is_top_level_pytest, - ) - cell._cell.configure(cell_config) - self._register_cell(cell, app=app) - # Manually set the signature for pytest. - if is_top_level_pytest: - func = wrap_fn_for_pytest(func, cell) - # NB. in place metadata update. - functools.wraps(func)(cell) - return cell - - if func is None: - # If the decorator was used with parentheses, func will be None, - # and we return a decorator that takes the decorated function as an - # argument - def decorator(func: Callable[..., Any]) -> Cell: - return _register(func) - - return decorator - else: - return _register(func) - - def _register_cell( - self, cell: Cell, app: InternalApp | None = None - ) -> None: - if app is not None: - cell._register_app(app) - cell_impl = cell._cell - self.register_cell( - cell_id=cell_impl.cell_id, - code=cell_impl.code, - name=cell.name, - config=cell_impl.config, - cell=cell, - ) - - def register_cell( - self, - cell_id: Optional[CellId_t], - code: str, - config: Optional[CellConfig], - name: str = DEFAULT_CELL_NAME, - cell: Optional[Cell] = None, - ) -> None: - if cell_id is None: - cell_id = self.create_cell_id() - - self._cell_data[cell_id] = CellData( - cell_id=cell_id, - code=code, - name=name, - config=config or CellConfig(), - cell=cell, - ) - - def register_unparsable_cell( - self, - code: str, - name: Optional[str], - cell_config: CellConfig, - ) -> None: - # - code.split("\n")[1:-1] disregards first and last lines, which are - # empty - # - line[4:] removes leading indent in multiline string - # - replace(...) unescapes double quotes - # - rstrip() removes an extra newline - code = "\n".join( - [line[4:].replace('\\"', '"') for line in code.split("\n")[1:-1]] - ) - - self.register_cell( - cell_id=self.create_cell_id(), - code=code, - config=cell_config, - name=name or DEFAULT_CELL_NAME, - cell=None, - ) - - def ensure_one_cell(self) -> None: - if not self._cell_data: - cell_id = self.create_cell_id() - self.register_cell( - cell_id=cell_id, - code="", - config=CellConfig(), - ) - - def cell_name(self, cell_id: CellId_t) -> str: - return self._cell_data[cell_id].name - - def names(self) -> Iterable[str]: - for cell_data in self._cell_data.values(): - yield cell_data.name - - def codes(self) -> Iterable[str]: - for cell_data in self._cell_data.values(): - yield cell_data.code - - def configs(self) -> Iterable[CellConfig]: - for cell_data in self._cell_data.values(): - yield cell_data.config - - def valid_cells( - self, - ) -> Iterable[tuple[CellId_t, Cell]]: - """Return cells and functions for each valid cell.""" - for cell_data in self._cell_data.values(): - if cell_data.cell is not None: - yield (cell_data.cell_id, cell_data.cell) - - def valid_cell_ids(self) -> Iterable[CellId_t]: - for cell_data in self._cell_data.values(): - if cell_data.cell is not None: - yield cell_data.cell_id - - def cell_ids(self) -> Iterable[CellId_t]: - """Cell IDs in the order they were registered.""" - return self._cell_data.keys() - - def has_cell(self, cell_id: CellId_t) -> bool: - return cell_id in self._cell_data - - def cells( - self, - ) -> Iterable[Optional[Cell]]: - for cell_data in self._cell_data.values(): - yield cell_data.cell - - def config_map(self) -> dict[CellId_t, CellConfig]: - return {cid: cd.config for cid, cd in self._cell_data.items()} - - def cell_data(self) -> Iterable[CellData]: - return self._cell_data.values() - - def cell_data_at(self, cell_id: CellId_t) -> CellData: - return self._cell_data[cell_id] - - def get_cell_id_by_code(self, code: str) -> Optional[CellId_t]: - """ - Finds the first cell with the given code and returns its cell ID. - """ - for cell_id, cell_data in self._cell_data.items(): - if cell_data.code == code: - return cell_id - return None - - class InternalApp: """ Internal representation of an app. diff --git a/marimo/_ast/cell.py b/marimo/_ast/cell.py index 863f0abcb88..2ef2bfa1ee5 100644 --- a/marimo/_ast/cell.py +++ b/marimo/_ast/cell.py @@ -605,7 +605,3 @@ class SourcePosition: filename: str lineno: int col_offset: int - - -def is_ws(char: str) -> bool: - return char == " " or char == "\n" or char == "\t" diff --git a/marimo/_ast/cell_manager.py b/marimo/_ast/cell_manager.py new file mode 100644 index 00000000000..8c71f51bc5a --- /dev/null +++ b/marimo/_ast/cell_manager.py @@ -0,0 +1,494 @@ +# Copyright 2024 Marimo. All rights reserved. +from __future__ import annotations + +import functools +import os +import random +import string +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + Optional, +) + +from marimo._ast.cell import Cell, CellConfig, CellId_t +from marimo._ast.compiler import cell_factory +from marimo._ast.models import CellData +from marimo._ast.names import DEFAULT_CELL_NAME +from marimo._ast.pytest import wrap_fn_for_pytest + +if TYPE_CHECKING: + from marimo._ast.app import InternalApp + + +class CellManager: + """A manager for cells in a marimo notebook. + + The CellManager is responsible for: + 1. Creating and managing unique cell IDs + 2. Registering and storing cell data (code, configuration, etc.) + 3. Providing access to cell information through various queries + 4. Managing both valid (parsable) and unparsable cells + 5. Handling cell decorators for the notebook interface + + Attributes: + prefix (str): A prefix added to all cell IDs managed by this instance + unparsable (bool): Flag indicating if any unparsable cells were encountered + random_seed (random.Random): Seeded random number generator for deterministic cell ID creation + """ + + def __init__(self, prefix: str = "") -> None: + """Initialize a new CellManager. + + Args: + prefix (str, optional): Prefix to add to all cell IDs. Defaults to "". + """ + self._cell_data: dict[CellId_t, CellData] = {} + self.prefix = prefix + self.unparsable = False + self.random_seed = random.Random(42) + self.seen_ids: set[CellId_t] = set() + + def create_cell_id(self) -> CellId_t: + """Create a new unique cell ID. + + Returns: + CellId_t: A new cell ID consisting of the manager's prefix followed by 4 random letters. + """ + # 4 random letters + _id = self.prefix + "".join( + self.random_seed.choices(string.ascii_letters, k=4) + ) + while _id in self.seen_ids: + _id = self.prefix + "".join( + self.random_seed.choices(string.ascii_letters, k=4) + ) + self.seen_ids.add(_id) + return _id + + # TODO: maybe remove this, it is leaky + def cell_decorator( + self, + func: Callable[..., Any] | None, + column: Optional[int], + disabled: bool, + hide_code: bool, + app: InternalApp | None = None, + ) -> Cell | Callable[..., Cell]: + """Create a cell decorator for marimo notebook cells.""" + cell_config = CellConfig( + column=column, disabled=disabled, hide_code=hide_code + ) + + def _register(func: Callable[..., Any]) -> Cell: + # Use PYTEST_VERSION here, opposed to PYTEST_CURRENT_TEST, in + # order to allow execution during test collection. + is_top_level_pytest = ( + "PYTEST_VERSION" in os.environ + and "PYTEST_CURRENT_TEST" not in os.environ + ) + cell = cell_factory( + func, + cell_id=self.create_cell_id(), + anonymous_file=app._app._anonymous_file if app else False, + test_rewrite=is_top_level_pytest, + ) + cell._cell.configure(cell_config) + self._register_cell(cell, app=app) + # Manually set the signature for pytest. + if is_top_level_pytest: + func = wrap_fn_for_pytest(func, cell) + # NB. in place metadata update. + functools.wraps(func)(cell) + return cell + + if func is None: + # If the decorator was used with parentheses, func will be None, + # and we return a decorator that takes the decorated function as an + # argument + def decorator(func: Callable[..., Any]) -> Cell: + return _register(func) + + return decorator + else: + return _register(func) + + def _register_cell( + self, cell: Cell, app: InternalApp | None = None + ) -> None: + if app is not None: + cell._register_app(app) + cell_impl = cell._cell + self.register_cell( + cell_id=cell_impl.cell_id, + code=cell_impl.code, + name=cell.name, + config=cell_impl.config, + cell=cell, + ) + + def register_cell( + self, + cell_id: Optional[CellId_t], + code: str, + config: Optional[CellConfig], + name: str = DEFAULT_CELL_NAME, + cell: Optional[Cell] = None, + ) -> None: + """Register a new cell with the manager. + + Args: + cell_id: Unique identifier for the cell. If None, one will be generated + code: The cell's source code + config: Cell configuration (column, disabled state, etc.) + name: Name of the cell, defaults to DEFAULT_CELL_NAME + cell: Optional Cell object for valid cells + """ + if cell_id is None: + cell_id = self.create_cell_id() + + self._cell_data[cell_id] = CellData( + cell_id=cell_id, + code=code, + name=name, + config=config or CellConfig(), + cell=cell, + ) + + def register_unparsable_cell( + self, + code: str, + name: Optional[str], + cell_config: CellConfig, + ) -> None: + """Register a cell that couldn't be parsed. + + Handles code formatting and registration of cells that couldn't be parsed + into valid Python code. + + Args: + code: The unparsable code string + name: Optional name for the cell + cell_config: Configuration for the cell + """ + # - code.split("\n")[1:-1] disregards first and last lines, which are + # empty + # - line[4:] removes leading indent in multiline string + # - replace(...) unescapes double quotes + # - rstrip() removes an extra newline + code = "\n".join( + [line[4:].replace('\\"', '"') for line in code.split("\n")[1:-1]] + ) + + self.register_cell( + cell_id=self.create_cell_id(), + code=code, + config=cell_config, + name=name or DEFAULT_CELL_NAME, + cell=None, + ) + + def ensure_one_cell(self) -> None: + """Ensure at least one cell exists in the manager. + + If no cells exist, creates an empty cell with default configuration. + """ + if not self._cell_data: + cell_id = self.create_cell_id() + self.register_cell( + cell_id=cell_id, + code="", + config=CellConfig(), + ) + + def cell_name(self, cell_id: CellId_t) -> str: + """Get the name of a cell by its ID. + + Args: + cell_id: The ID of the cell + + Returns: + str: The name of the cell + + Raises: + KeyError: If the cell_id doesn't exist + """ + return self._cell_data[cell_id].name + + def names(self) -> Iterable[str]: + """Get an iterator over all cell names. + + Returns: + Iterable[str]: Iterator yielding each cell's name + """ + for cell_data in self._cell_data.values(): + yield cell_data.name + + def codes(self) -> Iterable[str]: + """Get an iterator over all cell codes. + + Returns: + Iterable[str]: Iterator yielding each cell's source code + """ + for cell_data in self._cell_data.values(): + yield cell_data.code + + def configs(self) -> Iterable[CellConfig]: + """Get an iterator over all cell configurations. + + Returns: + Iterable[CellConfig]: Iterator yielding each cell's configuration + """ + for cell_data in self._cell_data.values(): + yield cell_data.config + + def valid_cells( + self, + ) -> Iterable[tuple[CellId_t, Cell]]: + """Get an iterator over all valid (parsable) cells. + + Returns: + Iterable[tuple[CellId_t, Cell]]: Iterator yielding tuples of (cell_id, cell) + for each valid cell + """ + for cell_data in self._cell_data.values(): + if cell_data.cell is not None: + yield (cell_data.cell_id, cell_data.cell) + + def valid_cell_ids(self) -> Iterable[CellId_t]: + """Get an iterator over IDs of all valid cells. + + Returns: + Iterable[CellId_t]: Iterator yielding cell IDs of valid cells + """ + for cell_data in self._cell_data.values(): + if cell_data.cell is not None: + yield cell_data.cell_id + + def cell_ids(self) -> Iterable[CellId_t]: + """Get an iterator over all cell IDs in registration order. + + Returns: + Iterable[CellId_t]: Iterator yielding all cell IDs + """ + return self._cell_data.keys() + + def has_cell(self, cell_id: CellId_t) -> bool: + """Check if a cell with the given ID exists. + + Args: + cell_id: The ID to check + + Returns: + bool: True if the cell exists, False otherwise + """ + return cell_id in self._cell_data + + def cells( + self, + ) -> Iterable[Optional[Cell]]: + """Get an iterator over all Cell objects. + + Returns: + Iterable[Optional[Cell]]: Iterator yielding Cell objects (or None for invalid cells) + """ + for cell_data in self._cell_data.values(): + yield cell_data.cell + + def config_map(self) -> dict[CellId_t, CellConfig]: + """Get a mapping of cell IDs to their configurations. + + Returns: + dict[CellId_t, CellConfig]: Dictionary mapping cell IDs to their configurations + """ + return {cid: cd.config for cid, cd in self._cell_data.items()} + + def cell_data(self) -> Iterable[CellData]: + """Get an iterator over all cell data. + + Returns: + Iterable[CellData]: Iterator yielding CellData objects for all cells + """ + return self._cell_data.values() + + def cell_data_at(self, cell_id: CellId_t) -> CellData: + """Get the cell data for a specific cell ID. + + Args: + cell_id: The ID of the cell to get data for + + Returns: + CellData: The cell's data + + Raises: + KeyError: If the cell_id doesn't exist + """ + return self._cell_data[cell_id] + + def get_cell_id_by_code(self, code: str) -> Optional[CellId_t]: + """Find a cell ID by its code content. + + Args: + code: The code to search for + + Returns: + Optional[CellId_t]: The ID of the first cell with matching code, + or None if no match is found + """ + for cell_id, cell_data in self._cell_data.items(): + if cell_data.code == code: + return cell_id + return None + + def sort_cell_ids_by_similarity( + self, prev_cell_manager: CellManager + ) -> None: + """Sort cell IDs by similarity to the current cell manager. + + This mutates the current cell manager. + """ + prev_ids = list(prev_cell_manager.cell_ids()) + prev_codes = [data.code for data in prev_cell_manager.cell_data()] + current_ids = list(self._cell_data.keys()) + current_codes = [data.code for data in self.cell_data()] + sorted_ids = _match_cell_ids_by_similarity( + prev_ids, prev_codes, current_ids, current_codes + ) + assert len(sorted_ids) == len(list(self.cell_ids())) + + # Create mapping from new to old ids + id_mapping = dict(zip(sorted_ids, current_ids)) + + # Update the cell data in place + self._cell_data = { + new_id: self._cell_data[old_id] + for new_id, old_id in id_mapping.items() + } + + # Add the new ids to the set, so we don't reuse them in the future + for _id in sorted_ids: + self.seen_ids.add(_id) + + +def _match_cell_ids_by_similarity( + prev_ids: list[CellId_t], + prev_codes: list[str], + next_ids: list[CellId_t], + next_codes: list[str], +) -> list[CellId_t]: + """Match cell IDs based on code similarity. + + Args: + prev_ids: List of previous cell IDs, used as the set of possible IDs + prev_codes: List of previous cell codes + next_ids: List of next cell IDs, used only when more cells than prev_ids + next_codes: List of next cell codes + + Returns: + List of cell IDs matching next_codes, using prev_ids where possible + """ + assert len(prev_codes) == len(prev_ids) + assert len(next_codes) == len(next_ids) + + # Initialize result and tracking sets + result: list[Optional[CellId_t]] = [None] * len(next_codes) + used_positions: set[int] = set() + used_prev_ids: set[CellId_t] = set() + + # Track which next_ids are new (not in prev_ids) + new_next_ids = [p_id for p_id in next_ids if p_id not in prev_ids] + new_id_idx = 0 + + # First pass: exact matches using hash map + next_code_to_idx: dict[str, list[int]] = {} + for idx, code in enumerate(next_codes): + next_code_to_idx.setdefault(code, []).append(idx) + + for prev_idx, prev_code in enumerate(prev_codes): + if prev_ids[prev_idx] in used_prev_ids: + continue + if prev_code in next_code_to_idx: + # Use first available matching position + for next_idx in next_code_to_idx[prev_code]: + if next_idx not in used_positions: + result[next_idx] = prev_ids[prev_idx] + used_positions.add(next_idx) + used_prev_ids.add(prev_ids[prev_idx]) + break + + # If all positions filled, we're done + if len(used_positions) == len(next_codes): + return [_id for _id in result if _id is not None] # type: ignore + + def similarity_score(s1: str, s2: str) -> int: + """Fast similarity score based on common prefix and suffix. + Returns lower score for more similar strings.""" + # Find common prefix length + prefix_len = 0 + for c1, c2 in zip(s1, s2): + if c1 != c2: + break + prefix_len += 1 + + # Find common suffix length if strings differ in middle + if prefix_len < min(len(s1), len(s2)): + s1_rev = s1[::-1] + s2_rev = s2[::-1] + suffix_len = 0 + for c1, c2 in zip(s1_rev, s2_rev): + if c1 != c2: + break + suffix_len += 1 + else: + suffix_len = 0 + + # Return inverse similarity - shorter common affix means higher score + return len(s1) + len(s2) - 2 * (prefix_len + suffix_len) + + # Filter out used positions and ids for similarity matrix + remaining_prev_indices = [ + i for i, pid in enumerate(prev_ids) if pid not in used_prev_ids + ] + remaining_next_indices = [ + i for i in range(len(next_codes)) if i not in used_positions + ] + + # Create similarity matrix only for remaining cells + similarity_matrix: list[list[int]] = [] + for prev_idx in remaining_prev_indices: + row: list[int] = [] + for next_idx in remaining_next_indices: + score = similarity_score( + prev_codes[prev_idx], next_codes[next_idx] + ) + row.append(score) + similarity_matrix.append(row) + + # Second pass: best matches for remaining positions + for matrix_prev_idx, prev_idx in enumerate(remaining_prev_indices): + # Find best match among unused positions + min_score = float("inf") # type: ignore + best_next_matrix_idx = None + for matrix_next_idx, score in enumerate( + similarity_matrix[matrix_prev_idx] + ): + if score < min_score: + min_score = score + best_next_matrix_idx = matrix_next_idx + + if best_next_matrix_idx is not None: + next_idx = remaining_next_indices[best_next_matrix_idx] + result[next_idx] = prev_ids[prev_idx] + used_positions.add(next_idx) + used_prev_ids.add(prev_ids[prev_idx]) + + # Fill remaining positions with new next_ids + for i in range(len(next_codes)): + if result[i] is None: + if new_id_idx < len(new_next_ids): + result[i] = new_next_ids[new_id_idx] + new_id_idx += 1 + + return [_id for _id in result if _id is not None] # type: ignore diff --git a/marimo/_ast/codegen.py b/marimo/_ast/codegen.py index 196bdb823d2..16c31663f89 100644 --- a/marimo/_ast/codegen.py +++ b/marimo/_ast/codegen.py @@ -215,7 +215,21 @@ class MarimoFileError(Exception): def get_app(filename: Optional[str]) -> Optional[App]: - """Load and return app from a marimo-generated module""" + """Load and return app from a marimo-generated module. + + Args: + filename: Path to a marimo notebook file (.py or .md) + + Returns: + The marimo App instance if the file exists and contains valid code, + None if the file is empty or contains only comments. + + Raises: + MarimoFileError: If the file exists but doesn't define a valid marimo app + RuntimeError: If there are issues loading the module + SyntaxError: If the file contains a syntax error + FileNotFoundError: If the file doesn't exist + """ if filename is None: return None @@ -247,7 +261,7 @@ def get_app(filename: Optional[str]) -> Optional[App]: marimo_app = importlib.util.module_from_spec(spec) if spec.loader is None: raise RuntimeError("Failed to load module spec's loader") - spec.loader.exec_module(marimo_app) + spec.loader.exec_module(marimo_app) # This may throw a SyntaxError if not hasattr(marimo_app, "app"): raise MarimoFileError(f"{filename} missing attribute `app`.") if not isinstance(marimo_app.app, App): diff --git a/marimo/_ast/models.py b/marimo/_ast/models.py new file mode 100644 index 00000000000..8fd3f22bbe0 --- /dev/null +++ b/marimo/_ast/models.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +from marimo._ast.cell import Cell, CellConfig, CellId_t + + +@dataclass +class CellData: + """A cell together with its metadata. + + This class bundles a cell with its associated metadata like ID, code, name and config. + It represents both valid cells that can be executed and invalid/unparsable cells. + + Attributes: + cell_id: Unique identifier for the cell + code: Raw source code text of the cell + name: User-provided name for the cell, or a default if none provided + config: Configuration options for the cell like column placement, disabled state, etc. + cell: The compiled Cell object if code is valid, None if code couldn't be parsed + """ + + cell_id: CellId_t + code: str + name: str + config: CellConfig + cell: Optional[Cell] diff --git a/marimo/_cli/cli.py b/marimo/_cli/cli.py index 187b6358948..e60eeaa85e1 100644 --- a/marimo/_cli/cli.py +++ b/marimo/_cli/cli.py @@ -52,6 +52,22 @@ def helpful_usage_error(self: Any, file: Any = None) -> None: click.echo(self.ctx.get_help(), file=file, color=color) +def check_app_correctness(filename: str) -> None: + try: + codegen.get_app(filename) + except SyntaxError: + import traceback + + # This prints a more readable error message, without internal details + # e.g. + # Error: File "/my/bad/file.py", line 17 + # x. + # ^ + # SyntaxError: invalid syntax + click.echo(f"Failed to parse notebook: {filename}\n", err=True) + raise click.ClickException(traceback.format_exc(limit=0)) from None + + click.exceptions.UsageError.show = helpful_usage_error # type: ignore @@ -351,7 +367,7 @@ def edit( if os.path.exists(name) and not is_dir: # module correctness check - don't start the server # if we can't import the module - codegen.get_app(name) + check_app_correctness(name) elif not is_dir: # write empty file try: @@ -640,7 +656,7 @@ def run( name, _ = validate_name(name, allow_new_file=False, allow_directory=False) # correctness check - don't start the server if we can't import the module - codegen.get_app(name) + check_app_correctness(name) start( file_router=AppFileRouter.from_filename(MarimoPath(name)), diff --git a/marimo/_messaging/ops.py b/marimo/_messaging/ops.py index 681e0776043..1e015fc83b1 100644 --- a/marimo/_messaging/ops.py +++ b/marimo/_messaging/ops.py @@ -639,6 +639,8 @@ class UpdateCellCodes(Op): name: ClassVar[str] = "update-cell-codes" cell_ids: List[CellId_t] codes: List[str] + # If true, this means the code was not run on the backend when updating + # the cell codes. code_is_stale: bool diff --git a/marimo/_server/file_manager.py b/marimo/_server/file_manager.py index 2dcb713b01d..6e38f14240b 100644 --- a/marimo/_server/file_manager.py +++ b/marimo/_server/file_manager.py @@ -42,7 +42,9 @@ def from_app(app: InternalApp) -> AppFileManager: def reload(self) -> None: """Reload the app from the file.""" + prev_cell_manager = self.app.cell_manager self.app = self._load_app(self.path) + self.app.cell_manager.sort_cell_ids_by_similarity(prev_cell_manager) def _is_same_path(self, filename: str) -> bool: if self.filename is None: diff --git a/marimo/_server/session/session_view.py b/marimo/_server/session/session_view.py index e8a2706cd47..c586854f0d3 100644 --- a/marimo/_server/session/session_view.py +++ b/marimo/_server/session/session_view.py @@ -13,6 +13,7 @@ Datasets, Interrupted, MessageOperation, + UpdateCellCodes, UpdateCellIdsRequest, Variables, VariableValue, @@ -52,6 +53,8 @@ def __init__(self) -> None: self.last_executed_code: dict[CellId_t, str] = {} # Map of cell id to the last cell execution time self.last_execution_time: dict[CellId_t, float] = {} + # Any stale code that was read from a file-watcher + self.stale_code: Optional[UpdateCellCodes] = None # Auto-saving self.has_auto_exported_html = False @@ -172,6 +175,11 @@ def add_operation(self, operation: MessageOperation) -> None: elif isinstance(operation, UpdateCellIdsRequest): self.cell_ids = operation + elif ( + isinstance(operation, UpdateCellCodes) and operation.code_is_stale + ): + self.stale_code = operation + def get_cell_outputs( self, ids: list[CellId_t] ) -> dict[CellId_t, CellOutput]: @@ -226,6 +234,8 @@ def operations(self) -> list[MessageOperation]: if self.datasets.tables: all_ops.append(self.datasets) all_ops.extend(self.cell_operations.values()) + if self.stale_code: + all_ops.append(self.stale_code) return all_ops def mark_auto_export_html(self) -> None: diff --git a/marimo/_server/sessions.py b/marimo/_server/sessions.py index 4380a022577..a6542b53600 100644 --- a/marimo/_server/sessions.py +++ b/marimo/_server/sessions.py @@ -41,6 +41,7 @@ MessageOperation, Reload, UpdateCellCodes, + UpdateCellIdsRequest, ) from marimo._messaging.types import KernelMessage from marimo._output.formatters.formatters import register_formatters @@ -828,7 +829,11 @@ async def on_file_changed(path: Path) -> None: cell_ids = list( session.app_file_manager.app.cell_manager.cell_ids() ) - # Send the updated codes to the frontend + # Send the updated cell ids and codes to the frontend + session.write_operation( + UpdateCellIdsRequest(cell_ids=cell_ids), + from_consumer_id=None, + ) session.write_operation( UpdateCellCodes( cell_ids=cell_ids, diff --git a/tests/_ast/test_cell_manager.py b/tests/_ast/test_cell_manager.py new file mode 100644 index 00000000000..dd2eb401070 --- /dev/null +++ b/tests/_ast/test_cell_manager.py @@ -0,0 +1,402 @@ +from __future__ import annotations + +import pytest + +from marimo._ast.cell import Cell, CellConfig +from marimo._ast.cell_manager import CellManager, _match_cell_ids_by_similarity +from marimo._ast.compiler import compile_cell +from marimo._ast.names import DEFAULT_CELL_NAME + + +@pytest.fixture +def cell_manager(): + return CellManager(prefix="test_") + + +def test_create_cell_id(cell_manager: CellManager) -> None: + # Test deterministic behavior with fixed seed + cell_id1 = cell_manager.create_cell_id() + cell_id2 = cell_manager.create_cell_id() + + assert cell_id1.startswith("test_") + assert len(cell_id1) == 9 # "test_" + 4 random letters + assert cell_id1 != cell_id2 + + +def test_register_cell(cell_manager: CellManager) -> None: + cell_id = "test_cell" + code = "print('hello')" + config = CellConfig() + + cell_manager.register_cell( + cell_id=cell_id, + code=code, + config=config, + name=DEFAULT_CELL_NAME, + ) + + assert cell_manager.has_cell(cell_id) + cell_data = cell_manager.cell_data_at(cell_id) + assert cell_data.code == code + assert cell_data.config == config + assert cell_data.name == DEFAULT_CELL_NAME + + +def test_register_cell_auto_id(cell_manager: CellManager) -> None: + code = "print('hello')" + config = CellConfig() + + cell_manager.register_cell( + cell_id=None, + code=code, + config=config, + ) + + # Should have created one cell with an auto-generated ID + assert len(list(cell_manager.cell_ids())) == 1 + cell_id = next(iter(cell_manager.cell_ids())) + assert cell_id.startswith("test_") + + +def test_ensure_one_cell(cell_manager: CellManager) -> None: + assert len(list(cell_manager.cell_ids())) == 0 + cell_manager.ensure_one_cell() + assert len(list(cell_manager.cell_ids())) == 1 + + # Calling again shouldn't add another cell + cell_manager.ensure_one_cell() + assert len(list(cell_manager.cell_ids())) == 1 + + +def test_cell_queries(cell_manager: CellManager) -> None: + cell_id1 = "test_cell1" + cell_id2 = "test_cell2" + code1 = "print('hello')" + code2 = "print('world')" + config1 = CellConfig(column=1) + config2 = CellConfig(disabled=True) + + cell_manager.register_cell(cell_id1, code1, config1, name="cell1") + cell_manager.register_cell(cell_id2, code2, config2, name="cell2") + + assert list(cell_manager.names()) == ["cell1", "cell2"] + assert list(cell_manager.codes()) == [code1, code2] + assert list(cell_manager.configs()) == [config1, config2] + assert list(cell_manager.cell_ids()) == [cell_id1, cell_id2] + assert cell_manager.config_map() == {cell_id1: config1, cell_id2: config2} + + +def test_get_cell_id_by_code(cell_manager: CellManager) -> None: + code = "print('unique')" + cell_manager.register_cell("test_cell1", code, CellConfig()) + cell_manager.register_cell("test_cell2", "different_code", CellConfig()) + + assert cell_manager.get_cell_id_by_code(code) == "test_cell1" + assert cell_manager.get_cell_id_by_code("nonexistent") is None + + +def test_register_unparsable_cell(cell_manager: CellManager) -> None: + code = """ + def unparsable(): + return "test" + """ + config = CellConfig() + + cell_manager.register_unparsable_cell( + code=code, + name="unparsable", + cell_config=config, + ) + + cell_data = next(iter(cell_manager.cell_data())) + assert cell_data.name == "unparsable" + assert "def unparsable():" in cell_data.code + assert cell_data.cell is None # Unparsable cells have no Cell object + + +def test_valid_cells(cell_manager: CellManager) -> None: + # Register a mix of valid and invalid cells + cell1 = Cell(_name="_", _cell=compile_cell("print('valid')", "test_cell1")) + cell_manager.register_cell( + "test_cell1", "print('valid')", CellConfig(), cell=cell1 + ) + cell_manager.register_cell( + "test_cell2", "print('invalid')", CellConfig(), cell=None + ) + + valid_cells = list(cell_manager.valid_cells()) + assert len(valid_cells) == 1 + assert valid_cells[0][0] == "test_cell1" + assert valid_cells[0][1] == cell1 + + valid_ids = list(cell_manager.valid_cell_ids()) + assert valid_ids == ["test_cell1"] + + +def test_match_cell_ids_by_similarity(): + # Test exact matches + assert _match_cell_ids_by_similarity( + prev_ids=["a", "b", "c"], + prev_codes=["abc", "def", "ghi"], + next_ids=["unused_a", "unused_b", "unused_c"], + next_codes=["abc", "def", "ghi"], + ) == ["a", "b", "c"] + + # Test with reordered codes + assert _match_cell_ids_by_similarity( + prev_ids=["a", "b", "c"], + prev_codes=["abc", "def", "ghi"], + next_ids=["unused_a", "unused_b", "unused_c"], + next_codes=["def", "ghi", "abc"], + ) == ["b", "c", "a"] + + # Test with similar but not exact matches + assert _match_cell_ids_by_similarity( + prev_ids=["a", "b", "c"], + prev_codes=["abc", "def", "ghi"], + next_ids=["unused_a", "unused_b", "unused_c"], + next_codes=["ghij", "abcd", "defg"], + ) == ["c", "a", "b"] + + # Test with fewer next cells + result = _match_cell_ids_by_similarity( + prev_ids=["a", "b", "c"], + prev_codes=["abc", "def", "ghi"], + next_ids=["unused_a", "unused_b"], + next_codes=["abc", "ghi"], + ) + assert len(result) == 2 + assert result == ["a", "c"] + + # Test with more next cells + result = _match_cell_ids_by_similarity( + prev_ids=["a", "b"], + prev_codes=["abc", "def"], + next_ids=["a", "b", "c"], + next_codes=["def", "ghi", "abc"], + ) + assert len(result) == 3 + assert result == ["b", "c", "a"] + + # Test with completely different codes + result = _match_cell_ids_by_similarity( + prev_ids=["a", "b"], + prev_codes=["abc", "def"], + next_ids=["unused_a", "unused_b"], + next_codes=["xyz", "123"], + ) + assert len(result) == 2 + + # Test empty lists + assert _match_cell_ids_by_similarity([], [], [], []) == [] + + # Test with empty strings + assert _match_cell_ids_by_similarity( + prev_ids=["a"], + prev_codes=[""], + next_ids=["unused_a"], + next_codes=[""], + ) == ["a"] + + +def test_match_cell_ids_by_similarity_edge_cases(): + # Test with multiple identical codes in prev + result = _match_cell_ids_by_similarity( + prev_ids=["a", "b", "c"], + prev_codes=["same", "same", "diff"], + next_ids=["x", "y"], + next_codes=["same", "diff"], + ) + assert len(result) == 2 + assert result[0] in ["a", "b"] + assert result[1] == "c" + + # Test with multiple identical codes in next + result = _match_cell_ids_by_similarity( + prev_ids=["a", "b"], + prev_codes=["code1", "code2"], + next_ids=["x", "y", "z"], + next_codes=["code1", "code1", "code2"], + ) + assert len(result) == 3 + assert result == ["a", "x", "b"] + + # Test with very long common prefixes/suffixes + long_prefix = "x" * 1000 + result = _match_cell_ids_by_similarity( + prev_ids=["a", "b"], + prev_codes=[long_prefix + "1", long_prefix + "2"], + next_ids=["x", "y"], + next_codes=[long_prefix + "2", long_prefix + "1"], + ) + assert result == ["b", "a"] + + # Test with Unicode and special characters + result = _match_cell_ids_by_similarity( + prev_ids=["a", "b", "c"], + prev_codes=["🔥", "∑∫", "\n\t\r"], + next_ids=["x", "y", "z"], + next_codes=["∑∫", "🔥", "\n\t\r"], + ) + assert result == ["b", "a", "c"] + + # Test with mixed case sensitivity + result = _match_cell_ids_by_similarity( + prev_ids=["a", "b"], + prev_codes=["ABC", "def"], + next_ids=["x", "y"], + next_codes=["abc", "DEF"], + ) + assert len(result) == 2 + + # Test with whitespace variations + result = _match_cell_ids_by_similarity( + prev_ids=["a", "b"], + prev_codes=["x y", "a\nb"], + next_ids=["x", "y"], + next_codes=["x y", "a b"], + ) + assert result == ["a", "b"] + + # Test with all codes being substrings of each other + result = _match_cell_ids_by_similarity( + prev_ids=["a", "b", "c"], + prev_codes=["x", "xy", "xyz"], + next_ids=["p", "q", "r"], + next_codes=["xyz", "xy", "x"], + ) + assert result == ["c", "b", "a"] + + # Test with maximum length differences + result = _match_cell_ids_by_similarity( + prev_ids=["a", "b"], + prev_codes=["x", "x" * 10000], + next_ids=["y", "z"], + next_codes=["x" * 10000, "x"], + ) + assert result == ["b", "a"] + + # Test with empty strings + result = _match_cell_ids_by_similarity( + prev_ids=["a", "b"], + prev_codes=["", "x"], + next_ids=["y", "z"], + next_codes=["x", ""], + ) + assert result == ["b", "a"] + + # Test with identical codes + result = _match_cell_ids_by_similarity( + prev_ids=["a", "b", "c"], + prev_codes=["same", "same", "same"], + next_ids=["x", "y", "z"], + next_codes=["same", "same", "same"], + ) + assert len(result) == 3 + assert set(result) == {"a", "b", "c"} + + # Test with completely different codes + result = _match_cell_ids_by_similarity( + prev_ids=["a", "b"], + prev_codes=["abc", "def"], + next_ids=["x", "y"], + next_codes=["123", "456"], + ) + assert len(result) == 2 + # This is probably ok since they are completely different. + # Not sure what the best behavior is. + assert result == ["b", "x"] + + # Test with special Python syntax + result = _match_cell_ids_by_similarity( + prev_ids=["a", "b"], + prev_codes=["def foo():", "class Bar:"], + next_ids=["x", "y"], + next_codes=["class Bar:", "def foo():"], + ) + assert result == ["b", "a"] + + +def test_sort_cell_ids_by_similarity_reorder(): + # Test simple reorder + prev_manager = CellManager() + prev_manager.register_cell("a", "code1", CellConfig()) + prev_manager.register_cell("b", "code2", CellConfig()) + + curr_manager = CellManager() + curr_manager.register_cell("x", "code2", CellConfig()) + curr_manager.register_cell("y", "code1", CellConfig()) + + # Save original seen_ids + original_seen_ids = curr_manager.seen_ids.copy() + + curr_manager.sort_cell_ids_by_similarity(prev_manager) + assert list(curr_manager.cell_ids()) == ["b", "a"] + assert curr_manager.cell_data_at("b").code == "code2" + assert curr_manager.cell_data_at("a").code == "code1" + + # Check seen_ids were updated + assert curr_manager.seen_ids == original_seen_ids | {"a", "b"} + + +def test_sort_cell_ids_by_similarity_reorder_same_ids(): + # Test simple reorder + prev_manager = CellManager() + prev_manager.register_cell("a", "code1", CellConfig()) + prev_manager.register_cell("b", "code2", CellConfig()) + + curr_manager = CellManager() + curr_manager.register_cell("a", "code2", CellConfig()) + curr_manager.register_cell("b", "code1", CellConfig()) + + # Save original seen_ids + original_seen_ids = curr_manager.seen_ids.copy() + + curr_manager.sort_cell_ids_by_similarity(prev_manager) + assert list(curr_manager.cell_ids()) == ["b", "a"] + assert curr_manager.cell_data_at("b").code == "code2" + assert curr_manager.cell_data_at("a").code == "code1" + + # Check seen_ids were updated + assert curr_manager.seen_ids == original_seen_ids | {"a", "b"} + + +def test_sort_cell_ids_by_similarity_less_cells(): + # Test less cells than before + prev_manager = CellManager() + prev_manager.register_cell("a", "code1", CellConfig()) + prev_manager.register_cell("b", "code2", CellConfig()) + prev_manager.register_cell("c", "code3", CellConfig()) + + curr_manager = CellManager() + curr_manager.register_cell("x", "code2", CellConfig()) + curr_manager.register_cell("y", "code1", CellConfig()) + + original_seen_ids = curr_manager.seen_ids.copy() + + curr_manager.sort_cell_ids_by_similarity(prev_manager) + assert list(curr_manager.cell_ids()) == ["b", "a"] + assert curr_manager.seen_ids == original_seen_ids | {"a", "b"} + assert curr_manager.cell_data_at("b").code == "code2" + assert curr_manager.cell_data_at("a").code == "code1" + + +def test_sort_cell_ids_by_similarity_more_cells(): + # Test more cells than before + prev_manager = CellManager() + prev_manager.register_cell("a", "code1", CellConfig()) + prev_manager.register_cell("b", "code2", CellConfig()) + + curr_manager = CellManager() + curr_manager.register_cell("x", "code2", CellConfig()) + curr_manager.register_cell("y", "code1", CellConfig()) + curr_manager.register_cell("z", "code3", CellConfig()) + + original_seen_ids = curr_manager.seen_ids.copy() + + curr_manager.sort_cell_ids_by_similarity(prev_manager) + assert list(curr_manager.cell_ids()) == ["b", "a", "x"] + assert curr_manager.seen_ids == original_seen_ids | {"a", "b", "x"} + assert curr_manager.cell_data_at("b").code == "code2" + assert curr_manager.cell_data_at("a").code == "code1" + assert curr_manager.cell_data_at("x").code == "code3" diff --git a/tests/_server/api/endpoints/test_resume_session.py b/tests/_server/api/endpoints/test_resume_session.py index 8da7d0a368b..9fbce076fc7 100644 --- a/tests/_server/api/endpoints/test_resume_session.py +++ b/tests/_server/api/endpoints/test_resume_session.py @@ -76,8 +76,8 @@ def test_refresh_session(client: TestClient) -> None: # Check the session still exists after closing the websocket session = get_session(client, "123") - session_view = session.session_view assert session + session_view = session.session_view # Mimic cell execution time save cell_op = CellOp("Hbol") @@ -284,3 +284,85 @@ def test_restart_session(client: TestClient) -> None: # Shutdown the kernel client.post("/api/kernel/shutdown", headers=HEADERS) + + +def test_resume_session_with_watch(client: TestClient) -> None: + session_manager = get_session_manager(client) + session_manager.watch = True + + with client.websocket_connect("/ws?session_id=123") as websocket: + data = websocket.receive_json() + assert_kernel_ready_response(data, create_response({})) + + # Write to the notebook file to trigger a reload + # we write it as the second to last cell + filename = session_manager.file_router.get_unique_file_key() + assert filename + with open(filename, "r+") as f: + content = f.read() + last_cell_pos = content.rindex("@app.cell") + f.seek(last_cell_pos) + f.write( + "\n@app.cell\ndef _(): x=10; x\n" + content[last_cell_pos:] + ) + f.close() + + data = websocket.receive_json() + assert data == { + "op": "update-cell-ids", + "data": {"cell_ids": ["MJUe", "Hbol"]}, + } + data = websocket.receive_json() + assert data == { + "op": "update-cell-codes", + "data": { + "cell_ids": ["MJUe", "Hbol"], + "code_is_stale": True, + "codes": ["x=10; x", "import marimo as mo"], + }, + } + + # Resume session with new ID (simulates refresh) + with client.websocket_connect("/ws?session_id=456") as websocket: + # First message is the kernel reconnected + data = websocket.receive_json() + assert data == {"op": "reconnected", "data": {}} + + # Check for KernelReady message + data = websocket.receive_json() + assert parse_raw(data["data"], KernelReady) + messages: list[dict[str, Any]] = [] + + # Wait for update-cell-codes message + while True: + data = websocket.receive_json() + messages.append(data) + if data["op"] == "update-cell-codes": + break + + # 3 messages: + # 1. banner + # 2. update-cell-ids + # 3. update-cell-codes + assert len(messages) == 3 + assert messages[0]["op"] == "banner" + assert messages[1] == { + "op": "update-cell-ids", + "data": {"cell_ids": ["MJUe", "Hbol"]}, + } + assert messages[2] == { + "op": "update-cell-codes", + "data": { + "cell_ids": ["MJUe", "Hbol"], + "code_is_stale": True, + "codes": ["x=10; x", "import marimo as mo"], + }, + } + + session = get_session(client, "456") + assert session + session_view = session.session_view + assert session_view.last_executed_code == {} + + session_manager.watch = False + client.post("/api/kernel/shutdown", headers=HEADERS) diff --git a/tests/_server/session/test_session_view.py b/tests/_server/session/test_session_view.py index c7b87bb6e76..0cb990b9c19 100644 --- a/tests/_server/session/test_session_view.py +++ b/tests/_server/session/test_session_view.py @@ -10,6 +10,7 @@ from marimo._messaging.ops import ( CellOp, Datasets, + UpdateCellCodes, UpdateCellIdsRequest, VariableDeclaration, Variables, @@ -652,3 +653,46 @@ def test_mark_auto_export(): ) assert not session_view.has_auto_exported_html assert not session_view.has_auto_exported_md + + +def test_stale_code() -> None: + """Test that stale code is properly tracked and included in operations.""" + session_view = SessionView() + assert session_view.stale_code is None + + # Add stale code operation + stale_code_op = UpdateCellCodes( + cell_ids=["cell1"], + codes=["print('hello')"], + code_is_stale=True, + ) + session_view.add_operation(stale_code_op) + + # Verify stale code is tracked + assert session_view.stale_code == stale_code_op + assert session_view.stale_code in session_view.operations + + # Add non-stale code operation + non_stale_code_op = UpdateCellCodes( + cell_ids=["cell2"], + codes=["print('world')"], + code_is_stale=False, + ) + session_view.add_operation(non_stale_code_op) + + # Verify non-stale code doesn't affect stale_code tracking + assert session_view.stale_code == stale_code_op + assert session_view.stale_code in session_view.operations + + # Update stale code + new_stale_code_op = UpdateCellCodes( + cell_ids=["cell3"], + codes=["print('updated')"], + code_is_stale=True, + ) + session_view.add_operation(new_stale_code_op) + + # Verify stale code is updated + assert session_view.stale_code == new_stale_code_op + assert session_view.stale_code in session_view.operations + assert stale_code_op not in session_view.operations diff --git a/tests/_server/test_file_manager.py b/tests/_server/test_file_manager.py index 2d2e32257fb..685b51f2beb 100644 --- a/tests/_server/test_file_manager.py +++ b/tests/_server/test_file_manager.py @@ -249,3 +249,183 @@ def test_to_code(app_file_manager: AppFileManager) -> None: "", ] ) + + +def test_reload_reorders_cells() -> None: + """Test that reload() reorders cell IDs based on similarity to previous cells.""" + # Create a temporary file with initial content + temp_file = tempfile.NamedTemporaryFile(suffix=".py", delete=False) + initial_content = """ +import marimo +__generated_with = "0.0.1" +app = marimo.App() + +@app.cell +def cell1(): + x = 1 + return x + +@app.cell +def cell2(): + y = 2 + return y + +if __name__ == "__main__": + app.run() +""" + temp_file.write(initial_content.encode()) + temp_file.close() + + # Initialize AppFileManager with the temp file + manager = AppFileManager(filename=temp_file.name) + original_cell_ids = list(manager.app.cell_manager.cell_ids()) + assert original_cell_ids == ["Hbol", "MJUe"] + + # Modify the file content - swap the cells but keep similar content + modified_content = """ +import marimo +__generated_with = "0.0.1" +app = marimo.App() + +@app.cell +def cell2(): + y = 2 + return y + +@app.cell +def cell1(): + x = 1 + return x + +if __name__ == "__main__": + app.run() +""" + with open(temp_file.name, "w") as f: + f.write(modified_content) + + # Reload the file + manager.reload() + + # The cell IDs should be reordered to match the original code + reloaded_cell_ids = list(manager.app.cell_manager.cell_ids()) + assert len(reloaded_cell_ids) == len(original_cell_ids) + assert reloaded_cell_ids == ["MJUe", "Hbol"] + + # Clean up + os.remove(temp_file.name) + + +def test_reload_updates_content() -> None: + """Test that reload() updates the file contents correctly.""" + # Create a temporary file with initial content + temp_file = tempfile.NamedTemporaryFile(suffix=".py", delete=False) + initial_content = """ +import marimo +__generated_with = "0.0.1" +app = marimo.App() + +@app.cell +def cell1(): + x = 1 + return x + +if __name__ == "__main__": + app.run() +""" + temp_file.write(initial_content.encode()) + temp_file.close() + + # Initialize AppFileManager with the temp file + manager = AppFileManager(filename=temp_file.name) + original_code = list(manager.app.cell_manager.codes())[0] + assert "x = 1" in original_code + + # Modify the file content + modified_content = """ +import marimo +__generated_with = "0.0.1" +app = marimo.App() + +@app.cell +def cell1(): + x = 42 # Changed value + return x + +if __name__ == "__main__": + app.run() +""" + with open(temp_file.name, "w") as f: + f.write(modified_content) + + # Reload the file + manager.reload() + + # Check that the code was updated + reloaded_code = list(manager.app.cell_manager.codes())[0] + assert "x = 42" in reloaded_code + assert "x = 1" not in reloaded_code + + # Clean up + os.remove(temp_file.name) + + +def test_reload_updates_new_cell() -> None: + """Test that reload() updates the file contents correctly.""" + + # Create a temp file with initial content + temp_file = tempfile.NamedTemporaryFile(suffix=".py", delete=False) + initial_content = """ +import marimo +app = marimo.App() + +@app.cell +def cell1(): + x = 1 + return x + +if __name__ == "__main__": + app.run() +""" + temp_file.write(initial_content.encode()) + temp_file.close() + + # Initialize AppFileManager with the temp file + manager = AppFileManager(filename=temp_file.name) + assert len(list(manager.app.cell_manager.codes())) == 1 + original_cell_ids = list(manager.app.cell_manager.cell_ids()) + assert original_cell_ids == ["Hbol"] + + # Modify the file content to add a new cell + modified_content = """ +import marimo +app = marimo.App() + +@app.cell +def cell2(): + y = 2 + return y + +@app.cell +def cell1(): + x = 1 + return x + +if __name__ == "__main__": + app.run() +""" + with open(temp_file.name, "w") as f: + f.write(modified_content) + + # Reload the file + manager.reload() + + # Check that the new cell was added + codes = list(manager.app.cell_manager.codes()) + assert len(codes) == 2 + assert "y = 2" in codes[0] + assert "x = 1" in codes[1] + next_cell_ids = list(manager.app.cell_manager.cell_ids()) + assert next_cell_ids == ["MJUe", "Hbol"] + + # Clean up + os.remove(temp_file.name)