Skip to content

Commit

Permalink
rolling should be deprecated in the near future
Browse files Browse the repository at this point in the history
  • Loading branch information
OuyangWenyu committed Nov 6, 2024
1 parent eae1eb9 commit 9c69216
Show file tree
Hide file tree
Showing 17 changed files with 122 additions and 140 deletions.
1 change: 0 additions & 1 deletion experiments/evaluate_with_era5land.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def get_config_data():
test_period=[("2015-06-01-01", "2016-05-31-01")],
which_first_tensor="batch",
rolling=True,
long_seq_pred=False,
weight_path=os.path.join(train_path, "best_model.pth"),
stat_dict_file=os.path.join(train_path, "dapengscaler_stat.json"),
train_mode=False,
Expand Down
1 change: 0 additions & 1 deletion experiments/evaluate_with_gpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def get_config_data():
test_period=[("2015-06-01-01", "2016-06-01-01")],
which_first_tensor="batch",
rolling=True,
long_seq_pred=False,
weight_path=os.path.join(train_path, "best_model.pth"),
stat_dict_file=os.path.join(train_path, "dapengscaler_stat.json"),
train_mode=False,
Expand Down
1 change: 0 additions & 1 deletion experiments/evaluate_with_gpm_streamflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def get_config_data():
test_period=[("2015-06-01-01", "2016-05-31-01")],
which_first_tensor="batch",
rolling=True,
long_seq_pred=False,
weight_path=os.path.join(train_path, "best_model.pth"),
stat_dict_file=os.path.join(train_path, "dapengscaler_stat.json"),
train_mode=False,
Expand Down
2 changes: 0 additions & 2 deletions experiments/train_with_era5land.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,6 @@ def config():
"lr_factor": 0.9,
},
which_first_tensor="batch",
rolling=False,
long_seq_pred=False,
calc_metrics=False,
early_stopping=True,
# ensemble=True,
Expand Down
2 changes: 0 additions & 2 deletions experiments/train_with_gpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ def config():
"lr_factor": 0.9,
},
which_first_tensor="batch",
rolling=False,
long_seq_pred=False,
calc_metrics=False,
early_stopping=True,
# ensemble=True,
Expand Down
1 change: 0 additions & 1 deletion experiments/train_with_gpm_dis.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def create_config():
"lr_factor": 0.96,
},
which_first_tensor="batch",
rolling=False,
static=False,
early_stopping=True,
patience=8,
Expand Down
2 changes: 0 additions & 2 deletions experiments/train_with_gpm_streamflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,6 @@ def create_config():
"lr_factor": 0.96,
},
which_first_tensor="batch",
rolling=False,
long_seq_pred=False,
early_stopping=True,
patience=8,
model_type="MTL",
Expand Down
7 changes: 1 addition & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,6 @@ def s2s_args(basin4test):
"lr_factor": 0.96,
},
which_first_tensor="batch",
rolling=False,
long_seq_pred=False,
calc_metrics=False,
early_stopping=True,
patience=8,
Expand Down Expand Up @@ -355,8 +353,6 @@ def trans_args(basin4test):
"lr_factor": 0.96,
},
which_first_tensor="sequence",
rolling=False,
long_seq_pred=False,
calc_metrics=False,
early_stopping=True,
patience=8,
Expand Down Expand Up @@ -536,8 +532,7 @@ def seq2seq_config():
"lr_factor": 0.9,
},
which_first_tensor="batch",
rolling=False,
long_seq_pred=False,
rolling=True,
calc_metrics=False,
early_stopping=True,
# ensemble=True,
Expand Down
1 change: 0 additions & 1 deletion tests/test_evaluate_grid_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def config_data():
test_period=[
("2017-07-01", "2017-09-29"),
],
rolling=False,
weight_path=os.path.join(train_path, "best_model.pth"),
stat_dict_file=os.path.join(train_path, "MutiBasinScaler_stat.json"),
continue_train=False,
Expand Down
2 changes: 0 additions & 2 deletions tests/test_evaluate_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,6 @@ def config_data():
},
test_period=[("2015-05-01", "2016-05-31")],
which_first_tensor="batch",
rolling=False,
long_seq_pred=False,
weight_path=os.path.join(train_path, "best_model.pth"),
stat_dict_file=os.path.join(train_path, "dapengscaler_stat.json"),
train_mode=False,
Expand Down
2 changes: 0 additions & 2 deletions tests/test_pretrain_dataenhanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,6 @@ def config():
# "14306500",
],
which_first_tensor="batch",
rolling=False,
long_seq_pred=False,
early_stopping=True,
patience=10,
ensemble=True,
Expand Down
2 changes: 0 additions & 2 deletions tests/test_pretrain_datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,6 @@ def config():
"01414500",
],
which_first_tensor="batch",
rolling=False,
long_seq_pred=False,
early_stopping=True,
patience=4,
)
Expand Down
1 change: 0 additions & 1 deletion tests/test_train_grid_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def config():
which_first_tensor="sequence",
early_stopping=True,
patience=4, # 连续n次valid loss不下降,则停止训练,与early_stopping配合使用
rolling=False, # evaluate 不采用滚动预测
ensemble=True, # 交叉验证
ensemble_items={
"kfold": 5, # exi_0即17年验证,...exi_4即21年验证
Expand Down
17 changes: 5 additions & 12 deletions torchhydro/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,11 @@ def default_config_file():
"metrics": ["NSE", "RMSE", "R2", "KGE", "FHV", "FLV"],
"fill_nan": "no",
"explainer": None,
"rolling": None,
"long_seq_pred": True,
# rolling means the test dataloader will sample data with overlapping time windows
# rolling=False means each time step has only one output per basin per variable
# rolling=True requires the time window to equal prec_window + horizon for now
# for example, if the data is |1|2|3|4| and time_window=2, the samples are |1|2|, |2|3| and |3|4|
"rolling": False,
"calc_metrics": True,
},
}
Expand Down Expand Up @@ -352,7 +355,6 @@ def cmd(
fill_nan=None,
explainer=None,
rolling=None,
long_seq_pred=None,
calc_metrics=None,
start_epoch=1,
stat_dict_file=None,
Expand Down Expand Up @@ -744,13 +746,6 @@ def cmd(
default=model_loader,
type=int,
)
parser.add_argument(
"--long_seq_pred",
dest="long_seq_pred",
help="if True, direct, one-step, long-term sequence prediction",
default=long_seq_pred,
type=bool,
)
parser.add_argument(
"--calc_metrics",
dest="calc_metrics",
Expand Down Expand Up @@ -983,8 +978,6 @@ def update_cfg(cfg_file, new_args):
cfg_file["data_cfgs"]["constant_only"] = bool(new_args.constant_only != 0)
else:
cfg_file["data_cfgs"]["target_as_input"] = True
if new_args.long_seq_pred is not None:
cfg_file["evaluation_cfgs"]["long_seq_pred"] = new_args.long_seq_pred
if new_args.calc_metrics is not None:
cfg_file["evaluation_cfgs"]["calc_metrics"] = new_args.calc_metrics
if new_args.train_epoch is not None:
Expand Down
23 changes: 18 additions & 5 deletions torchhydro/datasets/data_sets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2024-04-08 18:16:53
LastEditTime: 2024-11-05 20:02:49
LastEditTime: 2024-11-06 08:11:13
LastEditors: Wenyu Ouyang
Description: A pytorch dataset class; references to https://github.com/neuralhydrology/neuralhydrology
FilePath: \torchhydro\torchhydro\datasets\data_sets.py
Expand Down Expand Up @@ -645,10 +645,23 @@ def _read_xyc(self):
start_date = self.t_s_dict["t_final_range"][0]
end_date = self.t_s_dict["t_final_range"][1]
interval = self.data_cfgs["min_time_interval"]
# TODO: Now only support hour, need to better handle different time units, such as Month, Day
adjusted_end_date = (
datetime.strptime(end_date, "%Y-%m-%d-%H") + timedelta(hours=interval)
).strftime("%Y-%m-%d-%H")
time_unit = self.data_cfgs["min_time_unit"]

# Determine the date format
date_format = detect_date_format(end_date)

# Adjust the end date based on the time unit
end_date_dt = datetime.strptime(end_date, date_format)
if time_unit == "h":
adjusted_end_date = (end_date_dt + timedelta(hours=interval)).strftime(
date_format
)
elif time_unit == "D":
adjusted_end_date = (end_date_dt + timedelta(days=interval)).strftime(
date_format
)
else:
raise ValueError(f"Unsupported time unit: {time_unit}")
self._read_xyc_specified_time(start_date, adjusted_end_date)

def _normalize(self):
Expand Down
Loading

0 comments on commit 9c69216

Please sign in to comment.