-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor variable scaling, pressure level scalings only applied in specific circumstances #52
base: develop
Are you sure you want to change the base?
Changes from 3 commits
511ed18
7ddf6d6
3ddeccc
be4602c
195af07
2644c18
718fc57
a34ac02
b91af11
c22c50b
1f4a532
2843d98
c978871
f56f9b2
be90000
e474ae9
f005f84
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,6 @@ | |
from typing import Optional | ||
from typing import Union | ||
|
||
import numpy as np | ||
import pytorch_lightning as pl | ||
import torch | ||
from hydra.utils import instantiate | ||
|
@@ -31,6 +30,7 @@ | |
from anemoi.models.interface import AnemoiModelInterface | ||
from anemoi.training.losses.utils import grad_scaler | ||
from anemoi.training.losses.weightedloss import BaseWeightedLoss | ||
from anemoi.training.train.scaling import GeneralVariableLossScaler | ||
from anemoi.training.utils.jsonify import map_config_to_primitives | ||
from anemoi.training.utils.masks import Boolean1DMask | ||
from anemoi.training.utils.masks import NoOutputMask | ||
|
@@ -98,7 +98,18 @@ def __init__( | |
|
||
self.logger_enabled = config.diagnostics.log.wandb.enabled or config.diagnostics.log.mlflow.enabled | ||
|
||
variable_scaling = self.get_variable_scaling(config, data_indices) | ||
variable_scaling = GeneralVariableLossScaler( | ||
config.training.variable_loss_scaling, | ||
data_indices, | ||
).get_variable_scaling() | ||
|
||
# Instantiate the pressure level scaling class with the training configuration | ||
pressurelevelscaler = instantiate( | ||
config.training.pressure_level_scaler, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this config location is wrong? It should be training.variable_loss_scaling.pressure_level_scalar There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed, in the config file, the pressure_level_scalar should be defined directly in training as before. |
||
scaling_config=config.training.variable_loss_scaling, | ||
data_indices=data_indices, | ||
) | ||
pressure_level_scaling = pressurelevelscaler.get_variable_scaling() | ||
|
||
self.internal_metric_ranges, self.val_metric_ranges = self.get_val_metric_ranges(config, data_indices) | ||
|
||
|
@@ -117,6 +128,7 @@ def __init__( | |
# Filled after first application of preprocessor. dimension=[-2, -1] (latlon, n_outputs). | ||
self.scalars = { | ||
"variable": (-1, variable_scaling), | ||
"variable_pressure_level": (-1, pressure_level_scaling), | ||
"loss_weights_mask": ((-2, -1), torch.ones((1, 1))), | ||
"limited_area_mask": (2, limited_area_mask), | ||
} | ||
|
@@ -299,45 +311,6 @@ def get_val_metric_ranges(config: DictConfig, data_indices: IndexCollection) -> | |
|
||
return metric_ranges, metric_ranges_validation | ||
|
||
@staticmethod | ||
def get_variable_scaling( | ||
config: DictConfig, | ||
data_indices: IndexCollection, | ||
) -> torch.Tensor: | ||
variable_loss_scaling = ( | ||
np.ones((len(data_indices.internal_data.output.full),), dtype=np.float32) | ||
* config.training.variable_loss_scaling.default | ||
) | ||
pressure_level = instantiate(config.training.pressure_level_scaler) | ||
|
||
LOGGER.info( | ||
"Pressure level scaling: use scaler %s with slope %.4f and minimum %.2f", | ||
type(pressure_level).__name__, | ||
pressure_level.slope, | ||
pressure_level.minimum, | ||
) | ||
|
||
for key, idx in data_indices.internal_model.output.name_to_index.items(): | ||
split = key.split("_") | ||
if len(split) > 1 and split[-1].isdigit(): | ||
# Apply pressure level scaling | ||
if split[0] in config.training.variable_loss_scaling.pl: | ||
variable_loss_scaling[idx] = config.training.variable_loss_scaling.pl[ | ||
split[0] | ||
] * pressure_level.scaler( | ||
int(split[-1]), | ||
) | ||
else: | ||
LOGGER.debug("Parameter %s was not scaled.", key) | ||
else: | ||
# Apply surface variable scaling | ||
if key in config.training.variable_loss_scaling.sfc: | ||
variable_loss_scaling[idx] = config.training.variable_loss_scaling.sfc[key] | ||
else: | ||
LOGGER.debug("Parameter %s was not scaled.", key) | ||
|
||
return torch.from_numpy(variable_loss_scaling) | ||
|
||
@staticmethod | ||
def get_node_weights(config: DictConfig, graph_data: HeteroData) -> torch.Tensor: | ||
node_weighting = instantiate(config.training.node_loss_weights) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know if this is possible but I wonder if here we could instantiate from a list of scalars rather than specific ones? I think this is how it is done for the validation metrics
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this would be useful, as we want to allow more scalar methods for model levels/tendency/etc.
I will have a look.