diff --git a/tests/test_train_mean_lstm.py b/tests/test_train_mean_lstm.py
index 0d2d3a8..bb450d3 100644
--- a/tests/test_train_mean_lstm.py
+++ b/tests/test_train_mean_lstm.py
@@ -1,11 +1,11 @@
 """
 Author: Xinzhuo Wu
-Date: 2023-07-25 16:47:19
-LastEditTime: 2024-04-08 09:59:17
-LastEditors: Wenyu Ouyang
+Date: 2024-04-08 18:13:05
+LastEditTime: 2024-04-08 18:13:05
+LastEditors: Xinzhuo Wu
 Description: Test a full training and evaluating process
-FilePath: \torchhydro\tests\test_train_mean_lstm.py
-Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
+FilePath: /torchhydro/tests/test_train_mean_lstm.py
+Copyright (c) 2024-2024 Wenyu Ouyang. All rights reserved.
 """

 import pytest
@@ -21,15 +21,13 @@ def config():
         sub=project_name,
         source_cfgs={
             "source": "HydroMean",
-            "source_path": [
-                {
-                    "forcing": "basins-origin/hour_data/1h/mean_data/mean_data_forcing",
-                    "target": "basins-origin/hour_data/1h/mean_data/mean_data_target",
-                    "attributes": "basins-origin/attributes.nc",
-                }
-            ],
+            "source_path": {
+                "forcing": "basins-origin/hour_data/1h/mean_data/mean_data_forcing",
+                "target": "basins-origin/hour_data/1h/mean_data/mean_data_target",
+                "attributes": "basins-origin/attributes.nc",
+            },
         },
-        ctx=[0],
+        ctx=[1],
         model_name="SimpleLSTMForecast",
         model_hyperparam={
             "input_size": 16,
diff --git a/torchhydro/configs/config.py b/torchhydro/configs/config.py
index 134cf72..33fa414 100644
--- a/torchhydro/configs/config.py
+++ b/torchhydro/configs/config.py
@@ -205,7 +205,6 @@ def default_config_file():
             # sampler for pytorch dataloader, here we mainly use it for Kuai Fang's sampler in all his DL papers
             "sampler": None,
             "loading_batch": None,
-            "user": None,
         },
         "training_cfgs": {
             # if train_mode is False, don't train and evaluate
@@ -336,12 +335,6 @@ def cmd(
     ensemble_items=None,
     early_stopping=None,
     patience=None,
-    user=None,
-    endpoint_url=None,
-    access_key=None,
-    secret_key=None,
-    bucket_name=None,
-    folder_prefix=None,
 ):
     """input args from cmd"""
     parser = argparse.ArgumentParser(
@@ -766,13 +759,6 @@ def cmd(
         default=early_stopping,
         type=bool,
     )
-    parser.add_argument(
-        "--user",
-        dest="user",
-        help="user_name to distinguish trainer or tester",
-        default=user,
-        type=str,
-    )
     parser.add_argument(
         "--patience",
         dest="patience",
@@ -780,41 +766,6 @@ def cmd(
         default=patience,
         type=int,
     )
-    parser.add_argument(
-        "--endpoint_url",
-        dest="endpoint_url",
-        help="endpoint_url",
-        default=endpoint_url,
-        type=str,
-    )
-    parser.add_argument(
-        "--access_key",
-        dest="access_key",
-        help="access_key",
-        default=access_key,
-        type=str,
-    )
-    parser.add_argument(
-        "--secret_key",
-        dest="secret_key",
-        help="secret_key",
-        default=secret_key,
-        type=str,
-    )
-    parser.add_argument(
-        "--bucket_name",
-        dest="bucket_name",
-        help="bucket_name",
-        default=bucket_name,
-        type=str,
-    )
-    parser.add_argument(
-        "--folder_prefix",
-        dest="folder_prefix",
-        help="folder_prefix",
-        default=folder_prefix,
-        type=str,
-    )
     # To make pytest work in PyCharm, here we use the following code instead of "args = parser.parse_args()":
     # https://blog.csdn.net/u014742995/article/details/100119905
     args, unknown = parser.parse_known_args()
@@ -1087,8 +1038,8 @@ def update_cfg(cfg_file, new_args):
         cfg_file["training_cfgs"]["patience"] = new_args.patience
     if new_args.early_stopping is not None:
         cfg_file["training_cfgs"]["early_stopping"] = new_args.early_stopping
-    if new_args.user is not None:
-        cfg_file["data_cfgs"]["user"] = new_args.user
+    if new_args.lr_scheduler is not None:
+        cfg_file["training_cfgs"]["lr_scheduler"] = new_args.lr_scheduler
     # print("the updated config:\n", json.dumps(cfg_file, indent=4, ensure_ascii=False))
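Note on the config change above: `update_cfg` now forwards an `lr_scheduler` value into `training_cfgs`, which `_get_scheduler` in deep_hydro.py later inspects. A minimal, self-contained sketch of just that branch (the real `update_cfg` handles many more keys, and the dict shape shown here is an assumption inferred from `_get_scheduler`):

```python
import argparse

# Minimal sketch of the new update_cfg branch; the real function handles many
# more keys, and the lr_scheduler dict shape is an assumption, not a spec.
def _update_lr_scheduler(cfg_file: dict, new_args: argparse.Namespace) -> None:
    if new_args.lr_scheduler is not None:
        cfg_file["training_cfgs"]["lr_scheduler"] = new_args.lr_scheduler

cfg = {"training_cfgs": {"lr_scheduler": None}}
_update_lr_scheduler(cfg, argparse.Namespace(lr_scheduler={"lr_factor": 0.9}))
assert cfg["training_cfgs"]["lr_scheduler"] == {"lr_factor": 0.9}
```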
cfg_file["training_cfgs"]["lr_scheduler"] = new_args.lr_scheduler # print("the updated config:\n", json.dumps(cfg_file, indent=4, ensure_ascii=False)) diff --git a/torchhydro/datasets/data_scalers.py b/torchhydro/datasets/data_scalers.py index 20221af..c21b0ee 100644 --- a/torchhydro/datasets/data_scalers.py +++ b/torchhydro/datasets/data_scalers.py @@ -1,3 +1,13 @@ +""" +Author: Wenyu Ouyang +Date: 2024-04-08 18:17:44 +LastEditTime: 2024-04-08 18:17:44 +LastEditors: Xinzhuo Wu +Description: normalize the data +FilePath: /torchhydro/torchhydro/datasets/data_scalers.py +Copyright (c) 2024-2024 Wenyu Ouyang. All rights reserved. +""" + import copy import json import os @@ -220,6 +230,7 @@ def __init__( self.log_norm_cols = gamma_norm_cols + prcp_norm_cols self.pbm_norm = pbm_norm self.data_source = data_source + self.read_meanprep() # save stat_dict of training period in test_path for valid/test stat_file = os.path.join(data_cfgs["test_path"], "dapengscaler_stat.json") # for testing sometimes such as pub cases, we need stat_dict_file from trained dataset @@ -236,6 +247,17 @@ def __init__( with open(stat_file, "r") as fp: self.stat_dict = json.load(fp) + def read_meanprep(self): + if isinstance(self.data_source, HydroBasins): + self.mean_prep = self.data_source.read_MP( + self.t_s_dict["sites_id"], + self.data_cfgs["source_cfgs"]["source_path"]["attributes"], + ) + else: + self.mean_prep = self.data_source.read_mean_prcp( + self.t_s_dict["sites_id"] + ).to_array() + def inverse_transform(self, target_values): """ Denormalization for output variables @@ -265,41 +287,14 @@ def inverse_transform(self, target_values): ) for i in range(len(self.data_cfgs["target_cols"])): var = self.data_cfgs["target_cols"][i] - if var in self.prcp_norm_cols: - if var == "qobs_mm_per_hour": # should be deleted - mean_prep = self.read_attr_xrdataset( - self.t_s_dict["sites_id"], ["p_mean"] - ) - pred.loc[dict(variable=var)] = _prcp_norm( - pred.sel(variable=var).to_numpy(), - mean_prep.to_array().to_numpy().T, - to_norm=False, - ) - elif isinstance(self.data_source, HydroBasins): - mean_prep = self.data_source.read_MP( - self.t_s_dict["sites_id"], - self.data_cfgs["data_path"]["attributes"], - ) - pred.loc[dict(variable=var)] = self.mean_prcp_norm( - pred.sel(variable=var).to_numpy().T, - mean_prep.to_numpy(), - to_norm=False, - ) - else: - mean_prep = self.data_source.read_mean_prcp( - self.t_s_dict["sites_id"] - ) - pred.loc[dict(variable=var)] = _prcp_norm( - pred.sel(variable=var).to_numpy(), - mean_prep.to_array().to_numpy().T, - to_norm=False, - ) + pred.loc[dict(variable=var)] = _prcp_norm( + pred.sel(variable=var).to_numpy(), + self.mean_prep.to_numpy().T, + to_norm=False, + ) # add attrs for units pred.attrs.update(self.data_target.attrs) - # trans to xarray dataset - pred_ds = pred.to_dataset(dim="variable") - - return pred_ds + return pred.to_dataset(dim="variable") def cal_stat_all(self): """ @@ -316,31 +311,10 @@ def cal_stat_all(self): for i in range(len(target_cols)): var = target_cols[i] if var in self.prcp_norm_cols: - if var == "qobs_mm_per_hour": # should be deleted if data all in minio - mean_prep = self.read_attr_xrdataset( - self.t_s_dict["sites_id"], ["p_mean"] - ) - stat_dict[var] = cal_stat_prcp_norm( - self.data_target.sel(variable=var).to_numpy(), - mean_prep.to_array().to_numpy().T, - ) - elif isinstance(self.data_source, HydroBasins): - mean_prep = self.data_source.read_MP( - self.t_s_dict["sites_id"], - self.data_cfgs["data_path"]["attributes"], - ) - stat_dict[var] = 
@@ -316,31 +311,10 @@ def cal_stat_all(self):
         for i in range(len(target_cols)):
             var = target_cols[i]
             if var in self.prcp_norm_cols:
-                if var == "qobs_mm_per_hour":  # should be deleted if data all in minio
-                    mean_prep = self.read_attr_xrdataset(
-                        self.t_s_dict["sites_id"], ["p_mean"]
-                    )
-                    stat_dict[var] = cal_stat_prcp_norm(
-                        self.data_target.sel(variable=var).to_numpy(),
-                        mean_prep.to_array().to_numpy().T,
-                    )
-                elif isinstance(self.data_source, HydroBasins):
-                    mean_prep = self.data_source.read_MP(
-                        self.t_s_dict["sites_id"],
-                        self.data_cfgs["data_path"]["attributes"],
-                    )
-                    stat_dict[var] = self.mean_cal_stat_prcp_norm(
-                        self.data_target.sel(variable=var).to_numpy().T,
-                        mean_prep.to_numpy(),
-                    )
-                else:
-                    mean_prep = self.data_source.read_mean_prcp(
-                        self.t_s_dict["sites_id"]
-                    )
-                    stat_dict[var] = cal_stat_prcp_norm(
-                        self.data_target.sel(variable=var).to_numpy(),
-                        mean_prep.to_array().to_numpy().T,
-                    )
+                stat_dict[var] = cal_stat_prcp_norm(
+                    self.data_target.sel(variable=var).to_numpy(),
+                    self.mean_prep.to_numpy().T,
+                )
             elif var in self.gamma_norm_cols:
                 stat_dict[var] = cal_stat_gamma(
                     self.data_target.sel(variable=var).to_numpy()
@@ -392,34 +366,11 @@ def get_data_obs(self, to_norm: bool = True) -> np.array:
         for i in range(len(target_cols)):
             var = target_cols[i]
             if var in self.prcp_norm_cols:
-                if var == "qobs_mm_per_hour":  # should be deleted
-                    mean_prep = self.read_attr_xrdataset(
-                        self.t_s_dict["sites_id"], ["p_mean"]
-                    )
-                    out.loc[dict(variable=var)] = _prcp_norm(
-                        data.sel(variable=var).to_numpy(),
-                        mean_prep.to_array().to_numpy().T,
-                        to_norm=True,
-                    )
-                elif isinstance(self.data_source, HydroBasins):
-                    mean_prep = self.data_source.read_MP(
-                        self.t_s_dict["sites_id"],
-                        self.data_cfgs["data_path"]["attributes"],
-                    )
-                    out.loc[dict(variable=var)] = self.mean_prcp_norm(
-                        data.sel(variable=var).to_numpy().T,
-                        mean_prep.to_numpy(),
-                        to_norm=True,
-                    )
-                else:
-                    mean_prep = self.data_source.read_mean_prcp(
-                        self.t_s_dict["sites_id"]
-                    )
-                    out.loc[dict(variable=var)] = _prcp_norm(
-                        data.sel(variable=var).to_numpy(),
-                        mean_prep.to_array().to_numpy().T,
-                        to_norm=True,
-                    )
+                out.loc[dict(variable=var)] = _prcp_norm(
+                    data.sel(variable=var).to_numpy(),
+                    self.mean_prep.to_numpy().T,
+                    to_norm=True,
+                )
             out.attrs["units"][var] = "dimensionless"
         out = _trans_norm(
             out,
@@ -430,13 +381,6 @@ def get_data_obs(self, to_norm: bool = True) -> np.array:
         )
         return out

-    # temporarily used, it's related to hydrodataset
-    def read_attr_xrdataset(self, gage_id_lst=None, var_lst=None, **kwargs):
-        if var_lst is None or len(var_lst) == 0:
-            return None
-        attr = xr.open_dataset(os.path.join("/ftproot", "camelsus_attributes_us.nc"))
-        return attr[var_lst].sel(basin=gage_id_lst)
-
     def get_data_ts(self, to_norm=True) -> np.array:
         """
         Get dynamic input data
@@ -499,17 +443,6 @@ def load_data(self):
         c = self.get_data_const()
         return x, y, c

-    def mean_cal_stat_prcp_norm(self, x, meanprep):
-        tempprep = np.tile(meanprep, (x.shape[0], 1))
-        flowua = (x / tempprep).T
-        return cal_stat_gamma(flowua)
-
-    def mean_prcp_norm(
-        self, x: np.array, mean_prep: np.array, to_norm: bool
-    ) -> np.array:
-        tempprep = np.tile(mean_prep, (x.shape[0], 1))
-        return (x / tempprep).T if to_norm else (x * tempprep).T
-
 class MutiBasinScaler(object):
     def __init__(
@@ -620,9 +553,7 @@ def cal_stat_all(self):
             if var in self.gamma_norm_cols:
                 stat_dict[var] = self.grid_cal_stat_gamma(data_smap, var)

-        # const attribute
-        attr_lst = self.data_cfgs["constant_cols"]
-        if attr_lst:
+        if attr_lst := self.data_cfgs["constant_cols"]:
             data_attr = self.data_attr
             for k in range(len(attr_lst)):
                 var = attr_lst[k]
@@ -688,8 +619,16 @@ def get_data_const(self, to_norm=True) -> np.array:

     def load_data(self):
         x = self.get_data_ts(0).compute()
-        g = self.get_data_ts(1).compute() if self.data_cfgs["relevant_cols"][1] != ["None"] else None
-        s = self.get_data_ts(2).compute() if self.data_cfgs["relevant_cols"][2] != ["None"] else None
+        g = (
+            self.get_data_ts(1).compute()
+            if self.data_cfgs["relevant_cols"][1] != ["None"]
+            else None
+        )
+        s = (
+            self.get_data_ts(2).compute()
+            if self.data_cfgs["relevant_cols"][2] != ["None"]
+            else None
+        )
         y = self.get_data_obs().compute()
         c = self.get_data_const().compute() if self.data_cfgs["constant_cols"] else None
         return x, y, c, g, s
@@ -746,4 +685,4 @@ def grid_prcp_norm(
         self, x: np.array, mean_prep: np.array, to_norm: bool
     ) -> np.array:
         tempprep = np.tile(mean_prep, (x.shape[0], 1))
-        return x / tempprep if to_norm else x * tempprep
\ No newline at end of file
+        return x / tempprep if to_norm else x * tempprep
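A side note on the `attr_lst := ...` change in `cal_stat_all` above: the walrus operator binds and tests the value in one step and requires Python 3.8 or newer. A standalone illustration with a hypothetical config dict:

```python
# Standalone illustration of the walrus pattern adopted in cal_stat_all;
# the config dict here is hypothetical.
data_cfgs = {"constant_cols": ["area", "elev_mean"]}
if attr_lst := data_cfgs["constant_cols"]:  # binds attr_lst, then tests truthiness
    print(f"computing stats for {len(attr_lst)} attributes")
else:
    print("no constant attributes")  # an empty list or None skips the branch
```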
["None"] + else None + ) y = self.get_data_obs().compute() c = self.get_data_const().compute() if self.data_cfgs["constant_cols"] else None return x, y, c, g, s @@ -746,4 +685,4 @@ def grid_prcp_norm( self, x: np.array, mean_prep: np.array, to_norm: bool ) -> np.array: tempprep = np.tile(mean_prep, (x.shape[0], 1)) - return x / tempprep if to_norm else x * tempprep \ No newline at end of file + return x / tempprep if to_norm else x * tempprep diff --git a/torchhydro/datasets/data_sets.py b/torchhydro/datasets/data_sets.py index e0ef334..ff8fc66 100644 --- a/torchhydro/datasets/data_sets.py +++ b/torchhydro/datasets/data_sets.py @@ -1,11 +1,11 @@ """ Author: Wenyu Ouyang -Date: 2022-02-13 21:20:18 -LastEditTime: 2024-04-08 12:03:12 +Date: 2024-04-08 18:16:53 +LastEditTime: 2024-04-09 08:01:14 LastEditors: Wenyu Ouyang Description: A pytorch dataset class; references to https://github.com/neuralhydrology/neuralhydrology FilePath: \torchhydro\torchhydro\datasets\data_sets.py -Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved. +Copyright (c) 2024-2024 Wenyu Ouyang. All rights reserved. """ import logging @@ -512,7 +512,7 @@ def __init__(self, data_cfgs: dict, is_tra_val_te: str): @property def data_source(self): - return HydroBasins(self.data_cfgs["data_path"]) + return HydroBasins(self.data_cfgs["source_cfgs"]["source_path"]) def _normalize(self): x, y, c = super()._normalize() @@ -521,14 +521,14 @@ def _normalize(self): def _read_xyc(self): data_target_ds = self._prepare_target() if data_target_ds is not None: - y_origin = super()._trans2da_and_setunits(data_target_ds) + y_origin = self._trans2da_and_setunits(data_target_ds) else: y_origin = None data_forcing_ds = self._prepare_forcing() if data_forcing_ds is not None: - x_origin = super()._trans2da_and_setunits(data_forcing_ds) + x_origin = self._trans2da_and_setunits(data_forcing_ds) else: x_origin = None @@ -536,9 +536,9 @@ def _read_xyc(self): data_attr_ds = self.data_source.read_BA_xrdataset( self.t_s_dict["sites_id"], self.data_cfgs["constant_cols"], - self.data_cfgs["data_path"]["attributes"], + self.data_cfgs["source_cfgs"]["source_path"]["attributes"], ) - c_orgin = super()._trans2da_and_setunits(data_attr_ds) + c_orgin = self._trans2da_and_setunits(data_attr_ds) else: c_orgin = None self.x_origin, self.y_origin, self.c_origin = x_origin, y_origin, c_orgin @@ -547,7 +547,7 @@ def _prepare_target(self): gage_id_lst = self.t_s_dict["sites_id"] t_range = self.t_s_dict["t_final_range"] var_lst = self.data_cfgs["target_cols"] - path = self.data_cfgs["data_path"]["target"] + path = self.data_cfgs["source_cfgs"]["source_path"]["target"] if var_lst is None or not var_lst: return None @@ -561,7 +561,7 @@ def _prepare_target(self): for start_date, end_date in t_range: adjusted_end_date = ( datetime.strptime(end_date, "%Y-%m-%d") - + timedelta(hours=self.forecast_length) + + timedelta(hours=self.data_cfgs["forecast_length"]) ).strftime("%Y-%m-%d") subset = data.sel(time=slice(start_date, adjusted_end_date)) subset_list.append(subset) @@ -570,7 +570,7 @@ def _prepare_target(self): def _create_lookup_table(self): lookup = [] basins = self.t_s_dict["sites_id"] - forecast_length = self.forecast_length + forecast_length = self.data_cfgs["forecast_length"] warmup_length = self.warmup_length dates = self.y["time"].to_numpy() time_num = len(self.t_s_dict["t_final_range"]) @@ -645,7 +645,7 @@ def _prepare_forcing(self): gage_id_lst = self.t_s_dict["sites_id"] t_range = self.t_s_dict["t_final_range"] var_lst = self.data_cfgs["relevant_cols"] - 
diff --git a/torchhydro/trainers/deep_hydro.py b/torchhydro/trainers/deep_hydro.py
index 0b35c38..6548484 100644
--- a/torchhydro/trainers/deep_hydro.py
+++ b/torchhydro/trainers/deep_hydro.py
@@ -1,11 +1,11 @@
 """
 Author: Wenyu Ouyang
-Date: 2021-12-31 11:08:29
-LastEditTime: 2024-04-08 10:50:12
-LastEditors: Wenyu Ouyang
+Date: 2024-04-08 18:15:48
+LastEditTime: 2024-04-08 18:15:48
+LastEditors: Xinzhuo Wu
 Description: HydroDL model class
-FilePath: \torchhydro\torchhydro\trainers\deep_hydro.py
-Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.
+FilePath: /torchhydro/torchhydro/trainers/deep_hydro.py
+Copyright (c) 2024-2024 Wenyu Ouyang. All rights reserved.
 """

 from abc import ABC, abstractmethod
@@ -268,7 +268,6 @@ def model_train(self) -> None:
                     training_cfgs, criterion, validation_data_loader, valid_logs
                 )

-            lr_val_loss = training_cfgs["lr_val_loss"]
             scheduler.step()
             logger.save_session_param(
                 epoch, total_loss, n_iter_ep, valid_loss, valid_metrics
@@ -276,9 +275,8 @@ def model_train(self) -> None:
             )
             logger.save_model_and_params(self.model, epoch, self.cfgs)
             if es and not es.check_loss(
                 self.model,
-                valid_loss if lr_val_loss else list(valid_metrics.items())[0][1][0],
+                valid_loss,
                 self.cfgs["data_cfgs"]["test_path"],
-                lr_val_loss,
             ):
                 print("Stopping model now")
                 break
@@ -289,13 +287,17 @@ def model_train(self) -> None:
         return self.model.state_dict(), sum(logger.epoch_loss) / len(logger.epoch_loss)

     def _get_scheduler(self, training_cfgs, opt):
-        # TODO: not finished yet
-        lr_scheduler = training_cfgs["lr_scheduler"]
-        if lr_scheduler is not None and epoch in lr_scheduler.keys():
-            for param_group in opt.param_groups:
-                param_group["lr"] = lr_scheduler[epoch]
-        scheduler = ExponentialLR(opt, gamma=training_cfgs["lr_factor"])
-        return scheduler
+        return (
+            ExponentialLR(opt, gamma=training_cfgs["lr_scheduler"]["lr_factor"])
+            if "lr_factor" in training_cfgs["lr_scheduler"]
+            and "lr_patience" not in training_cfgs["lr_scheduler"]
+            else ReduceLROnPlateau(
+                opt,
+                mode="min",
+                factor=training_cfgs["lr_scheduler"]["lr_factor"],
+                patience=training_cfgs["lr_scheduler"]["lr_patience"],
+            )
+        )

     def _1epoch_valid(
         self, training_cfgs, criterion, validation_data_loader, valid_logs
@@ -452,7 +454,7 @@ def inference(self) -> Tuple[torch.Tensor, torch.Tensor]:
         # TODO: not support return_cell_states yet
         return_cell_state = False
         if return_cell_state:
-            return cellstates_when_inference(seq_first, data_cfgs, pred)
+            return cellstates_when_inference(seq_first, data_cfgs, pred)
         pred_xr, obs_xr = denormalize4eval(test_dataloader, pred, obs)
         return pred_xr, obs_xr, self.testdataset
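The rewritten `_get_scheduler` above picks a scheduler from the shape of `training_cfgs["lr_scheduler"]`: `lr_factor` alone selects `ExponentialLR`, while `lr_factor` together with `lr_patience` selects `ReduceLROnPlateau` (which must be imported alongside `ExponentialLR`). Note also that `ReduceLROnPlateau.step()` expects the monitored metric, whereas the training loop above calls `scheduler.step()` with no arguments. A self-contained sketch of the same dispatch; the config shapes are assumptions:

```python
import torch
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau

# Self-contained sketch of the dispatch in _get_scheduler; the config shapes
# ({"lr_factor": ...} vs {"lr_factor": ..., "lr_patience": ...}) are assumptions.
def pick_scheduler(opt, lr_cfg: dict):
    if "lr_factor" in lr_cfg and "lr_patience" not in lr_cfg:
        return ExponentialLR(opt, gamma=lr_cfg["lr_factor"])
    return ReduceLROnPlateau(
        opt, mode="min", factor=lr_cfg["lr_factor"], patience=lr_cfg["lr_patience"]
    )

opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
print(type(pick_scheduler(opt, {"lr_factor": 0.9})).__name__)                    # ExponentialLR
print(type(pick_scheduler(opt, {"lr_factor": 0.5, "lr_patience": 3})).__name__)  # ReduceLROnPlateau
```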
""" import copy @@ -220,19 +220,13 @@ def __init__( self.counter = 0 self.best_score = None - def check_loss(self, model, validation_loss, save_dir, lr_val_loss) -> bool: + def check_loss(self, model, validation_loss, save_dir) -> bool: score = validation_loss if self.best_score is None: self.save_model_checkpoint(model, save_dir) self.best_score = score - elif ( - (score + self.min_delta >= self.best_score) - if lr_val_loss - else (score + self.min_delta <= self.best_score) - ): - # if not self.cumulative_delta and score > self.best_score: - # self.best_score = score + elif score + self.min_delta >= self.best_score: self.counter += 1 print("Epochs without Model Update:", self.counter) if self.counter >= self.patience: