Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Series.hist #1859

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference/series.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
- filter
- gather_every
- head
- hist
- implementation
- is_between
- is_duplicated
Expand Down
113 changes: 113 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from narwhals._arrow.utils import narwhals_to_native_dtype
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._arrow.utils import pad_series
from narwhals.exceptions import InvalidOperationError
from narwhals.typing import CompliantSeries
from narwhals.utils import Implementation
from narwhals.utils import generate_temporary_column_name
Expand Down Expand Up @@ -1004,6 +1005,118 @@ def rank(
result = pc.if_else(null_mask, pa.scalar(None), rank)
return self._from_native_series(result)

def hist(
    self: Self,
    bins: list[float | int] | None = None,
    *,
    bin_count: int | None = None,
    include_category: bool = True,
    include_breakpoint: bool = True,
) -> ArrowDataFrame:
    """Bin values into right-closed buckets and count their occurrences.

    Arguments:
        bins: Bin edges. Checked first; when given, `bin_count` is ignored.
        bin_count: Number of equal-width bins spanning the minimum and maximum
            of the series. Only used when `bins` is None. `0` yields an empty
            result.
        include_category: Add a "category" column with each interval rendered
            as the string "(left, right]".
        include_breakpoint: Add a "breakpoint" column with the right edge of
            each bin.

    Returns:
        An ArrowDataFrame holding the optional "breakpoint" and "category"
        columns followed by a "count" column, one row per bin.

    Raises:
        InvalidOperationError: If `bins` contains a decreasing step, or if
            neither `bins` nor `bin_count` is provided.
    """
    import numpy as np  # ignore-banned-import
    import pyarrow as pa
    import pyarrow.compute as pc

    from narwhals._arrow.dataframe import ArrowDataFrame

    def _hist_from_bin_count(
        bin_count: int,
    ) -> tuple[Sequence[int], Sequence[int | float], Sequence[int | float]]:
        """Count values in `bin_count` equal-width bins over [min, max]."""
        d = pc.min_max(self._native_series)
        lower, upper = d["min"].as_py(), d["max"].as_py()
        if lower == upper:
            # every value is identical: widen the range so the bin width is non-zero
            lower -= 0.001 * abs(lower) if lower != 0 else 0.001
            upper += 0.001 * abs(upper) if upper != 0 else 0.001

        width = (upper - lower) / bin_count
        bin_proportions = pc.divide(pc.subtract(self._native_series, lower), width)
        bin_indices = pc.floor(bin_proportions)

        bin_indices = pc.if_else(  # shift bins so they are right-closed
            pc.and_(
                pc.equal(bin_indices, bin_proportions),
                pc.greater(bin_indices, 0),
            ),
            pc.subtract(bin_indices, 1),
            bin_indices,
        )
        counts = (  # count bin id occurrences
            pa.Table.from_arrays(
                pc.value_counts(bin_indices)
                .cast(pa.struct({"values": pa.int64(), "counts": pa.int64()}))
                .flatten(),
                names=["values", "counts"],
            )
            .join(  # align bin ids to all possible bin ids (populate in missing bins)
                # pin int64 so the join key type matches the "values" column
                # even where NumPy's default integer is 32-bit
                pa.Table.from_arrays(
                    [np.arange(bin_count, dtype="int64")], ["values"]
                ),
                keys="values",
                join_type="right outer",
            )
            .sort_by("values")
        )
        counts = counts.set_column(  # empty bin intervals should have a 0 count
            0, "counts", pc.coalesce(counts.column("counts"), 0)
        )

        # extract left/right side of the intervals; bin ids were derived from
        # (value - lower) / width, so the edges must be offset by `lower`
        bin_left = pc.add(lower, pc.multiply(counts.column("values"), width))
        bin_right = pc.add(bin_left, width)
        bin_left = pa.chunked_array(
            [  # pad lowest bin by 0.1% of range
                [pc.subtract(bin_left[0], (upper - lower) * 0.001)],
                bin_left[1:],
            ]
        )
        counts = counts.column("counts")
        return counts, bin_left, bin_right

    def _hist_from_bins(
        bins: Sequence[int | float],
    ) -> tuple[Sequence[int], Sequence[int | float], Sequence[int | float]]:
        """Count values falling between consecutive explicit bin edges."""
        bins = np.asarray(bins)
        if (np.diff(bins) < 0).any():
            msg = "bins must increase monotonically"
            raise InvalidOperationError(msg)

        # side="left" gives index i when bins[i - 1] < value <= bins[i],
        # i.e. right-closed intervals
        bin_indices = np.searchsorted(bins, self._native_series, side="left")
        obs_cats, obs_counts = np.unique(bin_indices, return_counts=True)
        # valid bin ids are 1..len(bins)-1; ids 0 and len(bins) fall outside the edges
        obj_cats = np.arange(1, len(bins))
        counts = np.zeros_like(obj_cats)
        counts[np.isin(obj_cats, obs_cats)] = obs_counts[np.isin(obs_cats, obj_cats)]

        bin_right = bins[1:]
        bin_left = bins[:-1]
        return counts, bin_left, bin_right

    if bins is not None:
        counts, bin_left, bin_right = _hist_from_bins(bins)

    elif bin_count is not None:
        if bin_count == 0:
            counts, bin_left, bin_right = [], [], []
        else:
            counts, bin_left, bin_right = _hist_from_bin_count(bin_count)

    else:  # pragma: no cover
        # caller guarantees that either bins or bin_count is specified
        msg = "must provide one of `bin_count` or `bins`"
        raise InvalidOperationError(msg)

    data: dict[str, Sequence[int | float | str]] = {}
    if include_breakpoint:
        data["breakpoint"] = bin_right
    if include_category:
        data["category"] = [
            f"({left}, {right}]" for left, right in zip(bin_left, bin_right)
        ]
    data["count"] = counts

    return ArrowDataFrame(
        pa.Table.from_pydict(data),
        backend_version=self._backend_version,
        version=self._version,
    )

def __iter__(self: Self) -> Iterator[Any]:
yield from (
maybe_extract_py_scalar(x, return_py_scalar=True)
Expand Down
48 changes: 48 additions & 0 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,54 @@ def rank(

return self._from_native_series(ranked_series)

def hist(
    self: Self,
    bins: list[float | int] | None = None,
    *,
    bin_count: int | None = None,
    include_category: bool = True,
    include_breakpoint: bool = True,
) -> PandasLikeDataFrame:
    """Bin values with `pandas.cut` and count the occurrences per bin.

    Arguments:
        bins: Bin edges handed to `pandas.cut`; used when `bin_count` is None.
        bin_count: Number of bins handed to `pandas.cut`. `0` short-circuits
            to an empty result without calling `cut`.
        include_category: Add a "category" column with the interval labels as
            a categorical of strings.
        include_breakpoint: Add a "breakpoint" column with the right edge of
            each interval.

    Returns:
        A PandasLikeDataFrame holding the optional "breakpoint" and "category"
        columns followed by a "count" column.
    """
    from pandas import Categorical
    from pandas import cut

    from narwhals._pandas_like.dataframe import PandasLikeDataFrame

    def _wrap(columns: dict[str, Sequence[int | float | str]]) -> PandasLikeDataFrame:
        # build the native frame and wrap it with this series' metadata
        native_frame = self.__native_namespace__().DataFrame(columns)
        return PandasLikeDataFrame(
            native_frame,
            implementation=self._implementation,
            backend_version=self._backend_version,
            version=self._version,
        )

    columns: dict[str, Sequence[int | float | str]] = {}

    if bin_count == 0:
        # zero bins requested: same schema, no rows
        if include_breakpoint:
            columns["breakpoint"] = []
        if include_category:
            columns["category"] = []
        columns["count"] = []
        return _wrap(columns)

    bin_spec = bins if bin_count is None else bin_count
    binned = cut(self._native_series, bins=bin_spec)
    per_bin = binned.value_counts().sort_index()

    if include_breakpoint:
        columns["breakpoint"] = per_bin.index.categories.right
    if include_category:
        columns["category"] = Categorical(per_bin.index.categories.astype(str))
    columns["count"] = per_bin.reset_index(drop=True)

    return _wrap(columns)

@property
def str(self: Self) -> PandasLikeSeriesStringNamespace:
    """Return a `PandasLikeSeriesStringNamespace` constructed from this series."""
    return PandasLikeSeriesStringNamespace(self)
Expand Down
23 changes: 23 additions & 0 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,29 @@ def __contains__(self: Self, other: Any) -> bool:
msg = f"Unable to compare other of type {type(other)} with series of type {self.dtype}."
raise InvalidOperationError(msg) from exc

def hist(
    self: Self,
    bins: list[float | int] | None = None,
    *,
    bin_count: int | None = None,
    include_category: bool = True,
    include_breakpoint: bool = True,
) -> PolarsDataFrame:
    """Delegate to the native series' `hist` and wrap the result.

    Arguments:
        bins: Forwarded unchanged to the native `hist`.
        bin_count: Forwarded unchanged to the native `hist`.
        include_category: Forwarded unchanged to the native `hist`.
        include_breakpoint: Forwarded unchanged to the native `hist`.

    Returns:
        A PolarsDataFrame wrapping the native histogram frame.
    """
    from narwhals._polars.dataframe import PolarsDataFrame

    native_hist = self._native_series.hist(
        bins=bins,
        bin_count=bin_count,
        include_category=include_category,
        include_breakpoint=include_breakpoint,
    )

    # with neither extra column requested, force the column name to "count"
    # NOTE(review): presumably the native name differs in this case - confirm
    if not (include_category or include_breakpoint):
        native_hist.columns = ["count"]

    return PolarsDataFrame(
        native_hist,
        backend_version=self._backend_version,
        version=self._version,
    )

def to_polars(self: Self) -> pl.Series:
    """Return the wrapped native series as-is."""
    return self._native_series

Expand Down
41 changes: 41 additions & 0 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from narwhals.dependencies import is_numpy_scalar
from narwhals.dtypes import _validate_dtype
from narwhals.exceptions import InvalidOperationError
from narwhals.series_cat import SeriesCatNamespace
from narwhals.series_dt import SeriesDateTimeNamespace
from narwhals.series_list import SeriesListNamespace
Expand Down Expand Up @@ -4900,6 +4901,46 @@ def rank(
self._compliant_series.rank(method=method, descending=descending)
)

def hist(
    self: Self,
    bins: list[float | int] | None = None,
    *,
    bin_count: int | None = None,
    include_category: bool = True,
    include_breakpoint: bool = True,
) -> DataFrame[Any]:
    """Bin values into buckets and count their occurrences.

    !!! warning
        This functionality is considered **unstable**. It may be changed at any point
        without it being considered a breaking change.

    Arguments:
        bins: A monotonically increasing sequence of bin edges. Cannot be
            combined with `bin_count`.
        bin_count: Number of bins to create when `bins` is not provided.
            Cannot be combined with `bins`.
        include_category: Include a column that shows the intervals as categories.
        include_breakpoint: Include a column that indicates the upper value of each bin.

    Returns:
        A new DataFrame containing the counts of values that occur within each passed bin.

    Raises:
        InvalidOperationError: If neither or both of `bins` and `bin_count`
            are provided.
    """
    # exactly one of `bins` / `bin_count` must be given
    if bins is None and bin_count is None:
        msg = "must provide one of `bin_count` or `bins`"
        raise InvalidOperationError(msg)
    if bins is not None and bin_count is not None:
        msg = "can only provide one of `bin_count` or `bins`"
        raise InvalidOperationError(msg)

    return self._dataframe(
        self._compliant_series.hist(
            bins=bins,
            bin_count=bin_count,
            include_category=include_category,
            include_breakpoint=include_breakpoint,
        ),
        level=self._level,
    )

@property
def str(self: Self) -> SeriesStringNamespace[Self]:
    """Return a `SeriesStringNamespace` constructed from this series."""
    return SeriesStringNamespace(self)
Expand Down
38 changes: 38 additions & 0 deletions narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,44 @@ def rolling_std(
ddof=ddof,
)

def hist(
    self: Self,
    bins: list[float | int] | None = None,
    *,
    bin_count: int | None = None,
    include_category: bool = True,
    include_breakpoint: bool = True,
) -> DataFrame[Any]:
    """Bin values into buckets and count their occurrences.

    !!! warning
        This functionality is considered **unstable**. It may be changed at any point
        without it being considered a breaking change.

    Emits a `NarwhalsUnstableWarning` on every call, then defers to the
    parent class implementation.

    Arguments:
        bins: A monotonically increasing sequence of bin edges.
        bin_count: Number of bins to create when `bins` is not provided.
        include_category: Include a column that shows the intervals as categories.
        include_breakpoint: Include a column that indicates the upper value of each bin.

    Returns:
        A new DataFrame containing the counts of values that occur within each passed bin.
    """
    from narwhals.exceptions import NarwhalsUnstableWarning
    from narwhals.utils import find_stacklevel

    msg = (
        "`Series.hist` is being called from the stable API although considered "
        "an unstable feature."
    )
    warn(message=msg, category=NarwhalsUnstableWarning, stacklevel=find_stacklevel())
    return super().hist(  # type: ignore[return-value]
        bins=bins,
        bin_count=bin_count,
        include_category=include_category,
        include_breakpoint=include_breakpoint,
    )


class Expr(NwExpr):
def _l1_norm(self: Self) -> Self:
Expand Down
Loading
Loading