Skip to content

Commit

Permalink
build_evt(): support declaring dtype of output columns
Browse files Browse the repository at this point in the history
  • Loading branch information
gipert committed Mar 31, 2024
1 parent af8f3ef commit e5c0527
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 9 deletions.
9 changes: 4 additions & 5 deletions src/pygama/evt/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def evaluate_to_first_or_last(
(t0 > outt[evt_ids_ch]) & (limarr), t0, outt[evt_ids_ch]
)

return types.Array(nda=out, dtype=type(default_value))
return types.Array(nda=out)


def evaluate_to_scalar(
Expand Down Expand Up @@ -216,7 +216,7 @@ def evaluate_to_scalar(
res = res.astype(bool)
out[evt_ids_ch] = out[evt_ids_ch] & res & limarr

return types.Array(nda=out, dtype=type(default_value))
return types.Array(nda=out)


def evaluate_at_channel(
Expand Down Expand Up @@ -277,7 +277,7 @@ def evaluate_at_channel(

out[evt_ids_ch] = np.where(ch == ch_comp.nda[idx_ch], res, out[evt_ids_ch])

return types.Array(nda=out, dtype=type(default_value))
return types.Array(nda=out)


def evaluate_at_channel_vov(
Expand Down Expand Up @@ -350,7 +350,7 @@ def evaluate_at_channel_vov(
if ch == channels[0]:
type_name = res.dtype

return types.VectorOfVectors(ak.values_astype(out, type_name), dtype=type_name)
return types.VectorOfVectors(ak.values_astype(out, type_name))


def evaluate_to_aoesa(
Expand Down Expand Up @@ -526,5 +526,4 @@ def evaluate_to_vector(
ak.values_astype(
ak.drop_none(ak.nan_to_none(ak.Array(out))), type(default_value)
),
dtype=type(default_value),
)
17 changes: 17 additions & 0 deletions src/pygama/evt/build_evt.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def build_evt(
(see :func:`evaluate_expression`),
- ``query`` defines an expression to mask the aggregation.
- ``parameters`` defines any other parameter used in expression.
- ``dtype` defines the NumPy data type of the resulting data.
- ``initial`` defines the initial/default value. Useful with some types
of aggregators.
For example:
Expand Down Expand Up @@ -259,6 +262,20 @@ def build_evt(
if "lgdo_attrs" in v.keys():
obj.attrs |= v["lgdo_attrs"]

# cast to type, if required
# hijack the poor LGDO
if "dtype" in v:
type_ = v["dtype"]

if isinstance(obj, Array):
obj.nda = obj.nda.astype(type_)
if isinstance(obj, VectorOfVectors):
fldata_ptr = obj.flattened_data
while isinstance(fldata_ptr, VectorOfVectors):
fldata_ptr = fldata_ptr.flattened_data

fldata_ptr.nda = fldata_ptr.nda.astype(type_)

log.debug(f"new column {k!s} = {obj!r}")
table.add_field(k, obj)

Expand Down
8 changes: 5 additions & 3 deletions tests/evt/configs/vov-test-evt-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,15 @@
"channels": "geds_on",
"aggregation_mode": "gather",
"query": "hit.cuspEmax_ctc_cal>25",
"expression": "hit.cuspEmax_ctc_cal"
"expression": "hit.cuspEmax_ctc_cal",
"dtype": "float32"
},
"energy_sum": {
"channels": "geds_on",
"aggregation_mode": "sum",
"query": "hit.cuspEmax_ctc_cal>25",
"expression": "hit.cuspEmax_ctc_cal",
"initial": 0.0
"initial": 0
},
"energy_idx": {
"channels": "geds_on",
Expand Down Expand Up @@ -66,7 +67,8 @@
"aggregation_mode": "sum",
"expression": "hit.cuspEmax_ctc_cal > a",
"parameters": { "a": 25 },
"initial": 0
"initial": 0,
"dtype": "int16"
},
"is_saturated": {
"aggregation_mode": "keep_at_ch:evt.energy_id",
Expand Down
11 changes: 10 additions & 1 deletion tests/evt/test_build_evt.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,18 +123,25 @@ def test_vov(lgnd_test_data, files_config):

assert os.path.exists(outfile)
assert len(lh5.ls(outfile, "/evt/")) == 12

vov_ene, _ = store.read("/evt/energy", outfile)
vov_aoe, _ = store.read("/evt/aoe", outfile)
arr_ac, _ = store.read("/evt/multiplicity", outfile)
vov_aoeene, _ = store.read("/evt/energy_times_aoe", outfile)
vov_eneac, _ = store.read("/evt/energy_times_multiplicity", outfile)
arr_ac2, _ = store.read("/evt/multiplicity_squared", outfile)

assert isinstance(vov_ene, VectorOfVectors)
assert isinstance(vov_aoe, VectorOfVectors)
assert isinstance(arr_ac, Array)
assert isinstance(vov_aoeene, VectorOfVectors)
assert isinstance(vov_eneac, VectorOfVectors)
assert isinstance(arr_ac2, Array)

assert vov_ene.dtype == "float32"
assert vov_aoe.dtype == "float64"
assert arr_ac.dtype == "int16"

assert (np.diff(vov_ene.cumulative_length.nda, prepend=[0]) == arr_ac.nda).all()

vov_eid = store.read("/evt/energy_id", outfile)[0].view_as("ak")
Expand All @@ -146,7 +153,9 @@ def test_vov(lgnd_test_data, files_config):
assert ak.all(ids == vov_eid)

arr_ene = store.read("/evt/energy_sum", outfile)[0].view_as("ak")
assert ak.all(arr_ene == ak.nansum(vov_ene.view_as("ak"), axis=-1))
assert ak.all(
ak.isclose(arr_ene, ak.nansum(vov_ene.view_as("ak"), axis=-1), rtol=1e-3)
)
assert ak.all(vov_aoe.view_as("ak") == vov_aoe_idx)


Expand Down

0 comments on commit e5c0527

Please sign in to comment.