Skip to content

Commit

Permalink
Refactor and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bennyweise committed Jan 19, 2025
1 parent e4bd7f8 commit 82b3b46
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
13 changes: 7 additions & 6 deletions marimo/_snippets/snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Snippets:
async def read_snippets() -> Snippets:
snippets: List[Snippet] = []

for file in snippet_files():
for file in read_snippet_filenames_from_config():
app = get_app(file)
assert app is not None
sections: List[SnippetSection] = []
Expand Down Expand Up @@ -90,20 +90,21 @@ def is_markdown(code: str) -> bool:
return code.startswith("mo.md")


def snippet_files() -> Generator[str, Any, None]:
def read_snippet_filenames_from_config() -> Generator[str, Any, None]:
# Get custom snippets path from config if present
config = get_default_config_manager(current_path=None).get_config()
custom_paths = [
Path(p) for p in config.get("snippets", {}).get("custom_paths", [])
]
custom_paths = config.get("snippets", {}).get("custom_paths", [])
include_default_snippets = config.get("snippets", {}).get(
"include_default_snippets", True
)
return read_snippet_filenames(include_default_snippets, custom_paths)

def read_snippet_filenames(include_default_snippets: bool, custom_paths: List[str]) -> Generator[str, Any, None]:
paths = []
if include_default_snippets:
paths.append(import_files("marimo") / "_snippets" / "data")
if custom_paths:
paths.extend(custom_paths)
paths.extend([Path(p) for p in custom_paths])
for root_path in paths:
if not root_path.is_dir():
# Note: currently no handling of permissions errors, but theoretically
Expand Down
29 changes: 28 additions & 1 deletion tests/_snippets/test_snippets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from marimo._snippets.snippets import get_title_from_code, read_snippets
from marimo._snippets.snippets import (
get_title_from_code,
read_snippets,
read_snippet_filenames,
)
import pytest


async def test_snippets() -> None:
Expand Down Expand Up @@ -36,3 +41,25 @@ def test_get_title_from_code_with_multiple_titles() -> None:
def test_get_title_from_code_with_non_title_hashes() -> None:
code = "print('# This is not a title')"
assert get_title_from_code(code) == ""


@pytest.mark.parametrize(
("include_default_snippets", "custom_paths", "expected_snippets"),
[
(True, [], 38),
(False, [], 0),
(True, ["/notarealdirectory"], 38),
(False, ["/notarealdirectory"], 0),
(False, ["marimo/_snippets/data"], 38),
(False, ["marimo/_snippets/data", "/notarealdirectory"], 38),
],
)
def test_read_snippet_filenames(
include_default_snippets, custom_paths, expected_snippets
) -> None:
filenames = list(
read_snippet_filenames(include_default_snippets, custom_paths)
)
assert len(filenames) == expected_snippets
assert all(filename.endswith(".py") for filename in filenames)
assert all("_snippets/data" in filename for filename in filenames)

0 comments on commit 82b3b46

Please sign in to comment.