diff --git a/src/pygama/flow/data_loader.py b/src/pygama/flow/data_loader.py
index 2c248633d..832d5dfa0 100644
--- a/src/pygama/flow/data_loader.py
+++ b/src/pygama/flow/data_loader.py
@@ -231,7 +231,7 @@ def set_config(self, config: dict | str) -> None:
             else:
                 self.cut_priority[level] = 0

-    def set_files(self, query: str | list[str]) -> None:
+    def set_files(self, query: str | list[str], append: bool = False) -> None:
         """Apply a file selection.

         Sets `self.file_list`, which is a list of indices corresponding to the
@@ -244,6 +244,9 @@ def set_files(self, query: str | list[str]) -> None:
             supported by :meth:`pandas.DataFrame.query`. In addition, the
             ``all`` keyword is supported to select all files in the database.
             If list of strings, will be interpreted as key (cycle timestamp) list.
+        append
+            if ``True``, appends files to the existing `self.file_list`
+            instead of overwriting.

         Note
         ----
@@ -273,10 +276,11 @@ def set_files(self, query: str | list[str]) -> None:
         if not inds:
             log.warning("no files matching selection found")

-        if self.file_list is None:
-            self.file_list = inds
-        else:
+        if append and self.file_list is not None:
             self.file_list += inds
+            self.file_list = sorted(list(set(self.file_list)))
+        else:
+            self.file_list = inds

     def get_file_list(self) -> pd.DataFrame:
         """
@@ -285,7 +289,9 @@ def get_file_list(self) -> pd.DataFrame:
         """
         return self.filedb.df.iloc[self.file_list]

-    def set_datastreams(self, ds: list | tuple | np.ndarray, word: str) -> None:
+    def set_datastreams(
+        self, ds: list | tuple | np.ndarray, word: str, append: bool = False
+    ) -> None:
         """Apply selection on data streams (or channels).

         Sets `self.table_list`.
@@ -299,12 +305,15 @@ def set_datastreams(self, ds: list | tuple | np.ndarray, word: str) -> None:
         word
             the type of identifier used in ds. Should be a key in the given
             channel map or a word defined in the configuration file.
+        append
+            if ``True``, appends datastreams to the existing `self.table_list`
+            instead of overwriting.

         Example
         -------
         >>> dl.set_datastreams(np.arange(40, 45), "ch")
         """
-        if self.table_list is None:
+        if self.table_list is None or not append:
             self.table_list = {}
         ds = list(ds)

@@ -324,6 +333,7 @@ def set_datastreams(self, ds: list | tuple | np.ndarray, word: str) -> None:
                     found = True
                     if level in self.table_list.keys():
                         self.table_list[level] += ds
+                        self.table_list[level] = sorted(list(set(self.table_list[level])))
                     else:
                         self.table_list[level] = ds

@@ -331,7 +341,7 @@ def set_datastreams(self, ds: list | tuple | np.ndarray, word: str) -> None:
             # look for word in channel map
             raise NotImplementedError

-    def set_cuts(self, cuts: dict | list) -> None:
+    def set_cuts(self, cuts: dict | list, append: bool = False) -> None:
         """Apply a selection on columns in the data tables.

         Parameters
@@ -342,12 +352,14 @@ def set_cuts(self, cuts: dict | list) -> None:
             structured as ``dict[tier] = cut_expr``. If passing a list, each
             item in the array should be able to be applied on one level of
             tables. The cuts at different levels will be joined with an AND.
+        append
+            if ``True``, appends cuts to the existing cuts instead of overwriting.

         Example
         -------
         >>> dl.set_cuts({"raw": "daqenergy > 1000", "hit": "AoE > 3"})
         """
-        if self.cuts is None:
+        if self.cuts is None or not append:
             self.cuts = {}
         if isinstance(cuts, dict):
             # verify the correct structure
@@ -356,7 +368,7 @@ def set_cuts(self, cuts: dict | list) -> None:
                     raise ValueError(
                         r"cuts dictionary must be in the format \{ level: string \}"
                     )
-                if key in self.cuts.keys():
+                if key in self.cuts.keys() and append:
                     self.cuts[key] += " and " + value
                 else:
                     self.cuts[key] = value
@@ -1090,7 +1102,7 @@ def explode_evt_cols(el: pd.DataFrame, tier_table: Table):
                 )
         else:  # not merge_files
             if in_memory:
-                load_out = {}
+                load_out = Struct(attrs={"int_keys": True})

             if log.getEffectiveLevel() >= logging.INFO:
                 progress_bar = tqdm(
@@ -1179,13 +1191,15 @@ def explode_evt_cols(el: pd.DataFrame, tier_table: Table):
                 f_table = utils.dict_to_table(col_dict, attr_dict)

                 if in_memory:
-                    load_out[file] = f_table
+                    load_out.add_field(name=file, obj=f_table)
                 if output_file:
                     sto.write_object(f_table, f"{file}", output_file, wo_mode="o")
                 # end file loop

             if log.getEffectiveLevel() >= logging.INFO:
                 progress_bar.close()
+            if load_out:
+                load_out.update_datatype()

             if in_memory:
                 if self.output_format == "lgdo.Table":
@@ -1219,7 +1233,7 @@ def load_evts(
             raise NotImplementedError
         else:  # Not merge_files
             if in_memory:
-                load_out = {}
+                load_out = Struct(attrs={"int_keys": True})
             for file, f_entries in entry_list.items():
                 field_mask = []
                 f_table = None
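All three setters above now share the same opt-in append semantics: by default a call overwrites the previous selection, while append=True merges with it (file and channel lists are deduplicated and sorted, cut strings are joined with " and "). A minimal usage sketch; the DataLoader construction, queries and channel ids below are hypothetical, while the setter signatures are the ones added in this diff:

    from pygama.flow import DataLoader

    dl = DataLoader(config, filedb)
    dl.set_files("type == 'cal'")                 # overwrites any earlier selection
    dl.set_files("type == 'phy'", append=True)    # merges into self.file_list
    dl.set_datastreams([1084803], "ch")
    dl.set_datastreams([1121600], "ch", append=True)
    dl.set_cuts({"hit": "trapEmax > 0"})
    dl.set_cuts({"hit": "AoE > 3"}, append=True)  # "trapEmax > 0 and AoE > 3"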
diff --git a/src/pygama/lgdo/lh5_store.py b/src/pygama/lgdo/lh5_store.py
index 4f4848701..d6c15910b 100644
--- a/src/pygama/lgdo/lh5_store.py
+++ b/src/pygama/lgdo/lh5_store.py
@@ -344,7 +344,13 @@ def read_object(
             # fields. If they all had shared indexing, they should be in a
             # table... Maybe should emit a warning? Or allow them to be
             # dicts keyed by field name?
-            obj_dict[field], _ = self.read_object(
+            # "int_keys" flags structs whose stringified keys should be
+            # restored to integers on read
+            if "int_keys" in h5f[name].attrs and h5f[name].attrs["int_keys"]:
+                f = int(field)
+            else:
+                f = str(field)
+            obj_dict[f], _ = self.read_object(
                 name + "/" + field,
                 h5f,
                 start_row=start_row,
@@ -929,9 +935,11 @@ def write_object(
             else:
                 obj_fld = obj[field]

+            # Convert keys to string for dataset names
+            f = str(field)
             self.write_object(
                 obj_fld,
-                field,
+                f,
                 lh5_file,
                 group=group,
                 start_row=start_row,
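The write/read pair above is what lets integer Struct keys survive an LH5 round trip: write_object stringifies each field name to form the HDF5 group name, and read_object consults the int_keys attribute to cast the names back. A rough sketch of the round trip, assuming the pygama.lgdo API (the file name and array contents are made up):

    import numpy as np
    from pygama import lgdo

    sto = lgdo.LH5Store()
    st = lgdo.Struct(attrs={"int_keys": True})
    st.add_field(0, lgdo.Array(np.arange(3)))   # integer field name
    sto.write_object(st, "st", "demo.lh5")      # stored under group "st/0"
    obj, _ = sto.read_object("st", "demo.lh5")
    assert list(obj.keys()) == [0]              # key restored as int, not "0"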
diff --git a/src/pygama/lgdo/struct.py b/src/pygama/lgdo/struct.py
index c8c5f0894..7fa2373f9 100644
--- a/src/pygama/lgdo/struct.py
+++ b/src/pygama/lgdo/struct.py
@@ -44,17 +44,19 @@ def datatype_name(self) -> str:
         return "struct"

     def form_datatype(self) -> str:
-        return self.datatype_name() + "{" + ",".join(self.keys()) + "}"
+        return (
+            self.datatype_name() + "{" + ",".join([str(k) for k in self.keys()]) + "}"
+        )

     def update_datatype(self) -> None:
         self.attrs["datatype"] = self.form_datatype()

-    def add_field(self, name: str, obj: LGDO) -> None:
+    def add_field(self, name: str | int, obj: LGDO) -> None:
         """Add a field to the table."""
         self[name] = obj
         self.update_datatype()

-    def remove_field(self, name: str, delete: bool = False) -> None:
+    def remove_field(self, name: str | int, delete: bool = False) -> None:
         """Remove a field from the table.

         Parameters
diff --git a/src/pygama/lgdo/table.py b/src/pygama/lgdo/table.py
index d86f07f85..a5985dff2 100644
--- a/src/pygama/lgdo/table.py
+++ b/src/pygama/lgdo/table.py
@@ -225,7 +225,7 @@ def get_dataframe(
             if not hasattr(column, "nda"):
                 raise ValueError(f"column {col} does not have an nda")
             else:
-                df[prefix + col] = column.nda.tolist()
+                df[prefix + str(col)] = column.nda.tolist()

         return df

diff --git a/tests/flow/test_data_loader.py b/tests/flow/test_data_loader.py
index a6eb7089f..3f0c466bb 100644
--- a/tests/flow/test_data_loader.py
+++ b/tests/flow/test_data_loader.py
@@ -76,7 +76,7 @@ def test_no_merge(test_dl):
     test_dl.set_output(columns=["timestamp"], merge_files=False)
     data = test_dl.load()

-    assert isinstance(data, dict)
+    assert isinstance(data, lgdo.Struct)
     assert isinstance(data[0], lgdo.Table)
     assert len(data) == 4  # 4 files
     assert list(data[0].keys()) == ["hit_table", "hit_idx", "timestamp"]
@@ -147,6 +147,27 @@ def test_set_cuts(test_dl):
     assert (data.is_valid_cal == False).all()  # noqa: E712


+def test_setter_overwrite(test_dl):
+    test_dl.set_files("all")
+    test_dl.set_datastreams([1084803, 1084804, 1121600], "ch")
+    test_dl.set_cuts({"hit": "trapEmax > 5000"})
+    test_dl.set_output(columns=["trapEmax"])
+
+    data = test_dl.load().get_dataframe()
+
+    test_dl.set_files("timestamp == '20230318T012144Z'")
+    test_dl.set_datastreams([1084803, 1121600], "ch")
+    test_dl.set_cuts({"hit": "trapEmax > 0"})
+
+    data2 = test_dl.load().get_dataframe()
+
+    assert 1084804 not in data2["hit_table"].values
+    assert len(pd.unique(data2["file"])) == 1
+    assert len(data2.query("hit_table == 1084803")) > len(
+        data.query("hit_table == 1084803")
+    )
+
+
 def test_browse(test_dl):
     test_dl.set_files("type == 'phy'")
     test_dl.set_datastreams([1057600, 1059201], "ch")
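A sketch of what the updated test_no_merge asserts for downstream code (a hypothetical continuation of the same fixture): with merge_files=False, load() now returns an lgdo.Struct keyed by the integer file index rather than a plain dict, and it carries the int_keys flag so the no-merge output can be written to and read back from an LH5 file without the keys degrading to strings:

    test_dl.set_output(columns=["timestamp"], merge_files=False)
    data = test_dl.load()
    assert isinstance(data, lgdo.Struct)
    assert data.attrs["int_keys"]    # consumed by LH5Store.read_object
    first = data[0]                  # per-file tables stay indexable by int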