Skip to content

Commit

Permalink
Merge pull request #3 from ggmarshall/xtalk
Browse files Browse the repository at this point in the history
Xtalk
  • Loading branch information
tdixon97 authored May 4, 2024
2 parents a2f6c1e + 5c46392 commit b3b03fc
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 157 deletions.
170 changes: 70 additions & 100 deletions src/pygama/evt/modules/geds.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,8 @@

from collections.abc import Sequence

import awkward as ak
import numpy as np
from legendmeta.catalog import Props
from lgdo import lh5, types
from lgdo.lh5 import ls

from pygama.hit.build_hit import _reorder_table_operations

from .. import utils
from . import xtalk
Expand Down Expand Up @@ -49,18 +44,20 @@ def apply_xtalk_correction(
tcm: utils.TCMData,
table_names: Sequence[str],
*,
mode: str,
uncalibrated_energy_name: str,
calibrated_energy_name: str,
multiplicity_logic: str,
threshold: float = None,
xtalk_matrix_filename: str,
xtalk_matrix_filename: str = "",
xtalk_rawid_name: str = "xtc/rawid_index",
xtalk_matrix_name: str = "xtc/xtalk_matrix_negative",
positive_xtalk_matrix_name: str = "xtc/xtalk_matrix_positive",
) -> types.VectorOfVectors:
"""Applies the cross-talk correction to the energy observable.
The format of `xtalk_matrix_filename` should be currently be a path to a lh5 file.
The correction is appplied using matrix algebra for all triggers above the threshold.
The correction is applied using matrix algebra for all triggers above the threshold.
Parameters
----------
Expand All @@ -87,53 +84,49 @@ def apply_xtalk_correction(
name of the lh5 object containing the name of the rawids
"""

# read lh5 files to numpy
xtalk_matrix_numpy = lh5.read_as(xtalk_matrix_name, xtalk_matrix_filename, "np")
print(f"Read {xtalk_rawid_name}, {xtalk_matrix_filename}")
xtalk_matrix_rawids = lh5.read_as(xtalk_rawid_name, xtalk_matrix_filename, "np")

