Skip to content

Commit

Permalink
feat: configurable TTL for marimo run sessions (#3344)
Browse files Browse the repository at this point in the history
  • Loading branch information
akshayka authored Jan 4, 2025
1 parent b6834a9 commit 864127c
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 9 deletions.
14 changes: 14 additions & 0 deletions marimo/_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ def edit(
base_url=base_url,
allow_origins=allow_origins,
redirect_console_to_browser=True,
ttl_seconds=None,
)


Expand Down Expand Up @@ -469,6 +470,7 @@ def new(
auth_token=_resolve_token(token, token_password),
base_url=base_url,
redirect_console_to_browser=True,
ttl_seconds=None,
)


Expand Down Expand Up @@ -533,6 +535,15 @@ def new(
type=bool,
help="Include notebook code in the app.",
)
@click.option(
"--session-ttl",
default=120,
show_default=True,
type=int,
help=(
"Seconds to wait before closing a session on " "websocket disconnect."
),
)
@click.option(
"--watch",
is_flag=True,
Expand Down Expand Up @@ -585,6 +596,7 @@ def run(
token: bool,
token_password: Optional[str],
include_code: bool,
session_ttl: int,
watch: bool,
base_url: str,
allow_origins: tuple[str, ...],
Expand Down Expand Up @@ -633,6 +645,7 @@ def run(
headless=headless,
mode=SessionMode.RUN,
include_code=include_code,
ttl_seconds=session_ttl,
watch=watch,
base_url=base_url,
allow_origins=allow_origins,
Expand Down Expand Up @@ -742,6 +755,7 @@ def tutorial(
cli_args={},
auth_token=_resolve_token(token, token_password),
redirect_console_to_browser=False,
ttl_seconds=None,
)


Expand Down
14 changes: 8 additions & 6 deletions marimo/_server/api/endpoints/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,9 @@ def _on_disconnect(
session.disconnect_consumer(self)

if self.manager.mode == SessionMode.RUN:
# When the websocket is closed, we wait TTL_SECONDS before
# closing the session. This is to prevent the session from
# being closed if the during an intermittent network issue.
# When the websocket is closed, we wait session.ttl_seconds before
# closing the session. This is to prevent the session from being
# closed if the during an intermittent network issue.
def _close() -> None:
if self.status != ConnectionState.OPEN:
LOGGER.debug(
Expand All @@ -330,11 +330,13 @@ def _close() -> None:
self.manager.close_session(self.session_id)

session = self.manager.get_session(self.session_id)
cancellation_handle = asyncio.get_event_loop().call_later(
Session.TTL_SECONDS, _close
)
if session is not None:
cancellation_handle = asyncio.get_event_loop().call_later(
session.ttl_seconds, _close
)
self.cancel_close_handle = cancellation_handle
else:
_close()
else:
cleanup_fn()

Expand Down
1 change: 1 addition & 0 deletions marimo/_server/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ def _create_app_for_file(base_url: str, file_path: str) -> ASGIApp:
cli_args={},
auth_token=auth_token,
redirect_console_to_browser=False,
ttl_seconds=None,
)
app = create_starlette_app(
base_url="",
Expand Down
1 change: 1 addition & 0 deletions marimo/_server/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def connection_state(self) -> ConnectionState:
user_config_manager=config_manager,
virtual_files_supported=False,
redirect_console_to_browser=False,
ttl_seconds=None,
)

# Run the notebook to completion once
Expand Down
15 changes: 12 additions & 3 deletions marimo/_server/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@
from marimo._utils.typed_connection import TypedConnection

LOGGER = _loggers.marimo_logger()
SESSION_MANAGER: Optional["SessionManager"] = None


class QueueManager:
Expand Down Expand Up @@ -396,15 +395,16 @@ def close(self) -> None:
self.main_consumer = None


_DEFAULT_TTL_SECONDS = 120


class Session:
"""A client session.
Each session has its own Python kernel, for editing and running the app,
and its own websocket, for sending messages to the client.
"""

TTL_SECONDS = 120

@classmethod
def create(
cls,
Expand All @@ -416,6 +416,7 @@ def create(
user_config_manager: MarimoConfigReader,
virtual_files_supported: bool,
redirect_console_to_browser: bool,
ttl_seconds: Optional[int],
) -> Session:
"""
Create a new session.
Expand All @@ -438,6 +439,7 @@ def create(
queue_manager,
kernel_manager,
app_file_manager,
ttl_seconds,
)

def __init__(
Expand All @@ -447,6 +449,7 @@ def __init__(
queue_manager: QueueManager,
kernel_manager: KernelManager,
app_file_manager: AppFileManager,
ttl_seconds: Optional[int],
) -> None:
"""Initialize kernel and client connection to it."""
# This is some unique ID that we can use to identify the session
Expand All @@ -457,6 +460,9 @@ def __init__(
self.room = Room()
self._queue_manager = queue_manager
self.kernel_manager = kernel_manager
self.ttl_seconds = (
ttl_seconds if ttl_seconds is not None else _DEFAULT_TTL_SECONDS
)
self.session_view = SessionView()

self.kernel_manager.start_kernel()
Expand Down Expand Up @@ -673,13 +679,15 @@ def __init__(
cli_args: SerializedCLIArgs,
auth_token: Optional[AuthToken],
redirect_console_to_browser: bool,
ttl_seconds: Optional[int],
) -> None:
self.file_router = file_router
self.mode = mode
self.development_mode = development_mode
self.quiet = quiet
self.sessions: dict[SessionId, Session] = {}
self.include_code = include_code
self.ttl_seconds = ttl_seconds
self.lsp_server = lsp_server
self.watcher: Optional[FileWatcher] = None
self.recents = RecentFilesManager()
Expand Down Expand Up @@ -753,6 +761,7 @@ def create_session(
user_config_manager=self.user_config_manager,
virtual_files_supported=True,
redirect_console_to_browser=self.redirect_console_to_browser,
ttl_seconds=self.ttl_seconds,
)
return self.sessions[session_id]

Expand Down
2 changes: 2 additions & 0 deletions marimo/_server/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def start(
development_mode: bool,
quiet: bool,
include_code: bool,
ttl_seconds: Optional[int],
headless: bool,
port: Optional[int],
host: str,
Expand Down Expand Up @@ -108,6 +109,7 @@ def start(
development_mode=development_mode,
quiet=quiet,
include_code=include_code,
ttl_seconds=ttl_seconds,
lsp_server=LspServer(lsp_port),
user_config_manager=config_reader,
cli_args=cli_args,
Expand Down
1 change: 1 addition & 0 deletions tests/_server/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __():
cli_args={},
auth_token=AuthToken("fake-token"),
redirect_console_to_browser=False,
ttl_seconds=None,
)
sm.skew_protection_token = SkewProtectionToken("skew-id-1")
return sm
Expand Down
1 change: 1 addition & 0 deletions tests/_server/test_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def session_manager():
cli_args={},
auth_token=None,
redirect_console_to_browser=False,
ttl_seconds=None,
)


Expand Down
3 changes: 3 additions & 0 deletions tests/_server/test_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def test_session() -> None:
queue_manager,
kernel_manager,
AppFileManager.from_app(InternalApp(App())),
ttl_seconds=None,
)

# Assert startup
Expand Down Expand Up @@ -282,6 +283,7 @@ def test_session_disconnect_reconnect() -> None:
queue_manager,
kernel_manager,
AppFileManager.from_app(InternalApp(App())),
ttl_seconds=None,
)

# Assert startup
Expand Down Expand Up @@ -338,6 +340,7 @@ def test_session_with_kiosk_consumers() -> None:
queue_manager,
kernel_manager,
AppFileManager.from_app(InternalApp(App())),
ttl_seconds=None,
)

# Assert startup
Expand Down

0 comments on commit 864127c

Please sign in to comment.