diff --git a/tests/test_seq2seq.py b/tests/test_seq2seq.py index eff92fc..0f29048 100644 --- a/tests/test_seq2seq.py +++ b/tests/test_seq2seq.py @@ -1,7 +1,7 @@ """ Author: Wenyu Ouyang Date: 2024-04-17 12:55:24 -LastEditTime: 2024-11-01 12:02:49 +LastEditTime: 2024-11-01 18:02:29 LastEditors: Wenyu Ouyang Description: Test funcs for seq2seq model FilePath: \torchhydro\tests\test_seq2seq.py @@ -21,6 +21,7 @@ import xarray as xr import torch.multiprocessing as mp +from torchhydro import SETTING from torchhydro.configs.config import cmd, default_config_file, update_cfg from torchhydro.trainers.deep_hydro import train_worker from torchhydro.trainers.trainer import train_and_evaluate @@ -32,11 +33,6 @@ logger = logging.getLogger(logger_name) logger.setLevel(logging.INFO) -show = pd.read_csv( - os.path.join(pathlib.Path(__file__).parent.parent, "data/basin_id(all).csv"), - dtype={"id": str}, -) -# gage_id = show["id"].values.tolist() gage_id = [ "songliao_21401050", "songliao_21401550", @@ -45,16 +41,13 @@ @pytest.fixture() def config(): - # 设置测试所需的项目名称和默认配置文件 project_name = os.path.join("train_with_gpm", "ex_test") config_data = default_config_file() - - # 填充测试所需的命令行参数 args = cmd( sub=project_name, source_cfgs={ "source": "HydroMean", - "source_path": "/ftproot/basins-interim/", + "source_path": SETTING["local_data_path"]["datasets-interim"], }, ctx=[0], model_name="Seq2Seq", @@ -102,7 +95,7 @@ def config(): scaler="DapengScaler", train_epoch=2, save_epoch=1, - train_mode=False, + train_mode=True, train_period=[("2016-06-01-01", "2016-08-01-01")], test_period=[("2015-06-01-01", "2015-08-01-01")], valid_period=[("2015-06-01-01", "2015-08-01-01")],