Commit 09550be: framework for data source
OuyangWenyu committed Oct 23, 2023
1 parent f76de65
Showing 2 changed files with 290 additions and 0 deletions.
119 changes: 119 additions & 0 deletions tests/test_tl_opendata.py
@@ -0,0 +1,119 @@
"""
Author: Wenyu Ouyang
Date: 2023-10-05 16:16:48
LastEditTime: 2023-10-20 19:59:38
LastEditors: Wenyu Ouyang
Description: Transfer learning for local basins with hydro_opendata
FilePath: \torchhydro\tests\test_tl_opendata.py
Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
"""
import os
import pytest
import hydrodataset as hds
from hydroutils.hydro_file import get_lastest_file_in_a_dir
from torchhydro.configs.config import cmd, default_config_file, update_cfg
from torchhydro.trainers.trainer import train_and_evaluate


@pytest.fixture()
def var_c_target():
    return [
        "elev_mean",
        "slope_mean",
        "area_gages2",
        "frac_forest",
        "lai_max",
        "lai_diff",
    ]


@pytest.fixture()
def var_c_source():
    return [
        "elev_mean",
        "slope_mean",
        "area_gages2",
        "frac_forest",
        "lai_max",
        "lai_diff",
        "dom_land_cover_frac",
        "dom_land_cover",
        "root_depth_50",
        "soil_depth_statsgo",
        "soil_porosity",
        "soil_conductivity",
        "max_water_content",
        "geol_1st_class",
        "geol_2nd_class",
        "geol_porostiy",
        "geol_permeability",
    ]


@pytest.fixture()
def var_t_target():
    return ["dayl", "prcp", "srad"]


@pytest.fixture()
def var_t_source():
    return ["dayl", "prcp", "srad", "tmax", "tmin", "vp"]


def test_transfer_gages_lstm_model(
    var_c_source, var_c_target, var_t_source, var_t_target
):
    weight_dir = os.path.join(
        os.getcwd(),
        "results",
        "test_camels",
        "exp1",
    )
    weight_path = get_lastest_file_in_a_dir(weight_dir)
    project_name = "test_caravan/exp6"
    args = cmd(
        sub=project_name,
        source="Caravan",
        source_path=os.path.join(hds.ROOT_DIR, "caravan"),
        source_region="Global",
        download=0,
        ctx=[0],
        model_type="TransLearn",
        model_name="KaiLSTM",
        model_hyperparam={
            "linear_size": len(var_c_target) + len(var_t_target),
            "n_input_features": len(var_c_source) + len(var_t_source),
            "n_output_features": 1,
            "n_hidden_states": 256,
        },
        opt="Adadelta",
        loss_func="RMSESum",
        batch_size=5,
        rho=20,
        rs=1234,
        train_period=["2010-10-01", "2011-10-01"],
        test_period=["2011-10-01", "2012-10-01"],
        scaler="DapengScaler",
        sampler="KuaiSampler",
        dataset="StreamflowDataset",
        weight_path=weight_path,
        weight_path_add={
            "freeze_params": ["lstm.b_hh", "lstm.b_ih", "lstm.w_hh", "lstm.w_ih"]
        },
        continue_train=True,
        train_epoch=20,
        te=20,
        save_epoch=10,
        var_t=var_t_target,
        var_c=var_c_target,
        var_out=["streamflow"],
        gage_id=[
            "01055000",
            "01057000",
            "01170100",
        ],
    )
    cfg = default_config_file()
    update_cfg(cfg, args)
    train_and_evaluate(cfg)
    print("All processes are finished!")
171 changes: 171 additions & 0 deletions torchhydro/datasets/data_sources.py
@@ -0,0 +1,171 @@
from abc import ABC
from pathlib import Path
from typing import Union
import hydrodataset as hds
from hydrodataset import HydroDataset
import numpy as np


class HydroData(ABC):
    """An abstract interface for reading multi-modal data sources.

    Parameters
    ----------
    data_path : str
        path of the data source directory, relative to hds.ROOT_DIR
    """

    def __init__(self, data_path):
        self.data_source_dir = Path(hds.ROOT_DIR, data_path)
        if not self.data_source_dir.is_dir():
            self.data_source_dir.mkdir(parents=True)

    def get_name(self):
        raise NotImplementedError

    def set_data_source_describe(self):
        raise NotImplementedError

    def check_data_ready(self):
        raise NotImplementedError

    def read_input(self):
        raise NotImplementedError

    def read_target(self):
        raise NotImplementedError


class HydroOpendata(HydroData):
    """A class for reading public data sources.

    Typically, we read GPM/GFS/ERA5/SMAP/NOAA weather-station data as forcing
    data for hydrological models, and read USGS NWIS data as input/target data.

    Parameters
    ----------
    data_path : str
        path of the data source directory, relative to hds.ROOT_DIR
    """

    def __init__(self, data_path):
        super().__init__(data_path)

    def get_name(self):
        return "HydroOpendata"

    def set_data_source_describe(self):
        self.grid_data_source = "MINIO"
        self.grid_data_source_url = (
            "https://s3.us-east-2.amazonaws.com/minio.t-integration.cloud.tibco.com"
        )
        self.grid_data_source_bucket = "test"
        self.ts_data_source = "Local"

    def check_data_ready(self):
        raise NotImplementedError

    def read_input(self):
        raise NotImplementedError

    def read_target(self):
        raise NotImplementedError


class HydroDatasetSim(HydroDataset):
    """A class for reading a dataset that is not yet a fully prepared
    hydrodataset, just data directories organized like a ready dataset.

    Typically, we read data from our self-made data.

    Parameters
    ----------
    data_path : str
        path of the data source directory, relative to hds.ROOT_DIR
    """

    def __init__(self, data_path):
        super().__init__(data_path)
        # A naming convention for basin ids is needed.
        # We use GRDC station ids as our default coding convention:
        # they are 7 digits, where the first digit is a continent code
        # and the following 4 digits are a sub-region-related code,
        # e.g., for the id "6335020", "6" is the continent code and
        # "3350" is the sub-region-related code.

    def get_name(self):
        return "HydroDatasetSim"

    def set_data_source_describe(self):
        self.attr_data_dir = Path(self.data_source_dir, "attr")
        self.forcing_data_dir = Path(self.data_source_dir, "forcing")
        self.streamflow_data_dir = Path(self.data_source_dir, "streamflow")

    def read_object_ids(self, object_params=None) -> np.ndarray:
        raise NotImplementedError

    def read_target_cols(
        self, object_ids=None, t_range_list=None, target_cols=None, **kwargs
    ) -> np.ndarray:
        raise NotImplementedError

    def read_relevant_cols(
        self, object_ids=None, t_range_list: list = None, relevant_cols=None, **kwargs
    ) -> Union[np.ndarray, list]:
        """3d data (site_num * time_length * var_num), i.e. time-series data"""
        raise NotImplementedError

    def read_constant_cols(
        self, object_ids=None, constant_cols=None, **kwargs
    ) -> np.ndarray:
        """2d data (site_num * var_num), i.e. non-time-series data"""
        raise NotImplementedError

    def read_other_cols(
        self, object_ids=None, other_cols: dict = None, **kwargs
    ) -> dict:
        """Data that cannot easily be treated as constant vars or as time series
        with the same length as relevant vars.

        CONVENTION: other_cols is a dict, where each item is also a dict with
        all params in it.
        """
        raise NotImplementedError

    def get_constant_cols(self) -> np.ndarray:
        """the constant cols in this data_source"""
        raise NotImplementedError

    def get_relevant_cols(self) -> np.ndarray:
        """the relevant cols in this data_source"""
        raise NotImplementedError

    def get_target_cols(self) -> np.ndarray:
        """the target cols in this data_source"""
        raise NotImplementedError

    def get_other_cols(self) -> dict:
        """the other cols in this data_source"""
        raise NotImplementedError

    def cache_xrdataset(self, **kwargs):
        """cache xarray dataset and pandas feather files for faster reading"""
        raise NotImplementedError

    def read_ts_xrdataset(
        self,
        gage_id_lst: list = None,
        t_range: list = None,
        var_lst: list = None,
        **kwargs
    ):
        """read time-series xarray dataset"""
        raise NotImplementedError

    def read_attr_xrdataset(self, gage_id_lst=None, var_lst=None, **kwargs):
        """read attributes from a pandas feather file"""
        raise NotImplementedError

    def read_area(self, gage_id_lst=None):
        """read the area of each basin/unit"""
        raise NotImplementedError

    def read_mean_prcp(self, gage_id_lst=None):
        """read the mean precipitation of each basin/unit"""
        raise NotImplementedError
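
Every reader in data_sources.py is still a stub that raises NotImplementedError, so a concrete data source has to subclass these skeletons. A minimal sketch of how one reader could be filled in (MyLocalDataset and the ids.txt layout are illustrative assumptions, not part of this commit):

import numpy as np
from pathlib import Path

class MyLocalDataset(HydroDatasetSim):
    def read_object_ids(self, object_params=None) -> np.ndarray:
        # assume one 7-digit GRDC-style basin id per line in <data_path>/ids.txt
        id_file = Path(self.data_source_dir, "ids.txt")
        return np.loadtxt(id_file, dtype=str, ndmin=1)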

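Similarly, the MINIO endpoint recorded by HydroOpendata.set_data_source_describe is an S3-compatible service, so the grid data could eventually be listed with s3fs. A minimal sketch under that assumption (list_grid_files is a hypothetical helper, not part of this commit):

import s3fs

def list_grid_files(endpoint_url: str, bucket: str = "test") -> list:
    # anonymous, read-only access to the S3-compatible MinIO endpoint
    fs = s3fs.S3FileSystem(anon=True, client_kwargs={"endpoint_url": endpoint_url})
    return fs.ls(bucket)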