positive_xtalk_matrix_numpy = lh5.read_as(
positive_xtalk_matrix_name, xtalk_matrix_filename, "np"
tcm_id_array = xtalk.build_tcm_id_array(tcm, datainfo, xtalk_matrix_rawids)

energies_corr = xtalk.get_xtalk_correction(
tcm,
datainfo,
uncalibrated_energy_name,
calibrated_energy_name,
threshold,
xtalk_matrix_filename,
xtalk_rawid_name,
xtalk_matrix_name,
positive_xtalk_matrix_name,
)

# Combine positive and negative matrixs
# Now the matrix should have negative values corresponding to negative cross talk
# and positive values corresponding to positive cross talk .
# we also set nan to 0 and we transpose so that the row corresponds to response and column trigger
xtalk_matrix = np.nan_to_num(
np.where(
abs(xtalk_matrix_numpy) > abs(positive_xtalk_matrix_numpy),
xtalk_matrix_numpy,
positive_xtalk_matrix_numpy,
),
0,
).T

uncalibrated_energy_array = xtalk.build_energy_array(
uncalibrated_energy_name, tcm, datainfo, xtalk_matrix_rawids
)
calibrated_energy_array = xtalk.build_energy_array(
calibrated_energy_name, tcm, datainfo, xtalk_matrix_rawids
multiplicity_mask = xtalk.filter_hits(
datainfo,
tcm,
multiplicity_logic,
energies_corr,
xtalk_matrix_rawids,
)

energies_corr = xtalk.xtalk_corrected_energy(
uncalibrated_energy_array, calibrated_energy_array, xtalk_matrix, threshold
)

# return the result as LGDO
return types.VectorOfVectors(energies_corr)
if mode == "energy":
return types.VectorOfVectors(energies_corr[multiplicity_mask])
elif mode == "tcm_id":
return types.VectorOfVectors(tcm_id_array[multiplicity_mask])
else:
raise ValueError(f"Unknown mode: {mode}")


def apply_xtalk_correction_and_calibrate(
datainfo: utils.DataInfo,
tcm: utils.TCMData,
table_names: Sequence[str],
*,
mode: str,
uncalibrated_energy_name: str,
calibrated_energy_name: str,
par_files: str | list[str],
multiplicity_logic: str,
threshold: float = None,
xtalk_matrix_filename: str,
xtalk_matrix_filename: str = "",
xtalk_rawid_name: str = "xtc/rawid_index",
xtalk_matrix_name: str = "xtc/xtalk_matrix_negative",
positive_xtalk_matrix_name: str = "xtc/xtalk_matrix_positive",
Expand All @@ -142,7 +135,7 @@ def apply_xtalk_correction_and_calibrate(
"""Applies the cross-talk correction to the energy observable.
The format of `xtalk_matrix_filename` should be currently be a path to a lh5 file.
The correction is appplied using matrix algebra for all triggers above the threshold.
The correction is applied using matrix algebra for all triggers above the threshold.
Parameters
----------
Expand All @@ -169,70 +162,47 @@ def apply_xtalk_correction_and_calibrate(
name of the lh5 object containing the name of the rawids
"""

xtalk_matrix_rawids = lh5.read_as(xtalk_rawid_name, xtalk_matrix_filename, "np")
tcm_id_array = xtalk.build_tcm_id_array(tcm, datainfo, xtalk_matrix_rawids)

energies_corr = xtalk.get_xtalk_correction(
tcm,
datainfo,
uncalibrated_energy_name,
calibrated_energy_name,
threshold,
xtalk_matrix_filename,
xtalk_rawid_name,
xtalk_matrix_name,
positive_xtalk_matrix_name,
)

energies_corr = apply_xtalk_correction(
datainfo
tcm
table_names
*,
uncalibrated_energy_name = uncalibrated_energy_name,
calibrated_energy_name = calibrated_energy_name,
threshold = threshold,
xtalk_matrix_filename = xtalk_matrix_filename,
xtalk_rawid_name = xtalk_rawid_name,
xtalk_matrix_name = xtalk_matrix_name,
positive_xtalk_matrix_name = positive_xtalk_matrix_name,
).view_as("np")

out_arr = np.full_like(energies_corr, np.nan)
par_dicts = Props.read_from(par_files)
pars = {
chan: chan_dict["pars"]["operations"] for chan, chan_dict in par_dicts.items()
}

p = uncalibrated_energy_name.split(".")
tier = p[0] if len(p) > 1 else "hit"
column = p[1] if len(p) > 1 else p[0]

table_fmt = datainfo._asdict()[tier].table_fmt
group = datainfo._asdict()[tier].group
file = datainfo._asdict()[tier].file

keys = ls(file)
xtalk_matrix_rawids = lh5.read_as(xtalk_rawid_name, xtalk_matrix_filename, "np")

if out_param is None:
out_param = calibrated_energy_name.split(".")[-1]

for i, chan in enumerate(xtalk_matrix_rawids):
try:
cfg = pars[f"ch{chan}"]
cfg, chan_inputs = xtalk.remove_uneeded_operations(
_reorder_table_operations(cfg), out_param
)
chan_inputs.remove(uncalibrated_energy_name.split(".")[-1])

# get the event indexs
table_id = utils.get_tcm_id_by_pattern(table_fmt, f"ch{chan}")
idx_events = ak.to_numpy(tcm.idx[tcm.id == table_id])

# read the energy data
if f"ch{chan}" in keys:
outtbl_obj = sto.read(
f"ch{chan}/dsp/", file, idx=idx_events, field_mask=chan_inputs
)[0]
outtbl_obj.add_column(
uncalibrated_energy_name.split(".")[-1],
types.Array(energies_corr[:, i]),
)

for outname, info in cfg.items():
outcol = outtbl_obj.eval(
info["expression"], info.get("parameters", None)
)
outtbl_obj.add_column(outname, outcol)
out_arr[:, i] = outtbl_obj[out_param].nda
except KeyError:
out_arr[:, i] = np.nan
calibrated_corr = xtalk.calibrate_energy(
datainfo,
tcm,
energies_corr,
xtalk_matrix_rawids,
par_files,
uncalibrated_energy_name,
out_param,
)

multiplicity_mask = xtalk.filter_hits(
datainfo,
tcm,
multiplicity_logic,
calibrated_corr,
xtalk_matrix_rawids,
)

# return the result as LGDO
return types.VectorOfVectors(out_arr)
if mode == "energy":
return types.VectorOfVectors(calibrated_corr[multiplicity_mask])
elif mode == "tcm_id":
return types.VectorOfVectors(tcm_id_array[multiplicity_mask])
else:
raise ValueError(f"Unknown mode: {mode}")
Loading

0 comments on commit b3b03fc

Please sign in to comment.