From caf67f1ea16a7a7f92d2096dfc763c54ad8d6a35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A5vard=20Haugen?= Date: Wed, 6 Nov 2024 18:21:41 +0200 Subject: [PATCH] Implementation of area weight rescaling --- lumi_train.sh | 14 +++---- .../config/diagnostics/eval_rollout.yaml | 12 +++--- .../training/config/graph/stretched_grid.yaml | 17 +++----- .../training/config/stretched_grid.yaml | 27 ++++++++---- src/anemoi/training/data/scaling.py | 42 +++++++++++++++++++ src/anemoi/training/train/forecaster.py | 15 ++++++- 6 files changed, 94 insertions(+), 33 deletions(-) diff --git a/lumi_train.sh b/lumi_train.sh index a653f9e0..c73a931d 100644 --- a/lumi_train.sh +++ b/lumi_train.sh @@ -1,12 +1,12 @@ #!/bin/bash -#SBATCH --output=/scratch/project_465001383/aifs/logs/name2.out -#SBATCH --error=/scratch/project_465001383/aifs/logs/name2.err -#SBATCH --nodes=1 +#SBATCH --output=/scratch/project_465001383/aifs/logs/anemoi_training_aw0.3.out +#SBATCH --error=/scratch/project_465001383/aifs/logs/anemoi_training_aw0.3.err +#SBATCH --nodes=4 #SBATCH --ntasks-per-node=8 #SBATCH --account=project_465001383 -#SBATCH --partition=dev-g +#SBATCH --partition=standard-g #SBATCH --gpus-per-node=8 -#SBATCH --time=00:15:00 +#SBATCH --time=01:00:00 #SBATCH --job-name=aifs #SBATCH --exclusive @@ -15,7 +15,7 @@ CONTAINER_SCRIPT=$PROJECT_DIR/aifs/run-pytorch/run-pytorch.sh #CHANGE THESE: CONTAINER=$PROJECT_DIR/aifs/container/containers/aifs-met-pytorch-2.2.0-rocm-5.6.1-py3.9-v2.0-new-correct-anemoi-models-sort-vars.sif -PYTHON_SCRIPT=$PROJECT_DIR/haugenha/anemoi-training-setup/anemoi-training-config/anemoi-training/src/lumi_train.py +PYTHON_SCIPT=$PROJECT_DIR/haugenha/anemoi-training-setup/weights_rescale/anemoi-training/src/lumi_train.py VENV=/users/haugenha/work/.venv-anemoi-training @@ -45,4 +45,4 @@ srun --cpu-bind=$CPU_BIND \ -B /opt/cray \ -B /usr/lib64 \ -B /usr/lib64/libjansson.so.4 \ - $CONTAINER $CONTAINER_SCRIPT $PYTHON_SCRIPT \ No newline at end of file + $CONTAINER $CONTAINER_SCRIPT $PYTHON_SCRIPT diff --git a/src/anemoi/training/config/diagnostics/eval_rollout.yaml b/src/anemoi/training/config/diagnostics/eval_rollout.yaml index 032faadf..ee7e9d3d 100644 --- a/src/anemoi/training/config/diagnostics/eval_rollout.yaml +++ b/src/anemoi/training/config/diagnostics/eval_rollout.yaml @@ -82,16 +82,16 @@ log: tensorboard: enabled: False mlflow: - enabled: False + enabled: True offline: False - authentication: False + authentication: True log_model: False - tracking_uri: ??? - experiment_name: 'anemoi-debug' + tracking_uri: https://mlflow.ecmwf.int + experiment_name: 'metno' project_name: 'Anemoi' - system: True + system: False terminal: True - run_name: null # If set to null, the run name will be the a random UUID + run_name: test_aw_rescale_0.3 # If set to null, the run name will be the a random UUID on_resume_create_child: True interval: 100 diff --git a/src/anemoi/training/config/graph/stretched_grid.yaml b/src/anemoi/training/config/graph/stretched_grid.yaml index 2f814fe2..1455c61d 100644 --- a/src/anemoi/training/config/graph/stretched_grid.yaml +++ b/src/anemoi/training/config/graph/stretched_grid.yaml @@ -1,21 +1,17 @@ overwrite: False -data: "stretched_grid" +data: "data" hidden: "hidden" nodes: - stretched_grid: + data: node_builder: - _target_: anemoi.graphs.nodes.CutOutZarrDatasetNodes - lam_dataset: ${hardware.paths.data}/${hardware.files.dataset_lam} - forcing_dataset: ${hardware.paths.data}/${hardware.files.dataset} - adjust: all -# min_distance_km: 0 + _target_: anemoi.graphs.nodes.ZarrDatasetNodes + dataset: ${dataloader.dataset} attributes: area_weight: _target_: anemoi.graphs.nodes.attributes.AreaWeights norm: unit-max -# lam_weights_rescale: 1.0 hidden: node_builder: _target_: anemoi.graphs.nodes.StretchedTriNodes @@ -23,7 +19,7 @@ nodes: global_resolution: 5 reference_node_name: ${graph.data} mask_attr_name: cutout - margin_radius_km: 10 + margin_radius_km: 0 edges: - source_name: ${graph.data} @@ -35,7 +31,6 @@ edges: edge_length: _target_: anemoi.graphs.edges.attributes.EdgeLength norm: unit-max - invert: True edge_dirs: _target_: anemoi.graphs.edges.attributes.EdgeDirection norm: unit-std @@ -48,7 +43,6 @@ edges: edge_length: _target_: anemoi.graphs.edges.attributes.EdgeLength norm: unit-max - invert: True edge_dirs: _target_: anemoi.graphs.edges.attributes.EdgeDirection norm: unit-std @@ -61,7 +55,6 @@ edges: edge_length: _target_: anemoi.graphs.edges.attributes.EdgeLength norm: unit-max - invert: True edge_dirs: _target_: anemoi.graphs.edges.attributes.EdgeDirection norm: unit-std diff --git a/src/anemoi/training/config/stretched_grid.yaml b/src/anemoi/training/config/stretched_grid.yaml index 28e137d8..664c093f 100644 --- a/src/anemoi/training/config/stretched_grid.yaml +++ b/src/anemoi/training/config/stretched_grid.yaml @@ -27,10 +27,13 @@ dataloader: - dataset: ${hardware.paths.data}/${hardware.files.dataset_lam} - dataset: ${hardware.paths.data}/${hardware.files.dataset} adjust: all + missing_dates: ['2020-03-23T06', '2020-06-07T00', '2021-07-07T06', '2022-01-15T12', '2022-01-15T18', '2022-01-16T00', '2022-01-22T00', '2022-09-07T12', '2022-12-25T06'] + skip_missing_dates: True + expected_access: 3 limit_batches: - training: 20 - validation: 20 + training: null + validation: null training: start: 2020-02-05 @@ -53,7 +56,7 @@ dataloader: hardware: #change these to lumi paths num_gpus_per_node: 8 - num_nodes: 1 + num_nodes: 4 num_gpus_per_model: 1 paths: data: /pfs/lustrep4/scratch/project_465001383/aifs/dataset/ @@ -63,7 +66,7 @@ hardware: #change these to lumi paths files: dataset: ERA5/aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr #aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr dataset_lam: MEPS/aifs-meps-10km-2020-2024-6h-v6.zarr - graph: test-anemoi-training.pt + graph: test-anemoi-training_weights_rescale.pt warm_start: null #specific checkpoint to start from, defaults to last.ckpt data: @@ -87,10 +90,20 @@ training: run_id: null #path to store the experiment in with output_base as root, null for random name, =fork_run_id to continue training in the same folder. fork_run_id: null #path to the experiment to fork from with output_base as root load_weights_only: False #loads entire model if False, loads only weights if True - max_epochs: 50 + max_epochs: 750 lr: - rate: 5.0e-6 - iterations: 10000 + rate: 6.25e-05 + iterations: 150000 min: 8.0e-6 + loss_scaling: + spatial: + __target__: anemoi.training.data.scaling.StretchedGridCutoutWeighting + target_nodes: ${graph.data} + cutout_weight_frac_of_global: 0.3 + + + + + diff --git a/src/anemoi/training/data/scaling.py b/src/anemoi/training/data/scaling.py index 74ba9c23..c97710d1 100644 --- a/src/anemoi/training/data/scaling.py +++ b/src/anemoi/training/data/scaling.py @@ -10,6 +10,10 @@ from abc import abstractmethod import numpy as np +from torch_geometric.data import HeteroData +from scipy.spatial import SphericalVoronoi +from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian + LOGGER = logging.getLogger(__name__) @@ -74,3 +78,41 @@ def scaler(plev: float) -> np.ndarray: del plev # unused # no scaling, always return 1.0 return 1.0 + +class BaseAreaWeights: + + def __init__(self, target_nodes: str, radius: float = 1.0, center: list = [0.0,0.0,0.0]): + self.target = target_nodes + self.radius = radius + self.center = center + + def global_area_weights(self, graph_data: HeteroData) -> np.ndarray: + lats, lons = graph_data[self.target].x[:,0], graph_data[self.target].x[:,1] + points = latlon_rad_to_cartesian((np.asarray(lats), np.asarray(lons))) + sv = SphericalVoronoi(points, self.radius, self.center) + area_weights = sv.calculate_areas() + return area_weights / np.max(area_weights) + + def area_weights(self, graph_data) -> np.ndarray: + return self.global_area_weights(graph_data) + +class StretchedGridCutoutAreaWeights(BaseAreaWeights): + + def __init__(self, target_nodes: str, cutout_weight_frac_of_global: float, radius: float = 1.0, center: list = [0.0,0.0,0.0]): + super().__init__(target_nodes=target_nodes, radius=radius, center = center) + self.fraction = cutout_weight_frac_of_global + + def area_weights(self, graph_data: HeteroData) -> np.ndarray: + area_weights = self.global_area_weights(graph_data) + mask = graph_data[self.target]["cutout"].squeeze().bool() + + global_sum = np.sum(area_weights[~mask]) + weight_per_cutout_node = self.fraction * global_sum / sum(mask) + area_weights[mask] = weight_per_cutout_node + + return area_weights + + + + + diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 6738848a..6a8351ef 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -82,7 +82,11 @@ def __init__( self.save_hyperparameters() self.latlons_data = graph_data[config.graph.data].x - self.loss_weights = graph_data[config.graph.data][config.model.node_loss_weight].squeeze() + loss_weights = self.get_node_weights(config, graph_data) + # For testing + torch.save(loss_weights, config.hardware.paths.output + "node_weights_rescaled.pt") + self.loss_weights = loss_weights + if config.model.get("output_mask", None) is not None: self.output_mask = Boolean1DMask(graph_data[config.graph.data][config.model.output_mask]) @@ -191,6 +195,15 @@ def metrics_loss_scaling(config: DictConfig, data_indices: IndexCollection) -> t metric_ranges_validation[key] = [idx] return metric_ranges, metric_ranges_validation, loss_scaling + def get_node_weights(self, config: DictConfig, graph_data: HeteroData) -> torch.Tensor: + node_weights = graph_data[config.graph.data][config.model.node_loss_weight].squeeze() + + if config.training.loss_scaling.spatial: + spatial_loss_scaler = instantiate(config.training.loss_scaling.spatial) + node_weights = torch.from_numpy(spatial_loss_scaler.area_weights(graph_data)) + + return node_weights + def set_model_comm_group(self, model_comm_group: ProcessGroup) -> None: LOGGER.debug("set_model_comm_group: %s", model_comm_group) self.model_comm_group = model_comm_group