diff --git a/src/pygama/evt/utils.py b/src/pygama/evt/utils.py index 4f8391353..4aedc1438 100644 --- a/src/pygama/evt/utils.py +++ b/src/pygama/evt/utils.py @@ -16,16 +16,22 @@ H5DataLoc = namedtuple( "H5DataLoc", ("file", "group", "table_fmt"), defaults=3 * (None,) ) - -DataInfo = namedtuple( - "DataInfo", ("raw", "tcm", "dsp", "hit", "evt"), defaults=5 * (None,) -) +DataInfo = namedtuple("DataInfo", ("raw", "tcm", "evt"), defaults=3 * (None,)) TCMData = namedtuple("TCMData", ("id", "idx", "cumulative_length")) def make_files_config(data: dict): - if not isinstance(data, DataInfo): + if not isinstance(data, tuple): + if "raw" not in data: + data["raw"] = (None,) + if "tcm" not in data: + data["tcm"] = (None,) + if "evt" not in data: + data["evt"] = (None,) + DataInfo = namedtuple( + "DataInfo", tuple(data.keys()), defaults=len(data.keys()) * (None,) + ) return DataInfo( *[ H5DataLoc(*data[tier]) if tier in data else H5DataLoc() @@ -72,7 +78,7 @@ def find_parameters( idx_ch, field_list, ) -> dict: - """Finds and returns parameters from `hit` and `dsp` tiers. + """Finds and returns parameters from non `tcm`, `evt` tiers. Parameters ---------- @@ -83,43 +89,38 @@ def find_parameters( idx_ch index array of entries to be read from datainfo. field_list - list of tuples ``(tier, field)`` to be found in the `hit/dsp` tiers. + list of tuples ``(tier, field)`` to be found in non `tcm`, `evt` tiers. """ f = make_files_config(datainfo) - # find fields in either dsp, hit - dsp_flds = [e[1] for e in field_list if e[0] == f.dsp.group] - hit_flds = [e[1] for e in field_list if e[0] == f.hit.group] + final_dict = {} - hit_dict, dsp_dict = {}, {} + for name, tier in f._asdict().items(): + if name not in ["tcm", "evt"] and tier.file is not None: # skip other tables + keys = [ + k.split("/")[-1] + for k in lh5.ls(tier.file, f"{ch.replace('/', '')}/{tier.group}/") + ] + flds = [e[1] for e in field_list if e[0] == name and e[1] in keys] - if len(hit_flds) > 0: - hit_ak = lh5.read_as( - f"{ch.replace('/', '')}/{f.hit.group}/", - f.hit.file, - field_mask=hit_flds, - idx=idx_ch, - library="ak", - ) + if len(flds) > 0: + tier_ak = lh5.read_as( + f"{ch.replace('/', '')}/{tier.group}/", + tier.file, + field_mask=flds, + idx=idx_ch, + library="ak", + ) - hit_dict = dict( - zip([f"{f.hit.group}_" + e for e in ak.fields(hit_ak)], ak.unzip(hit_ak)) - ) + tier_dict = dict( + zip( + [f"{name}_" + e for e in ak.fields(tier_ak)], + ak.unzip(tier_ak), + ) + ) + final_dict = final_dict | tier_dict - if len(dsp_flds) > 0: - dsp_ak = lh5.read_as( - f"{ch.replace('/', '')}/{f.dsp.group}/", - f.dsp.file, - field_mask=dsp_flds, - idx=idx_ch, - library="ak", - ) - - dsp_dict = dict( - zip([f"{f.dsp.group}_" + e for e in ak.fields(dsp_ak)], ak.unzip(dsp_ak)) - ) - - return hit_dict | dsp_dict + return final_dict def get_data_at_channel( @@ -178,10 +179,16 @@ def get_data_at_channel( # evaluate expression # move tier+dots in expression to underscores (e.g. evt.foo -> evt_foo) + + new_expr = expr + for name in f._asdict(): + if name == "evt": + new_expr = new_expr.replace(f"{name}.", "") + elif name not in ["tcm", "raw"]: + new_expr = new_expr.replace(f"{name}.", f"{name}_") + res = eval( - expr.replace(f"{f.dsp.group}.", f"{f.dsp.group}_") - .replace(f"{f.hit.group}.", f"{f.hit.group}_") - .replace(f"{f.evt.group}.", ""), + new_expr, var, ) @@ -231,17 +238,23 @@ def get_mask_from_query( # get sub evt based query condition if needed if isinstance(query, str): - query_lst = re.findall(r"(hit|dsp).([a-zA-Z_$][\w$]*)", query) + query_lst = re.findall( + rf"({'|'.join(f._asdict().keys())}).([a-zA-Z_$][\w$]*)", query + ) query_var = find_parameters( datainfo=datainfo, ch=ch, idx_ch=idx_ch, field_list=query_lst, ) + + new_query = query + for name in f._asdict(): + if name not in ["tcm", "evt"]: + new_query = new_query.replace(f"{name}.", f"{name}_") + limarr = eval( - query.replace(f"{f.dsp.group}.", f"{f.dsp.group}_").replace( - f"{f.hit.group}.", f"{f.hit.group}_" - ), + new_query, query_var, )