Minor changes for revision of paper (#56)
* Update citation and mention wave model

* Expose internal timestep of the model

* Fix title in README

* Expose level aggregation stabilisation tweak

* Test that changing the new flags changes the output

* Add more fine-tuning advice

* Make naming more consistent

* Test that model runs under DDP wrapping

* Add `Batch.{to,from}_netcdf`
wesselb authored Dec 22, 2024
1 parent 8b11659 commit e6aeac6
Showing 11 changed files with 279 additions and 21 deletions.
21 changes: 11 additions & 10 deletions README.md
@@ -2,16 +2,16 @@
<img src="docs/aurora.jpg" alt="Aurora logo" width="200"/>
</p>

# Aurora: A Foundation Model of the Atmosphere
# Aurora: A Foundation Model for the Earth System

[![CI](https://github.com/microsoft/Aurora/actions/workflows/ci.yaml/badge.svg)](https://github.com/microsoft/Aurora/actions/workflows/ci.yaml)
[![Documentation](https://img.shields.io/badge/docs-latest-blue.svg)](https://microsoft.github.io/aurora)
[![Paper](https://img.shields.io/badge/arXiv-2405.13063-blue)](https://arxiv.org/abs/2405.13063)

Implementation of the Aurora model for atmospheric forecasting.
Implementation of the Aurora model for Earth system forecasting.

_The package currently includes the pretrained model and the fine-tuned version for high-resolution weather forecasting._
_We are working on the fine-tuned version for air pollution forecasting, which will be included in due time._
_We are working on the fine-tuned versions for air pollution and ocean wave forecasting, which will be included in due time._

[Link to the paper on arXiv.](https://arxiv.org/abs/2405.13063)

@@ -22,8 +22,8 @@ Cite us as follows:

```
@misc{bodnar2024aurora,
title = {Aurora: A Foundation Model of the Atmosphere},
author = {Cristian Bodnar and Wessel P. Bruinsma and Ana Lucic and Megan Stanley and Johannes Brandstetter and Patrick Garvan and Maik Riechert and Jonathan Weyn and Haiyu Dong and Anna Vaughan and Jayesh K. Gupta and Kit Tambiratnam and Alex Archibald and Elizabeth Heider and Max Welling and Richard E. Turner and Paris Perdikaris},
title = {Aurora: A Foundation Model for the Earth System},
author = {Cristian Bodnar and Wessel P. Bruinsma and Ana Lucic and Megan Stanley and Anna Vaughan and Johannes Brandstetter and Patrick Garvan and Maik Riechert and Jonathan A. Weyn and Haiyu Dong and Jayesh K. Gupta and Kit Thambiratnam and Alexander T. Archibald and Chun-Chieh Wu and Elizabeth Heider and Max Welling and Richard E. Turner and Paris Perdikaris},
year = {2024},
url = {https://arxiv.org/abs/2405.13063},
eprint = {2405.13063},
@@ -48,10 +48,11 @@ Contents:
Aurora is a machine learning model that can predict atmospheric variables, such as temperature.
It is a _foundation model_, which means that it was first trained on a large amount of general data,
and can then be adapted to specialised atmospheric forecasting tasks with relatively little data.
We provide three such specialised versions:
We provide four such specialised versions:
one for medium-resolution weather prediction,
one for high-resolution weather prediction,
and one for air pollution prediction.
one for air pollution prediction,
and one for ocean wave prediction.

## Getting Started

@@ -127,7 +128,7 @@ Our goal in publishing this code is
This code has not been developed or tested for non-academic purposes and hence should not be used as such.

### Limitations
Although Aurora was trained to accurately predict future weather and air pollution,
Although Aurora was trained to accurately predict future weather, air pollution, and ocean waves,
Aurora is based on neural networks, which means that there are no strict guarantees that predictions will always be accurate.
Altering the inputs, providing a sample that was not in the training set,
or even providing a sample that was in the training set but is simply unlucky may result in arbitrarily poor predictions.
@@ -183,7 +184,7 @@ make docs

To locally view the documentation, open `docs/_build/index.html` in your browser.

### Why is the fine-tuned version of Aurora for air quality forecasting missing?
### Why are the fine-tuned versions of Aurora for air quality and ocean wave forecasting missing?

The package currently includes the pretrained model and the fine-tuned version for high-resolution weather forecasting.
We are working on the fine-tuned version for air pollution forecasting, which will be included in due time.
We are working on the fine-tuned versions for air pollution and ocean wave forecasting, which will be included in due time.
77 changes: 76 additions & 1 deletion aurora/batch.py
@@ -3,7 +3,8 @@
import dataclasses
from datetime import datetime
from functools import partial
from typing import Callable
from pathlib import Path
from typing import Callable, List

import numpy as np
import torch
@@ -220,6 +221,80 @@ def regrid(self, res: float) -> "Batch":
),
)

    def to_netcdf(self, path: str | Path) -> None:
        """Write the batch to a file.

        This requires `xarray` and `netcdf4` to be installed.
        """
        try:
            import xarray as xr
        except ImportError as e:
            raise RuntimeError("`xarray` must be installed.") from e

        ds = xr.Dataset(
            {
                **{
                    f"surf_{k}": (("batch", "history", "latitude", "longitude"), _np(v))
                    for k, v in self.surf_vars.items()
                },
                **{
                    f"static_{k}": (("latitude", "longitude"), _np(v))
                    for k, v in self.static_vars.items()
                },
                **{
                    f"atmos_{k}": (("batch", "history", "level", "latitude", "longitude"), _np(v))
                    for k, v in self.atmos_vars.items()
                },
            },
            coords={
                "latitude": _np(self.metadata.lat),
                "longitude": _np(self.metadata.lon),
                "time": list(self.metadata.time),
                "level": list(self.metadata.atmos_levels),
                "rollout_step": self.metadata.rollout_step,
            },
        )
        ds.to_netcdf(path)

    @classmethod
    def from_netcdf(cls, path: str | Path) -> "Batch":
        """Load a batch from a file."""
        try:
            import xarray as xr
        except ImportError as e:
            raise RuntimeError("`xarray` must be installed.") from e

        ds = xr.load_dataset(path, engine="netcdf4")

        surf_vars: List[str] = []
        static_vars: List[str] = []
        atmos_vars: List[str] = []

        for k in ds:
            if k.startswith("surf_"):
                surf_vars.append(k.removeprefix("surf_"))
            elif k.startswith("static_"):
                static_vars.append(k.removeprefix("static_"))
            elif k.startswith("atmos_"):
                atmos_vars.append(k.removeprefix("atmos_"))

        return Batch(
            surf_vars={k: torch.from_numpy(ds[f"surf_{k}"].values) for k in surf_vars},
            static_vars={k: torch.from_numpy(ds[f"static_{k}"].values) for k in static_vars},
            atmos_vars={k: torch.from_numpy(ds[f"atmos_{k}"].values) for k in atmos_vars},
            metadata=Metadata(
                lat=torch.from_numpy(ds.latitude.values),
                lon=torch.from_numpy(ds.longitude.values),
                time=tuple(ds.time.values.astype("datetime64[s]").tolist()),
                atmos_levels=tuple(ds.level.values),
                rollout_step=int(ds.rollout_step.values),
            ),
        )


def _np(x: torch.Tensor) -> np.ndarray:
    return x.detach().cpu().numpy()


def interpolate(
    v: torch.Tensor,
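The new `Batch.to_netcdf` and `Batch.from_netcdf` methods write a batch to a NetCDF file and load it back. A minimal usage sketch is below; the tensor shapes, variable names, and grid are purely illustrative, and `xarray` and `netcdf4` must be installed.

```python
from datetime import datetime

import torch

from aurora import Batch, Metadata

# Purely illustrative: a small randomly filled batch on a 17 x 32 lat/lon grid
# with two history steps and four pressure levels.
batch = Batch(
    surf_vars={"2t": torch.randn(1, 2, 17, 32)},
    static_vars={"lsm": torch.randn(17, 32)},
    atmos_vars={"t": torch.randn(1, 2, 4, 17, 32)},
    metadata=Metadata(
        lat=torch.linspace(90, -90, 17),
        lon=torch.linspace(0, 360, 32 + 1)[:-1],
        time=(datetime(2020, 6, 1, 12, 0),),
        atmos_levels=(100, 250, 500, 850),
    ),
)

# Write the batch to disk and restore it. The round trip preserves the surface,
# static, and atmospheric variables as well as the metadata.
batch.to_netcdf("example_batch.nc")
restored = Batch.from_netcdf("example_batch.nc")
```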
13 changes: 10 additions & 3 deletions aurora/model/aurora.py
@@ -50,6 +50,8 @@ def __init__(
dec_mlp_ratio: float = 2.0,
perceiver_ln_eps: float = 1e-5,
max_history_size: int = 2,
timestep: timedelta = timedelta(hours=6),
stabilise_level_agg: bool = False,
use_lora: bool = True,
lora_steps: int = 40,
lora_mode: LoRAMode = "single",
@@ -96,6 +98,9 @@ def __init__(
max_history_size (int, optional): Maximum number of history steps. You can load
checkpoints with a smaller `max_history_size`, but you cannot load checkpoints
with a larger `max_history_size`.
timestep (timedelta, optional): Timestep of the model. Defaults to 6 hours.
stabilise_level_agg (bool, optional): Stabilise the level aggregation by inserting an
additional layer normalisation. Defaults to `False`.
use_lora (bool, optional): Use LoRA adaptation.
lora_steps (int, optional): Use different LoRA adaptation for the first so-many roll-out
steps.
@@ -115,6 +120,7 @@ def __init__(
self.surf_stats = surf_stats or dict()
self.autocast = autocast
self.max_history_size = max_history_size
self.timestep = timestep

if self.surf_stats:
warnings.warn(
@@ -138,6 +144,7 @@ def __init__(
latent_levels=latent_levels,
max_history_size=max_history_size,
perceiver_ln_eps=perceiver_ln_eps,
stabilise_level_agg=stabilise_level_agg,
)

self.backbone = Swin3DTransformerBackbone(
@@ -202,19 +209,19 @@ def forward(self, batch: Batch) -> Batch:

x = self.encoder(
batch,
lead_time=timedelta(hours=6),
lead_time=self.timestep,
)
with torch.autocast(device_type="cuda") if self.autocast else contextlib.nullcontext():
x = self.backbone(
x,
lead_time=timedelta(hours=6),
lead_time=self.timestep,
patch_res=patch_res,
rollout_step=batch.metadata.rollout_step,
)
pred = self.decoder(
x,
batch,
lead_time=timedelta(hours=6),
lead_time=self.timestep,
patch_res=patch_res,
)

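The exposed `timestep` argument sets the lead time of a single forward pass. A hedged sketch of constructing a model with a non-default timestep (the 12-hour value is purely illustrative; the released checkpoints were trained with the default 6-hour step, so a different value would require fine-tuning on data with matching spacing):

```python
from datetime import timedelta

from aurora import Aurora

# Illustrative only: a model whose single forward pass corresponds to a
# 12-hour step rather than the default 6 hours. Fine-tuning on 12-hourly
# data would be required for this configuration to be useful.
model = Aurora(
    use_lora=False,
    timestep=timedelta(hours=12),
)
```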
4 changes: 4 additions & 0 deletions aurora/model/encoder.py
@@ -43,6 +43,7 @@
mlp_ratio: float = 4.0,
max_history_size: int = 2,
perceiver_ln_eps: float = 1e-5,
stabilise_level_agg: bool = False,
) -> None:
"""Initialise.
@@ -67,6 +68,8 @@
to `2`.
perceiver_ln_eps (float, optional): Epsilon value for layer normalisation in the
Perceiver. Defaults to 1e-5.
stabilise_level_agg (bool, optional): Stabilise the level aggregation by inserting an
additional layer normalisation. Defaults to `False`.
"""
super().__init__()

@@ -120,6 +123,7 @@ def __init__(
drop=drop_rate,
mlp_ratio=mlp_ratio,
ln_eps=perceiver_ln_eps,
ln_k_q=stabilise_level_agg,
)

# Drop patches after encoding.
20 changes: 19 additions & 1 deletion aurora/model/perceiver.py
@@ -97,6 +97,7 @@
context_dim: int,
head_dim: int = 64,
num_heads: int = 8,
ln_k_q: bool = False,
) -> None:
"""Initialise.
@@ -105,6 +106,7 @@
context_dim (int): Dimensionality of the context features also given as input.
head_dim (int): Attention head dimensionality.
num_heads (int): Number of heads.
ln_k_q (bool): Apply an extra layer norm. to the keys and queries.
"""
super().__init__()
self.num_heads = num_heads
@@ -115,6 +117,13 @@
self.to_kv = nn.Linear(context_dim, self.inner_dim * 2, bias=False)
self.to_out = nn.Linear(self.inner_dim, latent_dim, bias=False)

if ln_k_q:
self.ln_k = nn.LayerNorm(num_heads * head_dim)
self.ln_q = nn.LayerNorm(num_heads * head_dim)
else:
self.ln_k = lambda x: x
self.ln_q = lambda x: x

def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""Run the cross-attention module.
@@ -131,6 +140,11 @@ def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor:

q = self.to_q(latents) # (B, L1, D2) to (B, L1, D)
k, v = self.to_kv(x).chunk(2, dim=-1) # (B, L2, D1) to twice (B, L2, D)

# Apply LN before (!) splitting the heads.
k = self.ln_k(k)
q = self.ln_q(q)

q, k, v = map(lambda t: rearrange(t, "b l (h d) -> b h l d", h=h), (q, k, v))

out = F.scaled_dot_product_attention(q, k, v)
@@ -152,6 +166,7 @@
drop: float = 0.0,
residual_latent: bool = True,
ln_eps: float = 1e-5,
ln_k_q: bool = False,
) -> None:
"""Initialise.
@@ -168,13 +183,15 @@
Defaults to `True`.
ln_eps (float, optional): Epsilon in the layer normalisation layers. Defaults to
`1e-5`.
ln_k_q (bool, optional): Apply an extra layer norm. to the keys and queries of the first
resampling layer. Defaults to `False`.
"""
super().__init__()

self.residual_latent = residual_latent
self.layers = nn.ModuleList([])
mlp_hidden_dim = int(latent_dim * mlp_ratio)
for _ in range(depth):
for i in range(depth):
self.layers.append(
nn.ModuleList(
[
@@ -183,6 +200,7 @@
context_dim=context_dim,
head_dim=head_dim,
num_heads=num_heads,
ln_k_q=ln_k_q if i == 0 else False,
),
MLP(dim=latent_dim, hidden_features=mlp_hidden_dim, dropout=drop),
nn.LayerNorm(latent_dim, eps=ln_eps),
5 changes: 5 additions & 0 deletions docs/beware.md
@@ -17,6 +17,11 @@ exactly the right variables
at exactly the right pressure levels
from exactly the right source.

This also means that the performance of the model will be sensitive to how the
data is regridded.
For optimal performance, you should ensure that the data is regridded
exactly like the data seen during pretraining and fine-tuning.

(t0-vs-analysis)=
## HRES IFS T0 Versus HRES IFS Analysis

57 changes: 57 additions & 0 deletions docs/finetuning.md
@@ -37,6 +37,37 @@ loss = ...
loss.backward()
```

## Exploding Gradients

When fine-tuning, you may run into very large gradient values.
Gradient clipping and internal layer normalisation layers mitigate the impact
of large gradients,
meaning that large gradients will not immediately lead to abnormal model outputs and loss values.
Nevertheless, if the gradients do blow up, the model will stop learning and eventually the loss value
will blow up as well.
You should carefully monitor the gradient norms to detect exploding gradients.
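One simple way to do this is to log the total gradient norm after every backward pass. A rough sketch, assuming `model` and `loss` are as in the fine-tuning snippet above (the threshold of `100` is an arbitrary illustrative value):

```python
import torch

loss.backward()

# With a very large `max_norm`, this call effectively only measures the total
# gradient norm without clipping; lower `max_norm` to clip as well.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e9)
if float(total_norm) > 100.0:  # Illustrative threshold.
    print(f"Warning: large gradient norm {float(total_norm):.2f}")
```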

One cause of exploding gradients is excessively large internal activations.
Typically this can be fixed by judiciously inserting a layer normalisation layer.

We have identified the level aggregation as a weak point of the model that can be susceptible
to exploding gradients.
You can stabilise the level aggregation of the model
by setting the following flag in the constructor: `stabilise_level_agg=True`.
Note that `stabilise_level_agg=True` will considerably perturb the model,
so significant additional fine-tuning may be required to get to the desired level of performance.

```python
from aurora import Aurora

model = Aurora(
    use_lora=False,
    stabilise_level_agg=True,  # Insert extra layer norm. to mitigate exploding gradients.
)
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False)
```

## Extending Aurora with New Variables

Aurora can be extended with new variables by adjusting the keyword arguments `surf_vars`,
@@ -66,6 +97,18 @@ scales["new_static_var"] = 1.0
scales["new_atmos_var"] = 1.0
```

To more efficiently learn new variables, it is recommended to use a separate learning rate for
the patch embeddings of the new variables in the encoder and decoder.
For example, if you are using Adam, you can try `1e-3` for the new patch embeddings
and `3e-4` for the other parameters.
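A hedged sketch of such parameter groups with Adam is below. It assumes that the new variables use the illustrative names `new_surf_var` and `new_atmos_var` and that their encoder patch embeddings live in the parameter dictionaries `model.encoder.{surf,atmos}_token_embeds.weights` described below, keyed by variable name; decoder embeddings for the new variables can be added to the first group in the same way.

```python
import torch

# Patch embedding parameters for the new variables (illustrative names).
new_params = [
    model.encoder.surf_token_embeds.weights["new_surf_var"],
    model.encoder.atmos_token_embeds.weights["new_atmos_var"],
]
new_param_ids = {id(p) for p in new_params}
other_params = [p for p in model.parameters() if id(p) not in new_param_ids]

# Higher learning rate for the new patch embeddings, lower for everything else.
optimizer = torch.optim.Adam(
    [
        {"params": new_params, "lr": 1e-3},
        {"params": other_params, "lr": 3e-4},
    ]
)
```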

By default, patch embeddings in the encoder for new variables are initialised randomly.
This means that adding new variables to the model perturbs the predictions for the existing
variables.
If you do not want this, you can alternatively initialise the new patch embeddings in the encoder
to zero.
The relevant parameter dictionaries are `model.encoder.{surf,atmos}_token_embeds.weights`.
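A minimal sketch of zero-initialising those embeddings, again assuming the dictionaries are keyed by the illustrative variable names used above:

```python
import torch

# Zero-initialise the encoder patch embeddings of the new variables so that
# adding them does not change the predictions for the existing variables.
with torch.no_grad():
    model.encoder.surf_token_embeds.weights["new_surf_var"].zero_()
    model.encoder.atmos_token_embeds.weights["new_atmos_var"].zero_()
```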

## Other Model Extensions

It is possible to extend the model in any way you like.
Expand All @@ -83,3 +126,17 @@ model = Aurora(...)

model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False)
```

## Triple Check Your Fine-Tuning Data!

When fine-tuning the model, it is absolutely essential to carefully check your fine-tuning data.

* Are the old (and possibly new) normalisation statistics appropriate for the new data?

* Is any data missing?

* Does the data contain zeros or NaNs?

* Does the data contain any outliers that could possibly interfere with fine-tuning?

_Et cetera._
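Some of these checks can be partially automated. A rough sketch over the tensors of a `Batch`, assuming `batch` comes from your fine-tuning data pipeline:

```python
import torch

# Basic sanity checks on a fine-tuning batch. These do not replace careful
# manual inspection of the data, but they catch the most obvious problems.
all_vars = {**batch.surf_vars, **batch.static_vars, **batch.atmos_vars}
for name, tensor in all_vars.items():
    if torch.isnan(tensor).any():
        print(f"{name}: contains NaNs")
    if torch.isinf(tensor).any():
        print(f"{name}: contains infinite values")
    if (tensor == 0).all():
        print(f"{name}: is identically zero")
```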