Commit eae1eb9

add one more time step for seq2seqdataset _read_xyc

OuyangWenyu committed Nov 5, 2024
1 parent e557287 commit eae1eb9
Showing 4 changed files with 69 additions and 160 deletions.
133 changes: 1 addition & 132 deletions tests/test_data_sets.py
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2024-05-27 13:33:08
LastEditTime: 2024-11-05 11:46:01
LastEditTime: 2024-11-05 18:20:19
LastEditors: Wenyu Ouyang
Description: Unit test for datasets
FilePath: \torchhydro\tests\test_data_sets.py
@@ -149,134 +149,3 @@ def test_create_lookup_table(tmp_path):
isinstance(key, int) and isinstance(value, tuple)
for key, value in lookup_table.items()
)


def test_seq2seqdataset_getitem_train_mode(monkeypatch):
# Mock the return value of os.listdir
def mock_listdir(path):
if "timeseries" in path:
return ["1D"]
elif "attributes" in path:
return ["attributes.csv"]
return []

# Mock the return value of os.path.isdir
def mock_isdir(path):
return True

# Mock the return value of pandas.read_csv
def mock_read_csv(filepath, *args, **kwargs):
if "attributes.csv" in filepath:
return pd.DataFrame(
{
"basin_id": ["01013500", "01013501"],
"area": [100.0, 200.0],
"attr1": [1.0, 2.0],
"attr2": [3.0, 4.0],
}
)
elif "timeseries" in filepath:
return pd.DataFrame(
{
"date": pd.date_range(start="2001-01-01", periods=365, freq="D"),
"prcp": [0.1] * 365,
"pet": [0.05] * 365,
"streamflow": [1.0] * 365,
"surface_sm": [0.3] * 365,
}
)
return pd.DataFrame()

# Mock the return value of os.path.exists
def mock_exists(path):
return True

# Mock the return value of os.path.join
def mock_join(a, *p):
return a + "/" + "/".join(p)

# Use monkeypatch to patch os.listdir, os.path.isdir, os.path.exists, os.path.join, and pandas.read_csv
monkeypatch.setattr(os, "listdir", mock_listdir)
monkeypatch.setattr(os.path, "isdir", mock_isdir)
monkeypatch.setattr(os.path, "exists", mock_exists)
monkeypatch.setattr(os.path, "join", mock_join)
monkeypatch.setattr(pd, "read_csv", mock_read_csv)

data_sources_dict.update({"mockdatasource": MockDatasource})
data_cfgs = {
"source_cfgs": {
"source_name": "mockdatasource",
"source_path": "mock_path",
"other_settings": {"time_unit": ["1D"]},
},
"object_ids": ["01013500", "01013501"],
"t_range_train": ["2001-01-01", "2002-01-01"],
"t_range_test": ["2002-01-01", "2003-01-01"],
"relevant_cols": ["prcp", "pet"],
"target_cols": ["streamflow", "surface_sm"],
"constant_cols": ["geol_1st_class", "geol_2nd_class"],
"forecast_history": 7,
"warmup_length": 14,
"forecast_length": 1,
"min_time_unit": "D",
"min_time_interval": 1,
"target_rm_nan": True,
"relevant_rm_nan": True,
"constant_rm_nan": True,
"prec_window": 3,
"en_output_size": 5,
}
is_tra_val_te = "train"
dataset = Seq2SeqDataset(data_cfgs, is_tra_val_te)
item = 0
(x, x_h, y), y_out = dataset[item]
assert isinstance(x, torch.Tensor)
assert isinstance(x_h, torch.Tensor)
assert isinstance(y, torch.Tensor)
assert isinstance(y_out, torch.Tensor)
assert (
x.shape[1]
== len(data_cfgs["relevant_cols"]) + len(data_cfgs["constant_cols"]) + 1
)
assert x_h.shape[1] == len(data_cfgs["constant_cols"]) + 1
assert y.shape[1] == len(data_cfgs["target_cols"])
assert y_out.shape == y.shape


def test_seq2seqdataset_getitem_test_mode():
data_sources_dict.update({"mockdatasource": MockDatasource})
data_cfgs = {
"source_cfgs": {
"source_name": "mockdatasource",
"source_path": "mock_path",
},
"object_ids": ["01013500", "01013501"],
"t_range_train": ["2001-01-01", "2002-01-01"],
"t_range_test": ["2002-01-01", "2003-01-01"],
"relevant_cols": ["prcp", "pet"],
"target_cols": ["streamflow", "surface_sm"],
"constant_cols": ["geol_1st_class", "geol_2nd_class"],
"forecast_history": 7,
"warmup_length": 14,
"forecast_length": 1,
"min_time_unit": "D",
"min_time_interval": 1,
"target_rm_nan": True,
"relevant_rm_nan": True,
"constant_rm_nan": True,
"prec_window": 3,
"en_output_size": 5,
}
is_tra_val_te = "test"
dataset = Seq2SeqDataset(data_cfgs, is_tra_val_te)
item = 0
(x, x_h), y = dataset[item]
assert isinstance(x, torch.Tensor)
assert isinstance(x_h, torch.Tensor)
assert isinstance(y, torch.Tensor)
assert (
x.shape[1]
== len(data_cfgs["relevant_cols"]) + len(data_cfgs["constant_cols"]) + 1
)
assert x_h.shape[1] == len(data_cfgs["constant_cols"]) + 1
assert y.shape[1] == len(data_cfgs["target_cols"])
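Note: this commit drops en_output_size from the dataset config, the config parser, and the model (see the diffs below), so prec_window alone now controls how many encoder steps appear in the output. A minimal sketch of the changed keys, using the removed tests' hypothetical values:

data_cfgs = {
    # ... same keys as in the tests above, minus "en_output_size"
    "forecast_history": 7,   # rho: encoder periods
    "forecast_length": 1,    # horizon: decoder periods
    "prec_window": 3,        # encoder steps kept in the final output
}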
4 changes: 0 additions & 4 deletions torchhydro/configs/config.py
@@ -1031,10 +1031,6 @@ def update_cfg(cfg_file, new_args):
cfg_file["data_cfgs"]["prec_window"] = new_args.model_hyperparam[
"prec_window"
]
if "en_output_size" in new_args.model_hyperparam.keys():
cfg_file["data_cfgs"]["en_output_size"] = new_args.model_hyperparam[
"en_output_size"
]
if new_args.batch_size is not None:
# raise AttributeError("Please set the batch_size!!!")
batch_size = new_args.batch_size
75 changes: 60 additions & 15 deletions torchhydro/datasets/data_sets.py
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2024-04-08 18:16:53
LastEditTime: 2024-11-05 11:29:27
LastEditTime: 2024-11-05 20:02:49
LastEditors: Wenyu Ouyang
Description: A pytorch dataset class; references to https://github.com/neuralhydrology/neuralhydrology
FilePath: \torchhydro\torchhydro\datasets\data_sets.py
@@ -340,15 +340,31 @@ def _read_xyc(self):
x, y, c data
"""
# x
start_date = self.t_s_dict["t_final_range"][0]
end_date = self.t_s_dict["t_final_range"][1]
self._read_xyc_specified_time(start_date, end_date)

def _read_xyc_specified_time(self, start_date, end_date):
"""Read x, y, c data from data source with specified time range
We set this function as sometimes we need adjust the time range for some specific dataset,
such as seq2seq dataset (it needs one more period for the end of the time range)
Parameters
----------
start_date : str
start time
end_date : str
end time
"""
data_forcing_ds_ = self.data_source.read_ts_xrdataset(
self.t_s_dict["sites_id"],
self.t_s_dict["t_final_range"],
[start_date, end_date],
self.data_cfgs["relevant_cols"],
)
# y
data_output_ds_ = self.data_source.read_ts_xrdataset(
self.t_s_dict["sites_id"],
self.t_s_dict["t_final_range"],
[start_date, end_date],
self.data_cfgs["target_cols"],
)
if isinstance(data_output_ds_, dict) or isinstance(data_forcing_ds_, dict):
@@ -615,6 +631,26 @@ class Seq2SeqDataset(BaseDataset):
def __init__(self, data_cfgs: dict, is_tra_val_te: str):
super(Seq2SeqDataset, self).__init__(data_cfgs, is_tra_val_te)

def _read_xyc(self):
"""
NOTE: the lookup table is the same as BaseDataset's,
but the data retrieved from the data source should have one more period,
because we include both the start and the end moment of each period
Returns
-------
tuple[xr.Dataset, xr.Dataset, xr.Dataset]
x, y, c data
"""
start_date = self.t_s_dict["t_final_range"][0]
end_date = self.t_s_dict["t_final_range"][1]
interval = self.data_cfgs["min_time_interval"]
# TODO: only hourly data is supported for now; other time units such as Month and Day need better handling
adjusted_end_date = (
datetime.strptime(end_date, "%Y-%m-%d-%H") + timedelta(hours=interval)
).strftime("%Y-%m-%d-%H")
self._read_xyc_specified_time(start_date, adjusted_end_date)
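
Aside: the end-date adjustment in _read_xyc above is easiest to verify in isolation. A minimal standalone sketch of the same arithmetic, assuming hourly data (per the TODO) and hypothetical values:

from datetime import datetime, timedelta

interval = 3  # hypothetical min_time_interval, in hours
end_date = "2002-01-01-00"
# shift the end of the range one interval forward, as _read_xyc does
adjusted_end_date = (
    datetime.strptime(end_date, "%Y-%m-%d-%H") + timedelta(hours=interval)
).strftime("%Y-%m-%d-%H")
assert adjusted_end_date == "2002-01-01-03"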

def _normalize(self):
x, y, c = super()._normalize()
# TODO: does this work for MinIO? it may be better to move this to BaseDataset
@@ -627,29 +663,38 @@ def __getitem__(self, item: int):
basin, time = self.lookup_table[item]
rho = self.rho
horizon = self.horizon
prec = self.data_cfgs["prec_window"]
en_output_size = self.data_cfgs["en_output_size"]

prec = self.data_cfgs.get("prec_window", 0)
# p covers all encoder and decoder periods; +1 selects the end of each period, while +0 would select its start
p = self.x[basin, time + 1 : time + rho + horizon + 1, 0].reshape(-1, 1)
# s only covers the encoder periods
s = self.x[basin, time : time + rho, 1:]
x = np.concatenate((p[:rho], s), axis=1)

c = self.c[basin, :]
c = np.tile(c, (rho + horizon, 1))
x = np.concatenate((x, c[:rho]), axis=1)

x_h = np.concatenate((p[rho:], c[rho:]), axis=1)
if self.c is None or self.c.shape[-1] == 0:
xc = x
else:
c = self.c[basin, :]
c = np.tile(c, (rho + horizon, 1))
xc = np.concatenate((x, c[:rho]), axis=1)
# xh covers the decoder periods
try:
xh = np.concatenate((p[rho:], c[rho:]), axis=1)
except ValueError as e:
print(f"Error in np.concatenate: {e}")
print(f"p[rho:].shape: {p[rho:].shape}, c[rho:].shape: {c[rho:].shape}")
raise
# y covers the specified encoder window (prec_window) and all decoder periods
y = self.y[basin, time + rho - prec + 1 : time + rho + horizon + 1, :]

if self.is_tra_val_te == "train":
return [
torch.from_numpy(x).float(),
torch.from_numpy(x_h).float(),
torch.from_numpy(xc).float(),
torch.from_numpy(xh).float(),
torch.from_numpy(y).float(),
], torch.from_numpy(y).float()
return [
torch.from_numpy(x).float(),
torch.from_numpy(x_h).float(),
torch.from_numpy(xc).float(),
torch.from_numpy(xh).float(),
], torch.from_numpy(y).float()
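
Aside: the slicing above is dense; here is a minimal standalone shape check for a single basin, with hypothetical sizes matching the removed tests (rho=7, horizon=1, prec_window=3, two forcing/constant/target variables):

import numpy as np

rho, horizon, prec = 7, 1, 3   # forecast_history, forecast_length, prec_window
n_feat, n_const, n_tgt = 2, 2, 2

x_src = np.random.rand(rho + horizon + 1, n_feat)  # one extra period, per _read_xyc
c_src = np.random.rand(n_const)
y_src = np.random.rand(rho + horizon + 1, n_tgt)

time = 0
p = x_src[time + 1 : time + rho + horizon + 1, 0].reshape(-1, 1)
s = x_src[time : time + rho, 1:]
x = np.concatenate((p[:rho], s), axis=1)
c = np.tile(c_src, (rho + horizon, 1))
xc = np.concatenate((x, c[:rho]), axis=1)
xh = np.concatenate((p[rho:], c[rho:]), axis=1)
y = y_src[time + rho - prec + 1 : time + rho + horizon + 1, :]

assert xc.shape == (rho, n_feat + n_const)  # (7, 4): forcings + constants
assert xh.shape == (horizon, 1 + n_const)   # (1, 3): precipitation + constants
assert y.shape == (prec + horizon, n_tgt)   # (4, 2): encoder window + decoder steps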


17 changes: 8 additions & 9 deletions torchhydro/models/seq2seq.py
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2024-04-17 12:32:26
LastEditTime: 2024-11-01 12:01:16
LastEditTime: 2024-11-05 19:10:02
LastEditors: Wenyu Ouyang
Description:
FilePath: \torchhydro\torchhydro\models\seq2seq.py
@@ -123,7 +123,6 @@ def __init__(
forecast_length,
prec_window=0,
teacher_forcing_ratio=0.5,
en_output_size=1,
):
"""General Seq2Seq model
@@ -140,18 +139,15 @@
forecast_length : _type_
the length of the forecast, i.e., the periods of decoder outputs
prec_window : int, optional
starting index of decoder output for teacher forcing; default is 0
the number of the encoder's final outputs kept in the final output;
default is 0, which means no encoder output is included in the final output;
teacher_forcing_ratio : float, optional
the probability of using teacher forcing
en_output_size : int, optional
the encoder's final several outputs in the final output;
default is 1 which means the final encoder output is included in the final output
"""
super(GeneralSeq2Seq, self).__init__()
self.trg_len = forecast_length
self.prec_window = prec_window
self.teacher_forcing_ratio = teacher_forcing_ratio
self.en_output_size = en_output_size
self.encoder = Encoder(
input_dim=en_input_size, hidden_dim=hidden_size, output_dim=output_size
)
@@ -179,6 +175,9 @@ def forward(self, *src):
if trgs is None or self.teacher_forcing_ratio <= 0:
current_input = output
else:
# TODO: teacher forcing currently has no streamflow; we may need a mask to select streamflow
# trgs comes from Seq2SeqDataset; its time length is prec_window (encoder steps) + all decoder steps,
# so for decoder step t the target variable is trgs[:, prec_window + t, :]
sm_trg = trgs[:, (self.prec_window + t), 1].unsqueeze(1).unsqueeze(1)
# most remote-sensing soil moisture values are not NaN,
# so when we meet NaN values, we simply skip teacher forcing
@@ -197,8 +196,8 @@
current_input = output

outputs = torch.stack(outputs, dim=1)
if self.en_output_size > 0:
prec_outputs = encoder_outputs[:, -self.en_output_size :, :]
if self.prec_window > 0:
prec_outputs = encoder_outputs[:, -self.prec_window :, :]
outputs = torch.cat((prec_outputs, outputs), dim=1)
return outputs
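
Aside: the renamed prec_window flag replaces en_output_size here; what the concatenation does, as a minimal sketch with hypothetical dimensions:

import torch

batch, rho, horizon, prec_window, out_dim = 4, 7, 1, 3, 2
encoder_outputs = torch.randn(batch, rho, out_dim)
decoder_outputs = torch.randn(batch, horizon, out_dim)  # stacked decoder steps

# prepend the encoder's last prec_window outputs, as forward() does above
outputs = torch.cat((encoder_outputs[:, -prec_window:, :], decoder_outputs), dim=1)
assert outputs.shape == (batch, prec_window + horizon, out_dim)  # matches y's length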

Expand Down
