Skip to content

Commit

Permalink
style: pre-commit fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pre-commit-ci[bot] committed Apr 22, 2024
1 parent 66572ce commit 1ae805e
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 78 deletions.
92 changes: 56 additions & 36 deletions src/pygama/evt/modules/cross_talk.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

import awkward as ak
import numpy as np
from lgdo import Array, VectorOfVectors
from lgdo.lh5 import LH5Store
from lgdo import VectorOfVectors

from pygama.evt import utils

def cross_talk_corrected_energy_awkard_slow(energies:ak.Array,rawids:ak.Array,matrix:dict,allow_non_existing:bool=True):

def cross_talk_corrected_energy_awkard_slow(
energies: ak.Array, rawids: ak.Array, matrix: dict, allow_non_existing: bool = True
):
"""
Function to perform the cross talk correction on awkward arrays of energy and rawid.
The energies are first sorted from largest to smallest, a term is then added to the
Expand Down Expand Up @@ -40,75 +41,95 @@ def cross_talk_corrected_energy_awkard_slow(energies:ak.Array,rawids:ak.Array,ma
if not isinstance(rawids, ak.Array):
raise TypeError("rawids must be an awkward array")

if not isinstance(matrix,dict):
if not isinstance(matrix, dict):
raise TypeError("matrix must be a python dictonary")

if not isinstance(allow_non_existing,bool):
if not isinstance(allow_non_existing, bool):
raise TypeError("allow_non_existing must be a Boolean")

# first check that energies and rawids have the same dimensions
if (ak.all(ak.num(energies,axis=-1)!=ak.num(rawids,axis=-1))):
raise ValueError("Error: the length of each subarray of energies and rawids must be equal")

if (ak.num(energies,axis=-2)!=ak.num(rawids,axis=-2)):
if ak.all(ak.num(energies, axis=-1) != ak.num(rawids, axis=-1)):
raise ValueError(
"Error: the length of each subarray of energies and rawids must be equal"
)

if ak.num(energies, axis=-2) != ak.num(rawids, axis=-2):
raise ValueError("Error: the number of energies is not equal to rawids")

# check that the matrix elements exist

for c1 in np.unique(ak.flatten(rawids).to_numpy()):
if c1 not in matrix.keys():

if allow_non_existing==True:
matrix[c1]={}
if allow_non_existing == True:
matrix[c1] = {}
else:
raise ValueError(f"Error allow_non_existing is set to False and {c1} isnt present in the matrix")
raise ValueError(
f"Error allow_non_existing is set to False and {c1} isnt present in the matrix"
)

for c2 in np.unique(ak.flatten(rawids).to_numpy()):
if (c1==c2):
if c1 == c2:
continue
else:
if c2 not in matrix[c1].keys():
if allow_non_existing==True:
matrix[c1][c2]=0
if allow_non_existing == True:
matrix[c1][c2] = 0
else:
raise ValueError(f"Error allow_non_existing is set to False and {c2} isnt present in the matrix[{c1}]")
raise ValueError(
f"Error allow_non_existing is set to False and {c2} isnt present in the matrix[{c1}]"
)

## add a check that the matrix is symmetric

for c1 in matrix.keys():
for c2 in matrix[c1].keys():
if abs(matrix[c1][c2]-matrix[c2][c1])>1e-6:
raise ValueError(f"Error input cross talk matrix is not symmetric for {c1},{c2}")

if abs(matrix[c1][c2] - matrix[c2][c1]) > 1e-6:
raise ValueError(
f"Error input cross talk matrix is not symmetric for {c1},{c2}"
)

## sort the energies and rawids
args = ak.argsort(energies,ascending=False)
args = ak.argsort(energies, ascending=False)

energies= energies[args]
energies = energies[args]
rawids = rawids[args]

energies_corrected = []
## we should try to speed this up
for energy_vec_tmp,rawid_vec_tmp in zip(energies,rawids):
for energy_vec_tmp, rawid_vec_tmp in zip(energies, rawids):

energies_corrected_tmp = list(energy_vec_tmp)
for id_main, (energy_main,rawid_main) in enumerate(zip(energy_vec_tmp,rawid_vec_tmp)):
for id_other, (energy_other,rawid_other) in enumerate(zip(energy_vec_tmp,rawid_vec_tmp)):
for id_main, (energy_main, rawid_main) in enumerate(
zip(energy_vec_tmp, rawid_vec_tmp)
):
for id_other, (energy_other, rawid_other) in enumerate(
zip(energy_vec_tmp, rawid_vec_tmp)
):

if id_main != id_other:
energies_corrected_tmp[id_other] += (
matrix[rawid_main][rawid_other] * energy_main
)

if (id_main!=id_other):
energies_corrected_tmp[id_other]+=matrix[rawid_main][rawid_other]*energy_main

energies_corrected.append(energies_corrected_tmp)

## convert to awkward array and unsort
return ak.Array(energies_corrected)[args]





def get_energy_corrected(f_hit:str,f_dsp:str,f_tcm:str,hit_group:str,dsp_group:str,tcm_group:str,tcm_id_table_pattern:str,channels:list,
cross_talk_matrix:str,energy_variable:str)->VectorOfVectors:
def get_energy_corrected(
f_hit: str,
f_dsp: str,
f_tcm: str,
hit_group: str,
dsp_group: str,
tcm_group: str,
tcm_id_table_pattern: str,
channels: list,
cross_talk_matrix: str,
energy_variable: str,
) -> VectorOfVectors:
"""
Function to compute cross-talk corrected energies.
Parameters
Expand All @@ -129,4 +150,3 @@ def get_energy_corrected(f_hit:str,f_dsp:str,f_tcm:str,hit_group:str,dsp_group:s
"""

raise NotImplementedError("error cross talk correction is not yet implemented")

22 changes: 12 additions & 10 deletions src/pygama/evt/modules/geds.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from __future__ import annotations

import json
from collections.abc import Sequence

from lgdo import lh5, types
from lgdo import types

from .. import utils
from . import cross_talk
import json


def apply_xtalk_correction(
datainfo: utils.DataInfo,
Expand All @@ -18,7 +19,7 @@ def apply_xtalk_correction(
energy_observable: types.VectorOfVectors,
rawids: types.VectorOfVectors,
xtalk_matrix_filename: str,
threshold:float
threshold: float,
) -> types.VectorOfVectors:
"""Applies the cross-talk correction to the energy observable.
The format of `xtalk_matrix_filename` should be currently be a path to a JSON file.
Expand All @@ -44,15 +45,16 @@ def apply_xtalk_correction(
"""
# read in xtalk matrices (currently a json file)

with open(xtalk_matrix_filename, 'r') as file:
cross_talk_matrix = json.load(file)
with open(xtalk_matrix_filename) as file:
cross_talk_matrix = json.load(file)

# do the correction
energies_corr = cross_talk.cross_talk_corrected_energy_awkard_slow(energies=energy_observable,
rawids=rawids,
matrix=cross_talk_matrix,
allow_non_existing=False
)
energies_corr = cross_talk.cross_talk_corrected_energy_awkard_slow(
energies=energy_observable,
rawids=rawids,
matrix=cross_talk_matrix,
allow_non_existing=False,
)

# return the result as LGDO
return types.VectorOfVectors(
Expand Down
79 changes: 47 additions & 32 deletions tests/evt/test_cross_talk.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,71 @@
import awkward as ak
from pygama.evt.modules import cross_talk
import pytest

from pygama.evt.modules import cross_talk


def test_cross_talk_corrected_energy_awkard_slow():


energies_test=ak.Array([[1000,200],[100],[500,3000,100]])
rawid_test =ak.Array([[1,2],[3],[1,2,3]])
energies_test = ak.Array([[1000, 200], [100], [500, 3000, 100]])
rawid_test = ak.Array([[1, 2], [3], [1, 2, 3]])

## first check exceptions
matrix={1:{1:1.00, 2:0.01, 3:0.02, 4:0.00},
2:{1:0.01, 2:1.00, 3:0, 4:0.01},
3:{1:0.02, 2:0.00, 3:1.00, 4:0.01},
4:{1:0.00, 2:0.01, 3:0.01, 4:1.00}
}
matrix = {
1: {1: 1.00, 2: 0.01, 3: 0.02, 4: 0.00},
2: {1: 0.01, 2: 1.00, 3: 0, 4: 0.01},
3: {1: 0.02, 2: 0.00, 3: 1.00, 4: 0.01},
4: {1: 0.00, 2: 0.01, 3: 0.01, 4: 1.00},
}

# if rawid and energies have different shapes (juts the first entry)
with pytest.raises(ValueError):
cross_talk.cross_talk_corrected_energy_awkard_slow(ak.Array([[1000,200]]),rawid_test,matrix,True)
cross_talk.cross_talk_corrected_energy_awkard_slow(
ak.Array([[1000, 200]]), rawid_test, matrix, True
)

# filter some values from energy first so each event has a different size
with pytest.raises(ValueError):
cross_talk.cross_talk_corrected_energy_awkard_slow(energies_test[energies_test!=1000],rawid_test,matrix,True)
cross_talk.cross_talk_corrected_energy_awkard_slow(
energies_test[energies_test != 1000], rawid_test, matrix, True
)

## checks on the matrix
# first check if the matrix has empty elements an exception is raised if allow_non_existing is false
matrix_not_full={1:{1:1.00, 2:0.01, 3:0.02, 4:0.00},
2:{1:0.01, 2:1.00, 3:0},
3:{1:0.02, 2:0.00, 3:1.00, 4:0.01},
4:{1:0.00, 3:0.01, 4:1.00}
}

matrix_not_full = {
1: {1: 1.00, 2: 0.01, 3: 0.02, 4: 0.00},
2: {1: 0.01, 2: 1.00, 3: 0},
3: {1: 0.02, 2: 0.00, 3: 1.00, 4: 0.01},
4: {1: 0.00, 3: 0.01, 4: 1.00},
}

with pytest.raises(ValueError):
cross_talk.cross_talk_corrected_energy_awkard_slow(
energies_test, rawid_test, matrix_not_full, False
)

matrix_not_sym = {
1: {1: 1.00, 2: 0.0, 3: 0.02, 4: 0.00},
2: {1: 0.01, 2: 1.00, 3: 0, 4: 0.01},
3: {1: 0.02, 2: 0.00, 3: 1.00, 4: 0.01},
4: {1: 0.00, 2: 0.01, 3: 0.01, 4: 1.00},
}
with pytest.raises(ValueError):
cross_talk.cross_talk_corrected_energy_awkard_slow(energies_test,rawid_test,matrix_not_full,False)

matrix_not_sym={1:{1:1.00, 2:0.0, 3:0.02, 4:0.00},
2:{1:0.01, 2:1.00, 3:0, 4:0.01},
3:{1:0.02, 2:0.00, 3:1.00, 4:0.01},
4:{1:0.00, 2:0.01, 3:0.01, 4:1.00}
}
with pytest.raises(ValueError):
cross_talk.cross_talk_corrected_energy_awkard_slow(energies_test,rawid_test,matrix_not_sym,True)
cross_talk.cross_talk_corrected_energy_awkard_slow(
energies_test, rawid_test, matrix_not_sym, True
)

## now check the result returned (given no exceptions)
### 1st event is two channels 1% cross talk and [1000,200] energy so we expect Ecorr = [1002,210]
### 1st event is two channels 1% cross talk and [1000,200] energy so we expect Ecorr = [1002,210]
### 2nd event one channel so the energy shouldnt change
### 3rd event 3 channels [500,3000,100] energy ch1 has 1% ct to ch2 and 2% to ch3 so we will have
### E=[500,3005,102] then ch2 has 1% cross talk to ch1 and 0 to ch3
### E= [503,3005,102] finally ch3 has 2% cross talk with ch1
### E=[505,3005,102]

energy_corr =cross_talk.cross_talk_corrected_energy_awkard_slow(energies_test,rawid_test,matrix,True)

assert np.all(energy_corr[0].to_numpy()==np.array([1002,210]))
assert np.all(energy_corr[1].to_numpy()==np.array([100]))
assert np.all(energy_corr[2].to_numpy()==np.array([505,3005,102]))
energy_corr = cross_talk.cross_talk_corrected_energy_awkard_slow(
energies_test, rawid_test, matrix, True
)

assert np.all(energy_corr[0].to_numpy() == np.array([1002, 210]))
assert np.all(energy_corr[1].to_numpy() == np.array([100]))
assert np.all(energy_corr[2].to_numpy() == np.array([505, 3005, 102]))

0 comments on commit 1ae805e

Please sign in to comment.