diff --git a/docs/guides/editor_features/watching.md b/docs/guides/editor_features/watching.md new file mode 100644 index 00000000000..b7b11c19a93 --- /dev/null +++ b/docs/guides/editor_features/watching.md @@ -0,0 +1,29 @@ +# Watching notebooks + +marimo's `--watch` flag enables a file watcher that automatically sync your +notebook file with the marimo editor or running application. + +This allows you to edit your notebook file in the editor of your choice, and +have the changes automatically reflected in the running editor or application. + +!!! tip "Install watchdog for better file watching" + For better file watching performance, install watchdog with `pip install watchdog`. Without watchdog, marimo will poll for file changes which is less efficient. + +## `marimo run --watch` + +When you run a notebook with the `--watch` flag, whenever the file watcher +detects a change to the notebook file, the application will be refreshed. +The browser will trigger a page refresh to ensure your notebook starts from a fresh state. + +## `marimo watch --watch` + +When you edit a notebook file with the `--watch` flag, whenever the file watcher +detects a change to the notebook file, the new cells and code changes will be streamed to +the browser editor. + +This code will not be executed until you run the cell, and instead marked as stale. + +## Watching for data changes + +!!! note + Support for watching data files and automatically refreshing cells that depend on them is coming soon. Follow along at diff --git a/frontend/src/core/cells/__tests__/cells.test.ts b/frontend/src/core/cells/__tests__/cells.test.ts index b89f3423eaa..d9c501e0977 100644 --- a/frontend/src/core/cells/__tests__/cells.test.ts +++ b/frontend/src/core/cells/__tests__/cells.test.ts @@ -1305,9 +1305,29 @@ describe("cell reducer", () => { actions.setCellIds({ cellIds: newIds }); expect(state.cellIds.atOrThrow(FIRST_COLUMN).topLevelIds).toEqual(newIds); - actions.setCellCodes({ codes: newCodes, ids: newIds }); + // When codeIsStale is false, lastCodeRun should match code + actions.setCellCodes({ + codes: newCodes, + ids: newIds, + codeIsStale: false, + }); newIds.forEach((id, index) => { expect(state.cellData[id].code).toBe(newCodes[index]); + expect(state.cellData[id].lastCodeRun).toBe(newCodes[index]); + expect(state.cellData[id].edited).toBe(false); + }); + + // When codeIsStale is true, lastCodeRun should not change + const staleCodes = ["stale1", "stale2", "stale3"]; + actions.setCellCodes({ + codes: staleCodes, + ids: newIds, + codeIsStale: true, + }); + newIds.forEach((id, index) => { + expect(state.cellData[id].code).toBe(staleCodes[index]); + expect(state.cellData[id].lastCodeRun).toBe(newCodes[index]); + expect(state.cellData[id].edited).toBe(true); }); }); diff --git a/frontend/src/core/cells/cells.ts b/frontend/src/core/cells/cells.ts index 8da2aad3a3e..6e6ba2a94df 100644 --- a/frontend/src/core/cells/cells.ts +++ b/frontend/src/core/cells/cells.ts @@ -780,7 +780,10 @@ const { cellHandles: nextCellHandles, }; }, - setCellCodes: (state, action: { codes: string[]; ids: CellId[] }) => { + setCellCodes: ( + state, + action: { codes: string[]; ids: CellId[]; codeIsStale: boolean }, + ) => { invariant( action.codes.length === action.ids.length, "Expected codes and ids to have the same length", @@ -791,11 +794,26 @@ const { const code = action.codes[i]; state = updateCellData(state, cellId, (cell) => { + // No change + if (cell.code.trim() === code.trim()) { + return cell; + } + + // Update codemirror if mounted + const cellHandle = state.cellHandles[cellId].current; + if (cellHandle?.editorView) { + updateEditorCodeFromPython(cellHandle.editorView, code); + } + + // If code is stale, we don't promote it to lastCodeRun + const lastCodeRun = action.codeIsStale ? cell.lastCodeRun : code; + return { ...cell, - code, - edited: false, - lastCodeRun: code, + code: code, + // Mark as edited if the code has changed + edited: lastCodeRun ? lastCodeRun.trim() !== code.trim() : false, + lastCodeRun, }; }); } diff --git a/frontend/src/core/websocket/useMarimoWebSocket.tsx b/frontend/src/core/websocket/useMarimoWebSocket.tsx index d6809dd2063..f2906082f2e 100644 --- a/frontend/src/core/websocket/useMarimoWebSocket.tsx +++ b/frontend/src/core/websocket/useMarimoWebSocket.tsx @@ -203,6 +203,7 @@ export function useMarimoWebSocket(opts: { setCellCodes({ codes: msg.data.codes, ids: msg.data.cell_ids as CellId[], + codeIsStale: msg.data.code_is_stale, }); return; case "update-cell-ids": diff --git a/marimo/_cli/cli.py b/marimo/_cli/cli.py index 51db340551a..187b6358948 100644 --- a/marimo/_cli/cli.py +++ b/marimo/_cli/cli.py @@ -286,6 +286,14 @@ def main( help=sandbox_message, ) @click.option("--profile-dir", default=None, type=str, hidden=True) +@click.option( + "--watch", + is_flag=True, + default=False, + show_default=True, + type=bool, + help="Watch the file for changes and reload the code when saved in another editor.", +) @click.argument("name", required=False, type=click.Path()) @click.argument("args", nargs=-1, type=click.UNPROCESSED) def edit( @@ -300,6 +308,7 @@ def edit( skip_update_check: bool, sandbox: bool, profile_dir: Optional[str], + watch: bool, name: Optional[str], args: tuple[str, ...], ) -> None: @@ -369,7 +378,7 @@ def edit( headless=headless, mode=SessionMode.EDIT, include_code=True, - watch=False, + watch=watch, cli_args=parse_args(args), auth_token=_resolve_token(token, token_password), base_url=base_url, diff --git a/marimo/_messaging/ops.py b/marimo/_messaging/ops.py index 2086b236ce5..681e0776043 100644 --- a/marimo/_messaging/ops.py +++ b/marimo/_messaging/ops.py @@ -639,6 +639,7 @@ class UpdateCellCodes(Op): name: ClassVar[str] = "update-cell-codes" cell_ids: List[CellId_t] codes: List[str] + code_is_stale: bool @dataclass diff --git a/marimo/_server/api/endpoints/files.py b/marimo/_server/api/endpoints/files.py index 887078ac944..0322a3cdde7 100644 --- a/marimo/_server/api/endpoints/files.py +++ b/marimo/_server/api/endpoints/files.py @@ -104,6 +104,12 @@ async def rename_file( from_consumer_id=ConsumerId(app_state.require_current_session_id()), ) + if new_path: + # Handle rename for watch + app_state.session_manager.handle_file_rename_for_watch( + app_state.require_current_session_id(), new_path + ) + return SuccessResponse() diff --git a/marimo/_server/api/lifespans.py b/marimo/_server/api/lifespans.py index b781efd0022..6bef4ffb715 100644 --- a/marimo/_server/api/lifespans.py +++ b/marimo/_server/api/lifespans.py @@ -75,15 +75,6 @@ async def lsp(app: Starlette) -> AsyncIterator[None]: yield -@contextlib.asynccontextmanager -async def watcher(app: Starlette) -> AsyncIterator[None]: - state = AppState.from_app(app) - if state.watch: - session_mgr = state.session_manager - session_mgr.start_file_watcher() - yield - - @contextlib.asynccontextmanager async def open_browser(app: Starlette) -> AsyncIterator[None]: state = AppState.from_app(app) diff --git a/marimo/_server/sessions.py b/marimo/_server/sessions.py index 05f0cd0a2a5..4380a022577 100644 --- a/marimo/_server/sessions.py +++ b/marimo/_server/sessions.py @@ -71,7 +71,7 @@ ConnectionDistributor, QueueDistributor, ) -from marimo._utils.file_watcher import FileWatcher +from marimo._utils.file_watcher import FileWatcherManager from marimo._utils.paths import import_files from marimo._utils.repr import format_repr from marimo._utils.typed_connection import TypedConnection @@ -165,7 +165,7 @@ def __init__( virtual_files_supported: bool, redirect_console_to_browser: bool, ) -> None: - self.kernel_task: Optional[threading.Thread] | Optional[mp.Process] + self.kernel_task: Optional[threading.Thread | mp.Process] = None self.queue_manager = queue_manager self.mode = mode self.configs = configs @@ -554,6 +554,8 @@ def put_control_request( UpdateCellCodes( cell_ids=request.cell_ids, codes=request.codes, + # Not stale because we just ran the code + code_is_stale=False, ), except_consumer=from_consumer_id, ) @@ -704,6 +706,7 @@ def __init__( auth_token: Optional[AuthToken], redirect_console_to_browser: bool, ttl_seconds: Optional[int], + watch: bool = False, ) -> None: self.file_router = file_router self.mode = mode @@ -713,7 +716,8 @@ def __init__( self.include_code = include_code self.ttl_seconds = ttl_seconds self.lsp_server = lsp_server - self.watcher: Optional[FileWatcher] = None + self.watcher_manager = FileWatcherManager() + self.watch = watch self.recents = RecentFilesManager() self.user_config_manager = user_config_manager self.cli_args = cli_args @@ -772,7 +776,7 @@ def create_session( if app_file_manager.path: self.recents.touch(app_file_manager.path) - self.sessions[session_id] = Session.create( + session = Session.create( initialization_id=file_key, session_consumer=session_consumer, mode=self.mode, @@ -787,8 +791,103 @@ def create_session( redirect_console_to_browser=self.redirect_console_to_browser, ttl_seconds=self.ttl_seconds, ) + self.sessions[session_id] = session + + # Start file watcher if enabled + if self.watch and app_file_manager.path: + self._start_file_watcher_for_session(session) + return self.sessions[session_id] + def _start_file_watcher_for_session(self, session: Session) -> None: + """Start a file watcher for a session.""" + if not session.app_file_manager.path: + return + + async def on_file_changed(path: Path) -> None: + LOGGER.debug(f"{path} was modified") + # Skip if the session does not relate to the file + if session.app_file_manager.path != os.path.abspath(path): + return + + # Reload the file manager to get the latest code + try: + session.app_file_manager.reload() + except Exception as e: + # If there are syntax errors, we just skip + # and don't send the changes + LOGGER.error(f"Error loading file: {e}") + return + # In run, we just call Reload() + if self.mode == SessionMode.RUN: + session.write_operation(Reload(), from_consumer_id=None) + return + + # Get the latest codes + codes = list(session.app_file_manager.app.cell_manager.codes()) + cell_ids = list( + session.app_file_manager.app.cell_manager.cell_ids() + ) + # Send the updated codes to the frontend + session.write_operation( + UpdateCellCodes( + cell_ids=cell_ids, + codes=codes, + # The code is considered stale + # because it has not been run yet. + # In the future, we may add auto-run here. + code_is_stale=True, + ), + from_consumer_id=None, + ) + + session._unsubscribe_file_watcher_ = on_file_changed # type: ignore + + self.watcher_manager.add_callback( + Path(session.app_file_manager.path), on_file_changed + ) + + def handle_file_rename_for_watch( + self, session_id: SessionId, new_path: str + ) -> tuple[bool, Optional[str]]: + """Handle renaming a file for a session. + + Returns: + tuple[bool, Optional[str]]: (success, error_message) + """ + session = self.get_session(session_id) + if not session: + return False, "Session not found" + + if not os.path.exists(new_path): + return False, f"File {new_path} does not exist" + + if not session.app_file_manager.path: + return False, "Session has no associated file" + + old_path = session.app_file_manager.path + + try: + # Remove the old file watcher if it exists + if self.watch: + self.watcher_manager.remove_callback( + Path(old_path), + session._unsubscribe_file_watcher_, # type: ignore + ) + + # Add a watcher for the new path if needed + if self.watch: + self._start_file_watcher_for_session(session) + + return True, None + + except Exception as e: + LOGGER.error(f"Error handling file rename: {e}") + + if self.watch: + self._start_file_watcher_for_session(session) + return False, str(e) + def get_session(self, session_id: SessionId) -> Optional[Session]: session = self.sessions.get(session_id) if session: @@ -905,13 +1004,22 @@ async def start_lsp_server(self) -> None: return def close_session(self, session_id: SessionId) -> bool: + """Close a session and remove its file watcher if it has one.""" LOGGER.debug("Closing session %s", session_id) session = self.get_session(session_id) - if session is not None: - session.close() - del self.sessions[session_id] - return True - return False + if session is None: + return False + + # Remove the file watcher callback for this session + if session.app_file_manager.path and self.watch: + self.watcher_manager.remove_callback( + Path(session.app_file_manager.path), + session._unsubscribe_file_watcher_, # type: ignore + ) + + session.close() + del self.sessions[session_id] + return True def close_all_sessions(self) -> None: LOGGER.debug("Closing all sessions (sessions: %s)", self.sessions) @@ -921,43 +1029,16 @@ def close_all_sessions(self) -> None: self.sessions = {} def shutdown(self) -> None: + """Shutdown the session manager and stop all file watchers.""" LOGGER.debug("Shutting down") self.close_all_sessions() self.lsp_server.stop() - if self.watcher: - self.watcher.stop() + self.watcher_manager.stop_all() def should_send_code_to_frontend(self) -> bool: """Returns True if the server can send messages to the frontend.""" return self.mode == SessionMode.EDIT or self.include_code - def start_file_watcher(self) -> Disposable: - """Starts the file watcher if it is not already started""" - if self.mode == SessionMode.EDIT: - # We don't support file watching in edit mode yet - # as there are some edge cases that would need to be handled. - # - what to do if the file is deleted, or is renamed - # - do we re-run the app or just show the changed code - # - we don't properly handle saving from the frontend - LOGGER.warning("Cannot start file watcher in edit mode") - return Disposable.empty() - file = self.file_router.maybe_get_single_file() - if not file: - return Disposable.empty() - - file_path = file.path - - async def on_file_changed(path: Path) -> None: - LOGGER.debug(f"{path} was modified") - for _, session in self.sessions.items(): - session.app_file_manager.reload() - session.write_operation(Reload(), from_consumer_id=None) - - LOGGER.debug("Starting file watcher for %s", file_path) - self.watcher = FileWatcher.create(Path(file_path), on_file_changed) - self.watcher.start() - return Disposable(self.watcher.stop) - def get_active_connection_count(self) -> int: return len( [ diff --git a/marimo/_server/start.py b/marimo/_server/start.py index d79ea8fe2f3..4c1003dde1a 100644 --- a/marimo/_server/start.py +++ b/marimo/_server/start.py @@ -21,6 +21,7 @@ initialize_fd_limit, ) from marimo._server.uvicorn_utils import initialize_signals +from marimo._tracer import LOGGER from marimo._utils.paths import import_files DEFAULT_PORT = 2718 @@ -103,6 +104,20 @@ def start( config_reader = get_default_config_manager(current_path=start_path) + # If watch is true, disable auto-save and format-on-save, + # watch is enabled when they are editing in another editor + if watch: + config_reader = config_reader.with_overrides( + { + "save": { + "autosave": "off", + "format_on_save": False, + "autosave_delay": 1000, + } + } + ) + LOGGER.info("Watch mode enabled, auto-save is disabled") + session_manager = SessionManager( file_router=file_router, mode=mode, @@ -115,6 +130,7 @@ def start( cli_args=cli_args, auth_token=auth_token, redirect_console_to_browser=redirect_console_to_browser, + watch=watch, ) log_level = "info" if development_mode else "error" @@ -126,7 +142,6 @@ def start( lifespan=lifespans.Lifespans( [ lifespans.lsp, - lifespans.watcher, lifespans.etc, lifespans.signal_handler, lifespans.logging, diff --git a/marimo/_utils/file_watcher.py b/marimo/_utils/file_watcher.py index c7e4acc98d0..303138206a1 100644 --- a/marimo/_utils/file_watcher.py +++ b/marimo/_utils/file_watcher.py @@ -4,8 +4,9 @@ import asyncio import os from abc import ABC, abstractmethod +from collections import defaultdict from pathlib import Path -from typing import Any, Callable, Coroutine, Optional +from typing import Any, Callable, Coroutine, Dict, Optional, Set from marimo._ast.app import LOGGER from marimo._dependencies.dependencies import DependencyManager @@ -121,3 +122,53 @@ def stop(self) -> None: self.observer.join() return WatchdogFileWatcher(path, callback, loop) + + +class FileWatcherManager: + """Manages multiple file watchers, sharing watchers for the same file.""" + + def __init__(self) -> None: + # Map of file paths to their watchers + self._watchers: Dict[str, FileWatcher] = {} + # Map of file paths to their callbacks + self._callbacks: Dict[str, Set[Callback]] = defaultdict(set) + + def add_callback(self, path: Path, callback: Callback) -> None: + """Add a callback for a file path. Creates watcher if needed.""" + path_str = str(path) + self._callbacks[path_str].add(callback) + + if path_str not in self._watchers: + + async def shared_callback(changed_path: Path) -> None: + callbacks = self._callbacks.get(str(changed_path), set()) + for cb in callbacks: + await cb(changed_path) + + watcher = FileWatcher.create(path, shared_callback) + watcher.start() + self._watchers[path_str] = watcher + LOGGER.debug(f"Created new watcher for {path_str}") + + def remove_callback(self, path: Path, callback: Callback) -> None: + """Remove a callback for a file path. Removes watcher if no more callbacks.""" + path_str = str(path) + if path_str not in self._callbacks: + return + + self._callbacks[path_str].discard(callback) + + if not self._callbacks[path_str]: + # No more callbacks, clean up + del self._callbacks[path_str] + if path_str in self._watchers: + self._watchers[path_str].stop() + del self._watchers[path_str] + LOGGER.debug(f"Removed watcher for {path_str}") + + def stop_all(self) -> None: + """Stop all file watchers.""" + for watcher in self._watchers.values(): + watcher.stop() + self._watchers.clear() + self._callbacks.clear() diff --git a/mkdocs.yml b/mkdocs.yml index 5a4ed8cf90f..52e3c13a192 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -93,6 +93,7 @@ nav: - AI completion: guides/editor_features/ai_completion.md - Package management: guides/editor_features/package_management.md - Module autoreloading: guides/editor_features/module_autoreloading.md + - Watching notebooks: guides/editor_features/watching.md - Hotkeys: guides/editor_features/hotkeys.md - Run notebooks as apps: guides/apps.md - Run notebooks as scripts: guides/scripts.md diff --git a/openapi/api.yaml b/openapi/api.yaml index 5f8eb53db5c..baa18ab495b 100644 --- a/openapi/api.yaml +++ b/openapi/api.yaml @@ -1837,6 +1837,8 @@ components: items: type: string type: array + code_is_stale: + type: boolean codes: items: type: string @@ -1848,6 +1850,7 @@ components: required: - cell_ids - codes + - code_is_stale - name type: object UpdateCellIdsRequest: @@ -1956,7 +1959,7 @@ components: type: object info: title: marimo API - version: 0.10.2 + version: 0.10.13 openapi: 3.1.0 paths: /@file/{filename_and_length}: diff --git a/openapi/src/api.ts b/openapi/src/api.ts index 9e181116af6..8f7a0444c7d 100644 --- a/openapi/src/api.ts +++ b/openapi/src/api.ts @@ -2846,6 +2846,7 @@ export interface components { }; UpdateCellCodes: { cell_ids: string[]; + code_is_stale: boolean; codes: string[]; /** @enum {string} */ name: "update-cell-codes"; diff --git a/tests/_server/api/endpoints/test_kiosk.py b/tests/_server/api/endpoints/test_kiosk.py index 6a896a46d7d..eeb82f4f3dc 100644 --- a/tests/_server/api/endpoints/test_kiosk.py +++ b/tests/_server/api/endpoints/test_kiosk.py @@ -139,6 +139,7 @@ async def test_connect_kiosk_with_session(client: TestClient) -> None: "data": { "cell_ids": ["cell-3"], "codes": ["print('Hello, cell-3')"], + "code_is_stale": False, }, } # And a focused cell diff --git a/tests/_server/api/endpoints/test_ws.py b/tests/_server/api/endpoints/test_ws.py index 44968f3ebfb..6f658e7379e 100644 --- a/tests/_server/api/endpoints/test_ws.py +++ b/tests/_server/api/endpoints/test_ws.py @@ -208,22 +208,23 @@ def test_fails_on_multiple_connections_with_same_file( async def test_file_watcher_calls_reload(client: TestClient) -> None: session_manager: SessionManager = get_session_manager(client) + session_manager.watch = True with client.websocket_connect("/ws?session_id=123") as websocket: data = websocket.receive_json() assert_kernel_ready_response(data) session_manager.mode = SessionMode.RUN - unsubscribe = session_manager.start_file_watcher() filename = session_manager.file_router.get_unique_file_key() assert filename with open(filename, "a") as f: # noqa: ASYNC101 ASYNC230 f.write("\n# test") f.close() - assert session_manager.watcher - await session_manager.watcher.callback(Path(filename)) - unsubscribe() + assert session_manager.watcher_manager._watchers + watcher = list(session_manager.watcher_manager._watchers.values())[0] + await watcher.callback(Path(filename)) data = websocket.receive_json() assert data == {"op": "reload", "data": {}} session_manager.mode = SessionMode.EDIT + session_manager.watch = False client.post("/api/kernel/shutdown", headers=HEADERS) diff --git a/tests/_server/test_session_manager.py b/tests/_server/test_session_manager.py index 0e817ea98ea..3caa2ebb04b 100644 --- a/tests/_server/test_session_manager.py +++ b/tests/_server/test_session_manager.py @@ -151,8 +151,9 @@ def test_close_session( session_manager: SessionManager, mock_session: Session ) -> None: session_id = "test_session_id" + mock_session.app_file_manager = AppFileManager(filename=None) session_manager.sessions[session_id] = mock_session - session_manager.close_session(session_id) + assert session_manager.close_session(session_id) assert session_id not in session_manager.sessions mock_session.close.assert_called_once() diff --git a/tests/_server/test_sessions.py b/tests/_server/test_sessions.py index 343e5d7d05e..df57a7c56cd 100644 --- a/tests/_server/test_sessions.py +++ b/tests/_server/test_sessions.py @@ -1,6 +1,7 @@ # Copyright 2024 Marimo. All rights reserved. from __future__ import annotations +import asyncio import functools import inspect import os @@ -8,11 +9,14 @@ import sys import time from multiprocessing.queues import Queue as MPQueue -from typing import Any +from pathlib import Path +from tempfile import NamedTemporaryFile +from typing import Any, Callable, TypeVar from unittest.mock import MagicMock from marimo._ast.app import App, InternalApp from marimo._config.manager import get_default_config_manager +from marimo._messaging.ops import UpdateCellCodes from marimo._runtime.requests import ( AppMetadata, CreationRequest, @@ -20,12 +24,20 @@ SetUIElementValueRequest, ) from marimo._server.file_manager import AppFileManager +from marimo._server.file_router import AppFileRouter from marimo._server.model import ConnectionState, SessionMode -from marimo._server.sessions import KernelManager, QueueManager, Session +from marimo._server.sessions import ( + KernelManager, + QueueManager, + Session, + SessionManager, +) from marimo._server.utils import initialize_asyncio +from marimo._utils.marimo_path import MarimoPath initialize_asyncio() +F = TypeVar("F", bound=Callable[..., Any]) app_metadata = AppMetadata( query_params={"some_param": "some_value"}, filename="test.py", cli_args={} @@ -33,7 +45,7 @@ # TODO(akshayka): automatically do this for every test in our test suite -def save_and_restore_main(f): +def save_and_restore_main(f: F) -> F: """Kernels swap out the main module; restore it after running tests""" @functools.wraps(f) @@ -44,7 +56,7 @@ def wrapper(*args: Any, **kwargs: Any) -> None: finally: sys.modules["__main__"] = main - return wrapper + return wrapper # type: ignore @save_and_restore_main @@ -380,3 +392,349 @@ def test_session_with_kiosk_consumers() -> None: assert session.connection_state() == ConnectionState.CLOSED assert not session.room.consumers assert session.room.main_consumer is None + + +@save_and_restore_main +async def test_session_manager_file_watching() -> None: + # Create a temporary file + with NamedTemporaryFile(delete=False, suffix=".py") as tmp_file: + tmp_path = Path(tmp_file.name) + # Write initial notebook content + tmp_file.write( + b"""import marimo as mo + +@mo.cell +def __(): + return 1 +""" + ) + + try: + # Create a session manager with file watching enabled + file_router = AppFileRouter.from_filename(MarimoPath(str(tmp_path))) + session_manager = SessionManager( + file_router=file_router, + mode=SessionMode.EDIT, + development_mode=False, + quiet=True, + include_code=True, + lsp_server=MagicMock(), + user_config_manager=get_default_config_manager(current_path=None), + cli_args={}, + auth_token=None, + redirect_console_to_browser=False, + ttl_seconds=None, + watch=True, + ) + + # Create a mock session consumer + session_consumer = MagicMock() + session_consumer.connection_state.return_value = ConnectionState.OPEN + operations: list[Any] = [] + session_consumer.write_operation = ( + lambda op, *_args: operations.append(op) + ) + + # Create a session + session_manager.create_session( + session_id="test", + session_consumer=session_consumer, + query_params={}, + file_key=str(tmp_path), + ) + + # Wait a bit and then modify the file + await asyncio.sleep(0.2) + with open(tmp_path, "w") as f: # noqa: ASYNC230 + f.write( + """import marimo as mo + +@mo.cell +def __(): + return 2 +""" + ) + + # Wait for the watcher to detect the change + await asyncio.sleep(0.2) + + # Check that UpdateCellCodes was sent with the new code + update_ops = [ + op for op in operations if isinstance(op, UpdateCellCodes) + ] + assert len(update_ops) == 1 + assert "return 2" in update_ops[0].codes[0] + assert update_ops[0].code_is_stale is True + + # Create another session for the same file + session_consumer2 = MagicMock() + session_consumer2.connection_state.return_value = ConnectionState.OPEN + operations2: list[Any] = [] + session_consumer2.write_operation = ( + lambda op, *_args: operations2.append(op) + ) + + session_manager.create_session( + session_id="test2", + session_consumer=session_consumer2, + query_params={}, + file_key=str(tmp_path), + ) + + # Modify the file again + operations.clear() + operations2.clear() + with open(tmp_path, "w") as f: # noqa: ASYNC230 + f.write( + """import marimo as mo + +@mo.cell +def __(): + return 3 +""" + ) + + # Wait for the watcher to detect the change + await asyncio.sleep(0.2) + + # Both sessions should receive the update + update_ops = [ + op for op in operations if isinstance(op, UpdateCellCodes) + ] + update_ops2 = [ + op for op in operations2 if isinstance(op, UpdateCellCodes) + ] + assert len(update_ops) == 1 + assert len(update_ops2) == 1 + assert "return 3" in update_ops[0].codes[0] + assert "return 3" in update_ops2[0].codes[0] + + # Close one session and verify the other still receives updates + assert session_manager.close_session("test") + operations.clear() + operations2.clear() + + with open(tmp_path, "w") as f: # noqa: ASYNC230 + f.write( + """import marimo as mo + +@mo.cell +def __(): + return 4 +""" + ) + + # Wait for the watcher to detect the change + await asyncio.sleep(0.2) + + # Only session2 should receive the update + update_ops = [ + op for op in operations if isinstance(op, UpdateCellCodes) + ] + update_ops2 = [ + op for op in operations2 if isinstance(op, UpdateCellCodes) + ] + assert len(update_ops) == 0 + assert len(update_ops2) == 1 + assert "return 4" in update_ops2[0].codes[0] + + finally: + # Cleanup + session_manager.shutdown() + os.remove(tmp_path) + + +@save_and_restore_main +def test_watch_mode_config_override() -> None: + """Test that watch mode properly overrides config settings.""" + # Create a temporary file + with NamedTemporaryFile(delete=False, suffix=".py") as tmp_file: + tmp_path = Path(tmp_file.name) + tmp_file.write(b"import marimo as mo") + + # Create a config with autosave enabled + config_reader = get_default_config_manager(current_path=None) + config_reader_watch = config_reader.with_overrides( + { + "save": { + "autosave": "off", + "format_on_save": False, + "autosave_delay": 2000, + } + } + ) + + # Create a session manager with watch mode enabled + file_router = AppFileRouter.from_filename(MarimoPath(str(tmp_path))) + session_manager = SessionManager( + file_router=file_router, + mode=SessionMode.EDIT, + development_mode=False, + quiet=True, + include_code=True, + lsp_server=MagicMock(), + user_config_manager=config_reader_watch, + cli_args={}, + auth_token=None, + redirect_console_to_browser=False, + ttl_seconds=None, + watch=True, + ) + + session_manager_no_watch = SessionManager( + file_router=file_router, + mode=SessionMode.EDIT, + development_mode=False, + quiet=True, + include_code=True, + lsp_server=MagicMock(), + user_config_manager=config_reader, + cli_args={}, + auth_token=None, + redirect_console_to_browser=False, + ttl_seconds=None, + watch=False, + ) + + try: + # Verify that the config was overridden + config = session_manager.user_config_manager.get_config() + assert config["save"]["autosave"] == "off" + assert config["save"]["format_on_save"] is False + + # Verify that the config was not overridden + config = session_manager_no_watch.user_config_manager.get_config() + assert config["save"]["autosave"] == "after_delay" + assert config["save"]["format_on_save"] is True + + finally: + # Cleanup + session_manager.shutdown() + session_manager_no_watch.shutdown() + os.remove(tmp_path) + + +@save_and_restore_main +async def test_session_manager_file_rename() -> None: + """Test that file renaming works correctly with file watching.""" + # Create two temporary files + with ( + NamedTemporaryFile(delete=False, suffix=".py") as tmp_file1, + NamedTemporaryFile(delete=False, suffix=".py") as tmp_file2, + ): + tmp_path1 = Path(tmp_file1.name) + tmp_path2 = Path(tmp_file2.name) + # Write initial notebook content + tmp_file1.write( + b"""import marimo as mo + +@mo.cell +def __(): + return 1 +""" + ) + tmp_file2.write(b"import marimo as mo") + + try: + # Create a session manager with file watching enabled + file_router = AppFileRouter.from_filename(MarimoPath(str(tmp_path1))) + session_manager = SessionManager( + file_router=file_router, + mode=SessionMode.EDIT, + development_mode=False, + quiet=True, + include_code=True, + lsp_server=MagicMock(), + user_config_manager=get_default_config_manager(current_path=None), + cli_args={}, + auth_token=None, + redirect_console_to_browser=False, + ttl_seconds=None, + watch=True, + ) + + # Create a mock session consumer + session_consumer = MagicMock() + session_consumer.connection_state.return_value = ConnectionState.OPEN + operations: list[Any] = [] + session_consumer.write_operation = ( + lambda op, *_args: operations.append(op) + ) + + # Create a session + session_manager.create_session( + session_id="test", + session_consumer=session_consumer, + query_params={}, + file_key=str(tmp_path1), + ) + + # Try to rename to a non-existent file + success, error = session_manager.handle_file_rename_for_watch( + "test", "/nonexistent/file.py" + ) + assert not success + assert error is not None + assert "does not exist" in error + + # Try to rename with an invalid session + success, error = session_manager.handle_file_rename_for_watch( + "nonexistent", str(tmp_path2) + ) + assert not success + assert error is not None + assert "Session not found" in error + + # Rename to the second file + success, error = session_manager.handle_file_rename_for_watch( + "test", str(tmp_path2) + ) + assert success + assert error is None + + # Modify the new file + operations.clear() + with open(tmp_path2, "w") as f: # noqa: ASYNC230 + f.write( + """import marimo as mo + +@mo.cell +def __(): + return 2 +""" + ) + + # Wait for the watcher to detect the change + await asyncio.sleep(0.2) + + # Check that UpdateCellCodes was sent with the new code + update_ops = [ + op for op in operations if isinstance(op, UpdateCellCodes) + ] + assert len(update_ops) == 1 + assert "return 2" in update_ops[0].codes[0] + + # Modify the old file - should not trigger any updates + operations.clear() + with open(tmp_path1, "w") as f: # noqa: ASYNC230 + f.write( + """import marimo as mo + +@mo.cell +def __(): + return 3 +""" + ) + + # Wait to verify no updates are triggered + await asyncio.sleep(0.2) + update_ops = [ + op for op in operations if isinstance(op, UpdateCellCodes) + ] + assert len(update_ops) == 0 + + finally: + # Cleanup + session_manager.shutdown() + os.remove(tmp_path1) + os.remove(tmp_path2) diff --git a/tests/_utils/test_file_watcher.py b/tests/_utils/test_file_watcher.py index 84c0dd343d1..6522caf4aa3 100644 --- a/tests/_utils/test_file_watcher.py +++ b/tests/_utils/test_file_watcher.py @@ -7,7 +7,7 @@ from tempfile import NamedTemporaryFile from typing import List -from marimo._utils.file_watcher import PollingFileWatcher +from marimo._utils.file_watcher import FileWatcherManager, PollingFileWatcher async def test_polling_file_watcher() -> None: @@ -41,3 +41,107 @@ async def test_callback(path: Path): # Assert that the callback was called assert len(callback_calls) == 1 assert callback_calls[0] == tmp_path + + +async def test_file_watcher_manager() -> None: + # Create two temporary files + with ( + NamedTemporaryFile(delete=False) as tmp_file1, + NamedTemporaryFile(delete=False) as tmp_file2, + ): + tmp_path1 = Path(tmp_file1.name) + tmp_path2 = Path(tmp_file2.name) + + # Create manager and add callbacks + manager = FileWatcherManager() + + try: + # Track callback calls + callback1_calls: List[Path] = [] + callback2_calls: List[Path] = [] + callback3_calls: List[Path] = [] + + async def callback1(path: Path) -> None: + callback1_calls.append(path) + + async def callback2(path: Path) -> None: + callback2_calls.append(path) + + async def callback3(path: Path) -> None: + callback3_calls.append(path) + + # Speed up polling for tests + PollingFileWatcher.POLL_SECONDS = 0.1 + + # Add two callbacks for file1 + manager.add_callback(tmp_path1, callback1) + manager.add_callback(tmp_path1, callback2) + # Add one callback for file2 + manager.add_callback(tmp_path2, callback3) + + # Modify file1 + await asyncio.sleep(0.2) + with open(tmp_path1, "w") as f: # noqa: ASYNC101 ASYNC230 + f.write("modification1") + + # Wait for callbacks + await asyncio.sleep(0.2) + + # Both callbacks should be called for file1 + assert len(callback1_calls) == 1 + assert len(callback2_calls) == 1 + assert len(callback3_calls) == 0 + assert callback1_calls[0] == tmp_path1 + assert callback2_calls[0] == tmp_path1 + + # Remove one callback from file1 + manager.remove_callback(tmp_path1, callback1) + + # Modify file1 again + with open(tmp_path1, "w") as f: # noqa: ASYNC101 ASYNC230 + f.write("modification2") + + # Wait for callbacks + await asyncio.sleep(0.2) + + # Only callback2 should be called again + assert len(callback1_calls) == 1 # unchanged + assert len(callback2_calls) == 2 + assert len(callback3_calls) == 0 + + # Modify file2 + with open(tmp_path2, "w") as f: # noqa: ASYNC101 ASYNC230 + f.write("modification3") + + # Wait for callbacks + await asyncio.sleep(0.2) + + # callback3 should be called for file2 + assert len(callback1_calls) == 1 + assert len(callback2_calls) == 2 + assert len(callback3_calls) == 1 + assert callback3_calls[0] == tmp_path2 + + # Remove all callbacks + manager.remove_callback(tmp_path1, callback2) + manager.remove_callback(tmp_path2, callback3) + + # Modify both files + with open(tmp_path1, "w") as f: # noqa: ASYNC101 ASYNC230 + f.write("modification4") + with open(tmp_path2, "w") as f: # noqa: ASYNC101 ASYNC230 + f.write("modification4") + + # Wait for potential callbacks + await asyncio.sleep(0.2) + + # No new calls should happen + assert len(callback1_calls) == 1 + assert len(callback2_calls) == 2 + assert len(callback3_calls) == 1 + + finally: + # Cleanup + manager.stop_all() + os.remove(tmp_path1) + os.remove(tmp_path2)