From f4932e5928e2b8da5a167cd795e66d7b4da17a34 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz <48736305+JPXKQX@users.noreply.github.com> Date: Tue, 12 Nov 2024 10:49:59 +0100 Subject: [PATCH] Feature/update graph callbacks (#135) * fix: merge node and edge callbacks * fix: update to support anemoi-models PR #64 * fix: minor fix * fix: update CHANGELOG.md * fix: trainable tensor * fix: import --- CHANGELOG.md | 1 + .../config/diagnostics/plot/detailed.yaml | 3 +- .../config/diagnostics/plot/rollout_eval.yaml | 3 +- .../training/diagnostics/callbacks/plot.py | 47 +++---------------- src/anemoi/training/diagnostics/plots.py | 39 ++++----------- 5 files changed, 19 insertions(+), 74 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bc9d40f0..791f5977 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ Keep it human-readable, your future self will thank you! ### Changed - Renamed frequency keys in callbacks configuration. [#118](https://github.com/ecmwf/anemoi-training/pull/118) - Modified training configuration to support max_steps and tied lr iterations to max_steps by default [#67](https://github.com/ecmwf/anemoi-training/pull/67) +- Merged node & edge trainable feature callbacks into one. [#135](https://github.com/ecmwf/anemoi-training/pull/135) ## [0.2.2 - Maintenance: pin python <3.13](https://github.com/ecmwf/anemoi-training/compare/0.2.1...0.2.2) - 2024-10-28 diff --git a/src/anemoi/training/config/diagnostics/plot/detailed.yaml b/src/anemoi/training/config/diagnostics/plot/detailed.yaml index 6ed15fa4..b759c17b 100644 --- a/src/anemoi/training/config/diagnostics/plot/detailed.yaml +++ b/src/anemoi/training/config/diagnostics/plot/detailed.yaml @@ -24,8 +24,7 @@ precip_and_related_fields: [tp, cp] callbacks: # Add plot callbacks here - - _target_: anemoi.training.diagnostics.callbacks.plot.GraphNodeTrainableFeaturesPlot - - _target_: anemoi.training.diagnostics.callbacks.plot.GraphEdgeTrainableFeaturesPlot + - _target_: anemoi.training.diagnostics.callbacks.plot.GraphTrainableFeaturesPlot every_n_epochs: 5 - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss # group parameters by categories when visualizing contributions to the loss diff --git a/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml b/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml index 4c440e24..642e6e6b 100644 --- a/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml +++ b/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml @@ -24,8 +24,7 @@ precip_and_related_fields: [tp, cp] callbacks: # Add plot callbacks here - - _target_: anemoi.training.diagnostics.callbacks.plot.GraphNodeTrainableFeaturesPlot - - _target_: anemoi.training.diagnostics.callbacks.plot.GraphEdgeTrainableFeaturesPlot + - _target_: anemoi.training.diagnostics.callbacks.plot.GraphTrainableFeaturesPlot every_n_epochs: 5 - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss # group parameters by categories when visualizing contributions to the loss diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 93f9aa17..869a69fb 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -429,8 +429,8 @@ def on_validation_batch_end( self._plot(trainer, pl_module, output, batch, batch_idx, trainer.current_epoch) -class GraphNodeTrainableFeaturesPlot(BasePerEpochPlotCallback): - """Visualize the node trainable features defined.""" +class GraphTrainableFeaturesPlot(BasePerEpochPlotCallback): + """Visualize the node & edge trainable features defined.""" def __init__(self, config: OmegaConf, every_n_epochs: int | None = None) -> None: """Initialise the GraphTrainableFeaturesPlot callback. @@ -456,57 +456,22 @@ def _plot( fig = plot_graph_node_features(model) - tag = "node_trainable_params" - exp_log_tag = "node_trainable_params" - self._output_figure( trainer.logger, fig, epoch=trainer.current_epoch, - tag=tag, - exp_log_tag=exp_log_tag, + tag="node_trainable_params", + exp_log_tag="node_trainable_params", ) - -class GraphEdgeTrainableFeaturesPlot(BasePerEpochPlotCallback): - """Trainable edge features plot. - - Visualize the trainable features defined at the edges between meshes. - """ - - def __init__(self, config: OmegaConf, every_n_epochs: int | None = None) -> None: - """Plot trainable edge features. - - Parameters - ---------- - config : OmegaConf - Config object - every_n_epochs : int | None, optional - Override for frequency to plot at, by default None - """ - super().__init__(config, every_n_epochs=every_n_epochs) - - @rank_zero_only - def _plot( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - epoch: int, - ) -> None: - _ = epoch - - model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model fig = plot_graph_edge_features(model) - tag = "edge_trainable_params" - exp_log_tag = "edge_trainable_params" - self._output_figure( trainer.logger, fig, epoch=trainer.current_epoch, - tag=tag, - exp_log_tag=exp_log_tag, + tag="edge_trainable_params", + exp_log_tag="edge_trainable_params", ) diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index 6bf72637..dde80018 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -16,7 +16,6 @@ import matplotlib.pyplot as plt import matplotlib.style as mplstyle import numpy as np -import torch from anemoi.models.layers.mapper import GraphEdgeMixin from matplotlib.collections import LineCollection from matplotlib.colors import BoundaryNorm @@ -31,6 +30,7 @@ if TYPE_CHECKING: from matplotlib.figure import Figure + from torch import nn from dataclasses import dataclass @@ -680,25 +680,7 @@ def edge_plot( fig.colorbar(psc, ax=ax) -def sincos_to_latlon(sincos_coords: torch.Tensor) -> torch.Tensor: - """Get the lat/lon coordinates from the model. - - Parameters - ---------- - sincos_coords: torch.Tensor of shape (N, 4) - Sine and cosine of latitude and longitude coordinates. - - Returns - ------- - torch.Tensor of shape (N, 2) - Lat/lon coordinates. - """ - ndim = sincos_coords.shape[1] // 2 - sin_y, cos_y = sincos_coords[:, :ndim], sincos_coords[:, ndim:] - return torch.atan2(sin_y, cos_y) - - -def plot_graph_node_features(model: torch.nn.Module) -> Figure: +def plot_graph_node_features(model: nn.Module) -> Figure: """Plot trainable graph node features. Parameters @@ -712,14 +694,13 @@ def plot_graph_node_features(model: torch.nn.Module) -> Figure: Figure object handle """ nrows = len(nodes_name := model._graph_data.node_types) - ncols = min(getattr(model, f"trainable_{m}").trainable.shape[1] for m in nodes_name) + ncols = min(model.node_attributes.trainable_tensors[m].trainable.shape[1] for m in nodes_name) figsize = (ncols * 4, nrows * 3) fig, ax = plt.subplots(nrows, ncols, figsize=figsize) - for row, mesh in enumerate(nodes_name): - sincos_coords = getattr(model, f"latlons_{mesh}") - latlons = sincos_to_latlon(sincos_coords).cpu().numpy() - features = getattr(model, f"trainable_{mesh}").trainable.cpu().detach().numpy() + for row, (mesh, trainable_tensor) in enumerate(model.node_attributes.trainable_tensors.items()): + latlons = model.node_attributes.get_coordinates(mesh).cpu().numpy() + node_features = trainable_tensor.trainable.cpu().detach().numpy() lat, lon = latlons[:, 0], latlons[:, 1] @@ -730,14 +711,14 @@ def plot_graph_node_features(model: torch.nn.Module) -> Figure: ax_, lon=lon, lat=lat, - data=features[..., i], + data=node_features[..., i], title=f"{mesh} trainable feature #{i + 1}", ) return fig -def plot_graph_edge_features(model: torch.nn.Module, q_extreme_limit: float = 0.05) -> Figure: +def plot_graph_edge_features(model: nn.Module, q_extreme_limit: float = 0.05) -> Figure: """Plot trainable graph edge features. Parameters @@ -766,8 +747,8 @@ def plot_graph_edge_features(model: torch.nn.Module, q_extreme_limit: float = 0. fig, ax = plt.subplots(nrows, ncols, figsize=figsize) for row, ((src, dst), graph_mapper) in enumerate(trainable_modules.items()): - src_coords = sincos_to_latlon(getattr(model, f"latlons_{src}")).cpu().numpy() - dst_coords = sincos_to_latlon(getattr(model, f"latlons_{dst}")).cpu().numpy() + src_coords = model.node_attributes.get_coordinates(src).cpu().numpy() + dst_coords = model.node_attributes.get_coordinates(dst).cpu().numpy() edge_index = graph_mapper.edge_index_base.cpu().numpy() edge_features = graph_mapper.trainable.trainable.cpu().detach().numpy()