diff --git a/src/anemoi/training/losses/nodeweights.py b/src/anemoi/training/losses/nodeweights.py index 909de1f2..e229b71b 100644 --- a/src/anemoi/training/losses/nodeweights.py +++ b/src/anemoi/training/losses/nodeweights.py @@ -108,6 +108,20 @@ class ReweightedGraphNodeAttribute(GraphNodeAttribute): """ def __init__(self, target_nodes: str, node_attribute: str, scaled_attribute: str, weight_frac_of_total: float): + """Initialize reweighted graph node attribute. + + Parameters + ---------- + target_nodes: str + name of nodes, key in HeteroData graph object + node_attribute: str + name of node weight attribute, key in HeteroData graph object + scaled_attribute: str + name of node attribute defining the subset of nodes to be scaled, key in HeteroData graph object + weight_frac_of_total: float + sum of weight of subset nodes as a fraction of sum of weight of all nodes after rescaling + + """ super().__init__(target_nodes=target_nodes, node_attribute=node_attribute) self.scaled_attribute = scaled_attribute self.fraction = weight_frac_of_total