diff --git a/docs/api-reference/series.md b/docs/api-reference/series.md
index d3bbd62f6..95a6539af 100644
--- a/docs/api-reference/series.md
+++ b/docs/api-reference/series.md
@@ -30,6 +30,7 @@
         - filter
         - gather_every
        - head
+        - hist
         - implementation
         - is_between
         - is_duplicated
diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py
index e9a82d1d2..676f32219 100644
--- a/narwhals/_arrow/series.py
+++ b/narwhals/_arrow/series.py
@@ -21,6 +21,7 @@
 from narwhals._arrow.utils import narwhals_to_native_dtype
 from narwhals._arrow.utils import native_to_narwhals_dtype
 from narwhals._arrow.utils import pad_series
+from narwhals.exceptions import InvalidOperationError
 from narwhals.typing import CompliantSeries
 from narwhals.utils import Implementation
 from narwhals.utils import generate_temporary_column_name
@@ -1004,6 +1005,116 @@ def rank(
         result = pc.if_else(null_mask, pa.scalar(None), rank)
         return self._from_native_series(result)
 
+    def hist(
+        self: Self,
+        bins: list[float | int] | None = None,
+        *,
+        bin_count: int | None = None,
+        include_category: bool = True,
+        include_breakpoint: bool = True,
+    ) -> ArrowDataFrame:
+        import numpy as np  # ignore-banned-import
+
+        from narwhals._arrow.dataframe import ArrowDataFrame
+
+        def _hist_from_bin_count(
+            bin_count: int,
+        ) -> tuple[Sequence[int], Sequence[int | float], Sequence[int | float]]:
+            d = pc.min_max(self._native_series)
+            lower, upper = d["min"].as_py(), d["max"].as_py()
+            if lower == upper:
+                lower -= 0.001 * abs(lower) if lower != 0 else 0.001
+                upper += 0.001 * abs(upper) if upper != 0 else 0.001
+
+            width = (upper - lower) / bin_count
+            bin_proportions = pc.divide(pc.subtract(self._native_series, lower), width)
+            bin_indices = pc.floor(bin_proportions)
+
+            bin_indices = pc.if_else(  # shift bins so they are right-closed
+                pc.and_(
+                    pc.equal(bin_indices, bin_proportions),
+                    pc.greater(bin_indices, 0),
+                ),
+                pc.subtract(bin_indices, 1),
+                bin_indices,
+            )
+            counts = (  # count bin id occurrences
+                pa.Table.from_arrays(
+                    pc.value_counts(bin_indices)
+                    .cast(pa.struct({"values": pa.int64(), "counts": pa.int64()}))
+                    .flatten(),
+                    names=["values", "counts"],
+                )
+                .join(  # align bin ids to all possible bin ids (populate in missing bins)
+                    pa.Table.from_arrays([np.arange(bin_count)], ["values"]),
+                    keys="values",
+                    join_type="right outer",
+                )
+                .sort_by("values")
+            )
+            counts = counts.set_column(  # empty bin intervals should have a 0 count
+                0, "counts", pc.coalesce(counts.column("counts"), 0)
+            )
+
+            # extract left/right side of the intervals (shifted back into data coordinates)
+            bin_left = pc.add(pc.multiply(counts.column("values"), width), lower)
+            bin_right = pc.add(bin_left, width)
+            bin_left = pa.chunked_array(
+                [  # pad lowest bin by 0.1% of the range
+                    [pc.subtract(bin_left[0], (upper - lower) * 0.001)],
+                    bin_left[1:],
+                ]
+            )
+            counts = counts.column("counts")
+            return counts, bin_left, bin_right
+
+        def _hist_from_bins(
+            bins: Sequence[int | float],
+        ) -> tuple[Sequence[int], Sequence[int | float], Sequence[int | float]]:
+            bins = np.asarray(bins)
+            if (np.diff(bins) < 0).any():
+                msg = "bins must increase monotonically"
+                raise InvalidOperationError(msg)
+
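+            # Note: `searchsorted(..., side="left")` maps each value v to the index i for
+            # which bins[i - 1] < v <= bins[i], i.e. the bins behave as right-closed
+            # intervals. Indices 0 and len(bins) (v <= bins[0] or v > bins[-1]) fall
+            # outside the categories 1..len(bins)-1 built below, so out-of-range values
+            # simply do not contribute to any count.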
+            bin_indices = np.searchsorted(bins, self._native_series, side="left")
+            obs_cats, obs_counts = np.unique(bin_indices, return_counts=True)
+            obj_cats = np.arange(1, len(bins))
+            counts = np.zeros_like(obj_cats)
+            counts[np.isin(obj_cats, obs_cats)] = obs_counts[np.isin(obs_cats, obj_cats)]
+
+            bin_right = bins[1:]
+            bin_left = bins[:-1]
+            return counts, bin_left, bin_right
+
+        if bins is not None:
+            counts, bin_left, bin_right = _hist_from_bins(bins)
+
+        elif bin_count is not None:
+            if bin_count == 0:
+                counts, bin_left, bin_right = [], [], []
+            else:
+                counts, bin_left, bin_right = _hist_from_bin_count(bin_count)
+
+        else:  # pragma: no cover
+            # caller guarantees that either bins or bin_count is specified
+            msg = "must provide one of `bin_count` or `bins`"
+            raise InvalidOperationError(msg)
+
+        data: dict[str, Sequence[int | float | str]] = {}
+        if include_breakpoint:
+            data["breakpoint"] = bin_right
+        if include_category:
+            data["category"] = [
+                f"({left}, {right}]" for left, right in zip(bin_left, bin_right)
+            ]
+        data["count"] = counts
+
+        return ArrowDataFrame(
+            pa.Table.from_pydict(data),
+            backend_version=self._backend_version,
+            version=self._version,
+        )
+
     def __iter__(self: Self) -> Iterator[Any]:
         yield from (
             maybe_extract_py_scalar(x, return_py_scalar=True)
diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py
index f79aab1ea..024d46d70 100644
--- a/narwhals/_pandas_like/series.py
+++ b/narwhals/_pandas_like/series.py
@@ -1031,6 +1031,53 @@ def rank(
         return self._from_native_series(ranked_series)
 
+    def hist(
+        self: Self,
+        bins: list[float | int] | None = None,
+        *,
+        bin_count: int | None = None,
+        include_category: bool = True,
+        include_breakpoint: bool = True,
+    ) -> PandasLikeDataFrame:
+        from narwhals._pandas_like.dataframe import PandasLikeDataFrame
+
+        ns = self.__native_namespace__()
+        data: dict[str, Sequence[int | float | str]]
+
+        if bin_count is not None and bin_count == 0:
+            data = {}
+            if include_breakpoint:
+                data["breakpoint"] = []
+            if include_category:
+                data["category"] = []
+            data["count"] = []
+
+            return PandasLikeDataFrame(
+                ns.DataFrame(data),
+                implementation=self._implementation,
+                backend_version=self._backend_version,
+                version=self._version,
+            )
+
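+        # pandas' cut() defaults to right-closed intervals, which matches the
+        # `(left, right]` categories reported by the other backends; value_counts()
+        # followed by sort_index() then gives one count per interval, in bin order.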
+        result = (
+            ns.cut(self._native_series, bins=bins if bin_count is None else bin_count)
+            .value_counts()
+            .sort_index()
+        )
+        data = {}
+        if include_breakpoint:
+            data["breakpoint"] = result.index.categories.right
+        if include_category:
+            data["category"] = ns.Categorical(result.index.categories.astype(str))
+        data["count"] = result.reset_index(drop=True)
+
+        return PandasLikeDataFrame(
+            ns.DataFrame(data),
+            implementation=self._implementation,
+            backend_version=self._backend_version,
+            version=self._version,
+        )
+
     @property
     def str(self: Self) -> PandasLikeSeriesStringNamespace:
         return PandasLikeSeriesStringNamespace(self)
diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py
index fc8cfc8f1..2ab863df9 100644
--- a/narwhals/_polars/series.py
+++ b/narwhals/_polars/series.py
@@ -416,6 +416,60 @@ def __contains__(self: Self, other: Any) -> bool:
             msg = f"Unable to compare other of type {type(other)} with series of type {self.dtype}."
             raise InvalidOperationError(msg) from exc
 
+    def hist(
+        self: Self,
+        bins: list[float | int] | None = None,
+        *,
+        bin_count: int | None = None,
+        include_category: bool = True,
+        include_breakpoint: bool = True,
+    ) -> PolarsDataFrame:
+        from narwhals._polars.dataframe import PolarsDataFrame
+        from narwhals.exceptions import InvalidOperationError
+
+        # check for monotonicity; polars<1.0 does not do this itself
+        if bins is not None:
+            for i in range(1, len(bins)):
+                if bins[i - 1] >= bins[i]:
+                    msg = "bins must increase monotonically"
+                    raise InvalidOperationError(msg)
+
+        # polars<1.0 returned bins -inf to inf in these conditions
+        if (self._backend_version < (1, 0)) and (
+            (bins is not None and len(bins) == 0) or (bin_count == 0)
+        ):
+            data: list[pl.Series] = []
+            if include_breakpoint:
+                data.append(pl.Series("breakpoint", [], dtype=pl.Float64))
+            if include_category:
+                data.append(pl.Series("category", [], dtype=pl.Categorical))
+            data.append(pl.Series("count", [], dtype=pl.UInt32))
+            return PolarsDataFrame(
+                pl.DataFrame(data),
+                backend_version=self._backend_version,
+                version=self._version,
+            )
+
+        df = self._native_series.hist(
+            bins=bins,
+            bin_count=bin_count,
+            include_category=include_category,
+            include_breakpoint=include_breakpoint,
+        )
+        if not include_category and not include_breakpoint:
+            df.columns = ["count"]
+
+        if self._backend_version < (1, 0):  # pragma: no cover
+            if (
+                bins is not None
+            ):  # polars<1.0 implicitly adds -inf and inf to either end of bins
+                r = pl.int_range(0, len(df))
+                df = df.filter((r > 0) & (r < len(df) - 1))
+            if include_breakpoint:
+                df = df.rename({"break_point": "breakpoint"})
+
+        return PolarsDataFrame(
+            df, backend_version=self._backend_version, version=self._version
+        )
+
     def to_polars(self: Self) -> pl.Series:
         return self._native_series
diff --git a/narwhals/series.py b/narwhals/series.py
index 7b82e39e7..942bace10 100644
--- a/narwhals/series.py
+++ b/narwhals/series.py
@@ -12,6 +12,7 @@
 from narwhals.dependencies import is_numpy_scalar
 from narwhals.dtypes import _validate_dtype
+from narwhals.exceptions import InvalidOperationError
 from narwhals.series_cat import SeriesCatNamespace
 from narwhals.series_dt import SeriesDateTimeNamespace
 from narwhals.series_list import SeriesListNamespace
@@ -4900,6 +4901,46 @@ def rank(
             self._compliant_series.rank(method=method, descending=descending)
         )
 
+    def hist(
+        self: Self,
+        bins: list[float | int] | None = None,
+        *,
+        bin_count: int | None = None,
+        include_category: bool = True,
+        include_breakpoint: bool = True,
+    ) -> DataFrame[Any]:
+        """Bin values into buckets and count their occurrences.
+
+        !!! warning
+            This functionality is considered **unstable**. It may be changed at any point
+            without it being considered a breaking change.
+
+        Arguments:
+            bins: A monotonically increasing sequence of values.
+            bin_count: If `bins` is not provided, the number of equal-width bins to span the values.
+            include_category: Include a column that shows the intervals as categories.
+            include_breakpoint: Include a column that indicates the upper value of each bin.
+
+        Returns:
+            A new DataFrame containing the counts of values that occur within each passed bin.
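+
+        Examples:
+            A small sketch of the expected usage (pandas-backed here; exact category
+            labels and dtypes can differ slightly between backends):
+
+            >>> import pandas as pd
+            >>> import narwhals as nw
+            >>> s = nw.from_native(pd.Series([1, 3, 8, 8, 2, 1, 3]), series_only=True)
+            >>> hist = s.hist(bins=[0, 2, 4, 10])
+            >>> hist["count"].to_list()  # counts for (0, 2], (2, 4], (4, 10]
+            [3, 2, 2]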
+        """
+        if bins is None and bin_count is None:
+            msg = "must provide one of `bin_count` or `bins`"
+            raise InvalidOperationError(msg)
+        if bins is not None and bin_count is not None:
+            msg = "can only provide one of `bin_count` or `bins`"
+            raise InvalidOperationError(msg)
+
+        return self._dataframe(
+            self._compliant_series.hist(
+                bins=bins,
+                bin_count=bin_count,
+                include_category=include_category,
+                include_breakpoint=include_breakpoint,
+            ),
+            level=self._level,
+        )
+
     @property
     def str(self: Self) -> SeriesStringNamespace[Self]:
         return SeriesStringNamespace(self)
diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py
index df443230c..12f2b59b5 100644
--- a/narwhals/stable/v1/__init__.py
+++ b/narwhals/stable/v1/__init__.py
@@ -618,6 +618,44 @@ def rolling_std(
             ddof=ddof,
         )
 
+    def hist(
+        self: Self,
+        bins: list[float | int] | None = None,
+        *,
+        bin_count: int | None = None,
+        include_category: bool = True,
+        include_breakpoint: bool = True,
+    ) -> DataFrame[Any]:
+        """Bin values into buckets and count their occurrences.
+
+        !!! warning
+            This functionality is considered **unstable**. It may be changed at any point
+            without it being considered a breaking change.
+
+        Arguments:
+            bins: A monotonically increasing sequence of values.
+            bin_count: If `bins` is not provided, the number of equal-width bins to span the values.
+            include_category: Include a column that shows the intervals as categories.
+            include_breakpoint: Include a column that indicates the upper value of each bin.
+
+        Returns:
+            A new DataFrame containing the counts of values that occur within each passed bin.
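+
+        Examples:
+            A minimal sketch (calling this from the stable API currently emits a
+            `NarwhalsUnstableWarning`):
+
+            >>> import pandas as pd
+            >>> import narwhals.stable.v1 as nw
+            >>> s = nw.from_native(pd.Series([1, 1, 2, 5]), series_only=True)
+            >>> hist = s.hist(bins=[0, 2, 6])
+            >>> hist.columns
+            ['breakpoint', 'category', 'count']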
+        """
+        from narwhals.exceptions import NarwhalsUnstableWarning
+        from narwhals.utils import find_stacklevel
+
+        msg = (
+            "`Series.hist` is being called from the stable API although considered "
+            "an unstable feature."
+        )
+        warn(message=msg, category=NarwhalsUnstableWarning, stacklevel=find_stacklevel())
+        return super().hist(  # type: ignore[return-value]
+            bins=bins,
+            bin_count=bin_count,
+            include_category=include_category,
+            include_breakpoint=include_breakpoint,
+        )
+
 
 class Expr(NwExpr):
     def _l1_norm(self: Self) -> Self:
diff --git a/tests/series_only/hist_test.py b/tests/series_only/hist_test.py
new file mode 100644
index 000000000..d9f3c3a7c
--- /dev/null
+++ b/tests/series_only/hist_test.py
@@ -0,0 +1,282 @@
+from __future__ import annotations
+
+from typing import Any
+
+import hypothesis.strategies as st
+import pytest
+from hypothesis import given
+
+import narwhals.stable.v1 as nw
+from narwhals.exceptions import InvalidOperationError
+from tests.utils import POLARS_VERSION
+from tests.utils import ConstructorEager
+from tests.utils import assert_equal_data
+from tests.utils import nwise
+
+data = {
+    "int": [0, 1, 2, 3, 4, 5, 6],
+    "float": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+}
+
+bins_and_expected = [
+    {
+        "bins": [-float("inf"), 2.5, 5.5, float("inf")],
+        "expected": [3, 3, 1],
+    },
+    {
+        "bins": [1.0, 2.5, 5.5, float("inf")],
+        "expected": [1, 3, 1],
+    },
+    {
+        "bins": [1.0, 2.5, 5.5],
+        "expected": [1, 3],
+    },
+    {
+        "bins": [-10.0, -1.0, 2.5, 5.5],
+        "expected": [0, 3, 3],
+    },
+]
+counts_and_expected = [
+    {
+        "bin_count": 4,
+        "expected_bins": [-0.006, 1.5, 3.0, 4.5, 6.0],
+        "expected_count": [2, 2, 1, 2],
+    },
+    {
+        "bin_count": 12,
+        "expected_bins": [
+            -0.006,
+            0.5,
+            1.0,
+            1.5,
+            2.0,
+            2.5,
+            3.0,
+            3.5,
+            4.0,
+            4.5,
+            5.0,
+            5.5,
+            6.0,
+        ],
+        "expected_count": [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
+    },
+    {
+        "bin_count": 0,
+        "expected_bins": [],
+        "expected_count": [],
+    },
+]
+
+
+@pytest.mark.parametrize("params", bins_and_expected)
+@pytest.mark.parametrize("include_breakpoint", [True, False])
+@pytest.mark.parametrize("include_category", [True, False])
+@pytest.mark.filterwarnings(
+    "ignore:`Series.hist` is being called from the stable API although considered an unstable feature."
+)
+def test_hist_bin(
+    constructor_eager: ConstructorEager,
+    *,
+    params: dict[str, Any],
+    include_breakpoint: bool,
+    include_category: bool,
+) -> None:
+    df = nw.from_native(constructor_eager(data))
+    bins = params["bins"]
+
+    expected = {
+        "breakpoint": bins[1:],
+        "category": [f"({left}, {right}]" for left, right in nwise(bins, n=2)],
+        "count": params["expected"],
+    }
+    if not include_breakpoint:
+        del expected["breakpoint"]
+    if not include_category:
+        del expected["category"]
+
+    result = df["int"].hist(
+        bins=bins,
+        include_breakpoint=include_breakpoint,
+        include_category=include_category,
+    )
+    assert_equal_data(result, expected)
+
+    result = df["float"].hist(
+        bins=bins,
+        include_breakpoint=include_breakpoint,
+        include_category=include_category,
+    )
+    assert_equal_data(result, expected)
+
+
+@pytest.mark.skipif(
+    POLARS_VERSION < (1, 0),
+    reason="hist(bin_count=...) behavior significantly changed after 1.0",
+)
+@pytest.mark.parametrize("params", counts_and_expected)
+@pytest.mark.parametrize("include_breakpoint", [True, False])
+@pytest.mark.parametrize("include_category", [True, False])
+@pytest.mark.filterwarnings(
+    "ignore:`Series.hist` is being called from the stable API although considered an unstable feature."
+)
+) +def test_hist_count( + constructor_eager: ConstructorEager, + *, + params: dict[str, Any], + include_breakpoint: bool, + include_category: bool, +) -> None: + df = nw.from_native(constructor_eager(data)) + + bins = params["expected_bins"] + expected = { + "breakpoint": bins[1:], + "category": [f"({left}, {right}]" for left, right in nwise(bins, n=2)], + "count": params["expected_count"], + } + if not include_breakpoint: + del expected["breakpoint"] + if not include_category: + del expected["category"] + + result = df["int"].hist( + bin_count=params["bin_count"], + include_breakpoint=include_breakpoint, + include_category=include_category, + ) + assert_equal_data(result, expected) + + result = df["float"].hist( + bin_count=params["bin_count"], + include_breakpoint=include_breakpoint, + include_category=include_category, + ) + assert_equal_data(result, expected) + + +@pytest.mark.filterwarnings( + "ignore:`Series.hist` is being called from the stable API although considered an unstable feature." +) +def test_hist_bin_and_bin_count() -> None: + import polars as pl + + s = nw.from_native(pl.Series([1, 2, 3]), series_only=True) + with pytest.raises(InvalidOperationError, match="must provide one of"): + s.hist(bins=None, bin_count=None) + + with pytest.raises(InvalidOperationError, match="can only provide one of"): + s.hist(bins=[1, 3], bin_count=4) + + +@pytest.mark.filterwarnings( + "ignore:`Series.hist` is being called from the stable API although considered an unstable feature." +) +def test_hist_non_monotonic(constructor_eager: ConstructorEager) -> None: + df = nw.from_native(constructor_eager({"int": [0, 1, 2, 3, 4, 5, 6]})) + + with pytest.raises(Exception, match="monotonic"): + df["int"].hist(bins=[5, 0, 2]) + + with pytest.raises(Exception, match="monotonic"): + df["int"].hist(bins=[5, 2, 0]) + + +@given( # type: ignore[misc] + data=st.lists(st.floats(min_value=-1_000, max_value=1_000), min_size=1, max_size=100), + bin_deltas=st.lists( + st.floats(min_value=0.001, max_value=1_000, allow_nan=False), max_size=50 + ), +) +@pytest.mark.filterwarnings( + "ignore:`Series.hist` is being called from the stable API although considered an unstable feature." +) +@pytest.mark.skipif( + POLARS_VERSION < (1, 0), + reason="hist(bins=...) cannot be used for compatibility checks since narwhals aims to mimi polars>=1.0 behavior", +) +@pytest.mark.slow +def test_hist_bin_hypotheis( + constructor_eager: ConstructorEager, + data: list[float], + bin_deltas: list[float], +) -> None: + import polars as pl + + df = nw.from_native(constructor_eager({"values": data})).select( + nw.col("values").cast(nw.Float64) + ) + bins = ( + nw.from_native(constructor_eager({"bins": bin_deltas})["bins"], series_only=True) # type:ignore[index] + .cast(nw.Float64) + .cum_sum() + ) + + result = df["values"].hist( + bins=bins.to_list(), + include_breakpoint=False, + include_category=False, + ) + expected = ( + pl.Series(data, dtype=pl.Float64) + .hist( + bins=pl.Series(bin_deltas, dtype=pl.Float64).cum_sum().to_list(), + include_breakpoint=False, + include_category=False, + ) + .rename({"": "count"}) + ).to_dict(as_series=False) + + assert_equal_data(result, expected) + + +@given( # type: ignore[misc] + data=st.lists( + st.floats(min_value=-1_000, max_value=1_000, allow_subnormal=False), + min_size=1, + max_size=100, + ), + bin_count=st.integers(min_value=0, max_value=1_000), +) +@pytest.mark.skipif( + POLARS_VERSION < (1, 0), + reason="hist(bin_count=...) 
+    """
+    yield from zip(*(islice(it, i, None) for i, it in enumerate(tee(iterable, n))))
+
+
 def _to_comparable_list(column_values: Any) -> Any:
     if (
         hasattr(column_values, "_compliant_series")
diff --git a/utils/check_api_reference.py b/utils/check_api_reference.py
index e0ebf97a9..fa5a4a5c7 100644
--- a/utils/check_api_reference.py
+++ b/utils/check_api_reference.py
@@ -18,6 +18,7 @@
     "implementation",
     "is_empty",
     "is_sorted",
+    "hist",
     "item",
     "name",
     "rename",