Skip to content

Commit

Permalink
Merge branch 'ecmwf:develop' into pr/aw_rescale
Browse files Browse the repository at this point in the history
  • Loading branch information
havardhhaugen authored Nov 26, 2024
2 parents c159e27 + 1abb65e commit 7bb2919
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/anemoi/training/losses/weightedloss.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def scale(

if scalar_indices is None:
return x * scalar

scalar = scalar.expand_as(x)
return x * scalar[scalar_indices]

def scale_by_node_weights(self, x: torch.Tensor, squash: bool = True) -> torch.Tensor:
Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/training/train/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def calculate_val_metrics(
metrics[f"{metric_name}/{mkey}/{rollout_step + 1}"] = metric(
y_pred_postprocessed[..., indices],
y_postprocessed[..., indices],
scalar_indices=[..., indices],
scalar_indices=[..., indices] if -1 in metric.scalar else None,
)

return metrics
Expand Down

0 comments on commit 7bb2919

Please sign in to comment.