Skip to content

Commit

Permalink
Implementation of area weight rescaling
Browse files Browse the repository at this point in the history
  • Loading branch information
havardhhaugen committed Nov 6, 2024
1 parent 2a7fb15 commit caf67f1
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 33 deletions.
14 changes: 7 additions & 7 deletions lumi_train.sh
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#!/bin/bash
#SBATCH --output=/scratch/project_465001383/aifs/logs/name2.out
#SBATCH --error=/scratch/project_465001383/aifs/logs/name2.err
#SBATCH --nodes=1
#SBATCH --output=/scratch/project_465001383/aifs/logs/anemoi_training_aw0.3.out
#SBATCH --error=/scratch/project_465001383/aifs/logs/anemoi_training_aw0.3.err
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=8
#SBATCH --account=project_465001383
#SBATCH --partition=dev-g
#SBATCH --partition=standard-g
#SBATCH --gpus-per-node=8
#SBATCH --time=00:15:00
#SBATCH --time=01:00:00
#SBATCH --job-name=aifs
#SBATCH --exclusive

Expand All @@ -15,7 +15,7 @@ CONTAINER_SCRIPT=$PROJECT_DIR/aifs/run-pytorch/run-pytorch.sh

#CHANGE THESE:
CONTAINER=$PROJECT_DIR/aifs/container/containers/aifs-met-pytorch-2.2.0-rocm-5.6.1-py3.9-v2.0-new-correct-anemoi-models-sort-vars.sif
PYTHON_SCRIPT=$PROJECT_DIR/haugenha/anemoi-training-setup/anemoi-training-config/anemoi-training/src/lumi_train.py
PYTHON_SCIPT=$PROJECT_DIR/haugenha/anemoi-training-setup/weights_rescale/anemoi-training/src/lumi_train.py
VENV=/users/haugenha/work/.venv-anemoi-training


Expand Down Expand Up @@ -45,4 +45,4 @@ srun --cpu-bind=$CPU_BIND \
-B /opt/cray \
-B /usr/lib64 \
-B /usr/lib64/libjansson.so.4 \
$CONTAINER $CONTAINER_SCRIPT $PYTHON_SCRIPT
$CONTAINER $CONTAINER_SCRIPT $PYTHON_SCRIPT
12 changes: 6 additions & 6 deletions src/anemoi/training/config/diagnostics/eval_rollout.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,16 @@ log:
tensorboard:
enabled: False
mlflow:
enabled: False
enabled: True
offline: False
authentication: False
authentication: True
log_model: False
tracking_uri: ???
experiment_name: 'anemoi-debug'
tracking_uri: https://mlflow.ecmwf.int
experiment_name: 'metno'
project_name: 'Anemoi'
system: True
system: False
terminal: True
run_name: null # If set to null, the run name will be the a random UUID
run_name: test_aw_rescale_0.3 # If set to null, the run name will be the a random UUID
on_resume_create_child: True
interval: 100

Expand Down
17 changes: 5 additions & 12 deletions src/anemoi/training/config/graph/stretched_grid.yaml
Original file line number Diff line number Diff line change
@@ -1,29 +1,25 @@
overwrite: False

data: "stretched_grid"
data: "data"
hidden: "hidden"

nodes:
stretched_grid:
data:
node_builder:
_target_: anemoi.graphs.nodes.CutOutZarrDatasetNodes
lam_dataset: ${hardware.paths.data}/${hardware.files.dataset_lam}
forcing_dataset: ${hardware.paths.data}/${hardware.files.dataset}
adjust: all
# min_distance_km: 0
_target_: anemoi.graphs.nodes.ZarrDatasetNodes
dataset: ${dataloader.dataset}
attributes:
area_weight:
_target_: anemoi.graphs.nodes.attributes.AreaWeights
norm: unit-max
# lam_weights_rescale: 1.0
hidden:
node_builder:
_target_: anemoi.graphs.nodes.StretchedTriNodes
lam_resolution: 8
global_resolution: 5
reference_node_name: ${graph.data}
mask_attr_name: cutout
margin_radius_km: 10
margin_radius_km: 0

edges:
- source_name: ${graph.data}
Expand All @@ -35,7 +31,6 @@ edges:
edge_length:
_target_: anemoi.graphs.edges.attributes.EdgeLength
norm: unit-max
invert: True
edge_dirs:
_target_: anemoi.graphs.edges.attributes.EdgeDirection
norm: unit-std
Expand All @@ -48,7 +43,6 @@ edges:
edge_length:
_target_: anemoi.graphs.edges.attributes.EdgeLength
norm: unit-max
invert: True
edge_dirs:
_target_: anemoi.graphs.edges.attributes.EdgeDirection
norm: unit-std
Expand All @@ -61,7 +55,6 @@ edges:
edge_length:
_target_: anemoi.graphs.edges.attributes.EdgeLength
norm: unit-max
invert: True
edge_dirs:
_target_: anemoi.graphs.edges.attributes.EdgeDirection
norm: unit-std
Expand Down
27 changes: 20 additions & 7 deletions src/anemoi/training/config/stretched_grid.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,13 @@ dataloader:
- dataset: ${hardware.paths.data}/${hardware.files.dataset_lam}
- dataset: ${hardware.paths.data}/${hardware.files.dataset}
adjust: all
missing_dates: ['2020-03-23T06', '2020-06-07T00', '2021-07-07T06', '2022-01-15T12', '2022-01-15T18', '2022-01-16T00', '2022-01-22T00', '2022-09-07T12', '2022-12-25T06']
skip_missing_dates: True
expected_access: 3

