Skip to content

Commit

Permalink
new interface and setting
Browse files Browse the repository at this point in the history
  • Loading branch information
OuyangWenyu committed May 27, 2024
1 parent 1aa85ea commit a13e3fd
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions torchhydro/datasets/data_sets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2024-04-08 18:16:53
LastEditTime: 2024-05-27 17:05:14
LastEditTime: 2024-05-27 17:48:17
LastEditors: Wenyu Ouyang
Description: A pytorch dataset class; references to https://github.com/neuralhydrology/neuralhydrology
FilePath: \torchhydro\torchhydro\datasets\data_sets.py
Expand Down Expand Up @@ -141,7 +141,7 @@ def nt(self):
"""
if isinstance(self.t_s_dict["t_final_range"][0], tuple):
trange_type_num = len(self.t_s_dict["t_final_range"])
if trange_type_num != self.ngrid:
if trange_type_num not in [self.ngrid, 1]:
raise ValueError(
"The number of time ranges should be equal to the number of basins "
"if you choose different time ranges for different basins"
Expand Down Expand Up @@ -191,7 +191,7 @@ def times(self):
if isinstance(self.t_s_dict["t_final_range"][0], tuple):
times_ = []
trange_type_num = len(self.t_s_dict["t_final_range"])
if trange_type_num != self.ngrid:
if trange_type_num not in [self.ngrid, 1]:
raise ValueError(
"The number of time ranges should be equal to the number of basins "
"if you choose different time ranges for different basins"
Expand Down Expand Up @@ -426,7 +426,7 @@ def __len__(self):

class DplDataset(BaseDataset):
"""pytorch dataset for Differential parameter learning"""

# TODO: USE NUMPY ARRAY INSTEAD OF DATAARRAY FOR GET_ITEM
def __init__(self, data_cfgs: dict, is_tra_val_te: str):
"""
Parameters
Expand Down Expand Up @@ -1131,7 +1131,8 @@ def _prepare_forcing(self):
var_subset_list = []
for start_date, end_date in t_range:
adjusted_start_date = (
datetime.strptime(start_date, "%Y-%m-%d-%H") - timedelta(hours=self.rho)
datetime.strptime(start_date, "%Y-%m-%d-%H")
- timedelta(hours=self.rho * self.data_cfgs["min_time_interval"])
).strftime("%Y-%m-%d-%H")
adjusted_end_date = (
datetime.strptime(end_date, "%Y-%m-%d-%H")
Expand Down Expand Up @@ -1161,12 +1162,19 @@ def _prepare_target(self):
for start_date, end_date in t_range:
adjusted_start_date = (
datetime.strptime(start_date, "%Y-%m-%d-%H")
- timedelta(hours=(self.data_cfgs["prec_window"]))
- timedelta(
hours=(
self.data_cfgs["prec_window"]
* self.data_cfgs["min_time_interval"]
)
)
).strftime("%Y-%m-%d-%H")

adjusted_end_date = (
datetime.strptime(end_date, "%Y-%m-%d-%H")
+ timedelta(hours=self.forecast_length)
+ timedelta(
hours=self.forecast_length * self.data_cfgs["min_time_interval"]
)
).strftime("%Y-%m-%d-%H")
subset = data.sel(time=slice(adjusted_start_date, adjusted_end_date))
subset_list.append(subset)
Expand Down

0 comments on commit a13e3fd

Please sign in to comment.