Commit 09550be (1 parent: f76de65)
Showing 2 changed files with 290 additions and 0 deletions.
tests/test_tl_opendata.py (new file, +119 lines)
""" | ||
Author: Wenyu Ouyang | ||
Date: 2023-10-05 16:16:48 | ||
LastEditTime: 2023-10-20 19:59:38 | ||
LastEditors: Wenyu Ouyang | ||
Description: Transfer learning for local basins with hydro_opendata | ||
FilePath: \torchhydro\tests\test_tl_opendata.py | ||
Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved. | ||
""" | ||
import os | ||
import pytest | ||
import hydrodataset as hds | ||
from hydroutils.hydro_file import get_lastest_file_in_a_dir | ||
from torchhydro.configs.config import cmd, default_config_file, update_cfg | ||
from torchhydro.trainers.trainer import train_and_evaluate | ||
|
||
|
||
@pytest.fixture()
def var_c_target():
    # static basin attributes available in the target region
    return [
        "elev_mean",
        "slope_mean",
        "area_gages2",
        "frac_forest",
        "lai_max",
        "lai_diff",
    ]


@pytest.fixture()
def var_c_source():
    # static basin attributes used to pretrain the source model
    return [
        "elev_mean",
        "slope_mean",
        "area_gages2",
        "frac_forest",
        "lai_max",
        "lai_diff",
        "dom_land_cover_frac",
        "dom_land_cover",
        "root_depth_50",
        "soil_depth_statsgo",
        "soil_porosity",
        "soil_conductivity",
        "max_water_content",
        "geol_1st_class",
        "geol_2nd_class",
        "geol_porostiy",  # spelling follows the attribute name in the source dataset
        "geol_permeability",
    ]


@pytest.fixture()
def var_t_target():
    # time-series forcings available in the target region
    return ["dayl", "prcp", "srad"]


@pytest.fixture()
def var_t_source():
    # time-series forcings used to pretrain the source model
    return ["dayl", "prcp", "srad", "tmax", "tmin", "vp"]


def test_transfer_gages_lstm_model(
    var_c_source, var_c_target, var_t_source, var_t_target
):
    # directory holding the checkpoints of the model pretrained on CAMELS
    weight_dir = os.path.join(
        os.getcwd(),
        "results",
        "test_camels",
        "exp1",
    )
    # fine-tune from the most recent checkpoint
    weight_path = get_lastest_file_in_a_dir(weight_dir)
    project_name = "test_caravan/exp6"
    args = cmd(
        sub=project_name,
        source="Caravan",
        source_path=os.path.join(hds.ROOT_DIR, "caravan"),
        source_region="Global",
        download=0,
        ctx=[0],
        model_type="TransLearn",
        model_name="KaiLSTM",
        model_hyperparam={
            # 9 target inputs (6 static + 3 dynamic) are mapped onto the
            # 23 inputs (17 static + 6 dynamic) the source model expects
            "linear_size": len(var_c_target) + len(var_t_target),
            "n_input_features": len(var_c_source) + len(var_t_source),
            "n_output_features": 1,
            "n_hidden_states": 256,
        },
        opt="Adadelta",
        loss_func="RMSESum",
        batch_size=5,
        rho=20,
        rs=1234,
        train_period=["2010-10-01", "2011-10-01"],
        test_period=["2011-10-01", "2012-10-01"],
        scaler="DapengScaler",
        sampler="KuaiSampler",
        dataset="StreamflowDataset",
        weight_path=weight_path,
        # keep the pretrained LSTM weights frozen during fine-tuning
        weight_path_add={
            "freeze_params": ["lstm.b_hh", "lstm.b_ih", "lstm.w_hh", "lstm.w_ih"]
        },
        continue_train=True,
        train_epoch=20,
        te=20,
        save_epoch=10,
        var_t=var_t_target,
        var_c=var_c_target,
        var_out=["streamflow"],
        gage_id=[
            "01055000",
            "01057000",
            "01170100",
        ],
    )
    cfg = default_config_file()
    update_cfg(cfg, args)
    train_and_evaluate(cfg)
    print("All processes are finished!")
Second file (new, +171 lines): data source classes
from abc import ABC
from pathlib import Path
from typing import Union

import hydrodataset as hds
import numpy as np
from hydrodataset import HydroDataset


class HydroData(ABC):
    """An abstract interface for reading multi-modal data sources.

    Parameters
    ----------
    data_path : str
        directory of the data source, relative to hds.ROOT_DIR
    """

    def __init__(self, data_path):
        self.data_source_dir = Path(hds.ROOT_DIR, data_path)
        if not self.data_source_dir.is_dir():
            self.data_source_dir.mkdir(parents=True)

    def get_name(self):
        raise NotImplementedError

    def set_data_source_describe(self):
        raise NotImplementedError

    def check_data_ready(self):
        raise NotImplementedError

    def read_input(self):
        raise NotImplementedError

    def read_target(self):
        raise NotImplementedError


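To show how the interface is meant to be filled in, here is a minimal hypothetical subclass (the class name and file names are illustrative, not part of this commit):

import pandas as pd

class CsvHydroData(HydroData):
    # illustrative only: serves local CSV files through the interface
    def get_name(self):
        return "CsvHydroData"

    def set_data_source_describe(self):
        self.ts_data_source = "Local"

    def check_data_ready(self):
        return (self.data_source_dir / "forcing.csv").exists()

    def read_input(self):
        return pd.read_csv(self.data_source_dir / "forcing.csv")

    def read_target(self):
        return pd.read_csv(self.data_source_dir / "streamflow.csv")
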
class HydroOpendata(HydroData):
    """A class for reading public data sources.

    Typically, we read GPM/GFS/ERA5/SMAP/NOAA weather-station data as
    forcings for hydrological models, and USGS NWIS data as the
    input/target series.
    """

    def __init__(self, data_path):
        super().__init__(data_path)

    def get_name(self):
        return "HydroOpendata"

    def set_data_source_describe(self):
        # grid (raster) data live on an S3-compatible MinIO service;
        # time-series data are read from local files
        self.grid_data_source = "MINIO"
        self.grid_data_source_url = (
            "https://s3.us-east-2.amazonaws.com/minio.t-integration.cloud.tibco.com"
        )
        self.grid_data_source_bucket = "test"
        self.ts_data_source = "Local"

    def check_data_ready(self):
        raise NotImplementedError

    def read_input(self):
        raise NotImplementedError

    def read_target(self):
        raise NotImplementedError


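For the MINIO side, a minimal sketch of how a reader might list objects on that S3-compatible endpoint, assuming the s3fs package (an illustration, not the commit's implementation):

import s3fs

def list_grid_files(endpoint_url, bucket):
    # anonymous access to an S3-compatible (MinIO) service
    fs = s3fs.S3FileSystem(anon=True, client_kwargs={"endpoint_url": endpoint_url})
    return fs.ls(bucket)

# usage with the values set in set_data_source_describe:
# list_grid_files(
#     "https://s3.us-east-2.amazonaws.com/minio.t-integration.cloud.tibco.com",
#     "test",
# )
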
class HydroDatasetSim(HydroDataset):
    """A class for reading datasets organized like a hydrodataset but not
    actually released as one: data directories laid out in the ready-made
    style, typically filled with our self-made data.
    """

    def __init__(self, data_path):
        super().__init__(data_path)
        # a naming convention for basin ids is needed;
        # we use GRDC station ids as our default convention:
        # GRDC ids have 7 digits, of which the first digit is a
        # continent code and the next 4 digits are a sub-region code
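        # For illustration, with a hypothetical id (not from this commit):
        #     gid = "6335020"
        #     continent, subregion = gid[0], gid[1:5]  # -> "6", "3350"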

    def get_name(self):
        return "HydroDatasetSim"

    def set_data_source_describe(self):
        self.attr_data_dir = Path(self.data_source_dir, "attr")
        self.forcing_data_dir = Path(self.data_source_dir, "forcing")
        self.streamflow_data_dir = Path(self.data_source_dir, "streamflow")

    def read_object_ids(self, object_params=None) -> np.ndarray:
        raise NotImplementedError

    def read_target_cols(
        self, object_ids=None, t_range_list=None, target_cols=None, **kwargs
    ) -> np.ndarray:
        raise NotImplementedError

    def read_relevant_cols(
        self, object_ids=None, t_range_list: list = None, relevant_cols=None, **kwargs
    ) -> Union[np.ndarray, list]:
        """3d data (site_num * time_length * var_num), time-series data"""
        raise NotImplementedError

    def read_constant_cols(
        self, object_ids=None, constant_cols=None, **kwargs
    ) -> np.ndarray:
        """2d data (site_num * var_num), non-time-series data"""
        raise NotImplementedError

    def read_other_cols(
        self, object_ids=None, other_cols: dict = None, **kwargs
    ) -> dict:
        """Data that cannot easily be treated as constant vars or as
        time series of the same length as the relevant vars.

        CONVENTION: other_cols is a dict whose items are themselves
        dicts carrying all params for each variable."""
        raise NotImplementedError

    def get_constant_cols(self) -> np.ndarray:
        """the constant cols in this data_source"""
        raise NotImplementedError

    def get_relevant_cols(self) -> np.ndarray:
        """the relevant cols in this data_source"""
        raise NotImplementedError

    def get_target_cols(self) -> np.ndarray:
        """the target cols in this data_source"""
        raise NotImplementedError

    def get_other_cols(self) -> dict:
        """the other cols in this data_source"""
        raise NotImplementedError

    def cache_xrdataset(self, **kwargs):
        """cache xarray dataset and pandas feather file for faster reading"""
        raise NotImplementedError

    def read_ts_xrdataset(
        self,
        gage_id_lst: list = None,
        t_range: list = None,
        var_lst: list = None,
        **kwargs,
    ):
        """read time-series xarray dataset"""
        raise NotImplementedError

    def read_attr_xrdataset(self, gage_id_lst=None, var_lst=None, **kwargs):
        """read attribute pandas feather file"""
        raise NotImplementedError

    def read_area(self, gage_id_lst=None):
        """read the area of each basin/unit"""
        raise NotImplementedError

    def read_mean_prcp(self, gage_id_lst=None):
        """read the mean precipitation of each basin/unit"""
        raise NotImplementedError
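
Finally, set_data_source_describe implies the on-disk layout sketched below; a hedged usage example, assuming HydroDataset's constructor accepts the path as __init__ above suggests (the data_path value is illustrative):

# <hds.ROOT_DIR>/<data_path>/
#     attr/        static basin attributes
#     forcing/     time-series forcings
#     streamflow/  observed streamflow targets
ds = HydroDatasetSim("camels_like_sim")  # illustrative data_path
ds.set_data_source_describe()
print(ds.attr_data_dir, ds.forcing_data_dir, ds.streamflow_data_dir)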