Skip to content

Commit

Permalink
Merge branch 'ecmwf:develop' into pr/aw_rescale
Browse files Browse the repository at this point in the history
  • Loading branch information
havardhhaugen authored Nov 12, 2024
2 parents c576efb + f4932e5 commit 6a527a9
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 74 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Keep it human-readable, your future self will thank you!
### Changed
- Renamed frequency keys in callbacks configuration. [#118](https://github.com/ecmwf/anemoi-training/pull/118)
- Modified training configuration to support max_steps and tied lr iterations to max_steps by default [#67](https://github.com/ecmwf/anemoi-training/pull/67)
- Merged node & edge trainable feature callbacks into one. [#135](https://github.com/ecmwf/anemoi-training/pull/135)

## [0.2.2 - Maintenance: pin python <3.13](https://github.com/ecmwf/anemoi-training/compare/0.2.1...0.2.2) - 2024-10-28

Expand Down
3 changes: 1 addition & 2 deletions src/anemoi/training/config/diagnostics/plot/detailed.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ precip_and_related_fields: [tp, cp]

callbacks:
# Add plot callbacks here
- _target_: anemoi.training.diagnostics.callbacks.plot.GraphNodeTrainableFeaturesPlot
- _target_: anemoi.training.diagnostics.callbacks.plot.GraphEdgeTrainableFeaturesPlot
- _target_: anemoi.training.diagnostics.callbacks.plot.GraphTrainableFeaturesPlot
every_n_epochs: 5
- _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss
# group parameters by categories when visualizing contributions to the loss
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ precip_and_related_fields: [tp, cp]

callbacks:
# Add plot callbacks here
- _target_: anemoi.training.diagnostics.callbacks.plot.GraphNodeTrainableFeaturesPlot
- _target_: anemoi.training.diagnostics.callbacks.plot.GraphEdgeTrainableFeaturesPlot
- _target_: anemoi.training.diagnostics.callbacks.plot.GraphTrainableFeaturesPlot
every_n_epochs: 5
- _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss
# group parameters by categories when visualizing contributions to the loss
Expand Down
47 changes: 6 additions & 41 deletions src/anemoi/training/diagnostics/callbacks/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,8 @@ def on_validation_batch_end(
self._plot(trainer, pl_module, output, batch, batch_idx, trainer.current_epoch)


class GraphNodeTrainableFeaturesPlot(BasePerEpochPlotCallback):
"""Visualize the node trainable features defined."""
class GraphTrainableFeaturesPlot(BasePerEpochPlotCallback):
"""Visualize the node & edge trainable features defined."""

def __init__(self, config: OmegaConf, every_n_epochs: int | None = None) -> None:
"""Initialise the GraphTrainableFeaturesPlot callback.
Expand All @@ -456,57 +456,22 @@ def _plot(

fig = plot_graph_node_features(model)

tag = "node_trainable_params"
exp_log_tag = "node_trainable_params"

self._output_figure(
trainer.logger,
fig,
epoch=trainer.current_epoch,
tag=tag,
exp_log_tag=exp_log_tag,
tag="node_trainable_params",
exp_log_tag="node_trainable_params",
)


class GraphEdgeTrainableFeaturesPlot(BasePerEpochPlotCallback):
"""Trainable edge features plot.
Visualize the trainable features defined at the edges between meshes.
"""

def __init__(self, config: OmegaConf, every_n_epochs: int | None = None) -> None:
"""Plot trainable edge features.
Parameters
----------
config : OmegaConf
Config object
every_n_epochs : int | None, optional
Override for frequency to plot at, by default None
"""
super().__init__(config, every_n_epochs=every_n_epochs)

@rank_zero_only
def _plot(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
epoch: int,
) -> None:
_ = epoch

model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model
fig = plot_graph_edge_features(model)

tag = "edge_trainable_params"
exp_log_tag = "edge_trainable_params"

self._output_figure(
trainer.logger,
fig,
epoch=trainer.current_epoch,
tag=tag,
exp_log_tag=exp_log_tag,
tag="edge_trainable_params",
exp_log_tag="edge_trainable_params",
)


Expand Down
39 changes: 10 additions & 29 deletions src/anemoi/training/diagnostics/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import matplotlib.pyplot as plt
import matplotlib.style as mplstyle
import numpy as np
import torch
from anemoi.models.layers.mapper import GraphEdgeMixin
from matplotlib.collections import LineCollection
from matplotlib.colors import BoundaryNorm
Expand All @@ -31,6 +30,7 @@

if TYPE_CHECKING:
from matplotlib.figure import Figure
from torch import nn

from dataclasses import dataclass

Expand Down Expand Up @@ -680,25 +680,7 @@ def edge_plot(
fig.colorbar(psc, ax=ax)


def sincos_to_latlon(sincos_coords: torch.Tensor) -> torch.Tensor:
"""Get the lat/lon coordinates from the model.
Parameters
----------
sincos_coords: torch.Tensor of shape (N, 4)
Sine and cosine of latitude and longitude coordinates.
Returns
-------
torch.Tensor of shape (N, 2)
Lat/lon coordinates.
"""
ndim = sincos_coords.shape[1] // 2
sin_y, cos_y = sincos_coords[:, :ndim], sincos_coords[:, ndim:]
return torch.atan2(sin_y, cos_y)


def plot_graph_node_features(model: torch.nn.Module) -> Figure:
def plot_graph_node_features(model: nn.Module) -> Figure:
"""Plot trainable graph node features.
Parameters
Expand All @@ -712,14 +694,13 @@ def plot_graph_node_features(model: torch.nn.Module) -> Figure:
Figure object handle
"""
nrows = len(nodes_name := model._graph_data.node_types)
ncols = min(getattr(model, f"trainable_{m}").trainable.shape[1] for m in nodes_name)
ncols = min(model.node_attributes.trainable_tensors[m].trainable.shape[1] for m in nodes_name)
figsize = (ncols * 4, nrows * 3)
fig, ax = plt.subplots(nrows, ncols, figsize=figsize)

for row, mesh in enumerate(nodes_name):
sincos_coords = getattr(model, f"latlons_{mesh}")
latlons = sincos_to_latlon(sincos_coords).cpu().numpy()
features = getattr(model, f"trainable_{mesh}").trainable.cpu().detach().numpy()
for row, (mesh, trainable_tensor) in enumerate(model.node_attributes.trainable_tensors.items()):
latlons = model.node_attributes.get_coordinates(mesh).cpu().numpy()
node_features = trainable_tensor.trainable.cpu().detach().numpy()

lat, lon = latlons[:, 0], latlons[:, 1]

Expand All @@ -730,14 +711,14 @@ def plot_graph_node_features(model: torch.nn.Module) -> Figure:
ax_,
lon=lon,
lat=lat,
data=features[..., i],
data=node_features[..., i],
title=f"{mesh} trainable feature #{i + 1}",
)

return fig


def plot_graph_edge_features(model: torch.nn.Module, q_extreme_limit: float = 0.05) -> Figure:
def plot_graph_edge_features(model: nn.Module, q_extreme_limit: float = 0.05) -> Figure:
"""Plot trainable graph edge features.
Parameters
Expand Down Expand Up @@ -766,8 +747,8 @@ def plot_graph_edge_features(model: torch.nn.Module, q_extreme_limit: float = 0.
fig, ax = plt.subplots(nrows, ncols, figsize=figsize)

for row, ((src, dst), graph_mapper) in enumerate(trainable_modules.items()):
src_coords = sincos_to_latlon(getattr(model, f"latlons_{src}")).cpu().numpy()
dst_coords = sincos_to_latlon(getattr(model, f"latlons_{dst}")).cpu().numpy()
src_coords = model.node_attributes.get_coordinates(src).cpu().numpy()
dst_coords = model.node_attributes.get_coordinates(dst).cpu().numpy()
edge_index = graph_mapper.edge_index_base.cpu().numpy()
edge_features = graph_mapper.trainable.trainable.cpu().detach().numpy()

Expand Down

0 comments on commit 6a527a9

Please sign in to comment.