diff --git a/marimo/_plugins/ui/_impl/tables/narwhals_table.py b/marimo/_plugins/ui/_impl/tables/narwhals_table.py index db3150c5199..c1195dc727a 100644 --- a/marimo/_plugins/ui/_impl/tables/narwhals_table.py +++ b/marimo/_plugins/ui/_impl/tables/narwhals_table.py @@ -119,6 +119,8 @@ def get_field_type( return ("integer", dtype_string) elif is_narwhals_temporal_type(dtype): return ("date", dtype_string) + elif dtype == nw.Duration: + return ("number", dtype_string) elif dtype.is_numeric(): return ("number", dtype_string) else: @@ -147,6 +149,7 @@ def search(self, query: str) -> TableManager[Any]: elif ( dtype.is_numeric() or is_narwhals_temporal_type(dtype) + or dtype == nw.Duration or dtype == nw.Boolean ): expressions.append( @@ -212,6 +215,22 @@ def _get_summary_internal(self, column: str) -> ColumnSummary: p75=col.quantile(0.75, interpolation="nearest"), p95=col.quantile(0.95, interpolation="nearest"), ) + if col.dtype == nw.Duration and isinstance(col.dtype, nw.Duration): + unit_map = { + "ms": (col.dt.total_milliseconds, "ms"), + "ns": (col.dt.total_nanoseconds, "ns"), + "us": (col.dt.total_microseconds, "μs"), + "s": (col.dt.total_seconds, "s"), + } + method, unit = unit_map[col.dtype.time_unit] + res = method() + return ColumnSummary( + total=total, + nulls=col.null_count(), + min=str(res.min()) + unit, + max=str(res.max()) + unit, + mean=str(res.mean()) + unit, + ) if ( col.dtype == nw.List or col.dtype == nw.Struct diff --git a/marimo/_plugins/ui/_impl/tables/polars_table.py b/marimo/_plugins/ui/_impl/tables/polars_table.py index 06991042e12..56710a4e045 100644 --- a/marimo/_plugins/ui/_impl/tables/polars_table.py +++ b/marimo/_plugins/ui/_impl/tables/polars_table.py @@ -94,19 +94,15 @@ def to_csv( ) ) elif isinstance(dtype, pl.Duration): - if dtype.time_unit == "ms": - result = result.with_columns( - column.dt.total_milliseconds() - ) - - elif dtype.time_unit == "ns": - result = result.with_columns( - column.dt.total_nanoseconds() - ) - elif dtype.time_unit == "us": - result = result.with_columns( - column.dt.total_microseconds() - ) + unit_map = { + "ms": column.dt.total_milliseconds, + "ns": column.dt.total_nanoseconds, + "us": column.dt.total_microseconds, + "s": column.dt.total_seconds, + } + if dtype.time_unit in unit_map: + method = unit_map[dtype.time_unit] + result = result.with_columns(method()) return result.write_csv().encode("utf-8") def to_json(self) -> bytes: @@ -198,6 +194,8 @@ def get_field_type( return ("date", dtype_string) elif dtype == pl.Time: return ("time", dtype_string) + elif dtype == pl.Duration: + return ("number", dtype_string) elif dtype == pl.Datetime: return ("datetime", dtype_string) elif dtype.is_temporal(): diff --git a/marimo/_smoke_tests/tables/polars_duration.py b/marimo/_smoke_tests/tables/polars_duration.py new file mode 100644 index 00000000000..da394b96971 --- /dev/null +++ b/marimo/_smoke_tests/tables/polars_duration.py @@ -0,0 +1,57 @@ +import marimo + +__generated_with = "0.10.6" +app = marimo.App(width="medium") + + +@app.cell +def _(): + import marimo as mo + import polars as pl + return mo, pl + + +@app.cell +def _(mo): + mo.md(r"""## Polars""") + return + + +@app.cell +def _(pl): + df = pl.read_csv( + "https://raw.githubusercontent.com/vega/vega-datasets/refs/heads/main/data/co2-concentration.csv" + ) + df = df.with_columns( + pl.col("CO2").cast(pl.Duration), + ) + df + return (df,) + + +@app.cell +def _(df, mo): + mo.plain(df) + return + + +@app.cell +def _(mo): + mo.md(r"""## Pandas""") + return + + +@app.cell +def _(df): + df.to_pandas() + return + + +@app.cell +def _(df, mo): + mo.plain(df.to_pandas()) + return + + +if __name__ == "__main__": + app.run() diff --git a/marimo/_utils/narwhals_utils.py b/marimo/_utils/narwhals_utils.py index 60ef453ca63..a17fc94f90f 100644 --- a/marimo/_utils/narwhals_utils.py +++ b/marimo/_utils/narwhals_utils.py @@ -111,13 +111,11 @@ def is_narwhals_integer_type( def is_narwhals_temporal_type( dtype: Any, -) -> TypeGuard[nw.Datetime | nw.Date | nw.Duration | nw.Duration]: +) -> TypeGuard[nw.Datetime | nw.Date]: """ Check if the given dtype is temporal type. """ - return bool( - dtype == nw.Datetime or dtype == nw.Date or dtype == nw.Duration - ) + return bool(dtype == nw.Datetime or dtype == nw.Date) def is_narwhals_string_type( diff --git a/tests/_plugins/ui/_impl/tables/snapshots/narwhals.field_types.json b/tests/_plugins/ui/_impl/tables/snapshots/narwhals.field_types.json index 191b2f97e0c..18214c5655a 100644 --- a/tests/_plugins/ui/_impl/tables/snapshots/narwhals.field_types.json +++ b/tests/_plugins/ui/_impl/tables/snapshots/narwhals.field_types.json @@ -13,6 +13,6 @@ ["set", ["unknown", "Object"]], ["imaginary", ["unknown", "Object"]], ["time", ["unknown", "Unknown"]], - ["duration", ["date", "Duration(time_unit='us')"]], + ["duration", ["number", "Duration(time_unit='us')"]], ["mixed_list", ["unknown", "List(String)"]] ] diff --git a/tests/_plugins/ui/_impl/tables/snapshots/polars.field_types.json b/tests/_plugins/ui/_impl/tables/snapshots/polars.field_types.json index 2998e8b4313..f8107f389a6 100644 --- a/tests/_plugins/ui/_impl/tables/snapshots/polars.field_types.json +++ b/tests/_plugins/ui/_impl/tables/snapshots/polars.field_types.json @@ -13,6 +13,6 @@ ["set", ["unknown", "object"]], ["imaginary", ["unknown", "object"]], ["time", ["time", "Time"]], - ["duration", ["datetime", "duration[\u03bcs]"]], + ["duration", ["number", "duration[\u03bcs]"]], ["mixed_list", ["unknown", "list[str]"]] ] diff --git a/tests/_utils/test_narwhals_utils.py b/tests/_utils/test_narwhals_utils.py index 0f20b2a5e72..4fc0ab20e57 100644 --- a/tests/_utils/test_narwhals_utils.py +++ b/tests/_utils/test_narwhals_utils.py @@ -90,6 +90,7 @@ def test_narwhals_type_checks(): assert is_narwhals_temporal_type(nw.Datetime) assert is_narwhals_temporal_type(nw.Date) assert not is_narwhals_temporal_type(nw.Int64) + assert not is_narwhals_temporal_type(nw.Duration) assert is_narwhals_string_type(nw.String) assert is_narwhals_string_type(nw.Categorical)