Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let DataLoader default setters overwrite (not append) + lgdo.Struct.int_dtype #493

Merged
merged 18 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions src/pygama/flow/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def set_config(self, config: dict | str) -> None:
else:
self.cut_priority[level] = 0

def set_files(self, query: str | list[str]) -> None:
def set_files(self, query: str | list[str], append: bool = False) -> None:
"""Apply a file selection.

Sets `self.file_list`, which is a list of indices corresponding to the
Expand All @@ -244,6 +244,9 @@ def set_files(self, query: str | list[str]) -> None:
supported by :meth:`pandas.DataFrame.query`. In addition, the
``all`` keyword is supported to select all files in the database.
If list of strings, will be interpreted as key (cycle timestamp) list.
append
if ``True``, appends files to the existing `self.file_list`
instead of overwriting.

Note
----
Expand Down Expand Up @@ -273,10 +276,11 @@ def set_files(self, query: str | list[str]) -> None:
if not inds:
log.warning("no files matching selection found")

if self.file_list is None:
self.file_list = inds
else:
if append and self.file_list is not None:
self.file_list += inds
gipert marked this conversation as resolved.
Show resolved Hide resolved
self.file_list = sorted(list(set(self.file_list)))
else:
self.file_list = inds

def get_file_list(self) -> pd.DataFrame:
"""
Expand All @@ -285,7 +289,9 @@ def get_file_list(self) -> pd.DataFrame:
"""
return self.filedb.df.iloc[self.file_list]

def set_datastreams(self, ds: list | tuple | np.ndarray, word: str) -> None:
def set_datastreams(
self, ds: list | tuple | np.ndarray, word: str, append: bool = False
) -> None:
"""Apply selection on data streams (or channels).

Sets `self.table_list`.
Expand All @@ -299,12 +305,15 @@ def set_datastreams(self, ds: list | tuple | np.ndarray, word: str) -> None:
word
the type of identifier used in ds. Should be a key in the given
channel map or a word defined in the configuration file.
append
if ``True``, appends datastreams to the existing `self.table_list`
instead of overwriting.

Example
-------
>>> dl.set_datastreams(np.arange(40, 45), "ch")
"""
if self.table_list is None:
if self.table_list is None or not append:
self.table_list = {}

ds = list(ds)
Expand All @@ -324,14 +333,15 @@ def set_datastreams(self, ds: list | tuple | np.ndarray, word: str) -> None:
found = True
if level in self.table_list.keys():
self.table_list[level] += ds
self.table_list[level] = sorted(list(set(self.table_list[level])))
else:
self.table_list[level] = ds

if not found:
# look for word in channel map
raise NotImplementedError

def set_cuts(self, cuts: dict | list) -> None:
def set_cuts(self, cuts: dict | list, append: bool = False) -> None:
"""Apply a selection on columns in the data tables.

Parameters
Expand All @@ -342,12 +352,14 @@ def set_cuts(self, cuts: dict | list) -> None:
structured as ``dict[tier] = cut_expr``. If passing a list, each
item in the array should be able to be applied on one level of
tables. The cuts at different levels will be joined with an AND.
append
if True, appends cuts to the existing cuts instead of overwriting

