Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor variable scaling, pressure level scalings only applied in specific circumstances #52

Draft
wants to merge 17 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 37 additions & 20 deletions training/src/anemoi/training/config/training/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ training_loss:
# Available scalars include:
# - 'variable': See `variable_loss_scaling` for more information
# - 'loss_weights_mask': Giving imputed NaNs a zero weight in the loss function
scalars: ['variable', 'loss_weights_mask']
# - 'tendency': See `additional_scalars` for more information
scalars: ['variable', 'variable_pressure_level', 'loss_weights_mask']

ignore_nans: False

Expand Down Expand Up @@ -109,33 +110,49 @@ lr:
# Variable loss scaling
# 'variable' must be included in `scalars` in the losses for this to be applied.
variable_loss_scaling:
variable_groups:
default: sfc
pl: [q, t, u, v, w, z]
default: 1
pl:
q: 0.6 #1
t: 6 #1
u: 0.8 #0.5
v: 0.5 #0.33
w: 0.001
z: 12 #1
sfc:
sp: 10
10u: 0.1
10v: 0.1
2d: 0.5
tp: 0.025
cp: 0.0025
q: 0.6 #1
t: 6 #1
u: 0.8 #0.5
v: 0.5 #0.33
w: 0.001
z: 12 #1
sp: 10
10u: 0.1
10v: 0.1
2d: 0.5
tp: 0.025
cp: 0.0025
additional_scalars:
# pressure level scalar
- _target_: anemoi.training.train.scaling.ReluVariableLevelScaler
group: pl
y_intercept: 0.2
slope: 0.001
scale_dim: -1 # dimension on which scaling applied
name: "variable_pressure_level"
# tendency scalers
# scale the prognostic losses by the stdev of the variable tendencies (e.g. the 6-hourly differences of the data)
# useful if including slow vs fast evolving variables in the training (e.g. Land/Ocean vs Atmosphere)
# if using this option 'variable_loss_scalings' should all be set close to 1.0 for prognostic variables
# stdev tendency scaler
# - _target_: anemoi.training.data.scaling.StdevTendencyScaler
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should be train.scaling I think

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! I have just pushed an update now 🙏

# scale_dim: -1 # dimension on which scaling applied
# name: "tendency"
# var tendency scaler (this should be default!?)
# - _target_: anemoi.training.data.scaling.VarTendencyScaler
# scale_dim: -1 # dimension on which scaling applied
# name: "tendency"

metrics:
- z_500
- t_850
- u_850
- v_850

pressure_level_scaler:
_target_: anemoi.training.data.scaling.ReluPressureLevelScaler
minimum: 0.2
slope: 0.001

node_loss_weights:
_target_: anemoi.training.losses.nodeweights.GraphNodeAttribute
target_nodes: ${graph.data}
Expand Down
5 changes: 5 additions & 0 deletions training/src/anemoi/training/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def __init__(self, config: DictConfig, graph_data: HeteroData) -> None:
def statistics(self) -> dict:
return self.ds_train.statistics

@cached_property
def statistics_tendencies(self) -> dict:
return self.ds_train.statistics_tendencies

