Skip to content

Commit

Permalink
Small fixes - training now worked for all cases
Browse files Browse the repository at this point in the history
  • Loading branch information
havardhhaugen committed Nov 14, 2024
1 parent 8dc5e11 commit cc4f38b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/anemoi/training/config/training/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,6 @@ pressure_level_scaler:
slope: 0.001

node_loss_weights:
_target_: anemoi.traininig.losses.nodeweights.GraphNodeAttribute
_target_: anemoi.training.losses.nodeweights.GraphNodeAttribute
target_nodes: ${graph.data}
node_attribute: area_weight
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def weights(self, graph_data: HeteroData) -> torch.Tensor:

LOGGER.info("Loading node attribute %s from the graph", self.node_attribute)
except KeyError:
attr_weight = torch.from_numpy(self.global_area_weights(graph_data))
attr_weight = torch.from_numpy(self.area_weights(graph_data))

LOGGER.info(
"Node attribute %s not found in graph. Default area weighting will be used",
Expand All @@ -51,7 +51,7 @@ def weights(self, graph_data: HeteroData) -> torch.Tensor:


class ReweightedGraphNodeAttribute(GraphNodeAttribute):
"""Method to reweight a subset of the target nodes defined by scaled_attributes.
"""Method to reweight a subset of the target nodes defined by scaled_attribute.
Subset nodes will be scaled such that their weight sum equals weight_frac_of_total of the sum
over all nodes.
Expand All @@ -63,26 +63,15 @@ def __init__(self, target_nodes: str, node_attribute: str, scaled_attribute: str
self.fraction = weight_frac_of_total

def weights(self, graph_data: HeteroData) -> torch.Tensor:
try:
attr_weight = graph_data[self.target][self.node_attribute].squeeze()

LOGGER.info("Loading node attribute %s from the graph", self.node_attribute)
except KeyError:
attr_weight = torch.from_numpy(self.global_area_weights(graph_data))

LOGGER.info(
"Node attribute %s not found in graph. Default area weighting will be used",
self.node_attribute,
)
attr_weight = super().weights(graph_data)

mask = graph_data[self.target][self.scaled_attribute].squeeze().bool()

unmasked_sum = torch.sum(attr_weight[~mask])
weight_per_masked_node = self.fraction / (1 - self.fraction) * unmasked_sum / sum(mask)
attr_weight[mask] = weight_per_masked_node
LOGGER.info(
"Weight of nodes in %s rescaled such that their sum equals %.3f of the sum over all nodes",
self.node_attribute,
self.scaled_attribute,
self.fraction,
)

Expand Down

0 comments on commit cc4f38b

Please sign in to comment.