diff --git a/graphs/CHANGELOG.md b/graphs/CHANGELOG.md index 5b411893..4d7dbb32 100644 --- a/graphs/CHANGELOG.md +++ b/graphs/CHANGELOG.md @@ -10,6 +10,10 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-graphs/compare/0.4.2...HEAD) +### Added + +- feat: Support for multi-dimensional node attributes in plots (#48) + ## [0.4.2 - Optimisations and lat-lon](https://github.com/ecmwf/anemoi-graphs/compare/0.4.1...0.4.2) - 2024-12-19 ### Added diff --git a/graphs/src/anemoi/graphs/plotting/interactive_html.py b/graphs/src/anemoi/graphs/plotting/interactive_html.py index 7021bf16..f3162507 100644 --- a/graphs/src/anemoi/graphs/plotting/interactive_html.py +++ b/graphs/src/anemoi/graphs/plotting/interactive_html.py @@ -15,6 +15,7 @@ import matplotlib.pyplot as plt import numpy as np import plotly.graph_objects as go +import torch from matplotlib.colors import rgb2hex from torch_geometric.data import HeteroData @@ -197,25 +198,26 @@ def plot_interactive_nodes(graph: HeteroData, nodes_name: str, out_file: Optiona for node_attr in node_attrs: node_attr_values = graph[nodes_name][node_attr].float().numpy() - # Skip multi-dimensional attributes. Supported only: (N, 1) or (N,) tensors - if node_attr_values.ndim > 1 and node_attr_values.shape[1] > 1: - continue - - node_traces[node_attr] = go.Scattergeo( - lat=node_latitudes, - lon=node_longitudes, - name=" ".join(node_attr.split("_")).capitalize(), - mode="markers", - hoverinfo="text", - marker={ - "color": node_attr_values.squeeze().tolist(), - "showscale": True, - "colorscale": "RdBu", - "colorbar": {"thickness": 15, "title": node_attr, "xanchor": "left"}, - "size": 5, - }, - visible=False, - ) + if node_attr_values.ndim == 1: + node_attr_values = torch.unsqueeze(node_attr_values, -1) + + for attr_dim in range(node_attr_values.shape[1]): + suffix = "" if node_attr_values.shape[1] == 1 else f"_[{attr_dim}]" + node_traces[node_attr + suffix] = go.Scattergeo( + lat=node_latitudes, + lon=node_longitudes, + name=" ".join((node_attr + suffix).split("_")).capitalize(), + mode="markers", + hoverinfo="text", + marker={ + "color": node_attr_values[:, attr_dim].squeeze().tolist(), + "showscale": True, + "colorscale": "RdBu", + "colorbar": {"thickness": 15, "title": node_attr + suffix, "xanchor": "left"}, + "size": 5, + }, + visible=False, + ) # Create and add slider slider_steps = []