limit_batches:
training: 20
validation: 20
training: null
validation: null

training:
start: 2020-02-05
Expand All @@ -53,7 +56,7 @@ dataloader:

hardware: #change these to lumi paths
num_gpus_per_node: 8
num_nodes: 1
num_nodes: 4
num_gpus_per_model: 1
paths:
data: /pfs/lustrep4/scratch/project_465001383/aifs/dataset/
Expand All @@ -63,7 +66,7 @@ hardware: #change these to lumi paths
files:
dataset: ERA5/aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr #aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6.zarr
dataset_lam: MEPS/aifs-meps-10km-2020-2024-6h-v6.zarr
graph: test-anemoi-training.pt
graph: test-anemoi-training_weights_rescale.pt
warm_start: null #specific checkpoint to start from, defaults to last.ckpt

data:
Expand All @@ -87,10 +90,20 @@ training:
run_id: null #path to store the experiment in with output_base as root, null for random name, =fork_run_id to continue training in the same folder.
fork_run_id: null #path to the experiment to fork from with output_base as root
load_weights_only: False #loads entire model if False, loads only weights if True
max_epochs: 50
max_epochs: 750
lr:
rate: 5.0e-6
iterations: 10000
rate: 6.25e-05
iterations: 150000
min: 8.0e-6

loss_scaling:
spatial:
__target__: anemoi.training.data.scaling.StretchedGridCutoutWeighting
target_nodes: ${graph.data}
cutout_weight_frac_of_global: 0.3






42 changes: 42 additions & 0 deletions src/anemoi/training/data/scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from abc import abstractmethod

import numpy as np
from torch_geometric.data import HeteroData
from scipy.spatial import SphericalVoronoi
from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian


LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -74,3 +78,41 @@ def scaler(plev: float) -> np.ndarray:
del plev # unused
# no scaling, always return 1.0
return 1.0

class BaseAreaWeights:

def __init__(self, target_nodes: str, radius: float = 1.0, center: list = [0.0,0.0,0.0]):
self.target = target_nodes
self.radius = radius
self.center = center

def global_area_weights(self, graph_data: HeteroData) -> np.ndarray:
lats, lons = graph_data[self.target].x[:,0], graph_data[self.target].x[:,1]
points = latlon_rad_to_cartesian((np.asarray(lats), np.asarray(lons)))
sv = SphericalVoronoi(points, self.radius, self.center)
area_weights = sv.calculate_areas()
return area_weights / np.max(area_weights)

def area_weights(self, graph_data) -> np.ndarray:
return self.global_area_weights(graph_data)

class StretchedGridCutoutAreaWeights(BaseAreaWeights):

def __init__(self, target_nodes: str, cutout_weight_frac_of_global: float, radius: float = 1.0, center: list = [0.0,0.0,0.0]):
super().__init__(target_nodes=target_nodes, radius=radius, center = center)
self.fraction = cutout_weight_frac_of_global

def area_weights(self, graph_data: HeteroData) -> np.ndarray:
area_weights = self.global_area_weights(graph_data)
mask = graph_data[self.target]["cutout"].squeeze().bool()

global_sum = np.sum(area_weights[~mask])
weight_per_cutout_node = self.fraction * global_sum / sum(mask)
area_weights[mask] = weight_per_cutout_node

return area_weights





15 changes: 14 additions & 1 deletion src/anemoi/training/train/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@ def __init__(
self.save_hyperparameters()

self.latlons_data = graph_data[config.graph.data].x
self.loss_weights = graph_data[config.graph.data][config.model.node_loss_weight].squeeze()
loss_weights = self.get_node_weights(config, graph_data)
# For testing
torch.save(loss_weights, config.hardware.paths.output + "node_weights_rescaled.pt")
self.loss_weights = loss_weights


if config.model.get("output_mask", None) is not None:
self.output_mask = Boolean1DMask(graph_data[config.graph.data][config.model.output_mask])
Expand Down Expand Up @@ -191,6 +195,15 @@ def metrics_loss_scaling(config: DictConfig, data_indices: IndexCollection) -> t
metric_ranges_validation[key] = [idx]
return metric_ranges, metric_ranges_validation, loss_scaling

def get_node_weights(self, config: DictConfig, graph_data: HeteroData) -> torch.Tensor:
node_weights = graph_data[config.graph.data][config.model.node_loss_weight].squeeze()

if config.training.loss_scaling.spatial:
spatial_loss_scaler = instantiate(config.training.loss_scaling.spatial)
node_weights = torch.from_numpy(spatial_loss_scaler.area_weights(graph_data))

return node_weights

def set_model_comm_group(self, model_comm_group: ProcessGroup) -> None:
LOGGER.debug("set_model_comm_group: %s", model_comm_group)
self.model_comm_group = model_comm_group
Expand Down

0 comments on commit caf67f1

Please sign in to comment.