Example
-------
>>> dl.set_cuts({"raw": "daqenergy > 1000", "hit": "AoE > 3"})
"""
if self.cuts is None:
if self.cuts is None or not append:
self.cuts = {}
if isinstance(cuts, dict):
# verify the correct structure
Expand All @@ -356,7 +368,7 @@ def set_cuts(self, cuts: dict | list) -> None:
raise ValueError(
r"cuts dictionary must be in the format \{ level: string \}"
)
if key in self.cuts.keys():
if key in self.cuts.keys() and append:
self.cuts[key] += " and " + value
else:
self.cuts[key] = value
Expand Down Expand Up @@ -1090,7 +1102,7 @@ def explode_evt_cols(el: pd.DataFrame, tier_table: Table):
)
else: # not merge_files
if in_memory:
load_out = {}
load_out = Struct(attrs={"int_keys": True})

if log.getEffectiveLevel() >= logging.INFO:
progress_bar = tqdm(
Expand Down Expand Up @@ -1179,13 +1191,15 @@ def explode_evt_cols(el: pd.DataFrame, tier_table: Table):
f_table = utils.dict_to_table(col_dict, attr_dict)

if in_memory:
load_out[file] = f_table
load_out.add_field(name=file, obj=f_table)
if output_file:
sto.write_object(f_table, f"{file}", output_file, wo_mode="o")
# end file loop

if log.getEffectiveLevel() >= logging.INFO:
progress_bar.close()
if load_out:
load_out.update_datatype()

if in_memory:
if self.output_format == "lgdo.Table":
Expand Down Expand Up @@ -1219,7 +1233,7 @@ def load_evts(
raise NotImplementedError
else: # Not merge_files
if in_memory:
load_out = {}
load_out = Struct(attrs={"int_keys": True})
for file, f_entries in entry_list.items():
field_mask = []
f_table = None
Expand Down
11 changes: 9 additions & 2 deletions src/pygama/lgdo/lh5_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,12 @@ def read_object(
# fields. If they all had shared indexing, they should be in a
# table... Maybe should emit a warning? Or allow them to be
# dicts keyed by field name?
obj_dict[field], _ = self.read_object(
if "int_keys" in h5f[name].attrs:
if dict(h5f[name].attrs)["int_keys"]:
f = int(field)
else:
f = str(field)
obj_dict[f], _ = self.read_object(
name + "/" + field,
h5f,
start_row=start_row,
Expand Down Expand Up @@ -929,9 +934,11 @@ def write_object(
else:
obj_fld = obj[field]

# Convert keys to string for dataset names
f = str(field)
self.write_object(
obj_fld,
field,
f,
lh5_file,
group=group,
start_row=start_row,
Expand Down
8 changes: 5 additions & 3 deletions src/pygama/lgdo/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,19 @@ def datatype_name(self) -> str:
return "struct"

def form_datatype(self) -> str:
return self.datatype_name() + "{" + ",".join(self.keys()) + "}"
return (
self.datatype_name() + "{" + ",".join([str(k) for k in self.keys()]) + "}"
)

def update_datatype(self) -> None:
self.attrs["datatype"] = self.form_datatype()

def add_field(self, name: str, obj: LGDO) -> None:
def add_field(self, name: str | int, obj: LGDO) -> None:
"""Add a field to the table."""
self[name] = obj
self.update_datatype()

def remove_field(self, name: str, delete: bool = False) -> None:
def remove_field(self, name: str | int, delete: bool = False) -> None:
"""Remove a field from the table.
Parameters
Expand Down
2 changes: 1 addition & 1 deletion src/pygama/lgdo/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def get_dataframe(
if not hasattr(column, "nda"):
raise ValueError(f"column {col} does not have an nda")
else:
df[prefix + col] = column.nda.tolist()
df[prefix + str(col)] = column.nda.tolist()

return df

Expand Down
23 changes: 22 additions & 1 deletion tests/flow/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_no_merge(test_dl):
test_dl.set_output(columns=["timestamp"], merge_files=False)
data = test_dl.load()

assert isinstance(data, dict)
assert isinstance(data, lgdo.Struct)
assert isinstance(data[0], lgdo.Table)
assert len(data) == 4 # 4 files
assert list(data[0].keys()) == ["hit_table", "hit_idx", "timestamp"]
Expand Down Expand Up @@ -147,6 +147,27 @@ def test_set_cuts(test_dl):
assert (data.is_valid_cal == False).all() # noqa: E712


def test_setter_overwrite(test_dl):
test_dl.set_files("all")
test_dl.set_datastreams([1084803, 1084804, 1121600], "ch")
test_dl.set_cuts({"hit": "trapEmax > 5000"})
test_dl.set_output(columns=["trapEmax"])

data = test_dl.load().get_dataframe()

test_dl.set_files("timestamp == '20230318T012144Z'")
test_dl.set_datastreams([1084803, 1121600], "ch")
test_dl.set_cuts({"hit": "trapEmax > 0"})

data2 = test_dl.load().get_dataframe()

assert 1084804 not in data2["hit_table"]
assert len(pd.unique(data2["file"])) == 1
assert len(data2.query("hit_table == 1084803")) > len(
data.query("hit_table == 1084803")
)


def test_browse(test_dl):
test_dl.set_files("type == 'phy'")
test_dl.set_datastreams([1057600, 1059201], "ch")
Expand Down