@cached_property
def metadata(self) -> dict:
return self.ds_train.metadata
Expand Down Expand Up @@ -183,6 +187,7 @@ def _get_dataset(
rollout=r,
multistep=self.config.training.multistep_input,
timeincrement=self.timeincrement,
timestep=self.config.data.timestep,
shuffle=shuffle,
grid_indices=self.grid_indices,
label=label,
Expand Down
12 changes: 12 additions & 0 deletions training/src/anemoi/training/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
rollout: int = 1,
multistep: int = 1,
timeincrement: int = 1,
timestep: str = "6h",
shuffle: bool = True,
label: str = "generic",
effective_bs: int = 1,
Expand All @@ -57,6 +58,8 @@ def __init__(
length of rollout window, by default 12
timeincrement : int, optional
time increment between samples, by default 1
timestep : int, optional
the time frequency of the samples, by default '6h'
multistep : int, optional
collate (t-1, ... t - multistep) into the input state vector, by default 1
shuffle : bool, optional
Expand All @@ -73,6 +76,7 @@ def __init__(

self.rollout = rollout
self.timeincrement = timeincrement
self.timestep = timestep
self.grid_indices = grid_indices

# lazy init
Expand Down Expand Up @@ -104,6 +108,14 @@ def statistics(self) -> dict:
"""Return dataset statistics."""
return self.data.statistics

@cached_property
def statistics_tendencies(self) -> dict:
"""Return dataset tendency statistics."""
try:
return self.data.statistics_tendencies(self.timestep)
except (KeyError, AttributeError):
return None

@cached_property
def metadata(self) -> dict:
"""Return dataset metadata."""
Expand Down
79 changes: 0 additions & 79 deletions training/src/anemoi/training/data/scaling.py

This file was deleted.

73 changes: 32 additions & 41 deletions training/src/anemoi/training/train/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Optional
from typing import Union

import numpy as np
import pytorch_lightning as pl
import torch
from hydra.utils import instantiate
Expand All @@ -31,6 +30,7 @@
from anemoi.models.interface import AnemoiModelInterface
from anemoi.training.losses.utils import grad_scaler
from anemoi.training.losses.weightedloss import BaseWeightedLoss
from anemoi.training.train.scaling import GeneralVariableLossScaler
from anemoi.training.utils.jsonify import map_config_to_primitives
from anemoi.training.utils.masks import Boolean1DMask
from anemoi.training.utils.masks import NoOutputMask
Expand All @@ -48,6 +48,7 @@ def __init__(
config: DictConfig,
graph_data: HeteroData,
statistics: dict,
statistics_tendencies: dict,
data_indices: IndexCollection,
metadata: dict,
supporting_arrays: dict,
Expand Down Expand Up @@ -95,10 +96,36 @@ def __init__(
self.latlons_data = graph_data[config.graph.data].x
self.node_weights = self.get_node_weights(config, graph_data)
self.node_weights = self.output_mask.apply(self.node_weights, dim=0, fill_value=0.0)
self.statistics_tendencies = statistics_tendencies

self.logger_enabled = config.diagnostics.log.wandb.enabled or config.diagnostics.log.mlflow.enabled

variable_scaling = self.get_variable_scaling(config, data_indices)
variable_scaling = GeneralVariableLossScaler(
config.training.variable_loss_scaling,
data_indices,
).get_variable_scaling()

# Instantiate the pressure level scaling class with the training configuration
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if this is possible but I wonder if here we could instantiate from a list of scalars rather than specific ones? I think this is how it is done for the validation metrics

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this would be useful, as we want to allow more scalar methods for model levels/tendency/etc.
I will have a look.

config_container = OmegaConf.to_container(config.training.additional_scalars, resolve=False)
if isinstance(config_container, list):
scalar = [
(
instantiate(
scalar_config,
scaling_config=config.training.variable_loss_scaling,
data_indices=data_indices,
statistics=statistics,
statistics_tendencies=statistics_tendencies,
)
if scalar_config["name"] == "tendency"
else instantiate(
scalar_config,
scaling_config=config.training.variable_loss_scaling,
data_indices=data_indices,
)
)
for scalar_config in config_container
]

self.internal_metric_ranges, self.val_metric_ranges = self.get_val_metric_ranges(config, data_indices)

Expand All @@ -120,6 +147,9 @@ def __init__(
"loss_weights_mask": ((-2, -1), torch.ones((1, 1))),
"limited_area_mask": (2, limited_area_mask),
}
# add addtional user-defined scalars
[self.scalars.update({scale.name: (scale.scale_dim, scale.get_variable_scaling())}) for scale in scalar]

self.updated_loss_mask = False

self.loss = self.get_loss_function(config.training.training_loss, scalars=self.scalars, **loss_kwargs)
Expand Down Expand Up @@ -299,45 +329,6 @@ def get_val_metric_ranges(config: DictConfig, data_indices: IndexCollection) ->

return metric_ranges, metric_ranges_validation

@staticmethod
def get_variable_scaling(
config: DictConfig,
data_indices: IndexCollection,
) -> torch.Tensor:
variable_loss_scaling = (
np.ones((len(data_indices.internal_data.output.full),), dtype=np.float32)
* config.training.variable_loss_scaling.default
)
pressure_level = instantiate(config.training.pressure_level_scaler)

LOGGER.info(
"Pressure level scaling: use scaler %s with slope %.4f and minimum %.2f",
type(pressure_level).__name__,
pressure_level.slope,
pressure_level.minimum,
)

for key, idx in data_indices.internal_model.output.name_to_index.items():
split = key.split("_")
if len(split) > 1 and split[-1].isdigit():
# Apply pressure level scaling
if split[0] in config.training.variable_loss_scaling.pl:
variable_loss_scaling[idx] = config.training.variable_loss_scaling.pl[
split[0]
] * pressure_level.scaler(
int(split[-1]),
)
else:
LOGGER.debug("Parameter %s was not scaled.", key)
else:
# Apply surface variable scaling
if key in config.training.variable_loss_scaling.sfc:
variable_loss_scaling[idx] = config.training.variable_loss_scaling.sfc[key]
else:
LOGGER.debug("Parameter %s was not scaled.", key)

return torch.from_numpy(variable_loss_scaling)

@staticmethod
def get_node_weights(config: DictConfig, graph_data: HeteroData) -> torch.Tensor:
node_weighting = instantiate(config.training.node_loss_weights)
Expand Down
Loading
Loading