From e6aeac6777ecbedef77bfe2d25a87a843ba1e5f8 Mon Sep 17 00:00:00 2001 From: Wessel Date: Sun, 22 Dec 2024 13:10:07 +0100 Subject: [PATCH] Minor changes for revision of paper (#56) * Update citation and mention wave model * Expose internal timestep of the model * Fix title in README * Expose level aggregation stabilisation tweak * Test that changing the new flags changes the output * Add more fine-tuning advice * Make naming more consistent * Test that model runs under DDP wrapping * Add `Batch.{to,from}_netcdf` --- README.md | 21 ++++++----- aurora/batch.py | 77 ++++++++++++++++++++++++++++++++++++++- aurora/model/aurora.py | 13 +++++-- aurora/model/encoder.py | 4 ++ aurora/model/perceiver.py | 20 +++++++++- docs/beware.md | 5 +++ docs/finetuning.md | 57 +++++++++++++++++++++++++++++ docs/intro.md | 11 +++--- pyproject.toml | 2 + tests/test_batch.py | 22 +++++++++++ tests/test_model.py | 68 +++++++++++++++++++++++++++++++++- 11 files changed, 279 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 80e1de1..ed93cb5 100644 --- a/README.md +++ b/README.md @@ -2,16 +2,16 @@ Aurora logo

-# Aurora: A Foundation Model of the Atmosphere +# Aurora: A Foundation Model for the Earth System [![CI](https://github.com/microsoft/Aurora/actions/workflows/ci.yaml/badge.svg)](https://github.com/microsoft/Aurora/actions/workflows/ci.yaml) [![Documentation](https://img.shields.io/badge/docs-latest-blue.svg)](https://microsoft.github.io/aurora) [![Paper](https://img.shields.io/badge/arXiv-2405.13063-blue)](https://arxiv.org/abs/2405.13063) -Implementation of the Aurora model for atmospheric forecasting. +Implementation of the Aurora model for Earth system forecasting. _The package currently includes the pretrained model and the fine-tuned version for high-resolution weather forecasting._ -_We are working on the fine-tuned version for air pollution forecasting, which will be included in due time._ +_We are working on the fine-tuned versions for air pollution and ocean wave forecasting, which will be included in due time._ [Link to the paper on arXiv.](https://arxiv.org/abs/2405.13063) @@ -22,8 +22,8 @@ Cite us as follows: ``` @misc{bodnar2024aurora, - title = {Aurora: A Foundation Model of the Atmosphere}, - author = {Cristian Bodnar and Wessel P. Bruinsma and Ana Lucic and Megan Stanley and Johannes Brandstetter and Patrick Garvan and Maik Riechert and Jonathan Weyn and Haiyu Dong and Anna Vaughan and Jayesh K. Gupta and Kit Tambiratnam and Alex Archibald and Elizabeth Heider and Max Welling and Richard E. Turner and Paris Perdikaris}, + title = {Aurora: A Foundation Model for the Earth System}, + author = {Cristian Bodnar and Wessel P. Bruinsma and Ana Lucic and Megan Stanley and Anna Vaughan and Johannes Brandstetter and Patrick Garvan and Maik Riechert and Jonathan A. Weyn and Haiyu Dong and Jayesh K. Gupta and Kit Thambiratnam and Alexander T. Archibald and Chun-Chieh Wu and Elizabeth Heider and Max Welling and Richard E. Turner and Paris Perdikaris}, year = {2024}, url = {https://arxiv.org/abs/2405.13063}, eprint = {2405.13063}, @@ -48,10 +48,11 @@ Contents: Aurora is a machine learning model that can predict atmospheric variables, such as temperature. It is a _foundation model_, which means that it was first generally trained on a lot of data, and then can be adapted to specialised atmospheric forecasting tasks with relatively little data. -We provide three such specialised versions: +We provide four such specialised versions: one for medium-resolution weather prediction, one for high-resolution weather prediction, -and one for air pollution prediction. +one for air pollution prediction, +and one for ocean wave prediction. ## Getting Started @@ -127,7 +128,7 @@ Our goal in publishing this code is This code has not been developed nor tested for non-academic purposes and hence should not be used as such. ### Limitations -Although Aurora was trained to accurately predict future weather and air pollution, +Although Aurora was trained to accurately predict future weather, air pollution, and ocean waves, Aurora is based on neural networks, which means that there are no strict guarantees that predictions will always be accurate. Altering the inputs, providing a sample that was not in the training set, or even providing a sample that was in the training set but is simply unlucky may result in arbitrarily poor predictions. @@ -183,7 +184,7 @@ make docs To locally view the documentation, open `docs/_build/index.html` in your browser. -### Why is the fine-tuned version of Aurora for air quality forecasting missing? +### Why are the fine-tuned versions of Aurora for air quality and ocean wave forecasting missing? The package currently includes the pretrained model and the fine-tuned version for high-resolution weather forecasting. -We are working on the fine-tuned version for air pollution forecasting, which will be included in due time. +We are working on the fine-tuned versions for air pollution and ocean wave forecasting, which will be included in due time. diff --git a/aurora/batch.py b/aurora/batch.py index 9f96356..3221211 100644 --- a/aurora/batch.py +++ b/aurora/batch.py @@ -3,7 +3,8 @@ import dataclasses from datetime import datetime from functools import partial -from typing import Callable +from pathlib import Path +from typing import Callable, List import numpy as np import torch @@ -220,6 +221,80 @@ def regrid(self, res: float) -> "Batch": ), ) + def to_netcdf(self, path: str | Path) -> None: + """Write the batch to a file. + + This requires `xarray` and `netcdf4` to be installed. + """ + try: + import xarray as xr + except ImportError as e: + raise RuntimeError("`xarray` must be installed.") from e + + ds = xr.Dataset( + { + **{ + f"surf_{k}": (("batch", "history", "latitude", "longitude"), _np(v)) + for k, v in self.surf_vars.items() + }, + **{ + f"static_{k}": (("latitude", "longitude"), _np(v)) + for k, v in self.static_vars.items() + }, + **{ + f"atmos_{k}": (("batch", "history", "level", "latitude", "longitude"), _np(v)) + for k, v in self.atmos_vars.items() + }, + }, + coords={ + "latitude": _np(self.metadata.lat), + "longitude": _np(self.metadata.lon), + "time": list(self.metadata.time), + "level": list(self.metadata.atmos_levels), + "rollout_step": self.metadata.rollout_step, + }, + ) + ds.to_netcdf(path) + + @classmethod + def from_netcdf(cls, path: str | Path) -> "Batch": + """Load a batch from a file.""" + try: + import xarray as xr + except ImportError as e: + raise RuntimeError("`xarray` must be installed.") from e + + ds = xr.load_dataset(path, engine="netcdf4") + + surf_vars: List[str] = [] + static_vars: List[str] = [] + atmos_vars: List[str] = [] + + for k in ds: + if k.startswith("surf_"): + surf_vars.append(k.removeprefix("surf_")) + elif k.startswith("static_"): + static_vars.append(k.removeprefix("static_")) + elif k.startswith("atmos_"): + atmos_vars.append(k.removeprefix("atmos_")) + + return Batch( + surf_vars={k: torch.from_numpy(ds[f"surf_{k}"].values) for k in surf_vars}, + static_vars={k: torch.from_numpy(ds[f"static_{k}"].values) for k in static_vars}, + atmos_vars={k: torch.from_numpy(ds[f"atmos_{k}"].values) for k in atmos_vars}, + metadata=Metadata( + lat=torch.from_numpy(ds.latitude.values), + lon=torch.from_numpy(ds.longitude.values), + time=tuple(ds.time.values.astype("datetime64[s]").tolist()), + atmos_levels=tuple(ds.level.values), + rollout_step=int(ds.rollout_step.values), + ), + ) + + +def _np(x: torch.Tensor) -> np.ndarray: + return x.detach().cpu().numpy() + def interpolate( v: torch.Tensor, diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py index 17d7e39..fba4594 100644 --- a/aurora/model/aurora.py +++ b/aurora/model/aurora.py @@ -50,6 +50,8 @@ def __init__( dec_mlp_ratio: float = 2.0, perceiver_ln_eps: float = 1e-5, max_history_size: int = 2, + timestep: timedelta = timedelta(hours=6), + stabilise_level_agg: bool = False, use_lora: bool = True, lora_steps: int = 40, lora_mode: LoRAMode = "single", @@ -96,6 +98,9 @@ def __init__( max_history_size (int, optional): Maximum number of history steps. You can load checkpoints with a smaller `max_history_size`, but you cannot load checkpoints with a larger `max_history_size`. + timestep (timedelta, optional): Timestep of the model. Defaults to 6 hours. + stabilise_level_agg (bool, optional): Stabilise the level aggregation by inserting an + additional layer normalisation. Defaults to `False`. use_lora (bool, optional): Use LoRA adaptation. lora_steps (int, optional): Use different LoRA adaptation for the first so-many roll-out steps. @@ -115,6 +120,7 @@ def __init__( self.surf_stats = surf_stats or dict() self.autocast = autocast self.max_history_size = max_history_size + self.timestep = timestep if self.surf_stats: warnings.warn( @@ -138,6 +144,7 @@ def __init__( latent_levels=latent_levels, max_history_size=max_history_size, perceiver_ln_eps=perceiver_ln_eps, + stabilise_level_agg=stabilise_level_agg, ) self.backbone = Swin3DTransformerBackbone( @@ -202,19 +209,19 @@ def forward(self, batch: Batch) -> Batch: x = self.encoder( batch, - lead_time=timedelta(hours=6), + lead_time=self.timestep, ) with torch.autocast(device_type="cuda") if self.autocast else contextlib.nullcontext(): x = self.backbone( x, - lead_time=timedelta(hours=6), + lead_time=self.timestep, patch_res=patch_res, rollout_step=batch.metadata.rollout_step, ) pred = self.decoder( x, batch, - lead_time=timedelta(hours=6), + lead_time=self.timestep, patch_res=patch_res, ) diff --git a/aurora/model/encoder.py b/aurora/model/encoder.py index bcbdafa..c2b4c87 100644 --- a/aurora/model/encoder.py +++ b/aurora/model/encoder.py @@ -43,6 +43,7 @@ def __init__( mlp_ratio: float = 4.0, max_history_size: int = 2, perceiver_ln_eps: float = 1e-5, + stabilise_level_agg: bool = False, ) -> None: """Initialise. @@ -67,6 +68,8 @@ def __init__( to `2`. perceiver_ln_eps (float, optional): Epsilon value for layer normalisation in the Perceiver. Defaults to 1e-5. + stabilise_level_agg (bool, optional): Stabilise the level aggregation by inserting an + additional layer normalisation. Defaults to `False`. """ super().__init__() @@ -120,6 +123,7 @@ def __init__( drop=drop_rate, mlp_ratio=mlp_ratio, ln_eps=perceiver_ln_eps, + ln_k_q=stabilise_level_agg, ) # Drop patches after encoding. diff --git a/aurora/model/perceiver.py b/aurora/model/perceiver.py index dbb4e92..eda3122 100644 --- a/aurora/model/perceiver.py +++ b/aurora/model/perceiver.py @@ -97,6 +97,7 @@ def __init__( context_dim: int, head_dim: int = 64, num_heads: int = 8, + ln_k_q: bool = False, ) -> None: """Initialise. @@ -105,6 +106,7 @@ def __init__( context_dim (int): Dimensionality of the context features also given as input. head_dim (int): Attention head dimensionality. num_heads (int): Number of heads. + ln_k_q (bool): Apply an extra layer norm. to the keys and queries. """ super().__init__() self.num_heads = num_heads @@ -115,6 +117,13 @@ def __init__( self.to_kv = nn.Linear(context_dim, self.inner_dim * 2, bias=False) self.to_out = nn.Linear(self.inner_dim, latent_dim, bias=False) + if ln_k_q: + self.ln_k = nn.LayerNorm(num_heads * head_dim) + self.ln_q = nn.LayerNorm(num_heads * head_dim) + else: + self.ln_k = lambda x: x + self.ln_q = lambda x: x + def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor: """Run the cross-attention module. @@ -131,6 +140,11 @@ def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor: q = self.to_q(latents) # (B, L1, D2) to (B, L1, D) k, v = self.to_kv(x).chunk(2, dim=-1) # (B, L2, D1) to twice (B, L2, D) + + # Apply LN before (!) splitting the heads. + k = self.ln_k(k) + q = self.ln_q(q) + q, k, v = map(lambda t: rearrange(t, "b l (h d) -> b h l d", h=h), (q, k, v)) out = F.scaled_dot_product_attention(q, k, v) @@ -152,6 +166,7 @@ def __init__( drop: float = 0.0, residual_latent: bool = True, ln_eps: float = 1e-5, + ln_k_q: bool = False, ) -> None: """Initialise. @@ -168,13 +183,15 @@ def __init__( Defaults to `True`. ln_eps (float, optional): Epsilon in the layer normalisation layers. Defaults to `1e-5`. + ln_k_q (bool, optional): Apply an extra layer norm. to the keys and queries of the first + resampling layer. Defaults to `False`. """ super().__init__() self.residual_latent = residual_latent self.layers = nn.ModuleList([]) mlp_hidden_dim = int(latent_dim * mlp_ratio) - for _ in range(depth): + for i in range(depth): self.layers.append( nn.ModuleList( [ @@ -183,6 +200,7 @@ def __init__( context_dim=context_dim, head_dim=head_dim, num_heads=num_heads, + ln_k_q=ln_k_q if i == 0 else False, ), MLP(dim=latent_dim, hidden_features=mlp_hidden_dim, dropout=drop), nn.LayerNorm(latent_dim, eps=ln_eps), diff --git a/docs/beware.md b/docs/beware.md index f8235a1..6141661 100644 --- a/docs/beware.md +++ b/docs/beware.md @@ -17,6 +17,11 @@ exactly the right variables at exactly the right pressure levels from exactly the right source. +This also means that the performance of the model will be sensitive to how the +data is regridded. +For optimal performance, you should ensure that the data is regridded +exactly like the data seen during pretraining and fine-tuning. + (t0-vs-analysis)= ## HRES IFS T0 Versus HRES IFS Analysis diff --git a/docs/finetuning.md b/docs/finetuning.md index 266f894..4787242 100644 --- a/docs/finetuning.md +++ b/docs/finetuning.md @@ -37,6 +37,37 @@ loss = ... loss.backward() ``` +## Exploding Gradients + +When fine-tuning, you may run into very large gradient values. +Gradient clipping and internal layer normalisation layers mitigate the impact +of large gradients, +meaning that large gradients will not immediately lead to abnormal model outputs and loss values. +Nevertheless, if gradients do blow up, the model will not learn anymore and eventually the loss value +will also blow up. +You should carefully monitor the value of the gradients to detect exploding gradients. + +One cause of exploding gradients is too large values for internal activations. +Typically this can be fixed by judiciously inserting a layer normalisation layer. + +We have identified the level aggregation as weak point of the model that can be susceptible +to exploding gradients. +You can stabilise the level aggregation of the model +by setting the following flag in the constructor: `stabilise_level_agg=True`. +Note that `stabilise_level_agg=True` will considerably perturb the model, +so significant additional fine-tuning may be required to get to the desired level of performance. + +```python +from aurora import Aurora +from aurora.normalisation import locations, scales + +model = Aurora( + use_lora=False, + stabilise_level_agg=True, # Insert extra layer norm. to mitigate exploding gradients. +) +model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False) +``` + ## Extending Aurora with New Variables Aurora can be extended with new variables by adjusting the keyword arguments `surf_vars`, @@ -66,6 +97,18 @@ scales["new_static_var"] = 1.0 scales["new_atmos_var"] = 1.0 ``` +To more efficiently learn new variables, it is recommended to use a separate learning rate for +the patch embeddings of the new variables in the encoder and decoder. +For example, if you are using Adam, you can try `1e-3` for the new patch embeddings +and `3e-4` for the other parameters. + +By default, patch embeddings in the encoder for new variables are initialised randomly. +This means that adding new variables to the model perturbs the predictions for the existing +variables. +If you do not want this, you can alternatively initialise the new patch embeddings in the encoder +to zero. +The relevant parameter dictionaries are `model.encoder.{surf,atmos}_token_embeds.weights`. + ## Other Model Extensions It is possible to extend to model in any way you like. @@ -83,3 +126,17 @@ model = Aurora(...) model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False) ``` + +## Triple Check Your Fine-Tuning Data! + +When fine-tuning the model, it is absolutely essential to carefully check your fine-tuning data. + +* Are the old (and possibly new) normalisation statistics appropriate for the new data? + +* Is any data missing? + +* Does the data contains zeros or NaNs? + +* Does the data contain any outliers that could possibly interfere with fine-tuning? + +_Et cetera._ diff --git a/docs/intro.md b/docs/intro.md index 1219874..5c55f62 100644 --- a/docs/intro.md +++ b/docs/intro.md @@ -8,20 +8,21 @@ For details on how exactly the model works, [please see the paper on arXiv.](htt Aurora is a machine learning model that can predict atmospheric variables, such as temperature. It is a _foundation model_, which means that it was first generally trained on a lot of data, and then can adapted to specialised atmospheric forecasting tasks with relatively little data. -We provide three such specialised versions: +We provide four such specialised versions: one for medium-resolution weather prediction, one for high-resolution weather prediction, -and one for air pollution prediction. +one for air pollution prediction, +and one for ocean wave prediction. The package currently includes the pretrained model and the fine-tuned version for high-resolution weather forecasting. -We are working on the fine-tuned version for air pollution forecasting, which will be included in due time. +We are working on the fine-tuned versions for air pollution and ocean wave forecasting, which will be included in due time. Cite us as follows: ``` @misc{bodnar2024aurora, - title = {Aurora: A Foundation Model of the Atmosphere}, - author = {Cristian Bodnar and Wessel P. Bruinsma and Ana Lucic and Megan Stanley and Johannes Brandstetter and Patrick Garvan and Maik Riechert and Jonathan Weyn and Haiyu Dong and Anna Vaughan and Jayesh K. Gupta and Kit Tambiratnam and Alex Archibald and Elizabeth Heider and Max Welling and Richard E. Turner and Paris Perdikaris}, + title = {Aurora: A Foundation Model for the Earth System}, + author = {Cristian Bodnar and Wessel P. Bruinsma and Ana Lucic and Megan Stanley and Anna Vaughan and Johannes Brandstetter and Patrick Garvan and Maik Riechert and Jonathan A. Weyn and Haiyu Dong and Jayesh K. Gupta and Kit Thambiratnam and Alexander T. Archibald and Chun-Chieh Wu and Elizabeth Heider and Max Welling and Richard E. Turner and Paris Perdikaris}, year = {2024}, url = {https://arxiv.org/abs/2405.13063}, eprint = {2405.13063}, diff --git a/pyproject.toml b/pyproject.toml index 92c8861..50c76e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,8 @@ dev = [ "pre-commit", "jupyter-book", "scipy", + "xarray", + "netcdf4", ] [project.urls] diff --git a/tests/test_batch.py b/tests/test_batch.py index 7c68716..c1b1f0c 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -1,5 +1,7 @@ """Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" +from pathlib import Path + import numpy as np from tests.conftest import SavedBatch @@ -35,3 +37,23 @@ def test_interpolation(test_input_output: tuple[Batch, SavedBatch]) -> None: np.testing.assert_allclose(batch.metadata.lat, batch_regridded.metadata.lat, atol=1e-10) np.testing.assert_allclose(batch.metadata.lon, batch_regridded.metadata.lon, atol=1e-10) + + +def test_save_load(test_input_output: tuple[Batch, SavedBatch], tmp_path: Path) -> None: + batch, _ = test_input_output + + batch.to_netcdf(tmp_path / "batch.nc") + batch_loaded = Batch.from_netcdf(tmp_path / "batch.nc") + + for k in batch.surf_vars: + np.testing.assert_allclose(batch.surf_vars[k], batch_loaded.surf_vars[k]) + for k in batch.static_vars: + np.testing.assert_allclose(batch.static_vars[k], batch_loaded.static_vars[k]) + for k in batch.atmos_vars: + np.testing.assert_allclose(batch.atmos_vars[k], batch_loaded.atmos_vars[k]) + + np.testing.assert_allclose(batch.metadata.lat, batch_loaded.metadata.lat) + np.testing.assert_allclose(batch.metadata.lon, batch_loaded.metadata.lon) + assert batch.metadata.time == batch_loaded.metadata.time + assert batch.metadata.atmos_levels == batch_loaded.metadata.atmos_levels + assert batch.metadata.rollout_step == batch_loaded.metadata.rollout_step diff --git a/tests/test_model.py b/tests/test_model.py index be006e6..a0e00e0 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,15 +1,19 @@ """Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" +import os +from datetime import timedelta + import numpy as np import pytest import torch +import torch.distributed as dist from tests.conftest import SavedBatch from aurora import Aurora, AuroraSmall, Batch -@pytest.fixture() +@pytest.fixture(scope="session") def aurora_small() -> Aurora: model = AuroraSmall(use_lora=True) model.load_checkpoint( @@ -72,6 +76,23 @@ def assert_approx_equality(v_out: np.ndarray, v_ref: np.ndarray, tol: float) -> assert pred.metadata.time == tuple(test_output["metadata"]["time"]) +def test_aurora_small_ddp( + aurora_small: Aurora, test_input_output: tuple[Batch, SavedBatch] +) -> None: + batch, test_output = test_input_output + + if not dist.is_initialized(): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + dist.init_process_group("gloo", rank=0, world_size=1) + + aurora_small = torch.nn.parallel.DistributedDataParallel(aurora_small) + + # Just test that it runs. + with torch.inference_mode(): + aurora_small.forward(batch) + + def test_aurora_small_decoder_init() -> None: aurora_small = AuroraSmall(use_lora=True) @@ -120,3 +141,48 @@ def test_aurora_small_lat_lon_matrices( pred_matrix.atmos_vars[k], rtol=1e-5, ) + + +def test_aurora_small_flags(test_input_output: tuple[Batch, SavedBatch]) -> None: + batch, test_output = test_input_output + + flag_collections: list[dict] = [ + {}, + {"stabilise_level_agg": True}, + {"timestep": timedelta(hours=12)}, + ] + + preds = [] + for flags in flag_collections: + model = AuroraSmall(use_lora=True, **flags) + model.load_checkpoint( + "microsoft/aurora", + "aurora-0.25-small-pretrained.ckpt", + strict=False, # LoRA parameters not available. + ) + model = model.double() + model.eval() + with torch.inference_mode(): + preds.append(model.forward(batch).normalise(model.surf_stats)) + + # Check that all predictions are different. + for i, pred1 in enumerate(preds): + for pred2 in preds[i + 1 :]: + for k in pred1.surf_vars: + assert not np.allclose( + pred1.surf_vars[k], + pred2.surf_vars[k], + rtol=5e-2, + ) + for k in pred1.static_vars: + np.testing.assert_allclose( + pred1.static_vars[k], + pred2.static_vars[k], + rtol=1e-5, + ) + for k in pred1.atmos_vars: + assert not np.allclose( + pred1.atmos_vars[k], + pred2.atmos_vars[k], + rtol=5e-2, + )