Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle marimo[extras] in --sandbox and package installation #3425

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 48 additions & 15 deletions marimo/_cli/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def _read_pyproject(script: str) -> Dict[str, Any] | None:
return None


def _get_python_version_requirement(pyproject: Dict[str, Any]) -> str | None:
def _get_python_version_requirement(
pyproject: Dict[str, Any] | None,
) -> str | None:
"""Extract Python version requirement from pyproject metadata."""
if pyproject is None:
return None
Expand Down Expand Up @@ -203,6 +205,49 @@ def prompt_run_in_sandbox(name: str | None) -> bool:
return False


def _is_marimo_dependency(dependency: str) -> bool:
# Split on any version specifier
without_version = re.split(r"[=<>~]+", dependency)[0]
# Match marimo and marimo[extras], but not marimo-<something-else>
return without_version == "marimo" or without_version.startswith("marimo[")


def _is_versioned(dependency: str) -> bool:
return any(c in dependency for c in ("==", ">=", "<=", ">", "<", "~"))


def _normalize_sandbox_dependencies(
dependencies: List[str], marimo_version: str
) -> List[str]:
"""Normalize marimo dependencies to have only one version.

If multiple marimo dependencies exist, prefer the one with brackets.
Add version to the remaining one if not already versioned.
"""
# Find all marimo dependencies
marimo_deps = [d for d in dependencies if _is_marimo_dependency(d)]
if not marimo_deps:
# During development, you can comment this out to install an
# editable version of marimo assuming you are in the marimo directory
# DO NOT COMMIT THIS WHEN SUBMITTING PRs
# return dependencies + [f"marimo -e ."]

return dependencies + [f"marimo=={marimo_version}"]

# Prefer the one with brackets if it exists
bracketed = next((d for d in marimo_deps if "[" in d), None)
chosen = bracketed if bracketed else marimo_deps[0]

# Remove all marimo deps
filtered = [d for d in dependencies if not _is_marimo_dependency(d)]

# Add version if not already versioned
if not _is_versioned(chosen):
chosen = f"{chosen}=={marimo_version}"

return filtered + [chosen]


