diff --git a/src/pygama/evt/aggregators.py b/src/pygama/evt/aggregators.py index 6a50549bf..3d1bc2eb0 100644 --- a/src/pygama/evt/aggregators.py +++ b/src/pygama/evt/aggregators.py @@ -124,7 +124,7 @@ def evaluate_to_first_or_last( (t0 > outt[evt_ids_ch]) & (limarr), t0, outt[evt_ids_ch] ) - return types.Array(nda=out, dtype=type(default_value)) + return types.Array(nda=out) def evaluate_to_scalar( @@ -216,7 +216,7 @@ def evaluate_to_scalar( res = res.astype(bool) out[evt_ids_ch] = out[evt_ids_ch] & res & limarr - return types.Array(nda=out, dtype=type(default_value)) + return types.Array(nda=out) def evaluate_at_channel( @@ -277,7 +277,7 @@ def evaluate_at_channel( out[evt_ids_ch] = np.where(ch == ch_comp.nda[idx_ch], res, out[evt_ids_ch]) - return types.Array(nda=out, dtype=type(default_value)) + return types.Array(nda=out) def evaluate_at_channel_vov( @@ -350,7 +350,7 @@ def evaluate_at_channel_vov( if ch == channels[0]: type_name = res.dtype - return types.VectorOfVectors(ak.values_astype(out, type_name), dtype=type_name) + return types.VectorOfVectors(ak.values_astype(out, type_name)) def evaluate_to_aoesa( @@ -526,5 +526,4 @@ def evaluate_to_vector( ak.values_astype( ak.drop_none(ak.nan_to_none(ak.Array(out))), type(default_value) ), - dtype=type(default_value), ) diff --git a/src/pygama/evt/build_evt.py b/src/pygama/evt/build_evt.py index d2a426550..70f175b84 100644 --- a/src/pygama/evt/build_evt.py +++ b/src/pygama/evt/build_evt.py @@ -56,6 +56,9 @@ def build_evt( (see :func:`evaluate_expression`), - ``query`` defines an expression to mask the aggregation. - ``parameters`` defines any other parameter used in expression. + - ``dtype`` defines the NumPy data type of the resulting data. + - ``initial`` defines the initial/default value. Useful with some types + of aggregators. 
For example: @@ -259,6 +262,20 @@ def build_evt( if "lgdo_attrs" in v.keys(): obj.attrs |= v["lgdo_attrs"] + # cast to type, if required + # hijack the poor LGDO + if "dtype" in v: + type_ = v["dtype"] + + if isinstance(obj, Array): + obj.nda = obj.nda.astype(type_) + if isinstance(obj, VectorOfVectors): + fldata_ptr = obj.flattened_data + while isinstance(fldata_ptr, VectorOfVectors): + fldata_ptr = fldata_ptr.flattened_data + + fldata_ptr.nda = fldata_ptr.nda.astype(type_) + log.debug(f"new column {k!s} = {obj!r}") table.add_field(k, obj) diff --git a/tests/evt/configs/vov-test-evt-config.json b/tests/evt/configs/vov-test-evt-config.json index 31334101e..6de44075b 100644 --- a/tests/evt/configs/vov-test-evt-config.json +++ b/tests/evt/configs/vov-test-evt-config.json @@ -28,14 +28,15 @@ "channels": "geds_on", "aggregation_mode": "gather", "query": "hit.cuspEmax_ctc_cal>25", - "expression": "hit.cuspEmax_ctc_cal" + "expression": "hit.cuspEmax_ctc_cal", + "dtype": "float32" }, "energy_sum": { "channels": "geds_on", "aggregation_mode": "sum", "query": "hit.cuspEmax_ctc_cal>25", "expression": "hit.cuspEmax_ctc_cal", - "initial": 0.0 + "initial": 0 }, "energy_idx": { "channels": "geds_on", @@ -66,7 +67,8 @@ "aggregation_mode": "sum", "expression": "hit.cuspEmax_ctc_cal > a", "parameters": { "a": 25 }, - "initial": 0 + "initial": 0, + "dtype": "int16" }, "is_saturated": { "aggregation_mode": "keep_at_ch:evt.energy_id", diff --git a/tests/evt/test_build_evt.py b/tests/evt/test_build_evt.py index cb6abbdb5..127086cd8 100644 --- a/tests/evt/test_build_evt.py +++ b/tests/evt/test_build_evt.py @@ -123,18 +123,25 @@ def test_vov(lgnd_test_data, files_config): assert os.path.exists(outfile) assert len(lh5.ls(outfile, "/evt/")) == 12 + vov_ene, _ = store.read("/evt/energy", outfile) vov_aoe, _ = store.read("/evt/aoe", outfile) arr_ac, _ = store.read("/evt/multiplicity", outfile) vov_aoeene, _ = store.read("/evt/energy_times_aoe", outfile) vov_eneac, _ = 
store.read("/evt/energy_times_multiplicity", outfile) arr_ac2, _ = store.read("/evt/multiplicity_squared", outfile) + assert isinstance(vov_ene, VectorOfVectors) assert isinstance(vov_aoe, VectorOfVectors) assert isinstance(arr_ac, Array) assert isinstance(vov_aoeene, VectorOfVectors) assert isinstance(vov_eneac, VectorOfVectors) assert isinstance(arr_ac2, Array) + + assert vov_ene.dtype == "float32" + assert vov_aoe.dtype == "float64" + assert arr_ac.dtype == "int16" + assert (np.diff(vov_ene.cumulative_length.nda, prepend=[0]) == arr_ac.nda).all() vov_eid = store.read("/evt/energy_id", outfile)[0].view_as("ak") @@ -146,7 +153,9 @@ def test_vov(lgnd_test_data, files_config): assert ak.all(ids == vov_eid) arr_ene = store.read("/evt/energy_sum", outfile)[0].view_as("ak") - assert ak.all(arr_ene == ak.nansum(vov_ene.view_as("ak"), axis=-1)) + assert ak.all( + ak.isclose(arr_ene, ak.nansum(vov_ene.view_as("ak"), axis=-1), rtol=1e-3) + ) assert ak.all(vov_aoe.view_as("ak") == vov_aoe_idx)