diff --git a/src/pygama/evt/modules/cross_talk.py b/src/pygama/evt/modules/cross_talk.py index b6dd88e28..5d69211ad 100644 --- a/src/pygama/evt/modules/cross_talk.py +++ b/src/pygama/evt/modules/cross_talk.py @@ -4,12 +4,13 @@ import awkward as ak import numpy as np -from lgdo import Array, VectorOfVectors -from lgdo.lh5 import LH5Store +from lgdo import VectorOfVectors -from pygama.evt import utils -def cross_talk_corrected_energy_awkard_slow(energies:ak.Array,rawids:ak.Array,matrix:dict,allow_non_existing:bool=True): + +def cross_talk_corrected_energy_awkard_slow( + energies: ak.Array, rawids: ak.Array, matrix: dict, allow_non_existing: bool = True +): """ Function to perform the cross talk correction on awkward arrays of energy and rawid. The energies are first sorted from largest to smallest, a term is then added to the @@ -40,75 +41,95 @@ def cross_talk_corrected_energy_awkard_slow(energies:ak.Array,rawids:ak.Array,ma if not isinstance(rawids, ak.Array): raise TypeError("rawids must be an awkward array") - if not isinstance(matrix,dict): + if not isinstance(matrix, dict): raise TypeError("matrix must be a python dictonary") - if not isinstance(allow_non_existing,bool): + if not isinstance(allow_non_existing, bool): raise TypeError("allow_non_existing must be a Boolean") # first check that energies and rawids have the same dimensions - if (ak.all(ak.num(energies,axis=-1)!=ak.num(rawids,axis=-1))): - raise ValueError("Error: the length of each subarray of energies and rawids must be equal") - - if (ak.num(energies,axis=-2)!=ak.num(rawids,axis=-2)): + if ak.all(ak.num(energies, axis=-1) != ak.num(rawids, axis=-1)): + raise ValueError( + "Error: the length of each subarray of energies and rawids must be equal" + ) + + if ak.num(energies, axis=-2) != ak.num(rawids, axis=-2): raise ValueError("Error: the number of energies is not equal to rawids") - + # check that the matrix elements exist for c1 in np.unique(ak.flatten(rawids).to_numpy()): if c1 not in matrix.keys(): - if allow_non_existing==True: - matrix[c1]={} + if allow_non_existing == True: + matrix[c1] = {} else: - raise ValueError(f"Error allow_non_existing is set to False and {c1} isnt present in the matrix") + raise ValueError( + f"Error allow_non_existing is set to False and {c1} isnt present in the matrix" + ) for c2 in np.unique(ak.flatten(rawids).to_numpy()): - if (c1==c2): + if c1 == c2: continue else: if c2 not in matrix[c1].keys(): - if allow_non_existing==True: - matrix[c1][c2]=0 + if allow_non_existing == True: + matrix[c1][c2] = 0 else: - raise ValueError(f"Error allow_non_existing is set to False and {c2} isnt present in the matrix[{c1}]") + raise ValueError( + f"Error allow_non_existing is set to False and {c2} isnt present in the matrix[{c1}]" + ) ## add a check that the matrix is symmetric for c1 in matrix.keys(): for c2 in matrix[c1].keys(): - if abs(matrix[c1][c2]-matrix[c2][c1])>1e-6: - raise ValueError(f"Error input cross talk matrix is not symmetric for {c1},{c2}") - + if abs(matrix[c1][c2] - matrix[c2][c1]) > 1e-6: + raise ValueError( + f"Error input cross talk matrix is not symmetric for {c1},{c2}" + ) ## sort the energies and rawids - args = ak.argsort(energies,ascending=False) + args = ak.argsort(energies, ascending=False) - energies= energies[args] + energies = energies[args] rawids = rawids[args] energies_corrected = [] ## we should try to speed this up - for energy_vec_tmp,rawid_vec_tmp in zip(energies,rawids): - + for energy_vec_tmp, rawid_vec_tmp in zip(energies, rawids): + energies_corrected_tmp = list(energy_vec_tmp) - for id_main, (energy_main,rawid_main) in enumerate(zip(energy_vec_tmp,rawid_vec_tmp)): - for id_other, (energy_other,rawid_other) in enumerate(zip(energy_vec_tmp,rawid_vec_tmp)): + for id_main, (energy_main, rawid_main) in enumerate( + zip(energy_vec_tmp, rawid_vec_tmp) + ): + for id_other, (energy_other, rawid_other) in enumerate( + zip(energy_vec_tmp, rawid_vec_tmp) + ): + + if id_main != id_other: + energies_corrected_tmp[id_other] += ( + matrix[rawid_main][rawid_other] * energy_main + ) - if (id_main!=id_other): - energies_corrected_tmp[id_other]+=matrix[rawid_main][rawid_other]*energy_main - energies_corrected.append(energies_corrected_tmp) ## convert to awkward array and unsort return ak.Array(energies_corrected)[args] - - - -def get_energy_corrected(f_hit:str,f_dsp:str,f_tcm:str,hit_group:str,dsp_group:str,tcm_group:str,tcm_id_table_pattern:str,channels:list, - cross_talk_matrix:str,energy_variable:str)->VectorOfVectors: +def get_energy_corrected( + f_hit: str, + f_dsp: str, + f_tcm: str, + hit_group: str, + dsp_group: str, + tcm_group: str, + tcm_id_table_pattern: str, + channels: list, + cross_talk_matrix: str, + energy_variable: str, +) -> VectorOfVectors: """ Function to compute cross-talk corrected energies. Parameters @@ -129,4 +150,3 @@ def get_energy_corrected(f_hit:str,f_dsp:str,f_tcm:str,hit_group:str,dsp_group:s """ raise NotImplementedError("error cross talk correction is not yet implemented") - diff --git a/src/pygama/evt/modules/geds.py b/src/pygama/evt/modules/geds.py index 51b465ae3..7b327f8da 100644 --- a/src/pygama/evt/modules/geds.py +++ b/src/pygama/evt/modules/geds.py @@ -2,13 +2,14 @@ from __future__ import annotations +import json from collections.abc import Sequence -from lgdo import lh5, types +from lgdo import types from .. import utils from . import cross_talk -import json + def apply_xtalk_correction( datainfo: utils.DataInfo, @@ -18,7 +19,7 @@ def apply_xtalk_correction( energy_observable: types.VectorOfVectors, rawids: types.VectorOfVectors, xtalk_matrix_filename: str, - threshold:float + threshold: float, ) -> types.VectorOfVectors: """Applies the cross-talk correction to the energy observable. The format of `xtalk_matrix_filename` should be currently be a path to a JSON file. @@ -44,15 +45,16 @@ def apply_xtalk_correction( """ # read in xtalk matrices (currently a json file) - with open(xtalk_matrix_filename, 'r') as file: - cross_talk_matrix = json.load(file) + with open(xtalk_matrix_filename) as file: + cross_talk_matrix = json.load(file) # do the correction - energies_corr = cross_talk.cross_talk_corrected_energy_awkard_slow(energies=energy_observable, - rawids=rawids, - matrix=cross_talk_matrix, - allow_non_existing=False - ) + energies_corr = cross_talk.cross_talk_corrected_energy_awkard_slow( + energies=energy_observable, + rawids=rawids, + matrix=cross_talk_matrix, + allow_non_existing=False, + ) # return the result as LGDO return types.VectorOfVectors( diff --git a/tests/evt/test_cross_talk.py b/tests/evt/test_cross_talk.py index 561edaaeb..958c5d4f8 100644 --- a/tests/evt/test_cross_talk.py +++ b/tests/evt/test_cross_talk.py @@ -1,56 +1,71 @@ import awkward as ak -from pygama.evt.modules import cross_talk import pytest + +from pygama.evt.modules import cross_talk + + def test_cross_talk_corrected_energy_awkard_slow(): - - energies_test=ak.Array([[1000,200],[100],[500,3000,100]]) - rawid_test =ak.Array([[1,2],[3],[1,2,3]]) + energies_test = ak.Array([[1000, 200], [100], [500, 3000, 100]]) + rawid_test = ak.Array([[1, 2], [3], [1, 2, 3]]) ## first check exceptions - matrix={1:{1:1.00, 2:0.01, 3:0.02, 4:0.00}, - 2:{1:0.01, 2:1.00, 3:0, 4:0.01}, - 3:{1:0.02, 2:0.00, 3:1.00, 4:0.01}, - 4:{1:0.00, 2:0.01, 3:0.01, 4:1.00} - } + matrix = { + 1: {1: 1.00, 2: 0.01, 3: 0.02, 4: 0.00}, + 2: {1: 0.01, 2: 1.00, 3: 0, 4: 0.01}, + 3: {1: 0.02, 2: 0.00, 3: 1.00, 4: 0.01}, + 4: {1: 0.00, 2: 0.01, 3: 0.01, 4: 1.00}, + } # if rawid and energies have different shapes (juts the first entry) with pytest.raises(ValueError): - cross_talk.cross_talk_corrected_energy_awkard_slow(ak.Array([[1000,200]]),rawid_test,matrix,True) + cross_talk.cross_talk_corrected_energy_awkard_slow( + ak.Array([[1000, 200]]), rawid_test, matrix, True + ) # filter some values from energy first so each event has a different size with pytest.raises(ValueError): - cross_talk.cross_talk_corrected_energy_awkard_slow(energies_test[energies_test!=1000],rawid_test,matrix,True) + cross_talk.cross_talk_corrected_energy_awkard_slow( + energies_test[energies_test != 1000], rawid_test, matrix, True + ) ## checks on the matrix # first check if the matrix has empty elements an exception is raised if allow_non_existing is false - matrix_not_full={1:{1:1.00, 2:0.01, 3:0.02, 4:0.00}, - 2:{1:0.01, 2:1.00, 3:0}, - 3:{1:0.02, 2:0.00, 3:1.00, 4:0.01}, - 4:{1:0.00, 3:0.01, 4:1.00} - } - + matrix_not_full = { + 1: {1: 1.00, 2: 0.01, 3: 0.02, 4: 0.00}, + 2: {1: 0.01, 2: 1.00, 3: 0}, + 3: {1: 0.02, 2: 0.00, 3: 1.00, 4: 0.01}, + 4: {1: 0.00, 3: 0.01, 4: 1.00}, + } + + with pytest.raises(ValueError): + cross_talk.cross_talk_corrected_energy_awkard_slow( + energies_test, rawid_test, matrix_not_full, False + ) + + matrix_not_sym = { + 1: {1: 1.00, 2: 0.0, 3: 0.02, 4: 0.00}, + 2: {1: 0.01, 2: 1.00, 3: 0, 4: 0.01}, + 3: {1: 0.02, 2: 0.00, 3: 1.00, 4: 0.01}, + 4: {1: 0.00, 2: 0.01, 3: 0.01, 4: 1.00}, + } with pytest.raises(ValueError): - cross_talk.cross_talk_corrected_energy_awkard_slow(energies_test,rawid_test,matrix_not_full,False) - - matrix_not_sym={1:{1:1.00, 2:0.0, 3:0.02, 4:0.00}, - 2:{1:0.01, 2:1.00, 3:0, 4:0.01}, - 3:{1:0.02, 2:0.00, 3:1.00, 4:0.01}, - 4:{1:0.00, 2:0.01, 3:0.01, 4:1.00} - } - with pytest.raises(ValueError): - cross_talk.cross_talk_corrected_energy_awkard_slow(energies_test,rawid_test,matrix_not_sym,True) + cross_talk.cross_talk_corrected_energy_awkard_slow( + energies_test, rawid_test, matrix_not_sym, True + ) ## now check the result returned (given no exceptions) - ### 1st event is two channels 1% cross talk and [1000,200] energy so we expect Ecorr = [1002,210] + ### 1st event is two channels 1% cross talk and [1000,200] energy so we expect Ecorr = [1002,210] ### 2nd event one channel so the energy shouldnt change ### 3rd event 3 channels [500,3000,100] energy ch1 has 1% ct to ch2 and 2% to ch3 so we will have ### E=[500,3005,102] then ch2 has 1% cross talk to ch1 and 0 to ch3 ### E= [503,3005,102] finally ch3 has 2% cross talk with ch1 ### E=[505,3005,102] - - energy_corr =cross_talk.cross_talk_corrected_energy_awkard_slow(energies_test,rawid_test,matrix,True) - assert np.all(energy_corr[0].to_numpy()==np.array([1002,210])) - assert np.all(energy_corr[1].to_numpy()==np.array([100])) - assert np.all(energy_corr[2].to_numpy()==np.array([505,3005,102])) \ No newline at end of file + energy_corr = cross_talk.cross_talk_corrected_energy_awkard_slow( + energies_test, rawid_test, matrix, True + ) + + assert np.all(energy_corr[0].to_numpy() == np.array([1002, 210])) + assert np.all(energy_corr[1].to_numpy() == np.array([100])) + assert np.all(energy_corr[2].to_numpy() == np.array([505, 3005, 102]))