From 82b3b466863bfa325c39f22a7b27e6f2e19d68df Mon Sep 17 00:00:00 2001 From: Ben Weise Date: Sun, 19 Jan 2025 14:34:40 +1100 Subject: [PATCH] Refactor and add tests --- marimo/_snippets/snippets.py | 13 +++++++------ tests/_snippets/test_snippets.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/marimo/_snippets/snippets.py b/marimo/_snippets/snippets.py index ab7e3a0795b..cd9e835a679 100644 --- a/marimo/_snippets/snippets.py +++ b/marimo/_snippets/snippets.py @@ -33,7 +33,7 @@ class Snippets: async def read_snippets() -> Snippets: snippets: List[Snippet] = [] - for file in snippet_files(): + for file in read_snippet_filenames_from_config(): app = get_app(file) assert app is not None sections: List[SnippetSection] = [] @@ -90,20 +90,21 @@ def is_markdown(code: str) -> bool: return code.startswith("mo.md") -def snippet_files() -> Generator[str, Any, None]: +def read_snippet_filenames_from_config() -> Generator[str, Any, None]: # Get custom snippets path from config if present config = get_default_config_manager(current_path=None).get_config() - custom_paths = [ - Path(p) for p in config.get("snippets", {}).get("custom_paths", []) - ] + custom_paths = config.get("snippets", {}).get("custom_paths", []) include_default_snippets = config.get("snippets", {}).get( "include_default_snippets", True ) + return read_snippet_filenames(include_default_snippets, custom_paths) + +def read_snippet_filenames(include_default_snippets: bool, custom_paths: List[str]) -> Generator[str, Any, None]: paths = [] if include_default_snippets: paths.append(import_files("marimo") / "_snippets" / "data") if custom_paths: - paths.extend(custom_paths) + paths.extend([Path(p) for p in custom_paths]) for root_path in paths: if not root_path.is_dir(): # Note: currently no handling of permissions errors, but theoretically diff --git a/tests/_snippets/test_snippets.py b/tests/_snippets/test_snippets.py index d651e83ee85..de744412219 100644 --- a/tests/_snippets/test_snippets.py +++ b/tests/_snippets/test_snippets.py @@ -1,4 +1,9 @@ -from marimo._snippets.snippets import get_title_from_code, read_snippets +from marimo._snippets.snippets import ( + get_title_from_code, + read_snippets, + read_snippet_filenames, +) +import pytest async def test_snippets() -> None: @@ -36,3 +41,25 @@ def test_get_title_from_code_with_multiple_titles() -> None: def test_get_title_from_code_with_non_title_hashes() -> None: code = "print('# This is not a title')" assert get_title_from_code(code) == "" + + +@pytest.mark.parametrize( + ("include_default_snippets", "custom_paths", "expected_snippets"), + [ + (True, [], 38), + (False, [], 0), + (True, ["/notarealdirectory"], 38), + (False, ["/notarealdirectory"], 0), + (False, ["marimo/_snippets/data"], 38), + (False, ["marimo/_snippets/data", "/notarealdirectory"], 38), + ], +) +def test_read_snippet_filenames( + include_default_snippets, custom_paths, expected_snippets +) -> None: + filenames = list( + read_snippet_filenames(include_default_snippets, custom_paths) + ) + assert len(filenames) == expected_snippets + assert all(filename.endswith(".py") for filename in filenames) + assert all("_snippets/data" in filename for filename in filenames)