reset rolling and add activation for ann
OuyangWenyu committed Jan 1, 2025
1 parent 61d3eba commit ea62a55
Showing 11 changed files with 209 additions and 68 deletions.
2 changes: 1 addition & 1 deletion experiments/evaluate_with_era5land.py
@@ -93,7 +93,7 @@ def get_config_data():
},
test_period=[("2015-06-01-01", "2016-05-31-01")],
which_first_tensor="batch",
rolling=True,
rolling=56,
weight_path=os.path.join(train_path, "best_model.pth"),
stat_dict_file=os.path.join(train_path, "dapengscaler_stat.json"),
train_mode=False,
2 changes: 1 addition & 1 deletion experiments/evaluate_with_gpm.py
@@ -91,7 +91,7 @@ def get_config_data():
},
test_period=[("2015-06-01-01", "2016-06-01-01")],
which_first_tensor="batch",
rolling=True,
rolling=56,
weight_path=os.path.join(train_path, "best_model.pth"),
stat_dict_file=os.path.join(train_path, "dapengscaler_stat.json"),
train_mode=False,
2 changes: 1 addition & 1 deletion experiments/evaluate_with_gpm_streamflow.py
@@ -93,7 +93,7 @@ def get_config_data():
},
test_period=[("2015-06-01-01", "2016-05-31-01")],
which_first_tensor="batch",
rolling=True,
rolling=56,
weight_path=os.path.join(train_path, "best_model.pth"),
stat_dict_file=os.path.join(train_path, "dapengscaler_stat.json"),
train_mode=False,
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -533,7 +533,7 @@ def seq2seq_config():
"lr_factor": 0.9,
},
which_first_tensor="batch",
rolling=True,
rolling=56,
calc_metrics=False,
early_stopping=True,
# ensemble=True,
73 changes: 73 additions & 0 deletions tests/test_ann.py
@@ -0,0 +1,73 @@
"""
Author: Wenyu Ouyang
Date: 2025-01-01 10:20:02
LastEditTime: 2025-01-01 10:48:40
LastEditors: Wenyu Ouyang
Description: test function for multi-layer perceptron model
FilePath: \torchhydro\tests\test_ann.py
Copyright (c) 2023-2024 Wenyu Ouyang. All rights reserved.
"""

import pytest
from torch import nn
from torchhydro.models.ann import SimpleAnn
import torch


def test_get_activation_tanh():
model = SimpleAnn(1, 1)
activation = model._get_activation("tanh")
assert isinstance(activation, nn.Tanh)


def test_get_activation_sigmoid():
model = SimpleAnn(1, 1)
activation = model._get_activation("sigmoid")
assert isinstance(activation, nn.Sigmoid)


def test_get_activation_relu():
model = SimpleAnn(1, 1)
activation = model._get_activation("relu")
assert isinstance(activation, nn.ReLU)


def test_get_activation_linear():
model = SimpleAnn(1, 1)
activation = model._get_activation("linear")
assert isinstance(activation, nn.Identity)


def test_get_activation_not_implemented():
model = SimpleAnn(1, 1)
with pytest.raises(NotImplementedError):
model._get_activation("unsupported_activation")


def test_forward_single_layer():
model = SimpleAnn(3, 2, hidden_size=0)
x = torch.randn(5, 3)
output = model.forward(x)
assert output.shape == (5, 2)


def test_forward_multiple_layers():
model = SimpleAnn(3, 2, hidden_size=[4, 5], dr=[0.1, 0.2])
x = torch.randn(5, 3)
output = model.forward(x)
assert output.shape == (5, 2)


def test_forward_with_dropout():
model = SimpleAnn(3, 2, hidden_size=[4, 5], dr=[0.1, 0.2])
model.train() # Enable dropout
x = torch.randn(5, 3)
output = model.forward(x)
assert output.shape == (5, 2)


def test_forward_activation():
model = SimpleAnn(3, 2, hidden_size=[4, 5], activation="sigmoid")
x = torch.randn(5, 3)
output = model.forward(x)
assert output.shape == (5, 2)
2 changes: 1 addition & 1 deletion tests/test_seqforecast.py
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2024-12-31 18:20:31
LastEditTime: 2024-12-31 19:11:26
LastEditTime: 2025-01-01 11:08:29
LastEditors: Wenyu Ouyang
Description:
FilePath: \torchhydro\tests\test_seqforecast.py
16 changes: 8 additions & 8 deletions torchhydro/configs/config.py
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2021-12-31 11:08:29
LastEditTime: 2024-11-05 10:46:09
LastEditTime: 2025-01-01 09:21:21
LastEditors: Wenyu Ouyang
Description: Config for hydroDL
FilePath: \torchhydro\torchhydro\configs\config.py
@@ -286,11 +286,11 @@ def default_config_file():
"metrics": ["NSE", "RMSE", "R2", "KGE", "FHV", "FLV"],
"fill_nan": "no",
"explainer": None,
# rolling means testdataloader will sample data with overlap time
# rolling is False meaning each time has only one output for one basin one variable
# rolling is True and the time_window must be prec_window+horizon now!
# for example, data is |1|2|3|4| time_window=2 then the samples are |1|2|, |2|3| and |3|4|
"rolling": False,
# rolling=0 means a decoder-only prediction: each period has exactly one prediction
# rolling>0 means a forecast is issued every `rolling` periods.
# For example, with a 3-hour time step and a forecast issued at 8:00am,
# rolling=1 means further forecasts are issued at 11:00, 14:00, 17:00, ...
"rolling": 0,
"calc_metrics": True,
},
}
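
A minimal sketch of the rolling semantics described in the comment above (not part of this commit; the 3-hour step and 08:00 start are just the example values from the comment):

from datetime import datetime, timedelta

# a 3-hour time step, first forecast issued at 08:00;
# rolling=1 means a new forecast is issued every 1 period (i.e., every 3 hours)
step = timedelta(hours=3)
rolling = 1
start = datetime(2025, 1, 1, 8, 0)

issue_times = [start + i * rolling * step for i in range(4)]
print([t.strftime("%H:%M") for t in issue_times])  # ['08:00', '11:00', '14:00', '17:00']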
@@ -735,9 +735,9 @@ def cmd(
parser.add_argument(
"--rolling",
dest="rolling",
help="if False, evaluate 1-period output with a rolling window",
help="0 means no rolling; rolling>0, such as 1, means perform forecasting once after 1 period. For example, at 8:00am we perform one forecasting and our time-step is 3h, rolling=1 means 11:00, 14:00, 17:00 ..., we will perform forecasting",
default=rolling,
type=bool,
type=int,
)
parser.add_argument(
"--model_loader",
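A standalone sketch of how the updated flag parses, assuming a bare argparse parser that mirrors the option above (the wiring into torchhydro's cmd() is not shown); with the previous type=bool, argparse would have turned any non-empty string into True, so an integer type matches the new counting semantics:

from argparse import ArgumentParser

# mirror of the --rolling option after this change: an integer count, 0 disables rolling
parser = ArgumentParser()
parser.add_argument("--rolling", dest="rolling", type=int, default=0)

args = parser.parse_args(["--rolling", "56"])
assert args.rolling == 56  # parsed as an int; bool("56") would always have been True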
87 changes: 65 additions & 22 deletions torchhydro/models/ann.py
@@ -7,15 +7,21 @@
FilePath: \torchhydro\torchhydro\models\ann.py
Copyright (c) 2021-2022 Wenyu Ouyang. All rights reserved.
"""

from typing import Union

import torch
import torch.nn.functional as F
from torch import nn


class SimpleAnn(torch.nn.Module):
def __init__(self, nx: int, ny: int, hidden_size: Union[int, tuple, list] = None,
dr: Union[float, tuple, list] = 0.0):
class SimpleAnn(nn.Module):
def __init__(
self,
nx: int,
ny: int,
hidden_size: Union[int, tuple, list] = None,
dr: Union[float, tuple, list] = 0.0,
activation: str = "relu",
):
"""
A simple multi-layer NN model with final linear layer
@@ -31,56 +31,93 @@ def __init__(self, nx: int, ny: int, hidden_size: Union[int, tuple, list] = None
dr
dropout rate of layers, default is 0.0 which means no dropout;
here we set number of dropout layers to (number of nn layers - 1)
activation
activation function for hidden layers, default is "relu"
"""
super(SimpleAnn, self).__init__()
linear_list = torch.nn.ModuleList()
dropout_list = torch.nn.ModuleList()
linear_list = nn.ModuleList()
dropout_list = nn.ModuleList()
if (
hidden_size is None
or (type(hidden_size) is int and hidden_size == 0)
or (type(hidden_size) in [tuple, list] and len(hidden_size) < 1)
hidden_size is None
or (type(hidden_size) is int and hidden_size == 0)
or (type(hidden_size) in [tuple, list] and len(hidden_size) < 1)
):
linear_list.add_module("linear1", torch.nn.Linear(nx, ny))
linear_list.add_module("linear1", nn.Linear(nx, ny))
elif type(hidden_size) is int:
if type(dr) in [tuple, list]:
dr = dr[0]
linear_list.add_module("linear1", torch.nn.Linear(nx, hidden_size))
linear_list.add_module("linear1", nn.Linear(nx, hidden_size))
# dropout layers do not have additional weights, so we do not name them here
dropout_list.append(torch.nn.Dropout(dr))
linear_list.add_module("linear2", torch.nn.Linear(hidden_size, ny))
dropout_list.append(nn.Dropout(dr))
linear_list.add_module("linear2", nn.Linear(hidden_size, ny))
else:
linear_list.add_module("linear1", torch.nn.Linear(nx, hidden_size[0]))
linear_list.add_module("linear1", nn.Linear(nx, hidden_size[0]))
if type(dr) is float:
dr = [dr] * len(hidden_size)
elif len(dr) != len(hidden_size):
raise ArithmeticError(
"We set dropout layer for each nn layer, please check the number of dropout layers")
"We set dropout layer for each nn layer, please check the number of dropout layers"
)
# dropout_list.add_module("dropout1", torch.nn.Dropout(dr[0]))
dropout_list.append(torch.nn.Dropout(dr[0]))
dropout_list.append(nn.Dropout(dr[0]))
for i in range(len(hidden_size) - 1):
linear_list.add_module(
"linear%d" % (i + 1 + 1),
torch.nn.Linear(hidden_size[i], hidden_size[i + 1]),
nn.Linear(hidden_size[i], hidden_size[i + 1]),
)
dropout_list.append(
torch.nn.Dropout(dr[i + 1]),
nn.Dropout(dr[i + 1]),
)
linear_list.add_module(
"linear%d" % (len(hidden_size) + 1),
torch.nn.Linear(hidden_size[-1], ny),
nn.Linear(hidden_size[-1], ny),
)
self.linear_list = linear_list
self.dropout_list = dropout_list
self.activation = self._get_activation(activation)

def forward(self, x):
for i, model in enumerate(self.linear_list):
if i == 0:
if len(self.linear_list) == 1:
return model(x)
else:
out = F.relu(self.dropout_list[i](model(x)))
out = self.activation(self.dropout_list[i](model(x)))
elif i == len(self.linear_list) - 1:
# no activation after the final layer
return model(out)
else:
out = F.relu(self.dropout_list[i](model(out)))
out = self.activation(self.dropout_list[i](model(out)))

def _get_activation(self, name: str) -> nn.Module:
"""a function to get activation function by name, reference from:
https://github.com/neuralhydrology/neuralhydrology/blob/master/neuralhydrology/modelzoo/fc.py
Parameters
----------
name : str
_description_
Returns
-------
nn.Module
_description_
Raises
------
NotImplementedError
_description_
"""
if name.lower() == "tanh":
activation = nn.Tanh()
elif name.lower() == "sigmoid":
activation = nn.Sigmoid()
elif name.lower() == "relu":
activation = nn.ReLU()
elif name.lower() == "linear":
activation = nn.Identity()
else:
raise NotImplementedError(
f"{name} currently not supported as activation in this class"
)
return activation
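
A small usage sketch of the new activation argument, mirroring the tests above (layer sizes and shapes are only illustrative):

import torch

from torchhydro.models.ann import SimpleAnn

# two hidden layers with tanh activations between them; the final linear layer stays linear
model = SimpleAnn(nx=3, ny=2, hidden_size=[4, 5], dr=[0.1, 0.2], activation="tanh")
x = torch.randn(5, 3)
y = model(x)
print(y.shape)  # torch.Size([5, 2])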