add logsumexp, sort, searchsorted and ecdf functions (#40)
* add logsumexp, searchsorted and ecdf functions

* avoid full obj name + module in page headings

* try matplotlib plot directive

* extend tests

* lint fixes

* changelog and doc improvements
OriolAbril authored Jan 16, 2023
1 parent 36d8961 commit 6e5fb3d
Showing 16 changed files with 392 additions and 88 deletions.
1 change: 1 addition & 0 deletions .pylintrc
@@ -239,6 +239,7 @@ good-names=a,
m,
t,
q,
v,
x,
y,
z,
5 changes: 5 additions & 0 deletions docs/source/_templates/autosummary/base.rst
@@ -0,0 +1,5 @@
{{ objname | escape | underline}}

.. currentmodule:: {{ module }}

.. auto{{ objtype }}:: {{ objname }}
2 changes: 1 addition & 1 deletion docs/source/_templates/autosummary/class.rst
@@ -1,4 +1,4 @@
{{ fullname | escape | underline}}
{{ objname | escape | underline}}

.. currentmodule:: {{ module }}

1 change: 1 addition & 0 deletions docs/source/api/index.md
@@ -28,6 +28,7 @@ Moreover, it also provides some convenience functions in the top-level namespace
.. autosummary::
   :toctree: generated/

   sort
   empty_ref
   ones_ref
   zeros_ref
2 changes: 2 additions & 0 deletions docs/source/api/numba.md
@@ -8,4 +8,6 @@
   :toctree: generated/

   histogram
   searchsorted
   ecdf
```
1 change: 1 addition & 0 deletions docs/source/api/stats.md
@@ -47,6 +47,7 @@ but the output will be a numpy array.
   :toctree: generated/

   rankdata
   logsumexp
```

## Convenience functions
9 changes: 7 additions & 2 deletions docs/source/changelog.md
@@ -2,15 +2,20 @@

## v0.x.x (Unreleased)
### New features
* Added `empty_ref`, `ones_ref` and `zeros_ref` DataArray creation helpers {pull}`37`
* Added `linalg.diagonal` wrapper {pull}`37`
* Added {func}`.empty_ref`, {func}`.ones_ref` and {func}`.zeros_ref` DataArray creation helpers {pull}`37`
* Added {func}`.linalg.diagonal` wrapper {pull}`37`
* Added {func}`.stats.logsumexp` wrapper {pull}`40`
* Added {func}`.searchsorted` and {func}`.ecdf` in {mod}`~xarray_einstats.numba` module {pull}`40`
* Added {func}`~xarray_einstats.sort` wrapper for vectorized sort along specific dimension using values {pull}`40`

### Maintenance and fixes
* Fix issue in `linalg.svd` for non-square matrices {pull}`37`
* Fix evaluation of distribution methods (e.g. `.pdf`) on scalars {pull}`38` and {pull}`39`
* Ensure support on inputs with stacked dimensions {pull}`40`

### Documentation
* Ported NumPy tutorial on linear algebra with multidimensional arrays {pull}`37`
* Added ecdf usage example and plotting reference {pull}`40`

## v0.4.0 (2022 Dec 9)
### New features
6 changes: 6 additions & 0 deletions docs/source/conf.py
@@ -40,6 +40,7 @@
    "sphinx_copybutton",
    "jupyter_sphinx",
    "sphinx_design",
    "matplotlib.sphinxext.plot_directive",
]

templates_path = ["_templates"]
@@ -88,6 +89,11 @@
    **{f"{singular}s": f":any:`{singular}s <{singular}>`" for singular in singulars},
}

# Include the example source for plots in API docs
plot_include_source = True
plot_formats = [("png", 90)]
plot_html_show_formats = False
plot_html_show_source_link = False

# -- Options for HTML output

128 changes: 61 additions & 67 deletions docs/source/tutorials/np_linalg_tutorial_port.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -63,6 +63,7 @@ doc = [
    "jupyter-sphinx",
    "sphinx-design",
    "watermark",
    "matplotlib",
]

[tool.black]
37 changes: 37 additions & 0 deletions src/xarray_einstats/__init__.py
@@ -12,6 +12,43 @@
__version__ = "0.5.0.dev0"


def sort(da, dim, **kwargs):
    """Sort along dimension using DataArray values."""
    sort_kwargs = dict(axis=-1)
    if "kind" in kwargs:
        sort_kwargs["kind"] = kwargs.pop("kind")
    return xr.apply_ufunc(
        np.sort,
        da,
        input_core_dims=[[dim]],
        output_core_dims=[[dim]],
        kwargs=sort_kwargs,
        **kwargs,
    )
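Conceptually, the wrapper applies `np.sort` along the single core dimension while broadcasting over all the others. A dependency-free sketch of that contract, using hypothetical nested-list data rather than a DataArray:

```python
# Sort each 1-D slice along the chosen (inner) dimension independently,
# leaving the outer "batch" dimension untouched -- the same behaviour
# sort() gets from apply_ufunc with input_core_dims=[[dim]].
batch = [[3.0, 1.0, 2.0],
         [9.0, 7.0, 8.0]]
sorted_batch = [sorted(row) for row in batch]
print(sorted_batch)  # [[1.0, 2.0, 3.0], [7.0, 8.0, 9.0]]
```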


def _remove_indexes_to_reduce(da, dims):
    """Remove indexes related to provided dims.

    Removes indexes related to dims on which we need to operate.
    As many functions only support integer `axis` or None,
    in order to have our functions operate on multiple dimensions
    we need to stack/flatten them. If some of those dimensions
    are already indexed by a multiindex this doesn't work, so we
    remove the indexes. As they are reduced, that information
    will end up being lost eventually either way.
    """
    index_keys = list(da.indexes)
    remove_indicator = [
        any(da.indexes[k] is index for k in index_keys if k in dims)
        for name, index in da.indexes.items()
    ]
    indexes_to_remove = [k for k, remove in zip(index_keys, remove_indicator) if remove]
    da = da.drop_indexes(indexes_to_remove)
    coords_to_remove = [coord for coord in da.coords if coord in indexes_to_remove or coord in dims]
    return da.reset_coords(coords_to_remove, drop=True)
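The stack/flatten rationale in the docstring can be illustrated without xarray: reducing over two dimensions with an `axis`-only function is equivalent to flattening them into a single trailing dimension first. A pure-Python sketch with made-up data:

```python
# Flatten the two trailing "dims" of a nested array so that a function
# which only reduces along one axis (here: max over the last axis)
# can operate over both of them at once.
data = [[[1, 2], [3, 4]],
        [[5, 6], [7, 8]]]  # shape (2, 2, 2)

# stack the last two dims into a single flat dim -> shape (2, 4)
stacked = [[x for row in block for x in row] for block in data]

# reducing over the flat dim reduces over both original dims
reduced = [max(flat) for flat in stacked]
print(stacked)  # [[1, 2, 3, 4], [5, 6, 7, 8]]
print(reduced)  # [4, 8]
```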


def _find_index(elem, to_search_in):
    for i, da in enumerate(to_search_in):
        if elem in da.dims:
155 changes: 151 additions & 4 deletions src/xarray_einstats/numba.py
@@ -3,7 +3,9 @@
import numpy as np
import xarray as xr

__all__ = ["histogram"]
from . import _remove_indexes_to_reduce, sort

__all__ = ["histogram", "searchsorted", "ecdf"]


@numba.guvectorize(
@@ -104,15 +106,16 @@ def histogram(da, dims, bins=None, density=False, **kwargs):
    else:
        bin_edges = bins
    if not isinstance(dims, str):
        aux_dim = f"__hist_dim__:{','.join(dims)}"
        da = _remove_indexes_to_reduce(da, dims).stack({aux_dim: dims})
        dims = aux_dim
    histograms = xr.apply_ufunc(
        hist_ufunc,
        da,
        bin_edges,
        input_core_dims=[[dims], ["bin"]],
        output_core_dims=[["bin"]],
        **kwargs,
    )
    histograms = histograms.isel({"bin": slice(None, -1)}).assign_coords(
        left_edges=("bin", bin_edges[:-1]), right_edges=("bin", bin_edges[1:])
    )
@@ -122,3 +125,147 @@ def histogram(da, dims, bins=None, density=False, **kwargs):
            histograms.sum("bin") * (histograms.right_edges - histograms.left_edges)
        )
    return histograms
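The `density` normalization divides each count by the total number of samples times the bin width, so the histogram integrates to one. A small numeric sketch with made-up counts and edges:

```python
counts = [2, 5, 3]
edges = [0.0, 1.0, 2.0, 4.0]  # bin edges; note the last bin is wider

widths = [right - left for left, right in zip(edges[:-1], edges[1:])]
total = sum(counts)
density = [c / (total * w) for c, w in zip(counts, widths)]

# the area under the density histogram is 1 (up to floating point)
area = sum(d * w for d, w in zip(density, widths))
print(density)          # [0.2, 0.5, 0.15]
print(round(area, 12))  # 1.0
```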


@numba.guvectorize(
    [
        "void(uint8[:], uint8[:], uint8[:])",
        "void(uint16[:], uint16[:], uint16[:])",
        "void(uint32[:], uint32[:], uint32[:])",
        "void(uint64[:], uint64[:], uint64[:])",
        "void(int8[:], int8[:], int8[:])",
        "void(int16[:], int16[:], int16[:])",
        "void(int32[:], int32[:], int32[:])",
        "void(int64[:], int64[:], int64[:])",
        "void(float32[:], float32[:], float32[:])",
        "void(float64[:], float64[:], float64[:])",
    ],
    "(n),(m)->(m)",
    cache=True,
    target="parallel",
    nopython=True,
)
def searchsorted_ufunc(da, v, res):  # pragma: no cover
    """Use :func:`numba.guvectorize` to convert numpy searchsorted into a vectorized ufunc.

    Notes
    -----
    As of now, its only intended use is for `ecdf`, so the `side` is
    hardcoded and the rest of the library will assume so.
    """
    res[:] = np.searchsorted(da, v, side="right")
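The hardcoded `side="right"` means each query returns the number of sorted elements less than or equal to the query value, which is exactly the count an ecdf needs. The stdlib `bisect` module shows the same semantics:

```python
import bisect

sorted_vals = [1, 2, 2, 3]

# side="right" (bisect_right): insertion point after any equal entries,
# i.e. the count of elements <= v
print(bisect.bisect_right(sorted_vals, 2))  # 3
# side="left" would instead return the first position of an equal entry
print(bisect.bisect_left(sorted_vals, 2))   # 1
```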


def searchsorted(da, v, dims=None, **kwargs):
    """Numbify :func:`numpy.searchsorted` to support vectorized computations.

    Parameters
    ----------
    da : DataArray
        Input data.
    v : DataArray
        The values to insert into `da`.
    dims : str or iterable of str, optional
        The dimensions over which to apply the searchsort. Computation
        will be parallelized over the rest with numba.
    **kwargs : dict, optional
        Keyword arguments passed as-is to :func:`xarray.apply_ufunc`.

    Notes
    -----
    It has been designed to be used by :func:`~xarray_einstats.numba.ecdf`,
    so its setting of input and output core dims makes some assumptions
    based on that; it doesn't aim to be a general-use vectorized/parallelized
    searchsorted.
    """
    if dims is None:
        dims = [d for d in da.dims if d not in v.dims]
    if not isinstance(dims, str):
        aux_dim = f"__aux_dim__:{','.join(dims)}"
        da = _remove_indexes_to_reduce(da, dims).stack({aux_dim: dims}, create_index=False)
        core_dims = [aux_dim]
    else:
        aux_dim = dims
        core_dims = [dims]

    v_dims = [d for d in v.dims if d not in da.dims]

    return xr.apply_ufunc(
        searchsorted_ufunc,
        sort(da, dim=aux_dim),
        v,
        input_core_dims=[core_dims, v_dims],
        output_core_dims=[v_dims],
        **kwargs,
    )


def ecdf(da, dims=None, *, npoints=None, **kwargs):
    """Compute the x and y values of ecdf plots in a vectorized way.

    Parameters
    ----------
    da : DataArray
        Input data containing the samples on which we want to compute the ecdf.
    dims : str or iterable of str, optional
        Dimensions over which the ecdf should be computed. They are flattened
        and converted to a ``quantile`` dimension that contains the values
        to plot; the other dimensions should be used for facetting and aesthetics.
        The default is computing the ecdf over the flattened input.
    npoints : int, optional
        Number of points on which to evaluate the ecdf. It defaults
        to the minimum between 200 and the total number of points in each
        block defined by `dims`.
    **kwargs : dict, optional
        Keyword arguments passed as-is to :func:`xarray.apply_ufunc` through
        :func:`~xarray_einstats.numba.searchsorted`.

    Returns
    -------
    Dataset
        Dataset with two data variables: ``x`` and ``y`` with the values to plot.

    Examples
    --------
    Compute and plot the ecdf over all the data:

    .. plot::
        :context: close-figs

        from xarray_einstats import tutorial, numba
        import matplotlib.pyplot as plt
        ds = tutorial.generate_mcmc_like_dataset(3)
        out = numba.ecdf(ds["mu"], dims=("chain", "draw", "team"))
        plt.plot(out["x"], out["y"], drawstyle="steps-post");

    Compute vectorized ecdf values to plot multiple subplots and
    multiple lines in each with different hue:

    .. plot::
        :context: close-figs

        out = numba.ecdf(ds["mu"], dims="draw")
        out["y"].assign_coords(x=out["x"]).plot.line(
            x="x", hue="chain", col="team", col_wrap=3, drawstyle="steps-post"
        );

    Warnings
    --------
    New and experimental feature; its API might change.
    """
    if dims is None:
        dims = da.dims
    elif isinstance(dims, str):
        dims = [dims]
    total_points = np.product([da.sizes[d] for d in dims])
    if npoints is None:
        npoints = min(total_points, 200)
    x = xr.DataArray(np.linspace(0, 1, npoints), dims=["quantile"])
    max_da = da.max(dims)
    min_da = da.min(dims)
    x = (max_da - min_da) * x + min_da

    y = searchsorted(da, x, dims=dims, **kwargs) / total_points
    return xr.Dataset({"x": x, "y": y})
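The whole pipeline (evenly spaced grid between the sample min and max, then `searchsorted / total_points`) can be mirrored with the stdlib for a single 1-D sample. Variable names here are illustrative, not part of the library API:

```python
import bisect

samples = sorted([1.0, 4.0, 2.0, 8.0, 5.0])
total_points = len(samples)
npoints = min(total_points, 200)

# evenly spaced evaluation grid between the sample min and max,
# mirroring x = (max_da - min_da) * linspace(0, 1, npoints) + min_da
lo, hi = samples[0], samples[-1]
xs = [lo + (hi - lo) * i / (npoints - 1) for i in range(npoints)]

# y = searchsorted(samples, x, side="right") / total_points
ys = [bisect.bisect_right(samples, x) / total_points for x in xs]

print(ys[0], ys[-1])  # 0.2 1.0
```

Each `y` value is the fraction of samples at or below the corresponding grid point, so the curve starts at `1/n` (at the minimum) and reaches 1 at the maximum.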
