diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0cc14ec5079..9cf3de3a459 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: - id: typos - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.6 + rev: v0.9.1 hooks: # Run the linter - id: ruff diff --git a/marimo/_ai/llm.py b/marimo/_ai/llm.py index 6e6e686aab5..f10b947bccf 100644 --- a/marimo/_ai/llm.py +++ b/marimo/_ai/llm.py @@ -128,7 +128,7 @@ def __call__( client: AzureOpenAI | OpenAI = AzureOpenAI( api_key=self._require_api_key, api_version=api_version, - azure_endpoint=f"{cast(str,parsed_url.scheme)}://{cast(str,parsed_url.hostname)}", + azure_endpoint=f"{cast(str, parsed_url.scheme)}://{cast(str, parsed_url.hostname)}", ) else: client = OpenAI( diff --git a/marimo/_cli/cli.py b/marimo/_cli/cli.py index c7910b27adb..51db340551a 100644 --- a/marimo/_cli/cli.py +++ b/marimo/_cli/cli.py @@ -540,9 +540,7 @@ def new( default=120, show_default=True, type=int, - help=( - "Seconds to wait before closing a session on " "websocket disconnect." - ), + help=("Seconds to wait before closing a session on websocket disconnect."), ) @click.option( "--watch", diff --git a/marimo/_plugins/stateless/tree.py b/marimo/_plugins/stateless/tree.py index f89c901f166..276413e17aa 100644 --- a/marimo/_plugins/stateless/tree.py +++ b/marimo/_plugins/stateless/tree.py @@ -20,7 +20,9 @@ def tree( Example: ```python3 - mo.tree(["entry", "another entry", {"key": [0, 1, 2]}], label="A tree.") + mo.tree( + ["entry", "another entry", {"key": [0, 1, 2]}], label="A tree." + ) ``` Args: diff --git a/marimo/_plugins/ui/_impl/altair_chart.py b/marimo/_plugins/ui/_impl/altair_chart.py index 4b459f5318b..d7955519e66 100644 --- a/marimo/_plugins/ui/_impl/altair_chart.py +++ b/marimo/_plugins/ui/_impl/altair_chart.py @@ -296,8 +296,7 @@ def __init__( if not isinstance(chart, (alt.TopLevelMixin)): raise ValueError( - "Invalid type for chart: " - f"{type(chart)}; expected altair.Chart" + f"Invalid type for chart: {type(chart)}; expected altair.Chart" ) # Make full-width if no width is specified diff --git a/marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py b/marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py index 6400a91c6f8..a1e1e040eff 100644 --- a/marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py +++ b/marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py @@ -21,7 +21,7 @@ def python_print_transforms( strs: List[str] = [] for transform in transforms: strs.append( - f"{df_next_name} = {print_transform(df_next_name,all_columns, transform)}" # noqa: E501 + f"{df_next_name} = {print_transform(df_next_name, all_columns, transform)}" # noqa: E501 ) return "\n".join([f"{df_next_name} = {df_name}"] + strs) @@ -128,7 +128,7 @@ def generate_where_clause(df_name: str, where: Condition) -> str: ) if not column_ids: return f"{df_name}.agg({_list_of_strings(aggregations)})" - return f'{df_name}.agg({{{", ".join(f"{_as_literal(column_id)}: {_list_of_strings(aggregations)}" for column_id in column_ids)}}})' # noqa: E501 + return f"{df_name}.agg({{{', '.join(f'{_as_literal(column_id)}: {_list_of_strings(aggregations)}' for column_id in column_ids)}}})" # noqa: E501 elif transform.type == TransformType.GROUP_BY: column_ids, aggregation, drop_na = ( @@ -480,7 +480,7 @@ def _as_literal(value: Any) -> str: def _list_of_strings(value: Union[List[Any], Any]) -> str: if isinstance(value, list): - return f'[{", ".join(_as_literal(v) for v in value)}]' + return f"[{', '.join(_as_literal(v) for v in value)}]" return _as_literal(value) diff --git a/marimo/_plugins/ui/_impl/dataframes/transforms/types.py b/marimo/_plugins/ui/_impl/dataframes/transforms/types.py index 1e791975c88..1cf52d3e43c 100644 --- a/marimo/_plugins/ui/_impl/dataframes/transforms/types.py +++ b/marimo/_plugins/ui/_impl/dataframes/transforms/types.py @@ -75,9 +75,9 @@ def __hash__(self) -> int: def __post_init__(self) -> None: if self.operator == "in": - assert isinstance( - self.value, list - ), "value must be a list for 'in' operator" + assert isinstance(self.value, list), ( + "value must be a list for 'in' operator" + ) @dataclass diff --git a/marimo/_runtime/app_meta.py b/marimo/_runtime/app_meta.py index 2e39a8c20a7..63e94041abc 100644 --- a/marimo/_runtime/app_meta.py +++ b/marimo/_runtime/app_meta.py @@ -62,7 +62,9 @@ def mode(self) -> Optional[RunMode]: ```python # Only show this content when editing the notebook - mo.md("# Developer Notes") if mo.app_meta().mode == "edit" else None + mo.md( + "# Developer Notes" + ) if mo.app_meta().mode == "edit" else None ``` Returns: diff --git a/marimo/_runtime/requests.py b/marimo/_runtime/requests.py index c31d90dd7aa..d2396e3029e 100644 --- a/marimo/_runtime/requests.py +++ b/marimo/_runtime/requests.py @@ -51,9 +51,9 @@ def execution_requests(self) -> List[ExecutionRequest]: ] def __post_init__(self) -> None: - assert len(self.cell_ids) == len( - self.codes - ), "Mismatched cell_ids and codes" + assert len(self.cell_ids) == len(self.codes), ( + "Mismatched cell_ids and codes" + ) @dataclass @@ -74,9 +74,9 @@ class SetUIElementValueRequest: token: str = field(default_factory=lambda: str(uuid4())) def __post_init__(self) -> None: - assert len(self.object_ids) == len( - self.values - ), "Mismatched object_ids and values" + assert len(self.object_ids) == len(self.values), ( + "Mismatched object_ids and values" + ) @staticmethod def from_ids_and_values( diff --git a/marimo/_runtime/runner/hooks_on_finish.py b/marimo/_runtime/runner/hooks_on_finish.py index 9a6b86c0d9c..3c13ef271d9 100644 --- a/marimo/_runtime/runner/hooks_on_finish.py +++ b/marimo/_runtime/runner/hooks_on_finish.py @@ -74,8 +74,7 @@ def _send_cancellation_errors(runner: cell_runner.Runner) -> None: exception_type = type(runner.exceptions[raising_cell]).__name__ data = MarimoExceptionRaisedError( msg=( - "An ancestor raised an exception " - f"({exception_type}): " + f"An ancestor raised an exception ({exception_type}): " ), exception_type=exception_type, raising_cell=raising_cell, diff --git a/marimo/_runtime/runtime.py b/marimo/_runtime/runtime.py index 385d317eccc..27030a0ff67 100644 --- a/marimo/_runtime/runtime.py +++ b/marimo/_runtime/runtime.py @@ -250,7 +250,9 @@ def app_meta() -> AppMeta: import altair as alt # Enable dark theme for Altair when marimo is in dark mode - alt.themes.enable("dark" if mo.app_meta().theme == "dark" else "default") + alt.themes.enable( + "dark" if mo.app_meta().theme == "dark" else "default" + ) ``` Show content only in edit mode: diff --git a/marimo/_save/ast.py b/marimo/_save/ast.py index b602ab86363..2a7f954836c 100644 --- a/marimo/_save/ast.py +++ b/marimo/_save/ast.py @@ -159,9 +159,9 @@ def strip_function(fn: Callable[..., Any]) -> ast.Module: code, _ = inspect.getsourcelines(fn) function_ast = ast.parse(textwrap.dedent("".join(code))) body = function_ast.body.pop() - assert isinstance( - body, (ast.FunctionDef, ast.AsyncFunctionDef) - ), "Expected a function definition" + assert isinstance(body, (ast.FunctionDef, ast.AsyncFunctionDef)), ( + "Expected a function definition" + ) extracted = ast.Module(body.body, type_ignores=[]) module = RemoveReturns().visit(extracted) assert isinstance(module, ast.Module), "Expected a module" diff --git a/marimo/_save/hash.py b/marimo/_save/hash.py index 5326e451459..f7c0fdbfbb5 100644 --- a/marimo/_save/hash.py +++ b/marimo/_save/hash.py @@ -322,9 +322,9 @@ def __init__( if not scoped_refs: scoped_refs = set() else: - assert ( - not apply_content_hash - ), "scoped_refs should only be used with deferred hashing." + assert not apply_content_hash, ( + "scoped_refs should only be used with deferred hashing." + ) self._hash: Optional[str] = None self.graph = graph diff --git a/marimo/_save/loaders/loader.py b/marimo/_save/loaders/loader.py index 0c3d36447e8..03e92ce1475 100644 --- a/marimo/_save/loaders/loader.py +++ b/marimo/_save/loaders/loader.py @@ -57,9 +57,9 @@ def cache_attempt( loaded = self.load_cache(hashed_context, cache_type) # TODO: Consider more robust verification assert loaded.hash == hashed_context, INCONSISTENT_CACHE_BOILER_PLATE - assert set(defs | stateful_refs) == set( - loaded.defs - ), INCONSISTENT_CACHE_BOILER_PLATE + assert set(defs | stateful_refs) == set(loaded.defs), ( + INCONSISTENT_CACHE_BOILER_PLATE + ) return Cache( loaded.defs, hashed_context, diff --git a/marimo/_save/loaders/memory.py b/marimo/_save/loaders/memory.py index c4b19ee52b5..2e576d1829e 100644 --- a/marimo/_save/loaders/memory.py +++ b/marimo/_save/loaders/memory.py @@ -55,9 +55,9 @@ def cache_hit(self, hashed_context: str, cache_type: CacheType) -> bool: return self._maybe_lock(lambda: key in self._cache) def load_cache(self, hashed_context: str, cache_type: CacheType) -> Cache: - assert self.cache_hit( - hashed_context, cache_type - ), INCONSISTENT_CACHE_BOILER_PLATE + assert self.cache_hit(hashed_context, cache_type), ( + INCONSISTENT_CACHE_BOILER_PLATE + ) self.hits += 1 key = self.build_path(hashed_context, cache_type) if self.is_lru: diff --git a/marimo/_save/loaders/pickle.py b/marimo/_save/loaders/pickle.py index acab1fe1fec..a81bc100071 100644 --- a/marimo/_save/loaders/pickle.py +++ b/marimo/_save/loaders/pickle.py @@ -27,13 +27,13 @@ def cache_hit(self, hashed_context: str, cache_type: CacheType) -> bool: return os.path.exists(path) and os.path.getsize(path) > 0 def load_cache(self, hashed_context: str, cache_type: CacheType) -> Cache: - assert self.cache_hit( - hashed_context, cache_type - ), INCONSISTENT_CACHE_BOILER_PLATE + assert self.cache_hit(hashed_context, cache_type), ( + INCONSISTENT_CACHE_BOILER_PLATE + ) with open(self.build_path(hashed_context, cache_type), "rb") as handle: cache = pickle.load(handle) assert isinstance(cache, Cache), ( - "Excepted cache object, got" f"{type(cache)} ", + f"Excepted cache object, got{type(cache)} ", INCONSISTENT_CACHE_BOILER_PLATE, ) return cache diff --git a/marimo/_save/save.py b/marimo/_save/save.py index 9f20d5a64ea..e9004a05e06 100644 --- a/marimo/_save/save.py +++ b/marimo/_save/save.py @@ -480,7 +480,7 @@ def __exit__( sys.settrace(self._old_trace) # Clear to previous set trace. if not self._entered_trace: raise CacheException( - ("Unexpected block format" f"{UNEXPECTED_FAILURE_BOILERPLATE}") + (f"Unexpected block format{UNEXPECTED_FAILURE_BOILERPLATE}") ) # Backfill the loaded values into global scope. @@ -493,8 +493,7 @@ def __exit__( # NB: exception is a type. if exception: assert not isinstance(instance, SkipWithBlock), ( - "Cache was not correctly set" - f"{UNEXPECTED_FAILURE_BOILERPLATE}" + f"Cache was not correctly set{UNEXPECTED_FAILURE_BOILERPLATE}" ) if isinstance(instance, BaseException): raise instance from CacheException("Failure during save.") diff --git a/marimo/_server/ai/prompts.py b/marimo/_server/ai/prompts.py index c63f617ba2b..bb4c249e4fb 100644 --- a/marimo/_server/ai/prompts.py +++ b/marimo/_server/ai/prompts.py @@ -55,7 +55,7 @@ def get_system_prompt( "Do not output markdown or backticks.", ] + language_rules[language] rules = "\n".join( - f"{i+1}. {rule}" for i, rule in enumerate(all_rules) + f"{i + 1}. {rule}" for i, rule in enumerate(all_rules) ) system_prompt = ( f"You are a helpful assistant that can answer questions about {language}." @@ -109,7 +109,7 @@ def get_chat_system_prompt( continue rules = "\n".join( - f"{i+1}. {rule}" + f"{i + 1}. {rule}" for i, rule in enumerate(language_rules[language]) ) diff --git a/marimo/_server/api/endpoints/ai.py b/marimo/_server/api/endpoints/ai.py index d8f0b22c627..56e2fec24a1 100644 --- a/marimo/_server/api/endpoints/ai.py +++ b/marimo/_server/api/endpoints/ai.py @@ -107,7 +107,7 @@ def get_openai_client(config: MarimoConfig) -> "OpenAI": api_key=key, api_version=api_version, azure_deployment=deployment_model, - azure_endpoint=f"{cast(str,parsed_url.scheme)}://{cast(str,parsed_url.hostname)}", + azure_endpoint=f"{cast(str, parsed_url.scheme)}://{cast(str, parsed_url.hostname)}", ) else: return OpenAI( diff --git a/marimo/_server/api/middleware.py b/marimo/_server/api/middleware.py index 47cc4c2f1d4..d8dcc029c6e 100644 --- a/marimo/_server/api/middleware.py +++ b/marimo/_server/api/middleware.py @@ -375,7 +375,7 @@ async def _proxy_websocket( websocket = WebSocket(scope, receive=receive, send=send) original_params = websocket.query_params if original_params: - ws_url = f"{ws_url}?{'&'.join(f'{k}={v}' for k,v in original_params.items())}" + ws_url = f"{ws_url}?{'&'.join(f'{k}={v}' for k, v in original_params.items())}" await websocket.accept() async with connect(ws_url) as ws_client: diff --git a/marimo/_server/models/models.py b/marimo/_server/models/models.py index 90816d5bc2e..7c9e4dbe81a 100644 --- a/marimo/_server/models/models.py +++ b/marimo/_server/models/models.py @@ -28,9 +28,9 @@ def zip( # Validate same length def __post_init__(self) -> None: - assert len(self.object_ids) == len( - self.values - ), "Mismatched object_ids and values" + assert len(self.object_ids) == len(self.values), ( + "Mismatched object_ids and values" + ) @dataclass @@ -89,9 +89,9 @@ def as_execution_request(self) -> ExecuteMultipleRequest: # Validate same length def __post_init__(self) -> None: - assert len(self.cell_ids) == len( - self.codes - ), "Mismatched cell_ids and codes" + assert len(self.cell_ids) == len(self.codes), ( + "Mismatched cell_ids and codes" + ) @dataclass @@ -121,15 +121,15 @@ class SaveNotebookRequest: # Validate same length def __post_init__(self) -> None: - assert len(self.cell_ids) == len( - self.codes - ), "Mismatched cell_ids and codes" - assert len(self.cell_ids) == len( - self.names - ), "Mismatched cell_ids and names" - assert len(self.cell_ids) == len( - self.configs - ), "Mismatched cell_ids and configs" + assert len(self.cell_ids) == len(self.codes), ( + "Mismatched cell_ids and codes" + ) + assert len(self.cell_ids) == len(self.names), ( + "Mismatched cell_ids and names" + ) + assert len(self.cell_ids) == len(self.configs), ( + "Mismatched cell_ids and configs" + ) @dataclass @@ -147,9 +147,9 @@ def __post_init__(self) -> None: f'File "{self.source}" does not exist.' + "Please save the notebook and try again." ) - assert not os.path.exists( - self.destination - ), f'File "{destination}" already exists in this directory.' + assert not os.path.exists(self.destination), ( + f'File "{destination}" already exists in this directory.' + ) @dataclass diff --git a/marimo/_server/router.py b/marimo/_server/router.py index 779ed6292ef..f269820af59 100644 --- a/marimo/_server/router.py +++ b/marimo/_server/router.py @@ -37,12 +37,12 @@ def __init__(self, prefix: str = "") -> None: def __post_init__(self) -> None: if self.prefix: - assert self.prefix.startswith( - "/" - ), "Path prefix must start with '/'" - assert not self.prefix.endswith( - "/" - ), "Path prefix must not end with '/'" + assert self.prefix.startswith("/"), ( + "Path prefix must start with '/'" + ) + assert not self.prefix.endswith("/"), ( + "Path prefix must not end with '/'" + ) def post( self, path: str diff --git a/marimo/_server/sessions.py b/marimo/_server/sessions.py index 11ca538d2b6..05f0cd0a2a5 100644 --- a/marimo/_server/sessions.py +++ b/marimo/_server/sessions.py @@ -365,9 +365,9 @@ def add_consumer( self.consumers[consumer] = consumer_id self.disposables[consumer] = dispose if main: - assert ( - self.main_consumer is None - ), "Main session consumer already exists" + assert self.main_consumer is None, ( + "Main session consumer already exists" + ) self.main_consumer = consumer def remove_consumer(self, consumer: SessionConsumer) -> None: diff --git a/marimo/_server/templates/templates.py b/marimo/_server/templates/templates.py index 09ab36514fa..4e2c72d2555 100644 --- a/marimo/_server/templates/templates.py +++ b/marimo/_server/templates/templates.py @@ -155,17 +155,23 @@ def static_notebook_template( diff --git a/tests/_cli/test_cli_development.py b/tests/_cli/test_cli_development.py index 72aaf9fc120..e06015c7a67 100644 --- a/tests/_cli/test_cli_development.py +++ b/tests/_cli/test_cli_development.py @@ -33,6 +33,6 @@ def test_openapi_up_to_date() -> None: del generated_content["info"]["version"] cmd = "marimo development openapi > openapi/api.yaml && make fe-codegen" - assert ( - current_content == generated_content - ), f"openapi/api.yaml is not up to date. Run '{cmd}' to update it." + assert current_content == generated_content, ( + f"openapi/api.yaml is not up to date. Run '{cmd}' to update it." + ) diff --git a/tests/_cli/test_cli_export.py b/tests/_cli/test_cli_export.py index 1f479382da8..51f31cea190 100644 --- a/tests/_cli/test_cli_export.py +++ b/tests/_cli/test_cli_export.py @@ -450,9 +450,9 @@ def test_export_script_with_multiple_definitions( ], capture_output=True, ) - assert ( - p.returncode != 0 - ), "Expected non-zero return code due to multiple definitions" + assert p.returncode != 0, ( + "Expected non-zero return code due to multiple definitions" + ) error_message = p.stderr.decode() assert ( "MultipleDefinitionError: This app can't be run because it has multiple definitions of the name x" diff --git a/tests/_convert/test_ipynb.py b/tests/_convert/test_ipynb.py index 8a3a0ffad65..adf39460c0a 100644 --- a/tests/_convert/test_ipynb.py +++ b/tests/_convert/test_ipynb.py @@ -250,21 +250,18 @@ def test_transform_magic_commands_complex(): "%env MY_VAR=value", ] result = transform_magic_commands(sources) - assert ( - result - == [ - '_df = mo.sql("""\nSELECT *\nFROM table\nWHERE condition\n""")', - ( - "# magic command not supported in marimo; please file an issue to add support\n" # noqa: E501 - "# %%time\nfor i in range(1000000):\n" - " pass" - ), - ( - "# '%load_ext autoreload\\n%autoreload 2' command supported automatically in marimo" # noqa: E501 - ), - "import os\nos.environ['MY_VAR'] = 'value'", - ] - ) + assert result == [ + '_df = mo.sql("""\nSELECT *\nFROM table\nWHERE condition\n""")', + ( + "# magic command not supported in marimo; please file an issue to add support\n" # noqa: E501 + "# %%time\nfor i in range(1000000):\n" + " pass" + ), + ( + "# '%load_ext autoreload\\n%autoreload 2' command supported automatically in marimo" # noqa: E501 + ), + "import os\nos.environ['MY_VAR'] = 'value'", + ] @pytest.mark.skipif( diff --git a/tests/_messaging/test_run_id_context.py b/tests/_messaging/test_run_id_context.py index 348727e9aa3..f12b70c9b5b 100644 --- a/tests/_messaging/test_run_id_context.py +++ b/tests/_messaging/test_run_id_context.py @@ -7,9 +7,9 @@ class TestRunIDContext: def test_run_id_context(self): with run_id_context(): run_id = RUN_ID_CTX.get() - assert ( - run_id is not None - ), "within run_id context but unable to obtain run_id" + assert run_id is not None, ( + "within run_id context but unable to obtain run_id" + ) # out of context manager with pytest.raises(LookupError): diff --git a/tests/_plugins/stateless/test_routes.py b/tests/_plugins/stateless/test_routes.py index a2d9d50e89e..9f0a8638384 100644 --- a/tests/_plugins/stateless/test_routes.py +++ b/tests/_plugins/stateless/test_routes.py @@ -75,7 +75,7 @@ async def test_routes_non_lazy(k: Kernel, exec_req: ExecReqProvider) -> None: assert len(context.function_registry.namespaces.values()) == 0 routes = k.globals["routes"] - children = "42" "45" + children = "4245" assert children in routes.text assert ( "data-routes='["#/", "{/*path}"]'" in routes.text diff --git a/tests/_plugins/ui/_impl/tables/test_narwhals.py b/tests/_plugins/ui/_impl/tables/test_narwhals.py index c987ae741db..5cd187c213f 100644 --- a/tests/_plugins/ui/_impl/tables/test_narwhals.py +++ b/tests/_plugins/ui/_impl/tables/test_narwhals.py @@ -839,6 +839,6 @@ def test_get_field_types_with_many_columns_is_performant(df: Any) -> None: # This can be slow if get_field_types is not optimized. # https://github.com/marimo-team/marimo/issues/3107 total_ms = (end_time - start_time) * 1000 - assert ( - total_ms < 500 - ), f"Total time: {total_ms}ms for {df.shape[1]} columns with {type(df)}" + assert total_ms < 500, ( + f"Total time: {total_ms}ms for {df.shape[1]} columns with {type(df)}" + ) diff --git a/tests/_runtime/test_trace.py b/tests/_runtime/test_trace.py index f8e8d11657e..c1625c5868f 100644 --- a/tests/_runtime/test_trace.py +++ b/tests/_runtime/test_trace.py @@ -47,7 +47,7 @@ def test_script_trace_with_output() -> None: result = p.stderr.decode() assert "ZeroDivisionError: division by zero" in result - assert ('script_exception_with_output.py"' ", line 11") in result + assert ('script_exception_with_output.py", line 11') in result assert "y / x" in result @staticmethod diff --git a/tests/_save/test_cache.py b/tests/_save/test_cache.py index fe90abc9c39..297d49adc26 100644 --- a/tests/_save/test_cache.py +++ b/tests/_save/test_cache.py @@ -770,9 +770,9 @@ def __(f, ns): f.base_block.execution_refs, f.base_block.content_refs, ) - assert f.base_block.context_refs == { - "ns" - }, f.base_block.context_refs + assert f.base_block.context_refs == {"ns"}, ( + f.base_block.context_refs + ) return app.run() @@ -851,9 +851,9 @@ def __(f, ns): assert f.hits == 0 assert f() == 1 assert f.hits == 1 - assert ( - f.base_block.execution_refs == set() - ), f.base_block.execution_refs + assert f.base_block.execution_refs == set(), ( + f.base_block.execution_refs + ) assert f.base_block.missing == {"ns"}, f.base_block.missing return diff --git a/tests/_save/test_hash.py b/tests/_save/test_hash.py index 2e35bcd8e53..4ede3813aeb 100644 --- a/tests/_save/test_hash.py +++ b/tests/_save/test_hash.py @@ -410,9 +410,9 @@ def one(MockLoader, persistent_cache, expected_hash, np) -> tuple[int]: one = _A assert _cache._cache.cache_type == "ContentAddressed" - assert ( - _cache._cache.hash == expected_hash - ), f"expected_hash != {_cache._cache.hash}" + assert _cache._cache.hash == expected_hash, ( + f"expected_hash != {_cache._cache.hash}" + ) return (one,) @app.cell @@ -424,9 +424,9 @@ def two(MockLoader, persistent_cache, expected_hash, np) -> tuple[int]: two = _A assert _cache._cache.cache_type == "ContentAddressed" - assert ( - _cache._cache.hash == expected_hash - ), f"expected_hash != {_cache._cache.hash}" + assert _cache._cache.hash == expected_hash, ( + f"expected_hash != {_cache._cache.hash}" + ) return (two,) @app.cell @@ -468,9 +468,9 @@ def one(MockLoader, persistent_cache, expected_hash, np) -> tuple[int]: one = _A assert _cache._cache.cache_type == "ContentAddressed" - assert ( - _cache._cache.hash == expected_hash - ), f"expected_hash != {_cache._cache.hash}" + assert _cache._cache.hash == expected_hash, ( + f"expected_hash != {_cache._cache.hash}" + ) return (one,) @app.cell @@ -482,9 +482,9 @@ def two(MockLoader, persistent_cache, expected_hash, np) -> tuple[int]: two = _A assert _cache._cache.cache_type == "ContentAddressed" - assert ( - _cache._cache.hash == expected_hash - ), f"expected_hash != {_cache._cache.hash}" + assert _cache._cache.hash == expected_hash, ( + f"expected_hash != {_cache._cache.hash}" + ) return (two,) @app.cell @@ -528,9 +528,9 @@ def one( one = _A assert _cache._cache.cache_type == "ContentAddressed" - assert ( - _cache._cache.hash == expected_hash - ), f"expected_hash != {_cache._cache.hash}" + assert _cache._cache.hash == expected_hash, ( + f"expected_hash != {_cache._cache.hash}" + ) return (one,) @app.cell @@ -544,9 +544,9 @@ def two( two = _A assert _cache._cache.cache_type == "ContentAddressed" - assert ( - _cache._cache.hash == expected_hash - ), f"expected_hash != {_cache._cache.hash}" + assert _cache._cache.hash == expected_hash, ( + f"expected_hash != {_cache._cache.hash}" + ) return (two,) @app.cell @@ -591,12 +591,12 @@ def one( _A = torch.sum(_a) one = _A - assert ( - _cache._cache.cache_type == "ContextExecutionPath" - ), _cache._cache.cache_type - assert ( - _cache._cache.hash == expected_hash - ), f"expected_hash != {_cache._cache.hash}" + assert _cache._cache.cache_type == "ContextExecutionPath", ( + _cache._cache.cache_type + ) + assert _cache._cache.hash == expected_hash, ( + f"expected_hash != {_cache._cache.hash}" + ) return (one,) app.run() @@ -638,9 +638,9 @@ def one( assert strand == sibling assert _cache._cache.cache_type == "ContentAddressed" - assert ( - _cache._cache.hash == expected_hash - ), f"expected_hash != {_cache._cache.hash}" + assert _cache._cache.hash == expected_hash, ( + f"expected_hash != {_cache._cache.hash}" + ) return (one,) app.run() @@ -685,9 +685,9 @@ def one( one = _A assert _cache._cache.cache_type == "ContentAddressed" - assert ( - _cache._cache.hash == expected_hash - ), f"expected_hash != {_cache._cache.hash}" + assert _cache._cache.hash == expected_hash, ( + f"expected_hash != {_cache._cache.hash}" + ) return (one,) @app.cell @@ -702,9 +702,9 @@ def two( two = _A assert _cache._cache.cache_type == "ContentAddressed" - assert ( - _cache._cache.hash == expected_hash - ), f"expected_hash != {_cache._cache.hash}" + assert _cache._cache.hash == expected_hash, ( + f"expected_hash != {_cache._cache.hash}" + ) return (two,) @app.cell @@ -754,9 +754,9 @@ def one( one = _A assert _cache._cache.cache_type == "ContextExecutionPath" - assert ( - _cache._cache.hash == expected_hash - ), f"expected_hash != {_cache._cache.hash}" + assert _cache._cache.hash == expected_hash, ( + f"expected_hash != {_cache._cache.hash}" + ) return (one,) app.run() @@ -798,9 +798,9 @@ def one(MockLoader, persistent_cache, expected_hash, pl) -> tuple[int]: one = _A assert _cache._cache.cache_type == "ContentAddressed" - assert ( - _cache._cache.hash == expected_hash - ), f"expected_hash != {_cache._cache.hash}" + assert _cache._cache.hash == expected_hash, ( + f"expected_hash != {_cache._cache.hash}" + ) return (one,) @app.cell @@ -813,9 +813,9 @@ def two(MockLoader, persistent_cache, expected_hash, pl) -> tuple[int]: two = _A assert _cache._cache.cache_type == "ContentAddressed" - assert ( - _cache._cache.hash == expected_hash - ), f"expected_hash != {_cache._cache.hash}" + assert _cache._cache.hash == expected_hash, ( + f"expected_hash != {_cache._cache.hash}" + ) return (two,) @app.cell diff --git a/tests/_server/test_utils.py b/tests/_server/test_utils.py index b551c43f9a0..e284b7883f4 100644 --- a/tests/_server/test_utils.py +++ b/tests/_server/test_utils.py @@ -7,9 +7,9 @@ def test_require_header() -> None: # Happy path header = ["Content-Type"] - assert ( - require_header(header) == "Content-Type" - ), "The function should return the single header value" + assert require_header(header) == "Content-Type", ( + "The function should return the single header value" + ) with pytest.raises(ValueError) as e: require_header(None) diff --git a/tests/_smoke_tests/run_all.py b/tests/_smoke_tests/run_all.py index 111fea9f46e..161156a7760 100644 --- a/tests/_smoke_tests/run_all.py +++ b/tests/_smoke_tests/run_all.py @@ -83,25 +83,25 @@ async def _run_test( if failed_reason: # Expecting an error - assert ( - process.returncode != 0 - ), f"{relative_file} Expected error: {failed_reason}" + assert process.returncode != 0, ( + f"{relative_file} Expected error: {failed_reason}" + ) if isinstance(failed_reason, list): - assert any( - reason in stderr for reason in failed_reason - ), f"File: {file}. Expected error one of {failed_reason} in {stderr}" # noqa: E501 + assert any(reason in stderr for reason in failed_reason), ( + f"File: {file}. Expected error one of {failed_reason} in {stderr}" + ) # noqa: E501 else: assert failed_reason in stderr, f"File: {file}" # Allow MarimoStop elif "MarimoStop" in stderr: - assert ( - process.returncode != 0 - ), f"{relative_file} Unexpected error: {stderr}" + assert process.returncode != 0, ( + f"{relative_file} Unexpected error: {stderr}" + ) else: # Expecting no error, allow MarimoStop - assert ( - process.returncode == 0 - ), f"{relative_file} Unexpected error: {stderr}" + assert process.returncode == 0, ( + f"{relative_file} Unexpected error: {stderr}" + ) assert not any( line.startswith("Traceback") for line in stderr.splitlines() )