From 771b7fbda35261fbb09db371117b6fe0a7b4242a Mon Sep 17 00:00:00 2001
From: ouyangwenyu
Date: Mon, 8 Apr 2024 11:33:30 +0800
Subject: [PATCH] refactor HydroMeanDataset but not finished yet

---
 torchhydro/datasets/data_sets.py | 374 +++++++++++++++----------------
 1 file changed, 181 insertions(+), 193 deletions(-)

diff --git a/torchhydro/datasets/data_sets.py b/torchhydro/datasets/data_sets.py
index 9394e23..dc147ab 100644
--- a/torchhydro/datasets/data_sets.py
+++ b/torchhydro/datasets/data_sets.py
@@ -1,7 +1,7 @@
 """
 Author: Wenyu Ouyang
 Date: 2022-02-13 21:20:18
-LastEditTime: 2024-04-07 21:45:45
+LastEditTime: 2024-04-08 11:28:35
 LastEditors: Wenyu Ouyang
 Description: A pytorch dataset class; references to https://github.com/neuralhydrology/neuralhydrology
 FilePath: \torchhydro\torchhydro\datasets\data_sets.py
@@ -161,14 +161,10 @@ def _pre_load_data(self):
 
     def _load_data(self):
         self._pre_load_data()
-        data_forcing_ds, data_output_ds, data_attr_ds = self._read_xyc()
-        # save unnormalized data to use in physics-based modeling, we will use streamflow with unit of mm/day
-        self.x_origin, self.y_origin, self.c_origin = self._to_dataarray_with_unit(
-            data_forcing_ds, data_output_ds, data_attr_ds
-        )
+        self._read_xyc()
         # normalization
         norm_x, norm_y, norm_c = self._normalize()
-        self.x, self.y, self.c = self._kill_nan(norm_x, norm_c, norm_y)
+        self.x, self.y, self.c = self._kill_nan(norm_x, norm_y, norm_c)
         self._create_lookup_table()
 
     def _normalize(self):
@@ -210,14 +206,14 @@ def _read_xyc(self):
         """
         # y
-        data_flow_ds = self.data_source.read_ts_xrdataset(
+        data_output_ds = self.data_source.read_ts_xrdataset(
             self.t_s_dict["sites_id"],
             self.t_s_dict["t_final_range"],
             self.data_cfgs["target_cols"],
         )
         if self.data_source.streamflow_unit != "mm/d":
-            data_flow_ds = streamflow_unit_conv(
-                data_flow_ds, self.data_source.read_area(self.t_s_dict["sites_id"])
+            data_output_ds = streamflow_unit_conv(
+                data_output_ds, self.data_source.read_area(self.t_s_dict["sites_id"])
             )
         # x
         data_forcing_ds = self.data_source.read_ts_xrdataset(
@@ -232,8 +228,9 @@ def _read_xyc(self):
             self.data_cfgs["constant_cols"],
             all_number=True,
         )
-
-        return data_forcing_ds, data_flow_ds, data_attr_ds
+        self.x_origin, self.y_origin, self.c_origin = self._to_dataarray_with_unit(
+            data_forcing_ds, data_output_ds, data_attr_ds
+        )
 
     @property
     def basins(self):
@@ -256,7 +253,7 @@ def _trans2da_and_setunits(self, ds):
             result.attrs["units"] = units_dict
         return result
 
-    def _kill_nan(self, x, c, y):
+    def _kill_nan(self, x, y, c):
         data_cfgs = self.data_cfgs
         y_rm_nan = data_cfgs["target_rm_nan"]
         x_rm_nan = data_cfgs["relevant_rm_nan"]
@@ -298,54 +295,6 @@ def _create_lookup_table(self):
         self.lookup_table = dict(enumerate(lookup))
         self.num_samples = len(self.lookup_table)
 
-    def _create_lookup_table_grid_mean(self):
-        lookup = []
-        basins = self.t_s_dict["sites_id"]
-        forecast_length = self.forecast_length
-        warmup_length = self.warmup_length
-        dates = self.y["time"].to_numpy()
-        time_num = len(self.t_s_dict["t_final_range"])
-        time_total_length = len(dates)
-        time_single_length = int(time_total_length / time_num)
-        is_tra_val_te = self.is_tra_val_te
-        for basin in tqdm(
-            basins,
-            file=sys.stdout,
-            disable=False,
-            desc=f"Creating {is_tra_val_te} lookup table",
-        ):
-            for num in range(time_num):
-                lookup.extend(
-                    (basin, dates[f + num * time_single_length])
-                    for f in range(warmup_length, time_single_length - forecast_length)
-                )
-        self.lookup_table = dict(enumerate(lookup))
-        self.num_samples = len(self.lookup_table)
-
-    def prepare_target(self):
-        gage_id_lst = self.t_s_dict["sites_id"]
-        t_range = self.t_s_dict["t_final_range"]
-        var_lst = self.data_cfgs["target_cols"]
-        path = self.data_cfgs["data_path"]["target"]
-
-        if var_lst is None or not var_lst:
-            return None
-
-        data = self.data_source.merge_nc_minio_datasets(path, gage_id_lst, var_lst)
-
-        all_vars = data.data_vars
-        if any(var not in data.variables for var in var_lst):
-            raise ValueError(f"var_lst must all be in {all_vars}")
-        subset_list = []
-        for start_date, end_date in t_range:
-            adjusted_end_date = (
-                datetime.strptime(end_date, "%Y-%m-%d")
-                + timedelta(hours=self.forecast_length)
-            ).strftime("%Y-%m-%d")
-            subset = data.sel(time=slice(start_date, adjusted_end_date))
-            subset_list.append(subset)
-        return xr.concat(subset_list, dim="time")
-
 
 class BasinSingleFlowDataset(BaseDataset):
     """one time length output for each grid in a batch"""
@@ -555,30 +504,191 @@ def _normalize(self):
         return scaler_hub.x, scaler_hub.y, scaler_hub.c
 
 
-class HydroGridDataset(BaseDataset):
+class HydroMeanDataset(BaseDataset):
+    def __init__(self, data_cfgs: dict, is_tra_val_te: str):
+        super(HydroMeanDataset, self).__init__(data_cfgs, is_tra_val_te)
+
+    @property
+    def data_source(self):
+        return HydroBasins(self.data_cfgs["data_path"])
+
+    def _normalize(self):
+        x, y, c = super()._normalize()
+        return x.compute(), y.compute(), c.compute()
+
+    def _read_xyc(self):
+        data_target_ds = self._prepare_target()
+        if data_target_ds is not None:
+            y_origin = super()._trans2da_and_setunits(data_target_ds)
+        else:
+            y_origin = None
+
+        data_forcing_ds = self._prepare_forcing()
+
+        if data_forcing_ds is not None:
+            x_origin = super()._trans2da_and_setunits(data_forcing_ds)
+        else:
+            x_origin = None
+
+        if self.data_cfgs["constant_cols"]:
+            data_attr_ds = self.data_source.read_BA_xrdataset(
+                self.t_s_dict["sites_id"],
+                self.data_cfgs["constant_cols"],
+                self.data_cfgs["data_path"]["attributes"],
+            )
+            c_origin = super()._trans2da_and_setunits(data_attr_ds)
+        else:
+            c_origin = None
+        self.x_origin, self.y_origin, self.c_origin = x_origin, y_origin, c_origin
+
+    def _prepare_target(self):
+        gage_id_lst = self.t_s_dict["sites_id"]
+        t_range = self.t_s_dict["t_final_range"]
+        var_lst = self.data_cfgs["target_cols"]
+        path = self.data_cfgs["data_path"]["target"]
+
+        if var_lst is None or not var_lst:
+            return None
+
+        data = self.data_source.merge_nc_minio_datasets(path, gage_id_lst, var_lst)
+
+        all_vars = data.data_vars
+        if any(var not in data.variables for var in var_lst):
+            raise ValueError(f"var_lst must all be in {all_vars}")
+        subset_list = []
+        for start_date, end_date in t_range:
+            adjusted_end_date = (
+                datetime.strptime(end_date, "%Y-%m-%d")
+                + timedelta(hours=self.forecast_length)
+            ).strftime("%Y-%m-%d")
+            subset = data.sel(time=slice(start_date, adjusted_end_date))
+            subset_list.append(subset)
+        return xr.concat(subset_list, dim="time")
+
+    def _create_lookup_table(self):
+        lookup = []
+        basins = self.t_s_dict["sites_id"]
+        forecast_length = self.forecast_length
+        warmup_length = self.warmup_length
+        dates = self.y["time"].to_numpy()
+        time_num = len(self.t_s_dict["t_final_range"])
+        time_total_length = len(dates)
+        time_single_length = time_total_length // time_num
+        is_tra_val_te = self.is_tra_val_te
+        for basin in tqdm(
+            basins,
+            file=sys.stdout,
+            disable=False,
+            desc=f"Creating {is_tra_val_te} lookup table",
+        ):
+            for num in range(time_num):
+                lookup.extend(
+                    (basin, dates[f + num * time_single_length])
+                    for f in range(warmup_length, time_single_length - forecast_length)
+                )
+        self.lookup_table = dict(enumerate(lookup))
+        self.num_samples = len(self.lookup_table)
+
+    def __getitem__(self, item: int):
+        basin, time = self.lookup_table[item]
+        seq_length = self.rho
+        output_seq_len = self.data_cfgs["forecast_length"]
+        warmup_length = self.warmup_length
+        gpm_tp = (
+            self.x.sel(
+                variable="gpm_tp",
+                basin=basin,
+                time=slice(
+                    time - np.timedelta64(warmup_length + seq_length - 1, "h"),
+                    time,
+                ),
+            )
+            .to_numpy()
+            .T
+        ).reshape(-1, 1)
+        gfs_tp = (
+            self.x.sel(
+                variable="gfs_tp",
+                basin=basin,
+                time=slice(
+                    time + np.timedelta64(1, "h"),
+                    time + np.timedelta64(output_seq_len, "h"),
+                ),
+            )
+            .to_numpy()
+            .T
+        ).reshape(-1, 1)
+        x = np.concatenate((gpm_tp, gfs_tp), axis=0)
+        if self.c is not None and self.c.shape[-1] > 0:
+            c = self.c.sel(basin=basin).values
+            c = np.tile(c, (warmup_length + seq_length + output_seq_len, 1))
+            x = np.concatenate((x, c), axis=1)
+        y = (
+            self.y.sel(
+                basin=basin,
+                time=slice(
+                    time + np.timedelta64(1, "h"),
+                    time + np.timedelta64(output_seq_len, "h"),
+                ),
+            )
+            .to_numpy()
+            .T
+        )
+        return torch.from_numpy(x).float(), torch.from_numpy(y).float()
+
+    def __len__(self):
+        return self.num_samples
+
+    def _prepare_forcing(self):
+        gage_id_lst = self.t_s_dict["sites_id"]
+        t_range = self.t_s_dict["t_final_range"]
+        var_lst = self.data_cfgs["relevant_cols"]
+        path = self.data_cfgs["data_path"]["forcing"]
+
+        if var_lst is None:
+            return None
+
+        data = self.data_source.merge_nc_minio_datasets(path, gage_id_lst, var_lst)
+
+        var_subset_list = []
+        for start_date, end_date in t_range:
+            adjusted_start_date = (
+                datetime.strptime(start_date, "%Y-%m-%d") - timedelta(hours=self.rho)
+            ).strftime("%Y-%m-%d")
+            adjusted_end_date = (
+                datetime.strptime(end_date, "%Y-%m-%d")
+                + timedelta(hours=self.data_cfgs["forecast_length"])
+            ).strftime("%Y-%m-%d")
+            subset = data.sel(time=slice(adjusted_start_date, adjusted_end_date))
+            var_subset_list.append(subset)
+
+        return xr.concat(var_subset_list, dim="time")
+
+
+class HydroGridDataset(HydroMeanDataset):
     def __init__(self, data_cfgs: dict, is_tra_val_te: str):
         super(HydroGridDataset, self).__init__(data_cfgs, is_tra_val_te)
 
     def _load_data(self):
         self.data_source = HydroBasins(self.data_cfgs["data_path"])
         self.forecast_length = self.data_cfgs["forecast_length"]
-        super().common_load_data()
+        self._pre_load_data()
 
-        data_target_ds = super().prepare_target()
+        data_target_ds = self._prepare_target()
         if data_target_ds is not None:
-            y_origin = super()._trans2da_and_setunits(data_target_ds)
+            y_origin = self._trans2da_and_setunits(data_target_ds)
         else:
             y_origin = None
 
-        data_gpm = self.prepare_forcing(0)
+        data_gpm = self._prepare_forcing(0)
         if self.data_cfgs["relevant_cols"][1] != ["None"]:
-            data_gfs = self.prepare_forcing(1)
+            data_gfs = self._prepare_forcing(1)
         else:
             data_gfs = None
         if self.data_cfgs["relevant_cols"][2] != ["None"]:
-            data_smap = self.prepare_forcing(2)
+            data_smap = self._prepare_forcing(2)
         else:
             data_smap = None
 
@@ -588,7 +698,7 @@ def _load_data(self):
                 self.data_cfgs["constant_cols"],
                 self.data_cfgs["data_path"]["attributes"],
             )
-            data_attr = super()._trans2da_and_setunits(data_attr_ds)
+            data_attr = self._trans2da_and_setunits(data_attr_ds)
         else:
             data_attr = None
 
@@ -609,7 +719,7 @@ def _load_data(self):
 
         self.target_scaler = scaler_hub.target_scaler
 
-        super()._create_lookup_table_grid_mean()
+        self._create_lookup_table()
 
     def kill_nan(self, x, y, c, g, s):
         data_cfgs = self.data_cfgs
@@ -781,7 +891,7 @@ def get_s(self, basin, time):
         )
         return np.transpose(s, (1, 0, 2, 3))
 
-    def prepare_forcing(self, data_type):
+    def _prepare_forcing(self, data_type):
         gage_id_lst = self.t_s_dict["sites_id"]
         t_range = self.t_s_dict["t_final_range"]
         var_lst = self.data_cfgs["relevant_cols"][data_type]
@@ -817,128 +927,6 @@ def prepare_forcing(self, data_type):
         return data_dict
 
 
-class HydroMeanDataset(BaseDataset):
-    def __init__(self, data_cfgs: dict, is_tra_val_te: str):
-        super(HydroMeanDataset, self).__init__(data_cfgs, is_tra_val_te)
-
-    def _load_data(self):
-        self.data_source = HydroBasins(self.data_cfgs["data_path"])
-        self.forecast_length = self.data_cfgs["forecast_length"]
-        super()._pre_load_data()
-
-        data_target_ds = super().prepare_target()
-        if data_target_ds is not None:
-            y_origin = super()._trans2da_and_setunits(data_target_ds)
-        else:
-            y_origin = None
-
-        data_forcing_ds = self.prepare_forcing()
-
-        if data_forcing_ds is not None:
-            x_origin = super()._trans2da_and_setunits(data_forcing_ds)
-        else:
-            x_origin = None
-
-        if self.data_cfgs["constant_cols"]:
-            data_attr_ds = self.data_source.read_BA_xrdataset(
-                self.t_s_dict["sites_id"],
-                self.data_cfgs["constant_cols"],
-                self.data_cfgs["data_path"]["attributes"],
-            )
-            c_orgin = super()._trans2da_and_setunits(data_attr_ds)
-        else:
-            c_orgin = None
-
-        scaler_hub = ScalerHub(
-            y_origin,
-            x_origin,
-            c_orgin,
-            self.data_cfgs,
-            self.is_tra_val_te,
-            self.data_source,
-        )
-        self.target_scaler = scaler_hub.target_scaler
-        self.x, self.y, self.c = super()._kill_nan(
-            scaler_hub.x.compute(), scaler_hub.c.compute(), scaler_hub.y.compute()
-        )
-        super()._create_lookup_table_grid_mean()
-
-    def __getitem__(self, item: int):
-        basin, time = self.lookup_table[item]
-        seq_length = self.rho
-        output_seq_len = self.forecast_length
-        warmup_length = self.warmup_length
-        gpm_tp = (
-            self.x.sel(
-                variable="gpm_tp",
-                basin=basin,
-                time=slice(
-                    time - np.timedelta64(warmup_length + seq_length - 1, "h"),
-                    time,
-                ),
-            )
-            .to_numpy()
-            .T
-        ).reshape(-1, 1)
-        gfs_tp = (
-            self.x.sel(
-                variable="gfs_tp",
-                basin=basin,
-                time=slice(
-                    time + np.timedelta64(1, "h"),
-                    time + np.timedelta64(output_seq_len, "h"),
-                ),
-            )
-            .to_numpy()
-            .T
-        ).reshape(-1, 1)
-        x = np.concatenate((gpm_tp, gfs_tp), axis=0)
-        if self.c is not None and self.c.shape[-1] > 0:
-            c = self.c.sel(basin=basin).values
-            c = np.tile(c, (warmup_length + seq_length + output_seq_len, 1))
-            x = np.concatenate((x, c), axis=1)
-        y = (
-            self.y.sel(
-                basin=basin,
-                time=slice(
-                    time + np.timedelta64(1, "h"),
-                    time + np.timedelta64(output_seq_len, "h"),
-                ),
-            )
-            .to_numpy()
-            .T
-        )
-        return torch.from_numpy(x).float(), torch.from_numpy(y).float()
-
-    def __len__(self):
-        return self.num_samples
-
-    def prepare_forcing(self):
-        gage_id_lst = self.t_s_dict["sites_id"]
-        t_range = self.t_s_dict["t_final_range"]
-        var_lst = self.data_cfgs["relevant_cols"]
-        path = self.data_cfgs["data_path"]["forcing"]
-
-        if var_lst is None:
-            return None
-
-        data = self.data_source.merge_nc_minio_datasets(path, gage_id_lst, var_lst)
-
-        var_subset_list = []
-        for start_date, end_date in t_range:
-            adjusted_start_date = (
-                datetime.strptime(start_date, "%Y-%m-%d") - timedelta(hours=self.rho)
-            ).strftime("%Y-%m-%d")
-            adjusted_end_date = (
-                datetime.strptime(end_date, "%Y-%m-%d")
-                + timedelta(hours=self.forecast_length)
-            ).strftime("%Y-%m-%d")
-            subset = data.sel(time=slice(adjusted_start_date, adjusted_end_date))
-            var_subset_list.append(subset)
-
-        return xr.concat(var_subset_list, dim="time")
-
-
 # Most functions are the same or similar,
 # but the data is not uploaded to the MinIO server, and self._load_data() would run automatically on inheritance
 # because it is called in the __init__ function, so this class does not inherit from the class above
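
The comment above is about inheritance: because the dataset's __init__ triggers the load routine, any subclass runs its own override the moment it is constructed, which is why the class below the comment was not made a subclass. A minimal, self-contained sketch of that behaviour and one hedged workaround follows; the class and method names here are illustrative only, not the torchhydro API.

# Sketch of why loading in __init__ makes inheritance awkward, plus a lazy
# alternative. Names are made up for illustration, not taken from torchhydro.


class EagerBase:
    """Loads in __init__, so a subclass triggers its own _load_data on construction."""

    def __init__(self):
        self._load_data()

    def _load_data(self):
        print("base load")


class EagerChild(EagerBase):
    def _load_data(self):
        # Runs during EagerBase.__init__, whether or not the child wants it yet.
        print("child load (runs in __init__)")


class LazyBase:
    """Hedged alternative: defer loading until the data is first accessed."""

    def __init__(self):
        self._loaded = False

    def _load_data(self):
        print("lazy load")

    def _ensure_loaded(self):
        if not self._loaded:
            self._load_data()
            self._loaded = True

    def __getitem__(self, item):
        self._ensure_loaded()
        return item


if __name__ == "__main__":
    EagerChild()   # prints "child load (runs in __init__)"
    ds = LazyBase()
    _ = ds[0]      # prints "lazy load" only on first access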
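
More broadly, the refactor in this patch turns _load_data into a fixed pipeline (_read_xyc, then _normalize, _kill_nan, _create_lookup_table) in which _read_xyc now stores x_origin/y_origin/c_origin on the instance instead of returning them, so HydroMeanDataset only overrides the steps that differ. A reduced skeleton of that template-method shape, with made-up random data standing in for the real readers (again illustrative, not the actual torchhydro classes):

import numpy as np


class BaseDatasetSketch:
    """The _load_data pipeline is fixed; subclasses override individual steps."""

    def _load_data(self):
        self._read_xyc()                     # sets self.x_origin / y_origin / c_origin
        x, y, c = self._normalize()          # scaling step (identity here)
        self.x, self.y, self.c = self._kill_nan(x, y, c)
        self._create_lookup_table()

    def _read_xyc(self):
        # Stand-in for reading forcing (x), target (y) and attribute (c) data.
        self.x_origin = np.random.rand(10, 3)
        self.y_origin = np.random.rand(10, 1)
        self.c_origin = np.random.rand(1, 2)

    def _normalize(self):
        return self.x_origin, self.y_origin, self.c_origin

    @staticmethod
    def _nan_to_num_or_none(a):
        return None if a is None else np.nan_to_num(a)

    def _kill_nan(self, x, y, c):
        return (
            self._nan_to_num_or_none(x),
            self._nan_to_num_or_none(y),
            self._nan_to_num_or_none(c),
        )

    def _create_lookup_table(self):
        self.num_samples = len(self.y)


class MeanDatasetSketch(BaseDatasetSketch):
    """Analogue of HydroMeanDataset: only the reading step is replaced."""

    def _read_xyc(self):
        # Would pull basin-mean forcing/target from a different source instead.
        self.x_origin = np.zeros((20, 2))
        self.y_origin = np.zeros((20, 1))
        self.c_origin = None  # attributes may be absent


if __name__ == "__main__":
    ds = MeanDatasetSketch()
    ds._load_data()
    print(ds.num_samples)  # 20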