diff --git a/src/anemoi/training/data/scaling.py b/src/anemoi/training/data/scaling.py index dcacd3ae..f890a1f1 100644 --- a/src/anemoi/training/data/scaling.py +++ b/src/anemoi/training/data/scaling.py @@ -13,11 +13,10 @@ from abc import abstractmethod import numpy as np - import torch -from torch_geometric.data import HeteroData -from scipy.spatial import SphericalVoronoi from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian +from scipy.spatial import SphericalVoronoi +from torch_geometric.data import HeteroData LOGGER = logging.getLogger(__name__) @@ -83,6 +82,7 @@ def scaler(plev: float) -> np.ndarray: # no scaling, always return 1.0 return 1.0 + class BaseAreaWeights: """Method for overwriting the area-weights stored in the graph object.""" @@ -96,29 +96,34 @@ def __init__(self, target_nodes: str): """ self.target = target_nodes - def area_weights(self, graph_data) -> torch.Tensor: + def area_weights(self, graph_data: HeteroData) -> torch.Tensor: return torch.from_numpy(self.global_area_weights(graph_data)) def global_area_weights(self, graph_data: HeteroData) -> np.ndarray: - lats, lons = graph_data[self.target].x[:,0], graph_data[self.target].x[:,1] + lats, lons = graph_data[self.target].x[:, 0], graph_data[self.target].x[:, 1] points = latlon_rad_to_cartesian((np.asarray(lats), np.asarray(lons))) sv = SphericalVoronoi(points, radius=1.0, center=[0.0, 0.0, 0.0]) area_weights = sv.calculate_areas() return area_weights / np.max(area_weights) + class StretchedGridCutoutAreaWeights(BaseAreaWeights): - """Rescale area weight of nodes inside the cutout area by setting their sum to a fraction of the sum of global nodes area weight.""" + """Rescale area weight of nodes inside the cutout area. + + Sum of the area weight of cutout nodes set to a specified fraction of sum of global nodes. + """ def __init__(self, target_nodes: str, cutout_weight_frac_of_global: float): """Initialize area weights with target nodes and scaling factor for the cutout nodes area weight. - + Parameters ---------- target_nodes : str Name of the set of nodes to be rescaled (defined when creating the graph). cutout_weight_frac_of_global: float - Scaling factor for the cutout nodes area weight - sum of cutout nodes area weight set to a fraction of the sum of global nodes area weight. + Scaling factor for the cutout nodes area weight - sum of cutout nodes area weight set to a fraction of + the sum of global nodes area weight. """ super().__init__(target_nodes=target_nodes) self.fraction = cutout_weight_frac_of_global @@ -131,4 +136,4 @@ def area_weights(self, graph_data: HeteroData) -> torch.Tensor: weight_per_cutout_node = self.fraction * global_sum / sum(mask) area_weights[mask] = weight_per_cutout_node - return torch.from_numpy(area_weights) \ No newline at end of file + return torch.from_numpy(area_weights) diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 3b2ccb5b..94caa435 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -289,19 +289,16 @@ def get_feature_weights( LOGGER.debug("Parameter %s was not scaled.", key) return torch.from_numpy(loss_scaling) - + @staticmethod - def get_node_weights( - config: DictConfig, - graph_data: HeteroData - ) -> torch.Tensor: + def get_node_weights(config: DictConfig, graph_data: HeteroData) -> torch.Tensor: node_weights = graph_data[config.graph.data][config.model.node_loss_weight].squeeze() if "spatial" in config.training.loss_scaling: spatial_loss_scaler = instantiate(config.training.loss_scaling.spatial) node_weights = spatial_loss_scaler.area_weights(graph_data) LOGGER.info("Rescaling area weights") - + return node_weights def set_model_comm_group(self, model_comm_group: ProcessGroup) -> None: