merge dev, add some TODOs, and flag some bugs found
OuyangWenyu committed Apr 7, 2024
1 parent 21bb5fd commit a504b76
Showing 2 changed files with 19 additions and 35 deletions.
16 changes: 2 additions & 14 deletions torchhydro/datasets/data_sets.py
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2022-02-13 21:20:18
LastEditTime: 2024-04-06 21:09:51
LastEditTime: 2024-04-07 21:45:45
LastEditors: Wenyu Ouyang
Description: A pytorch dataset class; references to https://github.com/neuralhydrology/neuralhydrology
FilePath: \torchhydro\torchhydro\datasets\data_sets.py
@@ -537,6 +537,7 @@ def _read_xyc(self):
return x, y, c

def _normalize(self):
# TODO: bug for x -- after normalization, potential_evaporation is all NaN
var_to_source_map = self.data_cfgs["var_to_source_map"]
for var_name in var_to_source_map:
source_name = var_to_source_map[var_name]
@@ -553,19 +554,6 @@ def _normalize(self):
self.target_scaler = scaler_hub.target_scaler
return scaler_hub.x, scaler_hub.y, scaler_hub.c

def __len__(self):
main_source_length = ...
return main_source_length

def __getitem__(self, idx):
# merge data from different data sources
x = {}
for source_name, config in self.data_cfgs.items():
# read and preprocess the data for this source according to its config
data = self.read_and_process(source_name, idx, config)
x.update(data)
return x


class HydroGridDataset(BaseDataset):
def __init__(self, data_cfgs: dict, is_tra_val_te: str):
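Note on the TODO in _normalize above: one common way a variable such as potential_evaporation ends up all-NaN after normalization is a zero-variance (constant) series, because a z-score divides by the standard deviation. A minimal standalone sketch of that failure mode (illustrative only, not torchhydro code; the data is made up):

import numpy as np

pet = np.zeros(100)  # a constant (here all-zero) potential_evaporation series
normed = (pet - pet.mean()) / pet.std()  # std is 0, so 0/0 yields nan
print(np.isnan(normed).all())  # True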
38 changes: 17 additions & 21 deletions torchhydro/trainers/deep_hydro.py
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2021-12-31 11:08:29
LastEditTime: 2024-04-07 21:01:44
LastEditTime: 2024-04-07 21:25:48
LastEditors: Wenyu Ouyang
Description: HydroDL model class
FilePath: \torchhydro\torchhydro\trainers\deep_hydro.py
@@ -248,6 +248,7 @@ def model_train(self) -> None:
training_cfgs, data_cfgs
)
logger = TrainLogger(model_filepath, self.cfgs, opt)
# TODO: the opt config needs refactoring
scheduler = ExponentialLR(opt, gamma=training_cfgs["lr_factor"])
if training_cfgs["weight_decay"] is not None:
for param_group in opt.param_groups:
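For context on the scheduler line above: ExponentialLR multiplies each param group's learning rate by gamma after every scheduler.step() call, so after k steps the rate is lr0 * gamma**k. A minimal sketch with placeholder values (the model, lr, and gamma are illustrative, not taken from the repo's config):

import torch
from torch.optim.lr_scheduler import ExponentialLR

model = torch.nn.Linear(4, 1)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = ExponentialLR(opt, gamma=0.9)  # gamma plays the role of lr_factor
for epoch in range(3):
    opt.step()        # the training step would go here
    scheduler.step()  # lr is now 1e-3 * 0.9 ** (epoch + 1)
print(opt.param_groups[0]["lr"])  # 1e-3 * 0.9**3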
@@ -497,7 +498,7 @@ def _get_dataloader(self, training_cfgs, data_cfgs):
nt = train_dataset.y.time.size
if data_cfgs["sampler"] == "HydroSampler":
sampler = HydroSampler(train_dataset)
else:
elif data_cfgs["sampler"] == "KuaiSampler":
sampler = KuaiSampler(
train_dataset,
batch_size=batch_size,
@@ -506,6 +507,8 @@
ngrid=ngrid,
nt=nt,
)
else:
raise NotImplementedError("This sampler is not implemented yet")
data_loader = DataLoader(
train_dataset,
batch_size=training_cfgs["batch_size"],
@@ -517,26 +520,19 @@
)
if data_cfgs["t_range_valid"] is not None:
valid_dataset = self.validdataset

batch_size_valid = training_cfgs["batch_size"]
if data_cfgs["sampler"] == "HydroSampler":
eval_num_samples = valid_dataset.num_samples
validation_data_loader = DataLoader(
valid_dataset,
batch_size=int(eval_num_samples / ngrid),
shuffle=False,
num_workers=worker_num,
pin_memory=pin_memory,
timeout=0,
)
else:
validation_data_loader = DataLoader(
valid_dataset,
batch_size=training_cfgs["batch_size"],
shuffle=False,
num_workers=worker_num,
pin_memory=pin_memory,
timeout=0,
)
# for HydroSampler, we need to set a new batch size when evaluating
# TODO: this may apply to other samplers as well
batch_size_valid = int(valid_dataset.num_samples / ngrid)
validation_data_loader = DataLoader(
valid_dataset,
batch_size=batch_size_valid,
shuffle=False,
num_workers=worker_num,
pin_memory=pin_memory,
timeout=0,
)
return data_loader, validation_data_loader

return data_loader, None
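The unified validation branch above sizes evaluation batches as num_samples / ngrid, which makes the number of batches equal to the number of basins (ngrid). A toy calculation with made-up numbers to show the intent (the names mirror the diff, the values are hypothetical):

num_samples = 20_000  # total evaluation samples in valid_dataset
ngrid = 500           # number of basins/grid cells
batch_size_valid = int(num_samples / ngrid)
print(batch_size_valid)  # 40; yields 20_000 / 40 = 500 batches, i.e. ngrid batches in total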
