From 864127ce9d3dc660f48d919c532e4f46fe425947 Mon Sep 17 00:00:00 2001 From: Akshay Agrawal Date: Sat, 4 Jan 2025 12:58:57 -0800 Subject: [PATCH] feat: configurable TTL for marimo run sessions (#3344) --- marimo/_cli/cli.py | 14 ++++++++++++++ marimo/_server/api/endpoints/ws.py | 14 ++++++++------ marimo/_server/asgi.py | 1 + marimo/_server/export/__init__.py | 1 + marimo/_server/sessions.py | 15 ++++++++++++--- marimo/_server/start.py | 2 ++ tests/_server/mocks.py | 1 + tests/_server/test_session_manager.py | 1 + tests/_server/test_sessions.py | 3 +++ 9 files changed, 43 insertions(+), 9 deletions(-) diff --git a/marimo/_cli/cli.py b/marimo/_cli/cli.py index 14737de0fe2..c7910b27adb 100644 --- a/marimo/_cli/cli.py +++ b/marimo/_cli/cli.py @@ -375,6 +375,7 @@ def edit( base_url=base_url, allow_origins=allow_origins, redirect_console_to_browser=True, + ttl_seconds=None, ) @@ -469,6 +470,7 @@ def new( auth_token=_resolve_token(token, token_password), base_url=base_url, redirect_console_to_browser=True, + ttl_seconds=None, ) @@ -533,6 +535,15 @@ def new( type=bool, help="Include notebook code in the app.", ) +@click.option( + "--session-ttl", + default=120, + show_default=True, + type=int, + help=( + "Seconds to wait before closing a session on " "websocket disconnect." + ), +) @click.option( "--watch", is_flag=True, @@ -585,6 +596,7 @@ def run( token: bool, token_password: Optional[str], include_code: bool, + session_ttl: int, watch: bool, base_url: str, allow_origins: tuple[str, ...], @@ -633,6 +645,7 @@ def run( headless=headless, mode=SessionMode.RUN, include_code=include_code, + ttl_seconds=session_ttl, watch=watch, base_url=base_url, allow_origins=allow_origins, @@ -742,6 +755,7 @@ def tutorial( cli_args={}, auth_token=_resolve_token(token, token_password), redirect_console_to_browser=False, + ttl_seconds=None, ) diff --git a/marimo/_server/api/endpoints/ws.py b/marimo/_server/api/endpoints/ws.py index d0bb7f5f876..65ba0fb642f 100644 --- a/marimo/_server/api/endpoints/ws.py +++ b/marimo/_server/api/endpoints/ws.py @@ -315,9 +315,9 @@ def _on_disconnect( session.disconnect_consumer(self) if self.manager.mode == SessionMode.RUN: - # When the websocket is closed, we wait TTL_SECONDS before - # closing the session. This is to prevent the session from - # being closed if the during an intermittent network issue. + # When the websocket is closed, we wait session.ttl_seconds before + # closing the session. This is to prevent the session from being + # closed if the during an intermittent network issue. def _close() -> None: if self.status != ConnectionState.OPEN: LOGGER.debug( @@ -330,11 +330,13 @@ def _close() -> None: self.manager.close_session(self.session_id) session = self.manager.get_session(self.session_id) - cancellation_handle = asyncio.get_event_loop().call_later( - Session.TTL_SECONDS, _close - ) if session is not None: + cancellation_handle = asyncio.get_event_loop().call_later( + session.ttl_seconds, _close + ) self.cancel_close_handle = cancellation_handle + else: + _close() else: cleanup_fn() diff --git a/marimo/_server/asgi.py b/marimo/_server/asgi.py index 08822c6a7c9..90f9b1a129a 100644 --- a/marimo/_server/asgi.py +++ b/marimo/_server/asgi.py @@ -462,6 +462,7 @@ def _create_app_for_file(base_url: str, file_path: str) -> ASGIApp: cli_args={}, auth_token=auth_token, redirect_console_to_browser=False, + ttl_seconds=None, ) app = create_starlette_app( base_url="", diff --git a/marimo/_server/export/__init__.py b/marimo/_server/export/__init__.py index b99e6790aec..e7257ddf9a7 100644 --- a/marimo/_server/export/__init__.py +++ b/marimo/_server/export/__init__.py @@ -278,6 +278,7 @@ def connection_state(self) -> ConnectionState: user_config_manager=config_manager, virtual_files_supported=False, redirect_console_to_browser=False, + ttl_seconds=None, ) # Run the notebook to completion once diff --git a/marimo/_server/sessions.py b/marimo/_server/sessions.py index 6a79db85e2a..8f604281718 100644 --- a/marimo/_server/sessions.py +++ b/marimo/_server/sessions.py @@ -77,7 +77,6 @@ from marimo._utils.typed_connection import TypedConnection LOGGER = _loggers.marimo_logger() -SESSION_MANAGER: Optional["SessionManager"] = None class QueueManager: @@ -396,6 +395,9 @@ def close(self) -> None: self.main_consumer = None +_DEFAULT_TTL_SECONDS = 120 + + class Session: """A client session. @@ -403,8 +405,6 @@ class Session: and its own websocket, for sending messages to the client. """ - TTL_SECONDS = 120 - @classmethod def create( cls, @@ -416,6 +416,7 @@ def create( user_config_manager: MarimoConfigReader, virtual_files_supported: bool, redirect_console_to_browser: bool, + ttl_seconds: Optional[int], ) -> Session: """ Create a new session. @@ -438,6 +439,7 @@ def create( queue_manager, kernel_manager, app_file_manager, + ttl_seconds, ) def __init__( @@ -447,6 +449,7 @@ def __init__( queue_manager: QueueManager, kernel_manager: KernelManager, app_file_manager: AppFileManager, + ttl_seconds: Optional[int], ) -> None: """Initialize kernel and client connection to it.""" # This is some unique ID that we can use to identify the session @@ -457,6 +460,9 @@ def __init__( self.room = Room() self._queue_manager = queue_manager self.kernel_manager = kernel_manager + self.ttl_seconds = ( + ttl_seconds if ttl_seconds is not None else _DEFAULT_TTL_SECONDS + ) self.session_view = SessionView() self.kernel_manager.start_kernel() @@ -673,6 +679,7 @@ def __init__( cli_args: SerializedCLIArgs, auth_token: Optional[AuthToken], redirect_console_to_browser: bool, + ttl_seconds: Optional[int], ) -> None: self.file_router = file_router self.mode = mode @@ -680,6 +687,7 @@ def __init__( self.quiet = quiet self.sessions: dict[SessionId, Session] = {} self.include_code = include_code + self.ttl_seconds = ttl_seconds self.lsp_server = lsp_server self.watcher: Optional[FileWatcher] = None self.recents = RecentFilesManager() @@ -753,6 +761,7 @@ def create_session( user_config_manager=self.user_config_manager, virtual_files_supported=True, redirect_console_to_browser=self.redirect_console_to_browser, + ttl_seconds=self.ttl_seconds, ) return self.sessions[session_id] diff --git a/marimo/_server/start.py b/marimo/_server/start.py index d2eee2fa85c..d79ea8fe2f3 100644 --- a/marimo/_server/start.py +++ b/marimo/_server/start.py @@ -71,6 +71,7 @@ def start( development_mode: bool, quiet: bool, include_code: bool, + ttl_seconds: Optional[int], headless: bool, port: Optional[int], host: str, @@ -108,6 +109,7 @@ def start( development_mode=development_mode, quiet=quiet, include_code=include_code, + ttl_seconds=ttl_seconds, lsp_server=LspServer(lsp_port), user_config_manager=config_reader, cli_args=cli_args, diff --git a/tests/_server/mocks.py b/tests/_server/mocks.py index ab825d84e52..49c43cde901 100644 --- a/tests/_server/mocks.py +++ b/tests/_server/mocks.py @@ -57,6 +57,7 @@ def __(): cli_args={}, auth_token=AuthToken("fake-token"), redirect_console_to_browser=False, + ttl_seconds=None, ) sm.skew_protection_token = SkewProtectionToken("skew-id-1") return sm diff --git a/tests/_server/test_session_manager.py b/tests/_server/test_session_manager.py index eae86d91d39..0e817ea98ea 100644 --- a/tests/_server/test_session_manager.py +++ b/tests/_server/test_session_manager.py @@ -45,6 +45,7 @@ def session_manager(): cli_args={}, auth_token=None, redirect_console_to_browser=False, + ttl_seconds=None, ) diff --git a/tests/_server/test_sessions.py b/tests/_server/test_sessions.py index 0e1dac549bb..343e5d7d05e 100644 --- a/tests/_server/test_sessions.py +++ b/tests/_server/test_sessions.py @@ -237,6 +237,7 @@ def test_session() -> None: queue_manager, kernel_manager, AppFileManager.from_app(InternalApp(App())), + ttl_seconds=None, ) # Assert startup @@ -282,6 +283,7 @@ def test_session_disconnect_reconnect() -> None: queue_manager, kernel_manager, AppFileManager.from_app(InternalApp(App())), + ttl_seconds=None, ) # Assert startup @@ -338,6 +340,7 @@ def test_session_with_kiosk_consumers() -> None: queue_manager, kernel_manager, AppFileManager.from_app(InternalApp(App())), + ttl_seconds=None, ) # Assert startup