Merge branch 'ecmwf:develop' into pr/aw_rescale
havardhhaugen authored Nov 15, 2024
2 parents d0d2b57 + d0a8866 commit bc91253
Showing 9 changed files with 259 additions and 135 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -21,6 +21,7 @@ Keep it human-readable, your future self will thank you!
- Save entire config in mlflow
### Added
- Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70)
- Include option to use datashader and optimised asynchronous callbacks [#102](https://github.com/ecmwf/anemoi-training/pull/102)
- Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116)
- Allow updates to scalars [#137](https://github.com/ecmwf/anemoi-training/pulls/137)
- Add without subsetting in ScaleTensor
22 changes: 21 additions & 1 deletion docs/modules/diagnostics.rst
@@ -51,12 +51,32 @@ parameters to plot, as well as the plotting frequency, and
asynchronicity.

Setting ``config.diagnostics.plot.asynchronous`` means that the model
training doesn't stop whilst the callbacks are being evaluated. This is
useful for large models where the plotting can take a long time. The
plotting module uses asynchronous callbacks via ``asyncio`` and
``concurrent.futures.ThreadPoolExecutor`` to handle plotting tasks
without blocking the main application. A dedicated event loop runs in a
separate background thread, and plotting tasks are offloaded to worker
threads, so the main training thread stays responsive while plots are
produced in the background.
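
Schematically, the asynchronous setup follows the pattern sketched
below. This is an illustrative, self-contained example of the mechanism
rather than the actual callback code; ``make_plot`` and ``plot_async``
are placeholder names.

.. code:: python

   import asyncio
   import threading
   from concurrent.futures import ThreadPoolExecutor

   # Event loop living in a background thread: plotting work is scheduled
   # here so the main (training) thread never blocks on figure rendering.
   loop = asyncio.new_event_loop()
   threading.Thread(target=loop.run_forever, daemon=True).start()

   # Single worker thread that actually renders the figures.
   executor = ThreadPoolExecutor(max_workers=1)

   async def _submit(fn, *args):
       # Hand the blocking plot function over to the worker thread.
       await asyncio.get_running_loop().run_in_executor(executor, fn, *args)

   def plot_async(fn, *args):
       # Called from the main thread; returns immediately with a future.
       return asyncio.run_coroutine_threadsafe(_submit(fn, *args), loop)

   def make_plot(tag):
       print(f"rendering {tag} ...")  # stand-in for the real plotting work

   plot_async(make_plot, "example").result()  # .result() only to wait in this demo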

There is an additional flag in the plotting callbacks to control the
rendering method for geospatial plots, offering a trade-off between
performance and detail. When ``datashader`` is set to True, Datashader is
used for rendering, which accelerates plotting through efficient
hexbinning and is particularly useful for large datasets. This approach
can produce smoother-looking plots due to the aggregation of data points.
If ``datashader`` is set to False, ``matplotlib.scatter`` is used, which
provides sharper and more detailed visuals but may be slower for large
datasets.
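
Conceptually, the two rendering paths differ as in the sketch below
(illustrative only; the data and variable names are placeholders, not
the code used by the callbacks):

.. code:: python

   import datashader as ds
   import datashader.transfer_functions as tf
   import matplotlib.pyplot as plt
   import numpy as np
   import pandas as pd

   # Placeholder geospatial samples.
   rng = np.random.default_rng(0)
   df = pd.DataFrame(
       {
           "lon": rng.uniform(-180, 180, 100_000),
           "lat": rng.uniform(-90, 90, 100_000),
           "value": rng.normal(size=100_000),
       }
   )

   # Datashader: aggregate all points onto a fixed-size raster, then shade it.
   canvas = ds.Canvas(plot_width=800, plot_height=400)
   agg = canvas.points(df, "lon", "lat", ds.mean("value"))
   image = tf.shade(agg, how="linear")

   # Matplotlib: draw every point individually (sharper, but slower for
   # large numbers of points).
   plt.scatter(df["lon"], df["lat"], c=df["value"], s=1)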

**Note** - this asynchronous behaviour is only available for the
plotting callbacks.

.. code:: yaml

   plot:
     asynchronous: True # Whether to plot asynchronously
     datashader: True # Whether to use datashader for plotting (faster)
     frequency: # Frequency of the plotting
       batch: 750
       epoch: 5
1 change: 1 addition & 0 deletions pyproject.toml
@@ -44,6 +44,7 @@ dependencies = [
"anemoi-graphs>=0.4",
"anemoi-models>=0.3",
"anemoi-utils[provenance]>=0.4.4",
"datashader>=0.16.3",
"einops>=0.6.1",
"hydra-core>=1.3",
"matplotlib>=3.7.1",
1 change: 1 addition & 0 deletions src/anemoi/training/config/diagnostics/plot/detailed.yaml
@@ -1,4 +1,5 @@
asynchronous: True # Whether to plot asynchronously
datashader: True # Choose which technique to use for plotting
frequency: # Frequency of the plotting
batch: 750
epoch: 5
1 change: 1 addition & 0 deletions src/anemoi/training/config/diagnostics/plot/simple.yaml
@@ -1,4 +1,5 @@
asynchronous: True # Whether to plot asynchronously
datashader: True # Choose which technique to use for plotting
frequency: # Frequency of the plotting
batch: 750
epoch: 10
3 changes: 1 addition & 2 deletions src/anemoi/training/data/datamodule.py
@@ -140,8 +140,7 @@ def ds_train(self) -> NativeGridDataset:

@cached_property
def ds_valid(self) -> NativeGridDataset:
r = self.rollout
r = max(r, self.config.dataloader.get("validation_rollout", 1))
r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1))

assert self.config.dataloader.training.end < self.config.dataloader.validation.start, (
f"Training end date {self.config.dataloader.training.end} is not before"
153 changes: 79 additions & 74 deletions src/anemoi/training/diagnostics/callbacks/plot.py
@@ -7,13 +7,13 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# ruff: noqa: ANN001

from __future__ import annotations

import asyncio
import copy
import logging
import sys
import threading
import time
import traceback
from abc import ABC
@@ -23,8 +23,6 @@
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
@@ -43,33 +41,14 @@
from anemoi.training.losses.weightedloss import BaseWeightedLoss

if TYPE_CHECKING:
from typing import Any

import pytorch_lightning as pl
from omegaconf import OmegaConf

LOGGER = logging.getLogger(__name__)


class ParallelExecutor(ThreadPoolExecutor):
"""Wraps parallel execution and provides accurate information about errors.
Extends ThreadPoolExecutor to preserve the original traceback and line number.
Reference: https://stackoverflow.com/questions/19309514/getting-original-line-
number-for-exception-in-concurrent-futures/24457608#24457608
"""

def submit(self, fn: Any, *args, **kwargs) -> Callable:
"""Submits the wrapped function instead of `fn`."""
return super().submit(self._function_wrapper, fn, *args, **kwargs)

def _function_wrapper(self, fn: Any, *args: list, **kwargs: dict) -> Callable:
"""Wraps `fn` in order to preserve the traceback of any kind of."""
try:
return fn(*args, **kwargs)
except Exception as exc:
raise sys.exc_info()[0](traceback.format_exc()) from exc


class BasePlotCallback(Callback, ABC):
"""Factory for creating a callback that plots data to Experiment Logging."""

@@ -93,11 +72,21 @@ def __init__(self, config: OmegaConf) -> None:

self.plot = self._plot
self._executor = None
self._error: BaseException = None
self.datashader_plotting = config.diagnostics.plot.datashader

if self.config.diagnostics.plot.asynchronous:
self._executor = ParallelExecutor(max_workers=1)
self._error: BaseException | None = None
LOGGER.info("Setting up asynchronous plotting ...")
self.plot = self._async_plot
self._executor = ThreadPoolExecutor(max_workers=1)
self.loop_thread = threading.Thread(target=self.start_event_loop, daemon=True)
self.loop_thread.start()

def start_event_loop(self) -> None:
"""Start the event loop in a separate thread."""
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.loop.run_forever()

@rank_zero_only
def _output_figure(
@@ -113,27 +102,48 @@ def _output_figure(
save_path = Path(
self.save_basedir,
"plots",
f"{tag}_epoch{epoch:03d}.png",
f"{tag}_epoch{epoch:03d}.jpg",
)

save_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_path, dpi=100, bbox_inches="tight")
fig.canvas.draw()
image_array = np.array(fig.canvas.renderer.buffer_rgba())
plt.imsave(save_path, image_array, dpi=100)
if self.config.diagnostics.log.wandb.enabled:
import wandb

logger.experiment.log({exp_log_tag: wandb.Image(fig)})

if self.config.diagnostics.log.mlflow.enabled:
run_id = logger.run_id
logger.experiment.log_artifact(run_id, str(save_path))

plt.close(fig) # cleanup

@rank_zero_only
def _plot_with_error_catching(self, trainer: pl.Trainer, args: Any, kwargs: Any) -> None:
"""To execute the plot function but ensuring we catch any errors."""
try:
self._plot(trainer, *args, **kwargs)
except BaseException:
import os

LOGGER.exception(traceback.format_exc())
os._exit(1) # to force exit when sanity val steps are used

def teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
"""Method is called to close the threads."""
"""Teardown the callback."""
del trainer, pl_module, stage # unused
LOGGER.info("Teardown of the Plot Callback ...")

if self._executor is not None:
self._executor.shutdown(wait=True)
LOGGER.info("waiting and shutting down the executor ...")
self._executor.shutdown(wait=False, cancel_futures=True)

self.loop.call_soon_threadsafe(self.loop.stop)
self.loop_thread.join()
# Force cleanup of the (already stopped) loop thread's internal state
self.loop_thread._stop()
self.loop_thread._delete()

def apply_output_mask(self, pl_module: pl.LightningModule, data: torch.Tensor) -> torch.Tensor:
if hasattr(pl_module, "output_mask") and pl_module.output_mask is not None:
@@ -147,31 +157,39 @@ def _plot(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
*args,
**kwargs,
*args: Any,
**kwargs: Any,
) -> None:
"""Plotting function to be implemented by subclasses."""

# Async function to run the plot function in the background thread
async def submit_plot(self, trainer: pl.Trainer, *args: Any, **kwargs: Any) -> None:
"""Async function or coroutine to schedule the plot function."""
loop = asyncio.get_running_loop()
# run_in_executor doesn't support keyword arguments,
await loop.run_in_executor(
self._executor,
self._plot_with_error_catching,
trainer,
args,
kwargs,
) # because loop.run_in_executor expects positional arguments, not keyword arguments

@rank_zero_only
def _async_plot(
self,
trainer: pl.Trainer,
*args: list,
**kwargs: dict,
*args: Any,
**kwargs: Any,
) -> None:
"""To execute the plot function but ensuring we catch any errors."""
future = self._executor.submit(
self._plot,
trainer,
*args,
**kwargs,
)
# otherwise the error won't be thrown till the validation epoch is finished
try:
future.result()
except Exception:
LOGGER.exception("Critical error occurred in asynchronous plots.")
sys.exit(1)
"""Run the plot function asynchronously.
This is the function that is called by the callback. It schedules the plot
function to run in the background thread. Since we have an event loop running in
the background thread, we need to schedule the plot function to run in that
loop.
"""
asyncio.run_coroutine_threadsafe(self.submit_plot(trainer, *args, **kwargs), self.loop)


class BasePerBatchPlotCallback(BasePlotCallback):
@@ -192,26 +210,12 @@ def __init__(self, config: OmegaConf, every_n_batches: int | None = None):
super().__init__(config)
self.every_n_batches = every_n_batches or self.config.diagnostics.plot.frequency.batch

@abstractmethod
@rank_zero_only
def _plot(
def on_validation_batch_end(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs: list[torch.Tensor],
batch: torch.Tensor,
batch_idx: int,
epoch: int,
**kwargs,
) -> None:
"""Plotting function to be implemented by subclasses."""

@rank_zero_only
def on_validation_batch_end(
self,
trainer,
pl_module,
output,
output: list[torch.Tensor],
batch: torch.Tensor,
batch_idx: int,
**kwargs,
@@ -310,12 +314,12 @@ def __init__(
@rank_zero_only
def _plot(
self,
trainer,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
output: list[torch.Tensor],
batch: torch.Tensor,
batch_idx,
epoch,
batch_idx: int,
epoch: int,
) -> None:
_ = output

@@ -406,9 +410,9 @@ def _plot(
@rank_zero_only
def on_validation_batch_end(
self,
trainer,
pl_module,
output,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
output: list[torch.Tensor],
batch: torch.Tensor,
batch_idx: int,
) -> None:
@@ -454,7 +458,7 @@ def _plot(
_ = epoch
model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model

fig = plot_graph_node_features(model)
fig = plot_graph_node_features(model, datashader=self.datashader_plotting)

self._output_figure(
trainer.logger,
@@ -750,6 +754,7 @@ def _plot(
data[0, ...].squeeze(),
data[rollout_step + 1, ...].squeeze(),
output_tensor[rollout_step, ...],
datashader=self.datashader_plotting,
precip_and_related_fields=self.precip_and_related_fields,
)

@@ -839,7 +844,7 @@ def _plot(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs: list,
outputs: list[torch.Tensor],
batch: torch.Tensor,
batch_idx: int,
epoch: int,
@@ -921,7 +926,7 @@ def _plot(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs: list,
outputs: list[torch.Tensor],
batch: torch.Tensor,
batch_idx: int,
epoch: int,
2 changes: 1 addition & 1 deletion src/anemoi/training/diagnostics/maps.py
@@ -32,7 +32,7 @@ def __init__(self) -> None:
def __call__(self, lon: np.ndarray, lat: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
lon_rad = np.radians(lon)
lat_rad = np.radians(lat)
x = [v - 2 * np.pi if v > np.pi else v for v in lon_rad]
x = np.array([v - 2 * np.pi if v > np.pi else v for v in lon_rad], dtype=lon_rad.dtype)
y = lat_rad
return x, y
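
For reference, an equivalent vectorised form of this longitude wrapping
(illustrative only, not part of this commit) would be:

x = np.where(lon_rad > np.pi, lon_rad - 2 * np.pi, lon_rad)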
