
Pre-commit
havardhhaugen committed Nov 11, 2024
1 parent 17be84b commit e99a5a7
Showing 2 changed files with 17 additions and 15 deletions.
23 changes: 14 additions & 9 deletions src/anemoi/training/data/scaling.py
@@ -13,11 +13,10 @@
 from abc import abstractmethod
 
 import numpy as np
-
 import torch
-from torch_geometric.data import HeteroData
-from scipy.spatial import SphericalVoronoi
 from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian
+from scipy.spatial import SphericalVoronoi
+from torch_geometric.data import HeteroData
 
 LOGGER = logging.getLogger(__name__)
 
@@ -83,6 +82,7 @@ def scaler(plev: float) -> np.ndarray:
         # no scaling, always return 1.0
         return 1.0
 
+
 class BaseAreaWeights:
     """Method for overwriting the area-weights stored in the graph object."""
 
@@ -96,29 +96,34 @@ def __init__(self, target_nodes: str):
         """
         self.target = target_nodes
 
-    def area_weights(self, graph_data) -> torch.Tensor:
+    def area_weights(self, graph_data: HeteroData) -> torch.Tensor:
         return torch.from_numpy(self.global_area_weights(graph_data))
 
     def global_area_weights(self, graph_data: HeteroData) -> np.ndarray:
-        lats, lons = graph_data[self.target].x[:,0], graph_data[self.target].x[:,1]
+        lats, lons = graph_data[self.target].x[:, 0], graph_data[self.target].x[:, 1]
         points = latlon_rad_to_cartesian((np.asarray(lats), np.asarray(lons)))
         sv = SphericalVoronoi(points, radius=1.0, center=[0.0, 0.0, 0.0])
         area_weights = sv.calculate_areas()
 
         return area_weights / np.max(area_weights)
 
+
 class StretchedGridCutoutAreaWeights(BaseAreaWeights):
-    """Rescale area weight of nodes inside the cutout area by setting their sum to a fraction of the sum of global nodes area weight."""
+    """Rescale area weight of nodes inside the cutout area.
+
+    Sum of the area weight of cutout nodes set to a specified fraction of sum of global nodes.
+    """
 
     def __init__(self, target_nodes: str, cutout_weight_frac_of_global: float):
         """Initialize area weights with target nodes and scaling factor for the cutout nodes area weight.
 
         Parameters
         ----------
         target_nodes : str
             Name of the set of nodes to be rescaled (defined when creating the graph).
         cutout_weight_frac_of_global: float
-            Scaling factor for the cutout nodes area weight - sum of cutout nodes area weight set to a fraction of the sum of global nodes area weight.
+            Scaling factor for the cutout nodes area weight - sum of cutout nodes area weight set to a fraction of
+            the sum of global nodes area weight.
         """
         super().__init__(target_nodes=target_nodes)
         self.fraction = cutout_weight_frac_of_global
@@ -131,4 +136,4 @@ def area_weights(self, graph_data: HeteroData) -> torch.Tensor:
         weight_per_cutout_node = self.fraction * global_sum / sum(mask)
         area_weights[mask] = weight_per_cutout_node
 
-        return torch.from_numpy(area_weights)
+        return torch.from_numpy(area_weights)
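
For reference, a minimal usage sketch of the new BaseAreaWeights class, assuming anemoi-training is installed and importable as anemoi.training.data.scaling. It builds a toy HeteroData graph whose node set "data" stores latitude/longitude in radians in .x, which is what global_area_weights() reads, and prints the normalised spherical-Voronoi areas. The node-set name "data" and the coordinate values are illustrative assumptions; StretchedGridCutoutAreaWeights additionally needs the cutout information produced when the graph is built, so only the base class is exercised here.

import torch
from torch_geometric.data import HeteroData

from anemoi.training.data.scaling import BaseAreaWeights

# Toy graph: node set "data" with (lat, lon) in radians stored in .x,
# mirroring what BaseAreaWeights.global_area_weights() reads.
graph = HeteroData()
graph["data"].x = torch.tensor(
    [
        [0.9, -2.0],
        [0.6, -0.7],
        [0.2, 0.3],
        [-0.1, 1.2],
        [-0.5, 2.1],
        [-0.9, 2.9],
    ],
    dtype=torch.float64,
)

# Normalised spherical-Voronoi cell areas (the largest weight is 1.0).
weights = BaseAreaWeights(target_nodes="data").area_weights(graph)
print(weights)
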
9 changes: 3 additions & 6 deletions src/anemoi/training/train/forecaster.py
@@ -289,19 +289,16 @@ def get_feature_weights(
                 LOGGER.debug("Parameter %s was not scaled.", key)
 
         return torch.from_numpy(loss_scaling)
 
     @staticmethod
-    def get_node_weights(
-        config: DictConfig,
-        graph_data: HeteroData
-    ) -> torch.Tensor:
+    def get_node_weights(config: DictConfig, graph_data: HeteroData) -> torch.Tensor:
         node_weights = graph_data[config.graph.data][config.model.node_loss_weight].squeeze()
 
         if "spatial" in config.training.loss_scaling:
             spatial_loss_scaler = instantiate(config.training.loss_scaling.spatial)
             node_weights = spatial_loss_scaler.area_weights(graph_data)
             LOGGER.info("Rescaling area weights")
 
         return node_weights
 
     def set_model_comm_group(self, model_comm_group: ProcessGroup) -> None:
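
For reference, a hedged sketch of how the spatial loss scaler could be wired in through Hydra, mirroring the "spatial" lookup and the instantiate() call in get_node_weights() above. The key layout of anemoi-training's actual YAML configs may differ; the node-set name "data" and the 0.25 fraction are illustrative values, and the _target_ path simply follows the file location src/anemoi/training/data/scaling.py.

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Hypothetical config fragment; only training.loss_scaling.spatial is read by
# get_node_weights().
config = OmegaConf.create(
    {
        "training": {
            "loss_scaling": {
                "spatial": {
                    "_target_": "anemoi.training.data.scaling.StretchedGridCutoutAreaWeights",
                    "target_nodes": "data",  # assumed node-set name
                    "cutout_weight_frac_of_global": 0.25,  # assumed fraction
                },
            },
        },
    },
)

if "spatial" in config.training.loss_scaling:
    spatial_loss_scaler = instantiate(config.training.loss_scaling.spatial)
    # node_weights = spatial_loss_scaler.area_weights(graph_data)
    # graph_data must be a graph built by anemoi-graphs, including the cutout
    # information the stretched-grid scaler relies on.
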
