Skip to content

Commit

Permalink
Import latest changes to pygama.dsp
Browse files Browse the repository at this point in the history
  • Loading branch information
gipert committed May 24, 2023
1 parent 14bd3fc commit 89a3371
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 70 deletions.
51 changes: 29 additions & 22 deletions src/dspeed/processing_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Union
from typing import Any

import lgdo
import numpy as np
from lgdo import LGDO
from lgdo.lgdo_utils import expand_path
from numba import vectorize
from pint import Quantity, Unit
Expand All @@ -26,8 +27,6 @@

log = logging.getLogger(__name__)

LGDO = Union[lgdo.Scalar, lgdo.Array, lgdo.VectorOfVectors, lgdo.Struct]

# Filler value for variables to be automatically deduced later
auto = "auto"

Expand Down Expand Up @@ -129,22 +128,22 @@ def __init__(
Parameters
----------
proc_chain
ProcessingChain that contains this variable
:class:`ProcessingChain` that contains this variable.
name
Name of variable used to look it up
Name of variable used to look it up.
shape
Shape of variable, without buffer_len dimension
Shape of variable, without `buffer_len` dimension.
dtype
Data type of variable
Data type of variable.
grid
Coordinate grid associated with variable. This contains the
period and offset of the variable. For variables where
is_coord is True, use this to perform unit conversions
is_coord is True, use this to perform unit conversions.
unit
Unit associated with variable during I/O.
is_coord
If True, variable represents an array index and can be converted
into a unitted number using grid
If ``True``, variable represents an array index and can be converted
into a unitted number using grid.
"""
assert isinstance(proc_chain, ProcessingChain) and isinstance(name, str)
self.proc_chain = proc_chain
Expand Down Expand Up @@ -343,7 +342,7 @@ def __init__(self, block_width: int = 8, buffer_len: int = None) -> None:
number of entries to simultaneously process.
buffer_len
length of input and output buffers. Should be a multiple of
`block_width`
`block_width`.
"""
# Dictionary from name to scratch data buffers as ProcChainVar
self._vars_dict = {}
Expand Down Expand Up @@ -1083,15 +1082,21 @@ def __init__(
# check if arr_dims can be broadcast to match fun_dims
for i in range(max(len(fun_dims), len(arr_dims))):
fd = fun_dims[-i - 1] if i < len(fun_dims) else None
ad = arr_dims[-i - 1] if i < len(arr_dims) else None
ad = (
arr_dims[-i - 1]
if i < len(arr_dims)
else self.proc_chain._block_width
if i == len(arr_dims)
else None
)

if isinstance(fd, str):
if fd in dims_dict:
this_dim = dims_dict[fd]
if not ad or this_dim.length != ad:
raise ProcessingChainError(
f"failed to broadcast array dimensions for "
f"{func.__name}. Could not find consistent value "
f"{func.__name__}. Could not find consistent value "
f"for dimension {fd}"
)
if not this_dim.grid:
Expand Down Expand Up @@ -1164,6 +1169,7 @@ def __init__(
):
dim_list = outerdims.copy()
for d in dims.split(","):
d = d.strip()
if not d:
continue
if d not in dims_dict:
Expand Down Expand Up @@ -1311,11 +1317,12 @@ def __init__(self, var: ProcChainVar, unit: str | Unit) -> None:


class IOManager(metaclass=ABCMeta):
"""
Base class. IOManagers will be associated with a type of input/output
buffer, and must define a read and write for each one. __init__ methods
should update variable with any information from buffer, and check that
buffer and variable are compatible.
r"""Base class.
:class:`IOManager`\ s will be associated with a type of input/output
buffer, and must define a read and write for each one. ``__init__()``
methods should update variable with any information from buffer, and check
that buffer and variable are compatible.
"""

@abstractmethod
Expand All @@ -1333,7 +1340,7 @@ def __str__(self) -> str:

# Ok, this one's not LGDO
class NumpyIOManager(IOManager):
"""IO Manager for buffers that are numpy arrays"""
r""":class:`IOManager` for buffers that are :class:`numpy.ndarray`\ s."""

def __init__(self, io_buf: np.ndarray, var: ProcChainVar) -> None:
assert isinstance(io_buf, np.ndarray) and isinstance(var, ProcChainVar)
Expand Down Expand Up @@ -1368,7 +1375,7 @@ def __str__(self) -> str:


class LGDOArrayIOManager(IOManager):
"""IO Manager for buffers that are lgdo Arrays"""
r"""IO Manager for buffers that are :class:`lgdo.Array`\ s."""

def __init__(self, io_array: lgdo.Array, var: ProcChainVar) -> None:
assert isinstance(io_array, lgdo.Array) and isinstance(var, ProcChainVar)
Expand Down Expand Up @@ -1420,7 +1427,7 @@ def __str__(self) -> str:


class LGDOArrayOfEqualSizedArraysIOManager(IOManager):
"""IO Manager for buffers that are numpy ArrayOfEqualSizedArrays"""
r""":class:`IOManager` for buffers that are :class:`lgdo.ArrayOfEqualSizedArray`\ s."""

def __init__(self, io_array: np.ArrayOfEqualSizedArrays, var: ProcChainVar) -> None:
assert isinstance(io_array, lgdo.ArrayOfEqualSizedArrays) and isinstance(
Expand Down Expand Up @@ -1641,7 +1648,7 @@ def build_processing_chain(
(proc_chain, field_mask, lh5_out)
- `proc_chain` -- :class:`ProcessingChain` object that is executed
- `field_mask` -- list of input fields that are used
- `lh5_out` -- output :class:`~.lgdo.table.Table` containing processed
- `lh5_out` -- output :class:`~lgdo.table.Table` containing processed
values
"""
proc_chain = ProcessingChain(block_width, lh5_in.size)
Expand Down
Loading

0 comments on commit 89a3371

Please sign in to comment.