Skip to content

Commit

Permalink
improvement: intelligently keep order of cells when running --watch (#…
Browse files Browse the repository at this point in the history
…3451)

Intelligently keep order of cells when running `marimo edit --watch`. We
use a text similarity heuristic to figure out how cell ids moved since
they were last seen.

We also need to keep track of which IDs were used from reload-to-reload
so we don't surface the same ones.

This also has a few bug fixes with session-views.
  • Loading branch information
mscolnick authored Jan 15, 2025
1 parent d578c18 commit 5364d1c
Show file tree
Hide file tree
Showing 16 changed files with 1,356 additions and 247 deletions.
26 changes: 26 additions & 0 deletions frontend/src/core/cells/__tests__/cells.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1331,6 +1331,32 @@ describe("cell reducer", () => {
});
});

it("can set cell codes with new cell ids, while preserving the old cell data", () => {
actions.setCellCodes({
codes: ["code1", "code2", "code3"],
ids: ["3", "4", "5"] as CellId[],
codeIsStale: false,
});
expect(state.cellData["3" as CellId].code).toBe("code1");
expect(state.cellData["4" as CellId].code).toBe("code2");
expect(state.cellData["5" as CellId].code).toBe("code3");

// Update with some new cell ids and some old cell ids
actions.setCellIds({ cellIds: ["1", "2", "3", "4"] as CellId[] });
actions.setCellCodes({
codes: ["new1", "new2", "code1", "code2"],
ids: ["1", "2", "3", "4"] as CellId[],
codeIsStale: false,
});
expect(state.cellData["1" as CellId].code).toBe("new1");
expect(state.cellData["2" as CellId].code).toBe("new2");
expect(state.cellData["3" as CellId].code).toBe("code1");
expect(state.cellData["4" as CellId].code).toBe("code2");
expect(state.cellIds.inOrderIds).toEqual(["1", "2", "3", "4"]);
// Cell 5 data is preserved (possibly used for tracing), but it's not in the cellIds
expect(state.cellData["5" as CellId]).not.toBeUndefined();
});

