From eae1eb939f003b1c4ff5f98ef11b524c362cb858 Mon Sep 17 00:00:00 2001 From: owen Date: Tue, 5 Nov 2024 20:14:11 +0800 Subject: [PATCH] add one more time step for seq2seqdataset _read_xyc --- tests/test_data_sets.py | 133 +------------------------------ torchhydro/configs/config.py | 4 - torchhydro/datasets/data_sets.py | 75 +++++++++++++---- torchhydro/models/seq2seq.py | 17 ++-- 4 files changed, 69 insertions(+), 160 deletions(-) diff --git a/tests/test_data_sets.py b/tests/test_data_sets.py index 8435c42..81fe738 100644 --- a/tests/test_data_sets.py +++ b/tests/test_data_sets.py @@ -1,7 +1,7 @@ """ Author: Wenyu Ouyang Date: 2024-05-27 13:33:08 -LastEditTime: 2024-11-05 11:46:01 +LastEditTime: 2024-11-05 18:20:19 LastEditors: Wenyu Ouyang Description: Unit test for datasets FilePath: \torchhydro\tests\test_data_sets.py @@ -149,134 +149,3 @@ def test_create_lookup_table(tmp_path): isinstance(key, int) and isinstance(value, tuple) for key, value in lookup_table.items() ) - - -def test_seq2seqdataset_getitem_train_mode(monkeypatch): - # 模拟 os.listdir 返回值 - def mock_listdir(path): - if "timeseries" in path: - return ["1D"] - elif "attributes" in path: - return ["attributes.csv"] - return [] - - # 模拟 os.path.isdir 返回值 - def mock_isdir(path): - return True - - # 模拟 pandas.read_csv 返回值 - def mock_read_csv(filepath, *args, **kwargs): - if "attributes.csv" in filepath: - return pd.DataFrame( - { - "basin_id": ["01013500", "01013501"], - "area": [100.0, 200.0], - "attr1": [1.0, 2.0], - "attr2": [3.0, 4.0], - } - ) - elif "timeseries" in filepath: - return pd.DataFrame( - { - "date": pd.date_range(start="2001-01-01", periods=365, freq="D"), - "prcp": [0.1] * 365, - "pet": [0.05] * 365, - "streamflow": [1.0] * 365, - "surface_sm": [0.3] * 365, - } - ) - return pd.DataFrame() - - # 模拟 os.path.exists 返回值 - def mock_exists(path): - return True - - # 模拟 os.path.join 返回值 - def mock_join(a, *p): - return a + "/" + "/".join(p) - - # 使用 monkeypatch 模拟 os.listdir、os.path.isdir、os.path.exists、os.path.join 和 pandas.read_csv - monkeypatch.setattr(os, "listdir", mock_listdir) - monkeypatch.setattr(os.path, "isdir", mock_isdir) - monkeypatch.setattr(os.path, "exists", mock_exists) - monkeypatch.setattr(os.path, "join", mock_join) - monkeypatch.setattr(pd, "read_csv", mock_read_csv) - - data_sources_dict.update({"mockdatasource": MockDatasource}) - data_cfgs = { - "source_cfgs": { - "source_name": "mockdatasource", - "source_path": "mock_path", - "other_settings": {"time_unit": ["1D"]}, - }, - "object_ids": ["01013500", "01013501"], - "t_range_train": ["2001-01-01", "2002-01-01"], - "t_range_test": ["2002-01-01", "2003-01-01"], - "relevant_cols": ["prcp", "pet"], - "target_cols": ["streamflow", "surface_sm"], - "constant_cols": ["geol_1st_class", "geol_2nd_class"], - "forecast_history": 7, - "warmup_length": 14, - "forecast_length": 1, - "min_time_unit": "D", - "min_time_interval": 1, - "target_rm_nan": True, - "relevant_rm_nan": True, - "constant_rm_nan": True, - "prec_window": 3, - "en_output_size": 5, - } - is_tra_val_te = "train" - dataset = Seq2SeqDataset(data_cfgs, is_tra_val_te) - item = 0 - (x, x_h, y), y_out = dataset[item] - assert isinstance(x, torch.Tensor) - assert isinstance(x_h, torch.Tensor) - assert isinstance(y, torch.Tensor) - assert isinstance(y_out, torch.Tensor) - assert ( - x.shape[1] - == len(data_cfgs["relevant_cols"]) + len(data_cfgs["constant_cols"]) + 1 - ) - assert x_h.shape[1] == len(data_cfgs["constant_cols"]) + 1 - assert y.shape[1] == len(data_cfgs["target_cols"]) - assert y_out.shape == y.shape - - -def test_seq2seqdataset_getitem_test_mode(): - data_sources_dict.update({"mockdatasource": MockDatasource}) - data_cfgs = { - "source_cfgs": { - "source_name": "mockdatasource", - "source_path": "mock_path", - }, - "object_ids": ["01013500", "01013501"], - "t_range_train": ["2001-01-01", "2002-01-01"], - "t_range_test": ["2002-01-01", "2003-01-01"], - "relevant_cols": ["prcp", "pet"], - "target_cols": ["streamflow", "surface_sm"], - "constant_cols": ["geol_1st_class", "geol_2nd_class"], - "forecast_history": 7, - "warmup_length": 14, - "forecast_length": 1, - "min_time_unit": "D", - "min_time_interval": 1, - "target_rm_nan": True, - "relevant_rm_nan": True, - "constant_rm_nan": True, - "prec_window": 3, - "en_output_size": 5, - } - is_tra_val_te = "test" - dataset = Seq2SeqDataset(data_cfgs, is_tra_val_te) - item = 0 - (x, x_h), y = dataset[item] - assert isinstance(x, torch.Tensor) - assert isinstance(x_h, torch.Tensor) - assert isinstance(y, torch.Tensor) - assert ( - x.shape[1] - == len(data_cfgs["relevant_cols"]) + len(data_cfgs["constant_cols"]) + 1 - ) - assert x_h.shape[1] == len(data_cfgs["constant_cols"]) + 1 - assert y.shape[1] == len(data_cfgs["target_cols"]) diff --git a/torchhydro/configs/config.py b/torchhydro/configs/config.py index da1a773..8cfaa5b 100644 --- a/torchhydro/configs/config.py +++ b/torchhydro/configs/config.py @@ -1031,10 +1031,6 @@ def update_cfg(cfg_file, new_args): cfg_file["data_cfgs"]["prec_window"] = new_args.model_hyperparam[ "prec_window" ] - if "en_output_size" in new_args.model_hyperparam.keys(): - cfg_file["data_cfgs"]["en_output_size"] = new_args.model_hyperparam[ - "en_output_size" - ] if new_args.batch_size is not None: # raise AttributeError("Please set the batch_size!!!") batch_size = new_args.batch_size diff --git a/torchhydro/datasets/data_sets.py b/torchhydro/datasets/data_sets.py index 349756d..3144112 100644 --- a/torchhydro/datasets/data_sets.py +++ b/torchhydro/datasets/data_sets.py @@ -1,7 +1,7 @@ """ Author: Wenyu Ouyang Date: 2024-04-08 18:16:53 -LastEditTime: 2024-11-05 11:29:27 +LastEditTime: 2024-11-05 20:02:49 LastEditors: Wenyu Ouyang Description: A pytorch dataset class; references to https://github.com/neuralhydrology/neuralhydrology FilePath: \torchhydro\torchhydro\datasets\data_sets.py @@ -340,15 +340,31 @@ def _read_xyc(self): x, y, c data """ # x + start_date = self.t_s_dict["t_final_range"][0] + end_date = self.t_s_dict["t_final_range"][1] + self._read_xyc_specified_time(start_date, end_date) + + def _read_xyc_specified_time(self, start_date, end_date): + """Read x, y, c data from data source with specified time range + We set this function as sometimes we need adjust the time range for some specific dataset, + such as seq2seq dataset (it needs one more period for the end of the time range) + + Parameters + ---------- + start_date : str + start time + end_date : str + end time + """ data_forcing_ds_ = self.data_source.read_ts_xrdataset( self.t_s_dict["sites_id"], - self.t_s_dict["t_final_range"], + [start_date, end_date], self.data_cfgs["relevant_cols"], ) # y data_output_ds_ = self.data_source.read_ts_xrdataset( self.t_s_dict["sites_id"], - self.t_s_dict["t_final_range"], + [start_date, end_date], self.data_cfgs["target_cols"], ) if isinstance(data_output_ds_, dict) or isinstance(data_forcing_ds_, dict): @@ -615,6 +631,26 @@ class Seq2SeqDataset(BaseDataset): def __init__(self, data_cfgs: dict, is_tra_val_te: str): super(Seq2SeqDataset, self).__init__(data_cfgs, is_tra_val_te) + def _read_xyc(self): + """ + NOTE: the lookup table is same as BaseDataset, + but the data retrieved from datasource should has one more period, + because we include the concepts of start and end moment of the period + + Returns + ------- + tuple[xr.Dataset, xr.Dataset, xr.Dataset] + x, y, c data + """ + start_date = self.t_s_dict["t_final_range"][0] + end_date = self.t_s_dict["t_final_range"][1] + interval = self.data_cfgs["min_time_interval"] + # TODO: Now only support hour, need to better handle different time units, such as Month, Day + adjusted_end_date = ( + datetime.strptime(end_date, "%Y-%m-%d-%H") + timedelta(hours=interval) + ).strftime("%Y-%m-%d-%H") + self._read_xyc_specified_time(start_date, adjusted_end_date) + def _normalize(self): x, y, c = super()._normalize() # TODO: this work for minio? maybe better to move to basedataset @@ -627,29 +663,38 @@ def __getitem__(self, item: int): basin, time = self.lookup_table[item] rho = self.rho horizon = self.horizon - prec = self.data_cfgs["prec_window"] - en_output_size = self.data_cfgs["en_output_size"] - + prec = self.data_cfgs.get("prec_window", 0) + # p cover all encoder-decoder periods; +1 means the period while +0 means start of the current period p = self.x[basin, time + 1 : time + rho + horizon + 1, 0].reshape(-1, 1) + # s only cover encoder periods s = self.x[basin, time : time + rho, 1:] x = np.concatenate((p[:rho], s), axis=1) - c = self.c[basin, :] - c = np.tile(c, (rho + horizon, 1)) - x = np.concatenate((x, c[:rho]), axis=1) - - x_h = np.concatenate((p[rho:], c[rho:]), axis=1) + if self.c is None or self.c.shape[-1] == 0: + xc = x + else: + c = self.c[basin, :] + c = np.tile(c, (rho + horizon, 1)) + xc = np.concatenate((x, c[:rho]), axis=1) + # xh cover decoder periods + try: + xh = np.concatenate((p[rho:], c[rho:]), axis=1) + except ValueError as e: + print(f"Error in np.concatenate: {e}") + print(f"p[rho:].shape: {p[rho:].shape}, c[rho:].shape: {c[rho:].shape}") + raise + # y cover specified encoder size (prec_window) and all decoder periods y = self.y[basin, time + rho - prec + 1 : time + rho + horizon + 1, :] if self.is_tra_val_te == "train": return [ - torch.from_numpy(x).float(), - torch.from_numpy(x_h).float(), + torch.from_numpy(xc).float(), + torch.from_numpy(xh).float(), torch.from_numpy(y).float(), ], torch.from_numpy(y).float() return [ - torch.from_numpy(x).float(), - torch.from_numpy(x_h).float(), + torch.from_numpy(xc).float(), + torch.from_numpy(xh).float(), ], torch.from_numpy(y).float() diff --git a/torchhydro/models/seq2seq.py b/torchhydro/models/seq2seq.py index 6167ce9..555d55e 100644 --- a/torchhydro/models/seq2seq.py +++ b/torchhydro/models/seq2seq.py @@ -1,7 +1,7 @@ """ Author: Wenyu Ouyang Date: 2024-04-17 12:32:26 -LastEditTime: 2024-11-01 12:01:16 +LastEditTime: 2024-11-05 19:10:02 LastEditors: Wenyu Ouyang Description: FilePath: \torchhydro\torchhydro\models\seq2seq.py @@ -123,7 +123,6 @@ def __init__( forecast_length, prec_window=0, teacher_forcing_ratio=0.5, - en_output_size=1, ): """General Seq2Seq model @@ -140,18 +139,15 @@ def __init__( forecast_length : _type_ the length of the forecast, i.e., the periods of decoder outputs prec_window : int, optional - starting index of decoder output for teacher forcing; default is 0 + the encoder's final several outputs in the final output; + default is 0 which means no encoder output is included in the final output; teacher_forcing_ratio : float, optional the probability of using teacher forcing - en_output_size : int, optional - the encoder's final several outputs in the final output; - default is 1 which means the final encoder output is included in the final output """ super(GeneralSeq2Seq, self).__init__() self.trg_len = forecast_length self.prec_window = prec_window self.teacher_forcing_ratio = teacher_forcing_ratio - self.en_output_size = en_output_size self.encoder = Encoder( input_dim=en_input_size, hidden_dim=hidden_size, output_dim=output_size ) @@ -179,6 +175,9 @@ def forward(self, *src): if trgs is None or self.teacher_forcing_ratio <= 0: current_input = output else: + # TODO: teacher forcing has no streamflow now? maybe need a mask to choose streamflow + # trgs is retrieved from the seq2seqdataset, and its time-length is prec_window(encoder) + all decoder steps + # hence for decoder step t, the target variable is trgs[:, prec_window + t, :] sm_trg = trgs[:, (self.prec_window + t), 1].unsqueeze(1).unsqueeze(1) # most of soil moisture from remote sensing are not nan, # so if we meet nan values, we just ignore the teacher forcing @@ -197,8 +196,8 @@ def forward(self, *src): current_input = output outputs = torch.stack(outputs, dim=1) - if self.en_output_size > 0: - prec_outputs = encoder_outputs[:, -self.en_output_size :, :] + if self.prec_window > 0: + prec_outputs = encoder_outputs[:, -self.prec_window :, :] outputs = torch.cat((prec_outputs, outputs), dim=1) return outputs