Merge branch 'dev' into mtl

OuyangWenyu committed Apr 9, 2024
2 parents f599e21 + d39ca71 commit b41796a

Showing 6 changed files with 99 additions and 215 deletions.
24 changes: 11 additions & 13 deletions tests/test_train_mean_lstm.py
@@ -1,11 +1,11 @@
 """
 Author: Xinzhuo Wu
-Date: 2023-07-25 16:47:19
-LastEditTime: 2024-04-08 09:59:17
-LastEditors: Wenyu Ouyang
+Date: 2024-04-08 18:13:05
+LastEditTime: 2024-04-08 18:13:05
+LastEditors: Xinzhuo Wu
 Description: Test a full training and evaluating process
-FilePath: \torchhydro\tests\test_train_mean_lstm.py
-Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
+FilePath: /torchhydro/tests/test_train_mean_lstm.py
+Copyright (c) 2024-2024 Wenyu Ouyang. All rights reserved.
 """
 
 import pytest
@@ -21,15 +21,13 @@ def config():
         sub=project_name,
         source_cfgs={
             "source": "HydroMean",
-            "source_path": [
-                {
-                    "forcing": "basins-origin/hour_data/1h/mean_data/mean_data_forcing",
-                    "target": "basins-origin/hour_data/1h/mean_data/mean_data_target",
-                    "attributes": "basins-origin/attributes.nc",
-                }
-            ],
+            "source_path": {
+                "forcing": "basins-origin/hour_data/1h/mean_data/mean_data_forcing",
+                "target": "basins-origin/hour_data/1h/mean_data/mean_data_target",
+                "attributes": "basins-origin/attributes.nc",
+            },
         },
-        ctx=[0],
+        ctx=[1],
         model_name="SimpleLSTMForecast",
         model_hyperparam={
             "input_size": 16,
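The substantive changes in this test: "source_path" inside source_cfgs is flattened from a one-element list of dicts to a plain dict, and the test now runs on GPU 1 (ctx=[1]) instead of GPU 0. A minimal sketch of building a config against the new shape with the helpers from torchhydro/configs/config.py shown below (the experiment name is hypothetical and model_hyperparam is omitted; treat this as illustrative, not a verified end-to-end run):

from torchhydro.configs.config import cmd, default_config_file, update_cfg

# build the default config skeleton, then overlay CLI-style arguments
cfg = default_config_file()
args = cmd(
    sub="test_mean_lstm/exp1",  # hypothetical experiment name
    source_cfgs={
        "source": "HydroMean",
        # "source_path" is now a plain dict keyed by data role
        "source_path": {
            "forcing": "basins-origin/hour_data/1h/mean_data/mean_data_forcing",
            "target": "basins-origin/hour_data/1h/mean_data/mean_data_target",
            "attributes": "basins-origin/attributes.nc",
        },
    },
    ctx=[1],  # GPU index, as in the updated test
    model_name="SimpleLSTMForecast",
)
update_cfg(cfg, args)  # copies parsed args into the nested config dict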
53 changes: 2 additions & 51 deletions torchhydro/configs/config.py
@@ -205,7 +205,6 @@ def default_config_file():
             # sampler for pytorch dataloader, here we mainly use it for Kuai Fang's sampler in all his DL papers
             "sampler": None,
             "loading_batch": None,
-            "user": None,
         },
         "training_cfgs": {
             # if train_mode is False, don't train and evaluate
@@ -336,12 +335,6 @@ def cmd(
     ensemble_items=None,
     early_stopping=None,
     patience=None,
-    user=None,
-    endpoint_url=None,
-    access_key=None,
-    secret_key=None,
-    bucket_name=None,
-    folder_prefix=None,
 ):
     """input args from cmd"""
     parser = argparse.ArgumentParser(
@@ -766,55 +759,13 @@ def cmd(
         default=early_stopping,
         type=bool,
     )
-    parser.add_argument(
-        "--user",
-        dest="user",
-        help="user_name to distinguish trainer or tester",
-        default=user,
-        type=str,
-    )
     parser.add_argument(
         "--patience",
         dest="patience",
         help="patience config",
         default=patience,
         type=int,
     )
-    parser.add_argument(
-        "--endpoint_url",
-        dest="endpoint_url",
-        help="endpoint_url",
-        default=endpoint_url,
-        type=str,
-    )
-    parser.add_argument(
-        "--access_key",
-        dest="access_key",
-        help="access_key",
-        default=access_key,
-        type=str,
-    )
-    parser.add_argument(
-        "--secret_key",
-        dest="secret_key",
-        help="secret_key",
-        default=secret_key,
-        type=str,
-    )
-    parser.add_argument(
-        "--bucket_name",
-        dest="bucket_name",
-        help="bucket_name",
-        default=bucket_name,
-        type=str,
-    )
-    parser.add_argument(
-        "--folder_prefix",
-        dest="folder_prefix",
-        help="folder_prefix",
-        default=folder_prefix,
-        type=str,
-    )
     # To make pytest work in PyCharm, here we use the following code instead of "args = parser.parse_args()":
     # https://blog.csdn.net/u014742995/article/details/100119905
     args, unknown = parser.parse_known_args()
@@ -1087,8 +1038,8 @@ def update_cfg(cfg_file, new_args):
         cfg_file["training_cfgs"]["patience"] = new_args.patience
     if new_args.early_stopping is not None:
         cfg_file["training_cfgs"]["early_stopping"] = new_args.early_stopping
-    if new_args.user is not None:
-        cfg_file["data_cfgs"]["user"] = new_args.user
+    if new_args.lr_scheduler is not None:
+        cfg_file["training_cfgs"]["lr_scheduler"] = new_args.lr_scheduler
     # print("the updated config:\n", json.dumps(cfg_file, indent=4, ensure_ascii=False))
 
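Besides dropping the user and S3-credential plumbing (--user, --endpoint_url, --access_key, --secret_key, --bucket_name, --folder_prefix), the only functional change in update_cfg is the new lr_scheduler pass-through into training_cfgs. A hedged sketch of exercising that path (this diff shows neither the --lr_scheduler flag definition nor the scheduler dict's schema, so both are assumptions here):

from torchhydro.configs.config import cmd, default_config_file, update_cfg

cfg = default_config_file()
# assumed: the parser defines --lr_scheduler and accepts a dict-like value
args = cmd(lr_scheduler={"lr": 0.001, "lr_factor": 0.9})
update_cfg(cfg, args)
# the new branch stores it under training_cfgs (the removed branch wrote
# new_args.user under data_cfgs instead)
assert cfg["training_cfgs"]["lr_scheduler"] == {"lr": 0.001, "lr_factor": 0.9}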
159 changes: 49 additions & 110 deletions torchhydro/datasets/data_scalers.py
@@ -1,3 +1,13 @@
+"""
+Author: Wenyu Ouyang
+Date: 2024-04-08 18:17:44
+LastEditTime: 2024-04-08 18:17:44
+LastEditors: Xinzhuo Wu
+Description: normalize the data
+FilePath: /torchhydro/torchhydro/datasets/data_scalers.py
+Copyright (c) 2024-2024 Wenyu Ouyang. All rights reserved.
+"""
+
 import copy
 import json
 import os
@@ -220,6 +230,7 @@ def __init__(
         self.log_norm_cols = gamma_norm_cols + prcp_norm_cols
         self.pbm_norm = pbm_norm
         self.data_source = data_source
+        self.read_meanprep()
         # save stat_dict of training period in test_path for valid/test
         stat_file = os.path.join(data_cfgs["test_path"], "dapengscaler_stat.json")
         # for testing sometimes such as pub cases, we need stat_dict_file from trained dataset
@@ -236,6 +247,17 @@ def __init__(
             with open(stat_file, "r") as fp:
                 self.stat_dict = json.load(fp)
 
+    def read_meanprep(self):
+        if isinstance(self.data_source, HydroBasins):
+            self.mean_prep = self.data_source.read_MP(
+                self.t_s_dict["sites_id"],
+                self.data_cfgs["source_cfgs"]["source_path"]["attributes"],
+            )
+        else:
+            self.mean_prep = self.data_source.read_mean_prcp(
+                self.t_s_dict["sites_id"]
+            ).to_array()
+
     def inverse_transform(self, target_values):
         """
         Denormalization for output variables
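This read_meanprep refactor is the heart of the commit: instead of re-reading mean precipitation inside every branch of inverse_transform, cal_stat_all, and get_data_obs, the scaler resolves it once in __init__ and caches it as self.mean_prep. The _prcp_norm helper those methods feed it to is not part of this diff; judging from grid_prcp_norm at the bottom of the file, it presumably looks like the following sketch (an assumption, not the verified implementation):

import numpy as np

def _prcp_norm(x: np.ndarray, mean_prep: np.ndarray, to_norm: bool) -> np.ndarray:
    # broadcast each basin's long-term mean precipitation across the time axis,
    # then scale runoff by it: normalizing yields a dimensionless,
    # runoff-ratio-like series; de-normalizing restores the original units
    tempprep = np.tile(mean_prep, (x.shape[0], 1))
    return x / tempprep if to_norm else x * tempprep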
@@ -265,41 +287,14 @@ def inverse_transform(self, target_values):
         )
         for i in range(len(self.data_cfgs["target_cols"])):
             var = self.data_cfgs["target_cols"][i]
             if var in self.prcp_norm_cols:
-                if var == "qobs_mm_per_hour":  # should be deleted
-                    mean_prep = self.read_attr_xrdataset(
-                        self.t_s_dict["sites_id"], ["p_mean"]
-                    )
-                    pred.loc[dict(variable=var)] = _prcp_norm(
-                        pred.sel(variable=var).to_numpy(),
-                        mean_prep.to_array().to_numpy().T,
-                        to_norm=False,
-                    )
-                elif isinstance(self.data_source, HydroBasins):
-                    mean_prep = self.data_source.read_MP(
-                        self.t_s_dict["sites_id"],
-                        self.data_cfgs["data_path"]["attributes"],
-                    )
-                    pred.loc[dict(variable=var)] = self.mean_prcp_norm(
-                        pred.sel(variable=var).to_numpy().T,
-                        mean_prep.to_numpy(),
-                        to_norm=False,
-                    )
-                else:
-                    mean_prep = self.data_source.read_mean_prcp(
-                        self.t_s_dict["sites_id"]
-                    )
-                    pred.loc[dict(variable=var)] = _prcp_norm(
-                        pred.sel(variable=var).to_numpy(),
-                        mean_prep.to_array().to_numpy().T,
-                        to_norm=False,
-                    )
+                pred.loc[dict(variable=var)] = _prcp_norm(
+                    pred.sel(variable=var).to_numpy(),
+                    self.mean_prep.to_numpy().T,
+                    to_norm=False,
+                )
         # add attrs for units
         pred.attrs.update(self.data_target.attrs)
-        # trans to xarray dataset
-        pred_ds = pred.to_dataset(dim="variable")
-
-        return pred_ds
+        return pred.to_dataset(dim="variable")
 
     def cal_stat_all(self):
         """
@@ -316,31 +311,10 @@ def cal_stat_all(self):
         for i in range(len(target_cols)):
             var = target_cols[i]
             if var in self.prcp_norm_cols:
-                if var == "qobs_mm_per_hour":  # should be deleted if data all in minio
-                    mean_prep = self.read_attr_xrdataset(
-                        self.t_s_dict["sites_id"], ["p_mean"]
-                    )
-                    stat_dict[var] = cal_stat_prcp_norm(
-                        self.data_target.sel(variable=var).to_numpy(),
-                        mean_prep.to_array().to_numpy().T,
-                    )
-                elif isinstance(self.data_source, HydroBasins):
-                    mean_prep = self.data_source.read_MP(
-                        self.t_s_dict["sites_id"],
-                        self.data_cfgs["data_path"]["attributes"],
-                    )
-                    stat_dict[var] = self.mean_cal_stat_prcp_norm(
-                        self.data_target.sel(variable=var).to_numpy().T,
-                        mean_prep.to_numpy(),
-                    )
-                else:
-                    mean_prep = self.data_source.read_mean_prcp(
-                        self.t_s_dict["sites_id"]
-                    )
-                    stat_dict[var] = cal_stat_prcp_norm(
-                        self.data_target.sel(variable=var).to_numpy(),
-                        mean_prep.to_array().to_numpy().T,
-                    )
+                stat_dict[var] = cal_stat_prcp_norm(
+                    self.data_target.sel(variable=var).to_numpy(),
+                    self.mean_prep.to_numpy().T,
+                )
             elif var in self.gamma_norm_cols:
                 stat_dict[var] = cal_stat_gamma(
                     self.data_target.sel(variable=var).to_numpy()
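cal_stat_prcp_norm is likewise defined elsewhere in this file; judging from the mean_cal_stat_prcp_norm helper deleted further down in this diff, it presumably normalizes flow by mean precipitation first and then computes the gamma-transformed statistics (again a sketch under that assumption):

def cal_stat_prcp_norm(x, meanprep):
    # divide flow by mean precipitation, then reuse the gamma statistics
    flow_unit = _prcp_norm(x, meanprep, to_norm=True)
    return cal_stat_gamma(flow_unit)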
@@ -392,34 +366,11 @@ def get_data_obs(self, to_norm: bool = True) -> np.array:
         for i in range(len(target_cols)):
             var = target_cols[i]
             if var in self.prcp_norm_cols:
-                if var == "qobs_mm_per_hour":  # should be deleted
-                    mean_prep = self.read_attr_xrdataset(
-                        self.t_s_dict["sites_id"], ["p_mean"]
-                    )
-                    out.loc[dict(variable=var)] = _prcp_norm(
-                        data.sel(variable=var).to_numpy(),
-                        mean_prep.to_array().to_numpy().T,
-                        to_norm=True,
-                    )
-                elif isinstance(self.data_source, HydroBasins):
-                    mean_prep = self.data_source.read_MP(
-                        self.t_s_dict["sites_id"],
-                        self.data_cfgs["data_path"]["attributes"],
-                    )
-                    out.loc[dict(variable=var)] = self.mean_prcp_norm(
-                        data.sel(variable=var).to_numpy().T,
-                        mean_prep.to_numpy(),
-                        to_norm=True,
-                    )
-                else:
-                    mean_prep = self.data_source.read_mean_prcp(
-                        self.t_s_dict["sites_id"]
-                    )
-                    out.loc[dict(variable=var)] = _prcp_norm(
-                        data.sel(variable=var).to_numpy(),
-                        mean_prep.to_array().to_numpy().T,
-                        to_norm=True,
-                    )
+                out.loc[dict(variable=var)] = _prcp_norm(
+                    data.sel(variable=var).to_numpy(),
+                    self.mean_prep.to_numpy().T,
+                    to_norm=True,
+                )
             out.attrs["units"][var] = "dimensionless"
         out = _trans_norm(
             out,
@@ -430,13 +381,6 @@ def get_data_obs(self, to_norm: bool = True) -> np.array:
         )
         return out
 
-    # temporarily used, it's related to hydrodataset
-    def read_attr_xrdataset(self, gage_id_lst=None, var_lst=None, **kwargs):
-        if var_lst is None or len(var_lst) == 0:
-            return None
-        attr = xr.open_dataset(os.path.join("/ftproot", "camelsus_attributes_us.nc"))
-        return attr[var_lst].sel(basin=gage_id_lst)
-
     def get_data_ts(self, to_norm=True) -> np.array:
         """
         Get dynamic input data
@@ -499,17 +443,6 @@ def load_data(self):
         c = self.get_data_const()
         return x, y, c
 
-    def mean_cal_stat_prcp_norm(self, x, meanprep):
-        tempprep = np.tile(meanprep, (x.shape[0], 1))
-        flowua = (x / tempprep).T
-        return cal_stat_gamma(flowua)
-
-    def mean_prcp_norm(
-        self, x: np.array, mean_prep: np.array, to_norm: bool
-    ) -> np.array:
-        tempprep = np.tile(mean_prep, (x.shape[0], 1))
-        return (x / tempprep).T if to_norm else (x * tempprep).T
-
 
 class MutiBasinScaler(object):
     def __init__(
@@ -620,9 +553,7 @@ def cal_stat_all(self):
             if var in self.gamma_norm_cols:
                 stat_dict[var] = self.grid_cal_stat_gamma(data_smap, var)
 
-        # const attribute
-        attr_lst = self.data_cfgs["constant_cols"]
-        if attr_lst:
+        if attr_lst := self.data_cfgs["constant_cols"]:
             data_attr = self.data_attr
             for k in range(len(attr_lst)):
                 var = attr_lst[k]
@@ -688,8 +619,16 @@ def get_data_const(self, to_norm=True) -> np.array:
 
     def load_data(self):
         x = self.get_data_ts(0).compute()
-        g = self.get_data_ts(1).compute() if self.data_cfgs["relevant_cols"][1] != ["None"] else None
-        s = self.get_data_ts(2).compute() if self.data_cfgs["relevant_cols"][2] != ["None"] else None
+        g = (
+            self.get_data_ts(1).compute()
+            if self.data_cfgs["relevant_cols"][1] != ["None"]
+            else None
+        )
+        s = (
+            self.get_data_ts(2).compute()
+            if self.data_cfgs["relevant_cols"][2] != ["None"]
+            else None
+        )
         y = self.get_data_obs().compute()
         c = self.get_data_const().compute() if self.data_cfgs["constant_cols"] else None
         return x, y, c, g, s
@@ -746,4 +685,4 @@ def grid_prcp_norm(
         self, x: np.array, mean_prep: np.array, to_norm: bool
     ) -> np.array:
         tempprep = np.tile(mean_prep, (x.shape[0], 1))
-        return x / tempprep if to_norm else x * tempprep
\ No newline at end of file
+        return x / tempprep if to_norm else x * tempprep