it("can fold and unfold all cells", () => {
actions.foldAll();
expect(foldAllBulk).toHaveBeenCalled();
Expand Down
68 changes: 44 additions & 24 deletions frontend/src/core/cells/cells.ts
Original file line number Diff line number Diff line change
Expand Up @@ -789,36 +789,56 @@ const {
"Expected codes and ids to have the same length",
);

for (let i = 0; i < action.codes.length; i++) {
const cellId = action.ids[i];
const code = action.codes[i];
let nextState = { ...state };

const cellReducer = (
cell: CellData | undefined,
code: string,
cellId: CellId,
) => {
if (!cell) {
return createCell({ id: cellId, code });
}

state = updateCellData(state, cellId, (cell) => {
// No change
if (cell.code.trim() === code.trim()) {
return cell;
}
// No change
if (cell.code.trim() === code.trim()) {
return cell;
}

// Update codemirror if mounted
const cellHandle = state.cellHandles[cellId].current;
if (cellHandle?.editorView) {
updateEditorCodeFromPython(cellHandle.editorView, code);
}
// Update codemirror if mounted
const cellHandle = nextState.cellHandles[cellId].current;
if (cellHandle?.editorView) {
updateEditorCodeFromPython(cellHandle.editorView, code);
}

// If code is stale, we don't promote it to lastCodeRun
const lastCodeRun = action.codeIsStale ? cell.lastCodeRun : code;
// If code is stale, we don't promote it to lastCodeRun
const lastCodeRun = action.codeIsStale ? cell.lastCodeRun : code;

return {
...cell,
code: code,
// Mark as edited if the code has changed
edited: lastCodeRun ? lastCodeRun.trim() !== code.trim() : false,
lastCodeRun,
};
});
return {
...cell,
code: code,
// Mark as edited if the code has changed
edited: lastCodeRun
? lastCodeRun.trim() !== code.trim()
: Boolean(code),
lastCodeRun,
};
};

for (let i = 0; i < action.codes.length; i++) {
const cellId = action.ids[i];
const code = action.codes[i];

nextState = {
...nextState,
cellData: {
...nextState.cellData,
[cellId]: cellReducer(nextState.cellData[cellId], code, cellId),
},
};
}

return state;
return nextState;
},
setStdinResponse: (
state,
Expand Down
214 changes: 1 addition & 213 deletions marimo/_ast/app.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations

import functools
import inspect
import os
import random
import string
from dataclasses import asdict, dataclass, field
from typing import (
TYPE_CHECKING,
Expand All @@ -22,15 +18,13 @@

from marimo import _loggers
from marimo._ast.cell import Cell, CellConfig, CellId_t
from marimo._ast.compiler import cell_factory
from marimo._ast.cell_manager import CellManager
from marimo._ast.errors import (
CycleError,
DeleteNonlocalError,
MultipleDefinitionError,
UnparsableError,
)
from marimo._ast.names import DEFAULT_CELL_NAME
from marimo._ast.pytest import wrap_fn_for_pytest
from marimo._config.config import WidthType
from marimo._messaging.mimetypes import KnownMimeType
from marimo._output.hypertext import Html
Expand Down Expand Up @@ -106,22 +100,6 @@ def update(self, updates: dict[str, Any]) -> _AppConfig:
return self


@dataclass
class CellData:
"""A cell together with some metadata"""

cell_id: CellId_t
# User code comprising the cell
code: str
# User-provided name for cell (or default)
name: str
# Cell config
config: CellConfig

# The original cell, or None if cell was not parsable
cell: Optional[Cell]


class _Namespace(Mapping[str, object]):
def __init__(
self, dictionary: dict[str, object], owner: Cell | App
Expand Down Expand Up @@ -427,196 +405,6 @@ async def embed(self) -> AppEmbedResult:
)


class CellManager:
"""
A manager for cells.
This holds the cells that have been registered with the app, and
provides methods to access them.
"""

def __init__(self, prefix: str = "") -> None:
self._cell_data: dict[CellId_t, CellData] = {}
self.prefix = prefix
self.unparsable = False
self.random_seed = random.Random(42)

def create_cell_id(self) -> CellId_t:
# 4 random letters
return self.prefix + "".join(
self.random_seed.choices(string.ascii_letters, k=4)
)

def cell_decorator(
self,
func: Callable[..., Any] | None,
column: Optional[int],
disabled: bool,
hide_code: bool,
app: InternalApp | None = None,
) -> Cell | Callable[..., Cell]:
cell_config = CellConfig(
column=column, disabled=disabled, hide_code=hide_code
)

def _register(func: Callable[..., Any]) -> Cell:
# Use PYTEST_VERSION here, opposed to PYTEST_CURRENT_TEST, in
# order to allow execution during test collection.
is_top_level_pytest = (
"PYTEST_VERSION" in os.environ
and "PYTEST_CURRENT_TEST" not in os.environ
)
cell = cell_factory(
func,
cell_id=self.create_cell_id(),
anonymous_file=app._app._anonymous_file if app else False,
test_rewrite=is_top_level_pytest,
)
cell._cell.configure(cell_config)
self._register_cell(cell, app=app)
# Manually set the signature for pytest.
if is_top_level_pytest:
func = wrap_fn_for_pytest(func, cell)
# NB. in place metadata update.
functools.wraps(func)(cell)
return cell

if func is None:
# If the decorator was used with parentheses, func will be None,
# and we return a decorator that takes the decorated function as an
# argument
def decorator(func: Callable[..., Any]) -> Cell:
return _register(func)

return decorator
else:
return _register(func)

def _register_cell(
self, cell: Cell, app: InternalApp | None = None
) -> None:
if app is not None:
cell._register_app(app)
cell_impl = cell._cell
self.register_cell(
cell_id=cell_impl.cell_id,
code=cell_impl.code,
name=cell.name,
config=cell_impl.config,
cell=cell,
)

def register_cell(
self,
cell_id: Optional[CellId_t],
code: str,
config: Optional[CellConfig],
name: str = DEFAULT_CELL_NAME,
cell: Optional[Cell] = None,
) -> None:
if cell_id is None:
cell_id = self.create_cell_id()

self._cell_data[cell_id] = CellData(
cell_id=cell_id,
code=code,
name=name,
config=config or CellConfig(),
cell=cell,
)

def register_unparsable_cell(
self,
code: str,
name: Optional[str],
cell_config: CellConfig,
) -> None:
# - code.split("\n")[1:-1] disregards first and last lines, which are
# empty
# - line[4:] removes leading indent in multiline string
# - replace(...) unescapes double quotes
# - rstrip() removes an extra newline
code = "\n".join(
[line[4:].replace('\\"', '"') for line in code.split("\n")[1:-1]]
)

self.register_cell(
cell_id=self.create_cell_id(),
code=code,
config=cell_config,
name=name or DEFAULT_CELL_NAME,
cell=None,
)

def ensure_one_cell(self) -> None:
if not self._cell_data:
cell_id = self.create_cell_id()
self.register_cell(
cell_id=cell_id,
code="",
config=CellConfig(),
)

def cell_name(self, cell_id: CellId_t) -> str:
return self._cell_data[cell_id].name

def names(self) -> Iterable[str]:
for cell_data in self._cell_data.values():
yield cell_data.name

def codes(self) -> Iterable[str]:
for cell_data in self._cell_data.values():
yield cell_data.code

def configs(self) -> Iterable[CellConfig]:
for cell_data in self._cell_data.values():
yield cell_data.config

def valid_cells(
self,
) -> Iterable[tuple[CellId_t, Cell]]:
"""Return cells and functions for each valid cell."""
for cell_data in self._cell_data.values():
if cell_data.cell is not None:
yield (cell_data.cell_id, cell_data.cell)

def valid_cell_ids(self) -> Iterable[CellId_t]:
for cell_data in self._cell_data.values():
if cell_data.cell is not None:
yield cell_data.cell_id

def cell_ids(self) -> Iterable[CellId_t]:
"""Cell IDs in the order they were registered."""
return self._cell_data.keys()

def has_cell(self, cell_id: CellId_t) -> bool:
return cell_id in self._cell_data

def cells(
self,
) -> Iterable[Optional[Cell]]:
for cell_data in self._cell_data.values():
yield cell_data.cell

def config_map(self) -> dict[CellId_t, CellConfig]:
return {cid: cd.config for cid, cd in self._cell_data.items()}

def cell_data(self) -> Iterable[CellData]:
return self._cell_data.values()

def cell_data_at(self, cell_id: CellId_t) -> CellData:
return self._cell_data[cell_id]

def get_cell_id_by_code(self, code: str) -> Optional[CellId_t]:
"""
Finds the first cell with the given code and returns its cell ID.
"""
for cell_id, cell_data in self._cell_data.items():
if cell_data.code == code:
return cell_id
return None


class InternalApp:
"""
Internal representation of an app.
Expand Down
4 changes: 0 additions & 4 deletions marimo/_ast/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,3 @@ class SourcePosition:
filename: str
lineno: int
col_offset: int


def is_ws(char: str) -> bool:
return char == " " or char == "\n" or char == "\t"
Loading

0 comments on commit 5364d1c

Please sign in to comment.