Skip to content

Commit

Permalink
Docstrings GraphNodeAttributes, minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
havardhhaugen committed Nov 15, 2024
1 parent 5df91e1 commit d0d2b57
Showing 1 changed file with 54 additions and 3 deletions.
57 changes: 54 additions & 3 deletions src/anemoi/training/losses/nodeweights.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,48 @@


class GraphNodeAttribute:
"""Method to load and optionally change the weighting of node attributes in the graph."""
"""Base class to load and optionally change the weight attribute of nodes in the graph.
Attributes
----------
target: str
name of target nodes, key in HeteroData graph object
node_attribute: str
name of node weight attribute, key in HeteroData graph object
Methods
-------
weights(self, graph_data)
Load node weight attribute. Compute area weights if they can not be found in graph
object.
"""

def __init__(self, target_nodes: str, node_attribute: str):
"""Initialize graph node attribute with target nodes and node attribute.
Parameters
----------
target_nodes: str
name of nodes, key in HeteroData graph object
node_attribute: str
name of node weight attribute, key in HeteroData graph object
"""
self.target = target_nodes
self.node_attribute = node_attribute

def area_weights(self, graph_data: HeteroData) -> np.ndarray:
"""Nodes weighted by the size of the geographical area they represent.
Parameters
----------
graph_data: HeteroData
graph object
Returns
-------
np.ndarray
area weights of the target nodes
"""
lats, lons = graph_data[self.target].x[:, 0], graph_data[self.target].x[:, 1]
points = latlon_rad_to_cartesian((np.asarray(lats), np.asarray(lons)))
sv = SphericalVoronoi(points, radius=1.0, center=[0.0, 0.0, 0.0])
Expand All @@ -35,11 +70,26 @@ def area_weights(self, graph_data: HeteroData) -> np.ndarray:
return area_weights / np.max(area_weights)

def weights(self, graph_data: HeteroData) -> torch.Tensor:
try:
"""Returns weight of type self.node_attribute for nodes self.target.
Attempts to load from graph_data and calculates area weights for the target
nodes if they do not exist.
Parameters
----------
graph_data: HeteroData
graph object
Returns
-------
torch.Tensor
weight of target nodes
"""
if self.node_attribute in graph_data[self.target]:
attr_weight = graph_data[self.target][self.node_attribute].squeeze()

LOGGER.info("Loading node attribute %s from the graph", self.node_attribute)
except KeyError:
else:
attr_weight = torch.from_numpy(self.area_weights(graph_data))

LOGGER.info(
Expand Down Expand Up @@ -69,6 +119,7 @@ def weights(self, graph_data: HeteroData) -> torch.Tensor:
unmasked_sum = torch.sum(attr_weight[~mask])
weight_per_masked_node = self.fraction / (1 - self.fraction) * unmasked_sum / sum(mask)
attr_weight[mask] = weight_per_masked_node

LOGGER.info(
"Weight of nodes in %s rescaled such that their sum equals %.3f of the sum over all nodes",
self.scaled_attribute,
Expand Down

0 comments on commit d0d2b57

Please sign in to comment.