diff --git a/CHANGELOG.md b/CHANGELOG.md
index 446bda1..6945ecd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,6 +21,7 @@ Keep it human-readable, your future self will thank you!
 - Add CONTRIBUTORS.md file (#36)
 - Add sanetise command
 - Add support for huggingface
+- Add ability to run inference over multiple GPUs [#55](https://github.com/ecmwf/anemoi-inference/pull/55)
 
 ### Changed
 - Change `write_initial_state` default value to `true`
diff --git a/docs/parallel.rst b/docs/parallel.rst
new file mode 100644
index 0000000..1914087
--- /dev/null
+++ b/docs/parallel.rst
@@ -0,0 +1,70 @@
+####################
+ Parallel Inference
+####################
+
+If the memory requirements of your model are too large to fit on a
+single GPU, you can run Anemoi-Inference in parallel across multiple
+GPUs.
+
+Parallel inference requires SLURM to launch the parallel processes and
+to determine information about your network environment. If SLURM is
+not available to you, please create an issue on the `Anemoi-Inference
+GitHub page <https://github.com/ecmwf/anemoi-inference>`_.
+
+***************
+ Configuration
+***************
+
+To run in parallel, you must add ``runner: parallel`` to your inference
+config file.
+
+.. code:: yaml
+
+   checkpoint: /path/to/inference-last.ckpt
+   lead_time: 60
+   runner: parallel
+   input:
+      grib: /path/to/input.grib
+   output:
+      grib: /path/to/output.grib
+
+*******************************
+ Running inference in parallel
+*******************************
+
+Below is an example SLURM batch script to launch a parallel inference
+job across 4 GPUs.
+
+.. code:: bash
+
+   #!/bin/bash
+   #SBATCH --nodes=1
+   #SBATCH --ntasks-per-node=4
+   #SBATCH --gpus-per-node=4
+   #SBATCH --cpus-per-task=8
+   #SBATCH --time=0:05:00
+   #SBATCH --output=outputs/parallel_inf.%j.out
+
+   source /path/to/venv/bin/activate
+   srun anemoi-inference run parallel.yaml
+
+.. warning::
+
+   If you specify ``runner: parallel`` but do not launch with ``srun``,
+   your anemoi-inference job may hang, as only one process will be
+   launched.
+
+.. note::
+
+   By default, anemoi-inference will determine your system's master
+   address and port itself. If this fails (e.g. when running
+   Anemoi-Inference inside a container), you can instead set these
+   values yourself via environment variables in your SLURM batch script:
+
+   .. code:: bash
+
+      MASTER_ADDR=$(scontrol show hostname $SLURM_NODELIST | head -n 1)
+      export MASTER_ADDR=$(nslookup $MASTER_ADDR | grep -oP '(?<=Address: ).*')
+      export MASTER_PORT=$((10000 + RANDOM % 10000))
+
+      srun anemoi-inference run parallel.yaml
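+
+.. note::
+
+   For reference, the parallel runner currently derives its rank and
+   network settings from the following SLURM environment variables:
+
+   .. code:: bash
+
+      SLURM_PROCID    # global rank of this process
+      SLURM_LOCALID   # local rank (GPU index) within a node
+      SLURM_NTASKS    # total number of processes (world size)
+      SLURM_NODELIST  # used to derive MASTER_ADDR
+      SLURM_JOBID     # used to derive MASTER_PORT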
diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py
index 4ba0a8a..13957f8 100644
--- a/src/anemoi/inference/runner.py
+++ b/src/anemoi/inference/runner.py
@@ -251,11 +251,10 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
         # Create pytorch input tensor
         input_tensor_torch = torch.from_numpy(np.swapaxes(input_tensor_numpy, -2, -1)[np.newaxis, ...]).to(self.device)
 
-        LOG.info("Using autocast %s", self.autocast)
-
         lead_time = to_timedelta(lead_time)
         steps = lead_time // self.checkpoint.timestep
 
+        LOG.info("Using autocast %s", self.autocast)
         LOG.info("Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps)
 
         result = input_state.copy()  # We should not modify the input state
diff --git a/src/anemoi/inference/runners/parallel.py b/src/anemoi/inference/runners/parallel.py
new file mode 100644
index 0000000..f77a19d
--- /dev/null
+++ b/src/anemoi/inference/runners/parallel.py
@@ -0,0 +1,156 @@
+# (C) Copyright 2025 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+import datetime
+import logging
+import os
+import socket
+import subprocess
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from ..outputs import create_output
+from . import runner_registry
+from .default import DefaultRunner
+
+LOG = logging.getLogger(__name__)
+
+
+@runner_registry.register("parallel")
+class ParallelRunner(DefaultRunner):
+    """Runner which splits a model over multiple devices"""
+
+    def __init__(self, context):
+        super().__init__(context)
+        global_rank, local_rank, world_size = self.__get_parallel_info()
+        self.global_rank = global_rank
+        self.local_rank = local_rank
+        self.world_size = world_size
+
+        if self.device == "cuda":
+            self.device = f"{self.device}:{local_rank}"
+            torch.cuda.set_device(local_rank)
+
+        # disable most logging on non-zero ranks
+        if self.global_rank != 0:
+            logging.getLogger().setLevel(logging.WARNING)
+
+        # Create a model comm group for parallel inference
+        # A dummy comm group is created if only a single device is in use
+        model_comm_group = self.__init_parallel(self.device, self.global_rank, self.world_size)
+        self.model_comm_group = model_comm_group
+
+    def predict_step(self, model, input_tensor_torch, fcstep, **kwargs):
+        if self.model_comm_group is None:
+            return model.predict_step(input_tensor_torch)
+        else:
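+            # Pass the model comm group so that the model can be sharded
+            # across the participating ranks. Older versions of anemoi-models
+            # do not accept this argument and raise a TypeError instead.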
+            try:
+                return model.predict_step(input_tensor_torch, self.model_comm_group)
+            except TypeError as err:
+                LOG.error("Please upgrade to a newer version of anemoi-models to use parallel inference")
+                raise err
+
+    def create_output(self):
+        if self.global_rank == 0:
+            output = create_output(self, self.config.output)
+            LOG.info("Output: %s", output)
+            return output
+        else:
+            output = create_output(self, "none")
+            return output
+
+    def __del__(self):
+        if self.model_comm_group is not None:
+            dist.destroy_process_group()
+
+    def __init_network(self):
+        """Reads Slurm environment to set master address and port for parallel communication"""
+
+        # Get the master address from the SLURM_NODELIST environment variable
+        slurm_nodelist = os.environ.get("SLURM_NODELIST")
+        if not slurm_nodelist:
+            raise ValueError("SLURM_NODELIST environment variable is not set.")
+
+        # Check if MASTER_ADDR is given, otherwise try to set it using 'scontrol'
+        master_addr = os.environ.get("MASTER_ADDR")
+        if master_addr is None:
+            LOG.debug("'MASTER_ADDR' environment variable not set. Trying to set via SLURM")
+            try:
+                result = subprocess.run(
+                    ["scontrol", "show", "hostname", slurm_nodelist], stdout=subprocess.PIPE, text=True, check=True
+                )
+            except subprocess.CalledProcessError as err:
+                LOG.error(
+                    "Python could not execute 'scontrol show hostname $SLURM_NODELIST' while calculating MASTER_ADDR. You could avoid this error by setting the MASTER_ADDR env var manually."
+                )
+                raise err
+
+            master_addr = result.stdout.splitlines()[0]
+
+        # Resolve the master hostname to an IP address
+        try:
+            master_addr = socket.gethostbyname(master_addr)
+        except socket.gaierror:
+            raise ValueError(f"Could not resolve hostname: {master_addr}")
+
+        # Check if MASTER_PORT is given, otherwise generate one based on SLURM_JOBID
+        master_port = os.environ.get("MASTER_PORT")
+        if master_port is None:
+            LOG.debug("'MASTER_PORT' environment variable not set. Trying to set via SLURM")
+            slurm_jobid = os.environ.get("SLURM_JOBID")
+            if not slurm_jobid:
+                raise ValueError("SLURM_JOBID environment variable is not set.")
+
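+            # Derive a port in the range 10000-19999 from the last digits of
+            # the job id, so that every rank computes the same value.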
+            master_port = str(10000 + int(slurm_jobid[-4:]))
+
+        # Print the results for confirmation
+        LOG.debug(f"MASTER_ADDR: {master_addr}")
+        LOG.debug(f"MASTER_PORT: {master_port}")
+
+        return master_addr, master_port
+
+    def __init_parallel(self, device, global_rank, world_size):
+        """Creates a model communication group to be used for parallel inference"""
+
+        if world_size > 1:
+
+            master_addr, master_port = self.__init_network()
+
+            # use 'startswith' instead of '==' in case device is 'cuda:0'
+            if device.startswith("cuda"):
+                backend = "nccl"
+            else:
+                backend = "gloo"
+
+            dist.init_process_group(
+                backend=backend,
+                init_method=f"tcp://{master_addr}:{master_port}",
+                timeout=datetime.timedelta(minutes=3),
+                world_size=world_size,
+                rank=global_rank,
+            )
+            LOG.info(f"Creating a model comm group with {world_size} devices with the {backend} backend")
+
+            model_comm_group_ranks = np.arange(world_size, dtype=int)
+            model_comm_group = torch.distributed.new_group(model_comm_group_ranks)
+        else:
+            model_comm_group = None
+
+        return model_comm_group
+
+    def __get_parallel_info(self):
+        """Reads Slurm env vars, if they exist, to determine if inference is running in parallel"""
+        local_rank = int(os.environ.get("SLURM_LOCALID", 0))  # Rank within a node, between 0 and num_gpus
+        global_rank = int(os.environ.get("SLURM_PROCID", 0))  # Rank within all nodes
+        world_size = int(os.environ.get("SLURM_NTASKS", 1))  # Total number of processes
+
+        return global_rank, local_rank, world_size