Skip to content

Commit

Permalink
style: pre-commit fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pre-commit-ci[bot] committed May 3, 2024
1 parent b91c37c commit a2f6c1e
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions src/pygama/evt/modules/xtalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,13 @@ def build_energy_array(
return energies_out.T


def filter_hits(datainfo:utils.DataInfo,tcm:utils.TCMData,logic:str,corrected_energy:np.ndarray,rawids:np.ndarray)->np.ndarray:
def filter_hits(
datainfo: utils.DataInfo,
tcm: utils.TCMData,
logic: str,
corrected_energy: np.ndarray,
rawids: np.ndarray,
) -> np.ndarray:
"""
Function to which hits in an event are above threshold.
Parameters:
Expand All @@ -76,7 +82,7 @@ def filter_hits(datainfo:utils.DataInfo,tcm:utils.TCMData,logic:str,corrected_en
rawids
1D array of the rawids corresponding to each column
Returns
a numpy array of the mask of which
a numpy array of the mask of which
"""

Expand All @@ -87,19 +93,19 @@ def filter_hits(datainfo:utils.DataInfo,tcm:utils.TCMData,logic:str,corrected_en
logic = logic.replace(".", "___")

c = compile(logic, "gcc -O3 -ffast-math build_hit.py", "eval")
for idx_chan,channel in enumerate(rawids):
for idx_chan, channel in enumerate(rawids):
tbl = lgdo.Table()

for name in c.co_names:
if ("___" not in name):
if "___" not in name:
continue
tier, column = name.split("___")

try:
table_fmt = datainfo._asdict()[tier].table_fmt
group = datainfo._asdict()[tier].group
file = datainfo._asdict()[tier].file
keys=ls(file)
keys = ls(file)
table_id = utils.get_tcm_id_by_pattern(table_fmt, f"ch{channel}")
idx_events = ak.to_numpy(tcm.idx[tcm.id == table_id])

Expand All @@ -108,17 +114,18 @@ def filter_hits(datainfo:utils.DataInfo,tcm:utils.TCMData,logic:str,corrected_en
data = lh5.read(
f"ch{channel}/{group}/{column}", file, idx=idx_events
)
tbl.add_column(name,data)
tbl.add_column(name, data)
except KeyError:
pass

# add the corrected energy to the table
tbl.add_column("corrected_energy",lgdo.Array(corrected_energy[:][idx_chan]))
tbl.add_column("corrected_energy", lgdo.Array(corrected_energy[:][idx_chan]))
res = tbl.eval(logic)
mask[idx_chan][idx_events]=res
mask[idx_chan][idx_events] = res

return mask


def xtalk_corrected_energy(
uncalibrated_energies: np.ndarray,
calibrated_energies: np.ndarray,
Expand Down

0 comments on commit a2f6c1e

Please sign in to comment.