Merge pull request #493 from gracesong312/main
Let ``DataLoader`` default setters overwrite (not append) + `lgdo.Struct.int_dtype`
gipert authored May 18, 2023
2 parents 92c8d53 + 2cf4cc4 commit 9083f69
Showing 5 changed files with 63 additions and 19 deletions.
38 changes: 26 additions & 12 deletions src/pygama/flow/data_loader.py
@@ -231,7 +231,7 @@ def set_config(self, config: dict | str) -> None:
             else:
                 self.cut_priority[level] = 0
 
-    def set_files(self, query: str | list[str]) -> None:
+    def set_files(self, query: str | list[str], append: bool = False) -> None:
         """Apply a file selection.
 
         Sets `self.file_list`, which is a list of indices corresponding to the
@@ -244,6 +244,9 @@ def set_files(self, query: str | list[str]) -> None:
             supported by :meth:`pandas.DataFrame.query`. In addition, the
             ``all`` keyword is supported to select all files in the database.
             If list of strings, will be interpreted as key (cycle timestamp) list.
+        append
+            if ``True``, appends files to the existing `self.file_list`
+            instead of overwriting.
 
         Note
         ----
@@ -273,10 +276,11 @@ def set_files(self, query: str | list[str]) -> None:
         if not inds:
             log.warning("no files matching selection found")
 
-        if self.file_list is None:
-            self.file_list = inds
-        else:
+        if append and self.file_list is not None:
             self.file_list += inds
+            self.file_list = sorted(list(set(self.file_list)))
+        else:
+            self.file_list = inds
 
     def get_file_list(self) -> pd.DataFrame:
         """
@@ -285,7 +289,9 @@
         """
         return self.filedb.df.iloc[self.file_list]
 
-    def set_datastreams(self, ds: list | tuple | np.ndarray, word: str) -> None:
+    def set_datastreams(
+        self, ds: list | tuple | np.ndarray, word: str, append: bool = False
+    ) -> None:
         """Apply selection on data streams (or channels).
 
         Sets `self.table_list`.
@@ -299,12 +305,15 @@ def set_datastreams(self, ds: list | tuple | np.ndarray, word: str) -> None:
         word
             the type of identifier used in ds. Should be a key in the given
             channel map or a word defined in the configuration file.
+        append
+            if ``True``, appends datastreams to the existing `self.table_list`
+            instead of overwriting.
 
         Example
         -------
         >>> dl.set_datastreams(np.arange(40, 45), "ch")
         """
-        if self.table_list is None:
+        if self.table_list is None or not append:
             self.table_list = {}
 
         ds = list(ds)
@@ -324,14 +333,15 @@ def set_datastreams(self, ds: list | tuple | np.ndarray, word: str) -> None:
                 found = True
                 if level in self.table_list.keys():
                     self.table_list[level] += ds
+                    self.table_list[level] = sorted(list(set(self.table_list[level])))
                 else:
                     self.table_list[level] = ds
 
         if not found:
             # look for word in channel map
             raise NotImplementedError
 
-    def set_cuts(self, cuts: dict | list) -> None:
+    def set_cuts(self, cuts: dict | list, append: bool = False) -> None:
         """Apply a selection on columns in the data tables.
 
         Parameters
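`set_datastreams` follows the same pattern: overwrite by default, extend (with per-level deduplication) on request. A sketch reusing channel IDs from the test suite, again assuming a configured ``dl``:

    # new default: overwrite the previous table_list
    dl.set_datastreams([1084803, 1084804], "ch")
    dl.set_datastreams([1121600], "ch")               # only channel 1121600 selected

    # append=True extends and deduplicates the per-level stream list
    dl.set_datastreams([1084803, 1084804], "ch")
    dl.set_datastreams([1121600], "ch", append=True)  # all three channels selected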
@@ -342,12 +352,14 @@ def set_cuts(self, cuts: dict | list) -> None:
             structured as ``dict[tier] = cut_expr``. If passing a list, each
             item in the array should be able to be applied on one level of
             tables. The cuts at different levels will be joined with an AND.
+        append
+            if True, appends cuts to the existing cuts instead of overwriting
 
         Example
         -------
         >>> dl.set_cuts({"raw": "daqenergy > 1000", "hit": "AoE > 3"})
         """
-        if self.cuts is None:
+        if self.cuts is None or not append:
             self.cuts = {}
         if isinstance(cuts, dict):
             # verify the correct structure
@@ -356,7 +368,7 @@ def set_cuts(self, cuts: dict | list) -> None:
                     raise ValueError(
                         r"cuts dictionary must be in the format \{ level: string \}"
                     )
-                if key in self.cuts.keys():
+                if key in self.cuts.keys() and append:
                     self.cuts[key] += " and " + value
                 else:
                     self.cuts[key] = value
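With ``append=True``, cuts registered on the same level are now AND-joined on request; by default the latest call wins. A sketch with expressions borrowed from the docstring and tests:

    # new default: only the most recent cut on a level survives
    dl.set_cuts({"hit": "trapEmax > 5000"})
    dl.set_cuts({"hit": "AoE > 3"})               # hit cut is now just "AoE > 3"

    # append=True AND-joins cuts on the same level
    dl.set_cuts({"hit": "trapEmax > 5000"})
    dl.set_cuts({"hit": "AoE > 3"}, append=True)  # "trapEmax > 5000 and AoE > 3"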
@@ -1126,7 +1138,7 @@ def explode_evt_cols(el: pd.DataFrame, tier_table: Table):
                 )
             else:  # not merge_files
                 if in_memory:
-                    load_out = {}
+                    load_out = Struct(attrs={"int_keys": True})
 
                 if log.getEffectiveLevel() >= logging.INFO:
                     progress_bar = tqdm(
@@ -1215,13 +1227,15 @@ def explode_evt_cols(el: pd.DataFrame, tier_table: Table):
                     f_table = utils.dict_to_table(col_dict, attr_dict)
 
                     if in_memory:
-                        load_out[file] = f_table
+                        load_out.add_field(name=file, obj=f_table)
                     if output_file:
                         sto.write_object(f_table, f"{file}", output_file, wo_mode="o")
                 # end file loop
 
                 if log.getEffectiveLevel() >= logging.INFO:
                     progress_bar.close()
+                if load_out:
+                    load_out.update_datatype()
 
                 if in_memory:
                     if self.output_format == "lgdo.Table":
@@ -1255,7 +1269,7 @@ def load_evts(
                 raise NotImplementedError
             else:  # Not merge_files
                 if in_memory:
-                    load_out = {}
+                    load_out = Struct(attrs={"int_keys": True})
                 for file, f_entries in entry_list.items():
                     field_mask = []
                     f_table = None
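User-visible effect: with ``merge_files=False``, both ``load()`` and ``load_evts()`` now return an ``lgdo.Struct`` keyed by integer file index instead of a plain dict. A minimal sketch, assuming a configured ``DataLoader`` ``dl`` and ``pygama.lgdo`` imported as ``lgdo``:

    dl.set_output(columns=["timestamp"], merge_files=False)
    data = dl.load()

    assert isinstance(data, lgdo.Struct)  # was a plain dict before this change
    data[0]                               # per-file lgdo.Table, keyed by int file index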
11 changes: 9 additions & 2 deletions src/pygama/lgdo/lh5_store.py
@@ -344,7 +344,12 @@ def read_object(
                     # fields. If they all had shared indexing, they should be in a
                     # table... Maybe should emit a warning? Or allow them to be
                     # dicts keyed by field name?
-                    obj_dict[field], _ = self.read_object(
+                    if "int_keys" in h5f[name].attrs:
+                        if dict(h5f[name].attrs)["int_keys"]:
+                            f = int(field)
+                    else:
+                        f = str(field)
+                    obj_dict[f], _ = self.read_object(
                         name + "/" + field,
                         h5f,
                         start_row=start_row,
@@ -929,9 +934,11 @@ def write_object(
             else:
                 obj_fld = obj[field]
 
+            # Convert keys to string for dataset names
+            f = str(field)
             self.write_object(
                 obj_fld,
-                field,
+                f,
                 lh5_file,
                 group=group,
                 start_row=start_row,
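Read together, the two hunks let integer-keyed structs round-trip through LH5: the writer stringifies field names (HDF5 group names must be strings) and the reader converts them back whenever the ``int_keys`` attribute is set. A minimal round-trip sketch, assuming this release's ``pygama.lgdo`` API; the file name ``stash.lh5`` is hypothetical:

    import numpy as np
    from pygama.lgdo import Array, LH5Store, Struct

    sto = LH5Store()
    st = Struct(attrs={"int_keys": True})
    st.add_field(0, Array(nda=np.arange(3)))  # integer field name

    sto.write_object(st, "st", "stash.lh5", wo_mode="o")  # stored under group "st/0"
    st2, _ = sto.read_object("st", "stash.lh5")
    assert list(st2.keys()) == [0]  # integer keys restored on read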
8 changes: 5 additions & 3 deletions src/pygama/lgdo/struct.py
@@ -44,17 +44,19 @@ def datatype_name(self) -> str:
         return "struct"
 
     def form_datatype(self) -> str:
-        return self.datatype_name() + "{" + ",".join(self.keys()) + "}"
+        return (
+            self.datatype_name() + "{" + ",".join([str(k) for k in self.keys()]) + "}"
+        )
 
     def update_datatype(self) -> None:
         self.attrs["datatype"] = self.form_datatype()
 
-    def add_field(self, name: str, obj: LGDO) -> None:
+    def add_field(self, name: str | int, obj: LGDO) -> None:
         """Add a field to the table."""
         self[name] = obj
         self.update_datatype()
 
-    def remove_field(self, name: str, delete: bool = False) -> None:
+    def remove_field(self, name: str | int, delete: bool = False) -> None:
         """Remove a field from the table.
 
         Parameters
2 changes: 1 addition & 1 deletion src/pygama/lgdo/table.py
@@ -225,7 +225,7 @@ def get_dataframe(
                 if not hasattr(column, "nda"):
                     raise ValueError(f"column {col} does not have an nda")
                 else:
-                    df[prefix + col] = column.nda.tolist()
+                    df[prefix + str(col)] = column.nda.tolist()
 
         return df
 
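A small sketch of the fix, assuming the ``Table(col_dict=...)`` constructor: an integer column key no longer breaks the string concatenation used to name DataFrame columns:

    import numpy as np
    from pygama.lgdo import Array, Table

    tbl = Table(col_dict={0: Array(nda=np.array([1.0, 2.0]))})
    print(tbl.get_dataframe().columns.tolist())  # ['0']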
23 changes: 22 additions & 1 deletion tests/flow/test_data_loader.py
@@ -76,7 +76,7 @@ def test_no_merge(test_dl):
     test_dl.set_output(columns=["timestamp"], merge_files=False)
     data = test_dl.load()
 
-    assert isinstance(data, dict)
+    assert isinstance(data, lgdo.Struct)
     assert isinstance(data[0], lgdo.Table)
     assert len(data) == 4  # 4 files
     assert list(data[0].keys()) == ["hit_table", "hit_idx", "timestamp"]
@@ -147,6 +147,27 @@ def test_set_cuts(test_dl):
     assert (data.is_valid_cal == False).all()  # noqa: E712
 
 
+def test_setter_overwrite(test_dl):
+    test_dl.set_files("all")
+    test_dl.set_datastreams([1084803, 1084804, 1121600], "ch")
+    test_dl.set_cuts({"hit": "trapEmax > 5000"})
+    test_dl.set_output(columns=["trapEmax"])
+
+    data = test_dl.load().get_dataframe()
+
+    test_dl.set_files("timestamp == '20230318T012144Z'")
+    test_dl.set_datastreams([1084803, 1121600], "ch")
+    test_dl.set_cuts({"hit": "trapEmax > 0"})
+
+    data2 = test_dl.load().get_dataframe()
+
+    assert 1084804 not in data2["hit_table"]
+    assert len(pd.unique(data2["file"])) == 1
+    assert len(data2.query("hit_table == 1084803")) > len(
+        data.query("hit_table == 1084803")
+    )
+
+
 def test_browse(test_dl):
     test_dl.set_files("type == 'phy'")
     test_dl.set_datastreams([1057600, 1059201], "ch")