def run_in_sandbox(
args: List[str],
name: Optional[str] = None,
Expand All @@ -219,20 +264,8 @@ def run_in_sandbox(
get_dependencies_from_filename(name) if name is not None else []
)

# The sandbox needs to manage marimo, too, to make sure
# that the outer environment doesn't leak into the sandbox.
if "marimo" not in dependencies:
dependencies.append("marimo")

# Rename marimo to marimo=={__version__}
index_of_marimo = dependencies.index("marimo")
if index_of_marimo != -1:
dependencies[index_of_marimo] = f"marimo=={__version__}"

# During development, you can comment this out to install an
# editable version of marimo assuming you are in the marimo directory
# DO NOT COMMIT THIS WHEN SUBMITTING PRs
# dependencies[index_of_marimo] = "-e ."
# Normalize marimo dependencies
dependencies = _normalize_sandbox_dependencies(dependencies, __version__)

with tempfile.NamedTemporaryFile(
mode="w", delete=False, suffix=".txt"
Expand Down
13 changes: 11 additions & 2 deletions marimo/_runtime/packages/pypi_package_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,15 @@ def update_notebook_script_metadata(
import_namespaces_to_add: Optional[List[str]] = None,
import_namespaces_to_remove: Optional[List[str]] = None,
) -> None:
"""Update the notebook's script metadata with the packages to add/remove.

Args:
filepath: Path to the notebook file
packages_to_add: List of packages to add to the script metadata
packages_to_remove: List of packages to remove from the script metadata
import_namespaces_to_add: List of import namespaces to add
import_namespaces_to_remove: List of import namespaces to remove
"""
packages_to_add = packages_to_add or []
packages_to_remove = packages_to_remove or []
import_namespaces_to_add = import_namespaces_to_add or []
Expand All @@ -152,8 +161,8 @@ def _is_installed(package: str) -> bool:
return without_brackets.lower() in version_map

def _maybe_add_version(package: str) -> str:
# Skip marimo
if package == "marimo":
# Skip marimo and marimo[<mod>], but not marimo-<something-else>
if package == "marimo" or package.startswith("marimo["):
return package
without_brackets = package.split("[")[0]
version = version_map.get(without_brackets.lower())
Expand Down
69 changes: 69 additions & 0 deletions tests/_cli/test_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from marimo._cli.sandbox import (
_get_dependencies,
_get_python_version_requirement,
_is_marimo_dependency,
_normalize_sandbox_dependencies,
_pyproject_toml_to_requirements_txt,
_read_pyproject,
get_dependencies_from_filename,
Expand Down Expand Up @@ -239,3 +241,70 @@ def test_get_dependencies_with_nonexistent_file():

# Test with None
assert get_dependencies_from_filename(None) == [] # type: ignore


def test_normalize_marimo_dependencies():
# Test adding marimo when not present
assert _normalize_sandbox_dependencies(["numpy"], "1.0.0") == [
"numpy",
"marimo==1.0.0",
]

# Test preferring bracketed version
assert _normalize_sandbox_dependencies(
["marimo", "marimo[extras]", "numpy"], "1.0.0"
) == ["numpy", "marimo[extras]==1.0.0"]

# Test keeping existing version with brackets
assert _normalize_sandbox_dependencies(
["marimo[extras]>=0.1.0", "numpy"], "1.0.0"
) == ["numpy", "marimo[extras]>=0.1.0"]

# Test adding version when none exists
assert _normalize_sandbox_dependencies(
Copy link

@lucharo lucharo Jan 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be worthwhile to add a case for several marimo extra dependencies? What would the behaviour be in that case? or will that edge case never happen the way dependencies are added one by one?

extra test case:

    assert _normalize_sandbox_dependencies(
        ["marimo[extras]>=0.1.0", "marimo[sql]>=0.1.0", "numpy"], "1.0.0"
    ) == ["numpy", ???] 

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to add multiple extras, you declare them as marimo[extra1,extra2]. I am not sure if uv or other packages managers will merge this for you so you don't get into a bad state like that

["marimo[extras]", "numpy"], "1.0.0"
) == ["numpy", "marimo[extras]==1.0.0"]

# Test keeping only one marimo dependency
assert _normalize_sandbox_dependencies(
["marimo>=0.1.0", "marimo[extras]>=0.2.0", "numpy"], "1.0.0"
) == ["numpy", "marimo[extras]>=0.2.0"]
assert _normalize_sandbox_dependencies(
["marimo", "marimo[extras]>=0.2.0", "numpy"], "1.0.0"
) == ["numpy", "marimo[extras]>=0.2.0"]

# Test various version specifiers are preserved
version_specs = [
"==0.1.0",
">=0.1.0",
"<=0.1.0",
">0.1.0",
"<0.1.0",
"~=0.1.0",
]
for spec in version_specs:
assert _normalize_sandbox_dependencies(
[f"marimo{spec}", "numpy"], "1.0.0"
) == ["numpy", f"marimo{spec}"]


def test_is_marimo_dependency():
assert _is_marimo_dependency("marimo")
assert _is_marimo_dependency("marimo[extras]")
assert not _is_marimo_dependency("marimo-extras")
assert not _is_marimo_dependency("marimo-ai")

# With version specifiers
assert _is_marimo_dependency("marimo==0.1.0")
assert _is_marimo_dependency("marimo[extras]>=0.1.0")
assert _is_marimo_dependency("marimo[extras]==0.1.0")
assert _is_marimo_dependency("marimo[extras]~=0.1.0")
assert _is_marimo_dependency("marimo[extras]<=0.1.0")
assert _is_marimo_dependency("marimo[extras]>=0.1.0")
assert _is_marimo_dependency("marimo[extras]<=0.1.0")

# With other packages
assert not _is_marimo_dependency("numpy")
assert not _is_marimo_dependency("pandas")
assert not _is_marimo_dependency("marimo-ai")
assert not _is_marimo_dependency("marimo-ai==0.1.0")
108 changes: 108 additions & 0 deletions tests/_runtime/packages/test_package_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,111 @@ def _get_version_map(self) -> dict[str, str]:
"ibis-framework[duckdb]==2.0",
],
]


def test_update_script_metadata_marimo_packages() -> None:
runs_calls: list[list[str]] = []

class MockUvPackageManager(UvPackageManager):
def run(self, command: list[str]) -> bool:
runs_calls.append(command)
return True

def _get_version_map(self) -> dict[str, str]:
return {
"marimo": "0.1.0",
"marimo-ai": "0.2.0",
"pandas": "2.0.0",
}

pm = MockUvPackageManager()

# Test 1: Basic package handling
pm.update_notebook_script_metadata(
filepath="nb.py",
packages_to_add=[
"marimo-ai", # Should have version (different package)
"pandas", # Should have version
],
)
assert runs_calls == [
[
"uv",
"--quiet",
"add",
"--script",
"nb.py",
"marimo-ai==0.2.0",
"pandas==2.0.0",
]
]
runs_calls.clear()

# Test 2: Marimo package consolidation - should prefer marimo[ai] over marimo
pm.update_notebook_script_metadata(
filepath="nb.py",
packages_to_add=[
"marimo",
"marimo[sql]",
"pandas",
],
)
assert runs_calls == [
[
"uv",
"--quiet",
"add",
"--script",
"nb.py",
"marimo",
"marimo[sql]",
"pandas==2.0.0",
]
]
runs_calls.clear()

# Test 3: Multiple marimo extras - should use first one
pm.update_notebook_script_metadata(
filepath="nb.py",
packages_to_add=[
"marimo",
"marimo[sql]",
"marimo[recommended]",
"pandas",
],
)
assert runs_calls == [
[
"uv",
"--quiet",
"add",
"--script",
"nb.py",
"marimo",
"marimo[sql]",
"marimo[recommended]",
"pandas==2.0.0",
]
]
runs_calls.clear()

# Test 4: Only plain marimo
pm.update_notebook_script_metadata(
filepath="nb.py",
packages_to_add=[
"marimo",
"pandas",
],
)
assert runs_calls == [
[
"uv",
"--quiet",
"add",
"--script",
"nb.py",
"marimo",
"pandas==2.0.0",
]
]
runs_calls.clear()
Loading