Skip to content

Commit

Permalink
Merge pull request #33 from jacanchaplais/feature/serialize-2D-32
Browse files Browse the repository at this point in the history
Serialization of Histogram2D
  • Loading branch information
jacanchaplais authored Sep 11, 2024
2 parents c7abaa5 + b048c84 commit 6b38279
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 9 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: tests

on:
- push

jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, macos-13, windows-latest, macos-latest]
pyver: ["3.8", "3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v3
- uses: mamba-org/setup-micromamba@v1
with:
micromamba-version: '1.5.6-0'
environment-name: test-env
create-args: >-
python=${{ matrix.pyver }}
pytest>=6.2.5
init-shell: bash
cache-environment: true
post-cleanup: 'all'

- name: Install colliderscope
shell: bash -el {0}
run: pip install ".[dev]"

- name: Run tests
shell: bash -el {0}
run: pytest
3 changes: 2 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
colliderscope
=============

|PyPI version| |Documentation| |License| |pre-commit| |Code style:
|PyPI version| |Tests| |Documentation| |License| |pre-commit| |Code style:
black|

Visualise your high energy physics (HEP) data with colliderscope!
Expand Down Expand Up @@ -36,3 +36,4 @@ More information on the API is available in the
:target: https://github.com/pre-commit/pre-commit
.. |Code style: black| image:: https://img.shields.io/badge/code%20style-black-000000.svg
:target: https://github.com/psf/black
.. |Tests| image:: https://github.com/jacanchaplais/colliderscope/actions/workflows/tests.yml/badge.svg
79 changes: 72 additions & 7 deletions colliderscope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,9 @@ class Histogram:
Migrated from ``data`` module.
Renamed ``_align`` parameter to ``align``.
.. versionchanged:: 0.2.7
Added ``missed`` property for number of out-of-bounds updates.
:group: data
Parameters
Expand Down Expand Up @@ -538,10 +541,12 @@ def __post_init__(self) -> None:
self.bin_edges = np.linspace(*self.x_range, self.num_bins + 1)
self.bin_edges = self.bin_edges.squeeze()

def __eq__(self, other: "Histogram") -> bool:
def __eq__(self, other: tyx.Self) -> bool:
"""Determines if two ``Histogram`` instances contain the same
``bin_edges``, ``counts``, and normalisation for ``pdf``.
"""
if not isinstance(other, type(self)):
raise TypeError("Cannot compare Histogram with different types.")
eq = np.array_equal(self.bin_edges, other.bin_edges)
eq = eq and np.array_equal(self.counts, other.counts)
eq = eq and (self._total == other._total)
Expand Down Expand Up @@ -575,6 +580,7 @@ def update(self, val: base.HistValue) -> None:
return
mask = np.logical_and(idx > -1, idx < self.num_bins)
np.add.at(self.counts, idx[mask], 1)
self._total += idx.reshape(-1).shape[0]

def to_json(
self,
Expand Down Expand Up @@ -674,6 +680,13 @@ def pdf(self) -> ty.Tuple[base.DoubleVector, base.DoubleVector]:
"""
return self.midpoints(), self.density()

@property
def missed(self) -> int:
"""Number of update values which fell out-of-bounds of the
histogram's domain.
"""
return self.total - np.sum(self.counts).item()

def density(self) -> base.DoubleVector:
"""Probability density of the histogram, normalised by the total
count of the histogram.
Expand Down Expand Up @@ -1195,7 +1208,11 @@ def _is_1d(data: np.ndarray) -> bool:
class Histogram2D:
"""Constant memory 2D histogram data structure.
.. versionadded:: 0.3.0
.. versionadded:: 0.2.6
.. versionchanged:: 0.2.7
Added serialize / from_serialized methods for easy IO.
Added ``missed`` property for number of out-of-bounds updates.
:group: data
Expand Down Expand Up @@ -1241,21 +1258,37 @@ def __init__(
dtype: npt.DTypeLike = np.float64,
) -> None:
self.num_bins_x = num_bins_x
self.window_x = window_x
self.window_x: ty.Tuple[float, float] = tuple(window_x)
self.num_bins_y = num_bins_y
self.window_y = window_y
self.window_y: ty.Tuple[float, float] = tuple(window_y)
self.dtype = np.dtype(dtype)
self._total = 0
bin_width_x = abs((window_x[1] - window_x[0]) / num_bins_x)
bin_width_y = abs((window_y[1] - window_y[0]) / num_bins_y)
self.bin_width_x = bin_width_x
self.bin_width_y = bin_width_y
self.accumulate: base.IntVector = np.zeros(
self.accumulate: base.NumberVector = np.zeros(
(num_bins_y, num_bins_x), dtype=dtype
)
self.bin_edges_x = np.linspace(*window_x, num_bins_x + 1).squeeze()
self.bin_edges_y = np.linspace(*window_y, num_bins_y + 1).squeeze()

def __eq__(self, other: tyx.Self) -> bool:
if not isinstance(other, type(self)):
raise TypeError("Cannot compare Histogram2D with different types.")
for attrib in (
"window_x",
"window_y",
"num_bins_x",
"num_bins_y",
"dtype",
):
if getattr(self, attrib) != getattr(other, attrib):
return False
if not np.array_equal(self.accumulate, other.accumulate):
return False
return True

@property
def total(self) -> int:
return self._total
Expand All @@ -1271,12 +1304,11 @@ def _table_info(self) -> ty.List[ty.List[ty.Any]]:
f"[{self.window_y[0]:.3f}, "
f"{self.window_y[1]:.3f})"
)
missed = self.total - np.sum(self.accumulate).item()
return [
["x-axis", xaxis],
["y-axis", yaxis],
["total", round(self.total, 3)],
["missed", round(missed, 3)],
["missed", round(self.missed, 3)],
["dtype", self.dtype],
]

Expand Down Expand Up @@ -1384,3 +1416,36 @@ def midpoints(self) -> ty.Tuple[base.DoubleVector, base.DoubleVector]:
self.num_bins_y,
),
)

@property
def missed(self) -> ty.Union[float, int]:
"""Number of update values which fell out-of-bounds of the
histogram's domain.
"""
return self.total - np.sum(self.accumulate).item()

def serialize(self) -> ty.Dict[str, ty.Any]:
"""Converts ``Histogram2D`` into serialized representation."""
return {
"num_bins_x": self.num_bins_x,
"num_bins_y": self.num_bins_y,
"window_x": self.window_x,
"window_y": self.window_y,
"dtype": np.dtype(self.dtype).str,
"accumulate": self.accumulate.tolist(),
"total": self.total,
}

@classmethod
def from_serialized(cls, data: ty.Dict[str, ty.Any]) -> tyx.Self:
"""Instantiates ``Histogram2D`` from serialized data."""
accumulate = data.pop("accumulate", None)
if accumulate is None:
raise ValueError("Must have `accumulate` item in data dictionary")
total = data.pop("total", None)
if total is None:
raise ValueError("Must have `total` item in data dictionary")
hist2d = cls(**data)
hist2d.accumulate[...] = accumulate
hist2d._total = total
return hist2d
1 change: 1 addition & 0 deletions colliderscope/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
BoolVector: tyx.TypeAlias = npt.NDArray[np.bool_]
IntVector: tyx.TypeAlias = npt.NDArray[np.int32]
DoubleVector: tyx.TypeAlias = npt.NDArray[np.float64]
NumberVector: tyx.TypeAlias = npt.NDArray[np.number]
VoidVector: tyx.TypeAlias = npt.NDArray[np.void]
AnyVector: tyx.TypeAlias = npt.NDArray[ty.Any]

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ dependencies = [
"pyvis <=0.1.9",
"colour",
"webcolors",
"graphicle >=0.3.9",
"graphicle >=0.4.1",
"plotly",
"more-itertools >=2.1",
]
Expand Down
66 changes: 66 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import json
import random

import numpy as np

import colliderscope as csp


def create_hist1d(total: int, num_misses: int) -> csp.Histogram:
if total < num_misses:
raise ValueError("Cannot create more misses than there are values!")
num_bins = random.randint(20, 1000)
window = (random.uniform(-50.0, 0.0), random.uniform(1.0, 1000.0))
hist = csp.Histogram(num_bins, window)
rng = np.random.default_rng()
hist.update(rng.uniform(*window, size=(total - num_misses)))
lo_misses = num_misses // 2
hi_misses = num_misses - lo_misses
hist.update(rng.uniform(window[1], 5000.0, size=hi_misses))
hist.update(rng.uniform(-1000.0, window[0], size=lo_misses))
return hist


def test_missed_and_total_1d() -> None:
total, misses = 1_000, 24
hist = create_hist1d(total=total, num_misses=misses)
assert hist.missed == misses, "Number of missed updates not consistent."
assert hist.total == total, "Total number of updates incorrect."


def create_hist2d(total: int, num_misses: int) -> csp.Histogram2D:
if total < num_misses:
raise ValueError("Cannot create more misses than there are values!")
nbinx, nbiny = random.randint(20, 1000), random.randint(20, 1000)
winx = (random.uniform(-50.0, 0.0), random.uniform(1.0, 1000.0))
winy = (random.uniform(-50.0, 0.0), random.uniform(1.0, 1000.0))
hist = csp.Histogram2D(nbinx, winx, nbiny, winy, np.int32)
rng = np.random.default_rng()
hist.update(
x=rng.uniform(*winx, size=total - num_misses),
y=rng.uniform(*winy, size=total - num_misses),
)
hist.update(
x=rng.uniform(winx[1], 1000.0, size=num_misses),
y=rng.uniform(-1000.0, winy[0], size=num_misses),
)
return hist


def test_missed_and_total_2d() -> None:
hist = create_hist2d(total=2_000, num_misses=44)
assert hist.missed == 44, "Number of missed updates not consistent."
assert hist.total == 2_000, "Total number of updates incorrect."


def test_hist_serialize_inversion() -> None:
# checking 1D histogram
hist1d = create_hist1d(100, 3)
hist1d_str = json.dumps(hist1d.serialize())
hist1d_read = csp.Histogram.from_serialized(json.loads(hist1d_str))
assert hist1d_read == hist1d, "Serialization not invertible for 1d hist."
# checking 2D histogram
hist2d = create_hist2d(200, 7)
hist2d_str = json.dumps(hist2d.serialize())
hist2d_read = csp.Histogram2D.from_serialized(json.loads(hist2d_str))
assert hist2d_read == hist2d, "Serialization not invertible for 2d hist."

0 comments on commit 6b38279

Please sign in to comment.