Parallel inference #108

Open

cathalobrien wants to merge 27 commits into develop from feature/model-parallel
Changes from 20 commits (27 commits total)

Commits
ce74e6e
model parallel wip
cathalobrien Nov 25, 2024
936c60a
logging only on rank 0
cathalobrien Nov 26, 2024
d870289
fallback if env vars arent set and some work only done by rank 0
cathalobrien Nov 26, 2024
b39b796
changelog
cathalobrien Nov 26, 2024
b95e167
pre-commit checks and no model comm group for single gpu case
cathalobrien Nov 26, 2024
9fe691c
changelog
cathalobrien Nov 26, 2024
5f92574
added parallel inf
cathalobrien Jan 14, 2025
71fdf0e
precommit
cathalobrien Jan 14, 2025
9264754
9k parallel inference works
cathalobrien Jan 15, 2025
06a575d
refactor
cathalobrien Jan 15, 2025
fa89bb8
refactor
cathalobrien Jan 15, 2025
a6a4ea4
tidy
cathalobrien Jan 16, 2025
8a73f62
more compatible with older versions of models
cathalobrien Jan 16, 2025
db560eb
forgot precommit
cathalobrien Jan 16, 2025
b21d811
remove commented code
cathalobrien Jan 16, 2025
48ad37b
added license
cathalobrien Jan 16, 2025
b9ecc14
feedback
cathalobrien Jan 16, 2025
1a0ae49
Merge remote-tracking branch 'origin/develop' into feature/model-para…
cathalobrien Jan 17, 2025
27965ff
refactor to parallel runner
cathalobrien Jan 17, 2025
43167c5
refactored into explicit parallel runner class
cathalobrien Jan 17, 2025
6974ac3
allow MASTER_ADDR and MASTER_PORT to be set as env vars before runtime
cathalobrien Jan 20, 2025
2016c7b
readd line accicdentally deleted
cathalobrien Jan 21, 2025
bd391f5
added documentation
cathalobrien Jan 21, 2025
1cd4982
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 21, 2025
079036a
forgot precommit
cathalobrien Jan 21, 2025
d6a77ff
Merge branch 'feature/model-parallel' of github.com:ecmwf/anemoi-infe…
cathalobrien Jan 21, 2025
b8be926
docs feedback
cathalobrien Jan 21, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -21,6 +21,7 @@ Keep it human-readable, your future self will thank you!
- Add CONTRIBUTORS.md file (#36)
- Add sanetise command
- Add support for huggingface
- Added ability to run inference over multiple GPUs [#55](https://github.com/ecmwf/anemoi-inference/pull/55)

### Changed
- Change `write_initial_state` default value to `true`
5 changes: 1 addition & 4 deletions src/anemoi/inference/runner.py
@@ -244,18 +244,15 @@ def predict_step(self, model, input_tensor_torch, fcstep, **kwargs):
        return model.predict_step(input_tensor_torch)

    def forecast(self, lead_time, input_tensor_numpy, input_state):
        self.model.eval()

        torch.set_grad_enabled(False)

        # Create pytorch input tensor
        input_tensor_torch = torch.from_numpy(np.swapaxes(input_tensor_numpy, -2, -1)[np.newaxis, ...]).to(self.device)

        LOG.info("Using autocast %s", self.autocast)

        lead_time = to_timedelta(lead_time)
        steps = lead_time // self.checkpoint.timestep

        LOG.info("Using autocast %s", self.autocast)
        LOG.info("Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps)

        result = input_state.copy()  # We should not modify the input state
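As a small worked example of the step calculation in the hunk above (the values are illustrative, not taken from the PR):

```python
import datetime

# e.g. a 10-day forecast with a 6-hour model timestep from the checkpoint
lead_time = datetime.timedelta(hours=240)
timestep = datetime.timedelta(hours=6)

steps = lead_time // timestep  # -> 40 forecast steps
print(steps)
```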
146 changes: 146 additions & 0 deletions src/anemoi/inference/runners/parallel.py
@@ -0,0 +1,146 @@
# (C) Copyright 2025 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import datetime
import logging
import os
import socket
import subprocess

import numpy as np
import torch
import torch.distributed as dist

from ..outputs import create_output
from . import runner_registry
from .default import DefaultRunner

LOG = logging.getLogger(__name__)


@runner_registry.register("parallel")
class ParallelRunner(DefaultRunner):
"""Runner which splits a model over multiple devices"""

def __init__(self, context):
super().__init__(context)
global_rank, local_rank, world_size = self.__get_parallel_info()
self.global_rank = global_rank
self.local_rank = local_rank
self.world_size = world_size

if self.device == "cuda":
self.device = f"{self.device}:{local_rank}"
torch.cuda.set_device(local_rank)

# disable most logging on non-zero ranks
if self.global_rank != 0:
logging.getLogger().setLevel(logging.WARNING)

# Create a model comm group for parallel inference
# A dummy comm group is created if only a single device is in use
model_comm_group = self.__init_parallel(self.device, self.global_rank, self.world_size)
self.model_comm_group = model_comm_group

def predict_step(self, model, input_tensor_torch, fcstep, **kwargs):
if self.model_comm_group is None:
return model.predict_step(input_tensor_torch)
else:
try:
return model.predict_step(input_tensor_torch, self.model_comm_group)
except TypeError as err:
LOG.error("Please upgrade to a newer version of anemoi-models to use parallel inference")
raise err

def create_output(self):
if self.global_rank == 0:
output = create_output(self, self.config.output)
LOG.info("Output: %s", output)
return output
else:
output = create_output(self, "none")
return output

def __del__(self):
if self.model_comm_group is not None:
dist.destroy_process_group()

    def __init_network(self):
        """Reads Slurm environment to set master address and port for parallel communication"""

        # Get the master address from the SLURM_NODELIST environment variable
        slurm_nodelist = os.environ.get("SLURM_NODELIST")
        if not slurm_nodelist:
            raise ValueError("SLURM_NODELIST environment variable is not set.")

        # Use subprocess to execute scontrol and get the first hostname
        result = subprocess.run(
            ["scontrol", "show", "hostname", slurm_nodelist], stdout=subprocess.PIPE, text=True, check=True
        )
        master_addr = result.stdout.splitlines()[0]
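        # (illustrative note, not in the original diff) e.g. SLURM_NODELIST="node[001-004]"
        # expands to "node001 ... node004", so "node001" becomes the master address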

        # Resolve the master address using nslookup
        try:
            resolved_addr = socket.gethostbyname(master_addr)
        except socket.gaierror:
            raise ValueError(f"Could not resolve hostname: {master_addr}")

        # Set the resolved address as MASTER_ADDR
        master_addr = resolved_addr

        # Calculate the MASTER_PORT using SLURM_JOBID
        slurm_jobid = os.environ.get("SLURM_JOBID")
        if not slurm_jobid:
            raise ValueError("SLURM_JOBID environment variable is not set.")

        master_port = str(10000 + int(slurm_jobid[-4:]))
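        # (illustrative) e.g. SLURM_JOBID="12345678" -> last four digits "5678" -> MASTER_PORT "15678"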

        # Print the results for confirmation
        LOG.debug(f"MASTER_ADDR: {master_addr}")
        LOG.debug(f"MASTER_PORT: {master_port}")

        return master_addr, master_port

    def __init_parallel(self, device, global_rank, world_size):
        """Creates a model communication group to be used for parallel inference"""

        if world_size > 1:

            master_addr, master_port = self.__init_network()

            # use 'startswith' instead of '==' in case device is 'cuda:0'
            if device.startswith("cuda"):
                backend = "nccl"
            else:
                backend = "gloo"

            dist.init_process_group(
                backend=backend,
                init_method=f"tcp://{master_addr}:{master_port}",
                timeout=datetime.timedelta(minutes=3),
                world_size=world_size,
                rank=global_rank,
            )
            LOG.info(f"Creating a model comm group with {world_size} devices with the {backend} backend")

            model_comm_group_ranks = np.arange(world_size, dtype=int)
            model_comm_group = torch.distributed.new_group(model_comm_group_ranks)
        else:
            model_comm_group = None

        return model_comm_group

    def __get_parallel_info(self):
        """Reads Slurm env vars, if they exist, to determine if inference is running in parallel"""
        local_rank = int(os.environ.get("SLURM_LOCALID", 0))  # Rank within a node, between 0 and num_gpus
        global_rank = int(os.environ.get("SLURM_PROCID", 0))  # Rank within all nodes
        world_size = int(os.environ.get("SLURM_NTASKS", 1))  # Total number of processes
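        # (illustrative) e.g. "srun --nodes=2 --ntasks-per-node=4" gives SLURM_NTASKS=8,
        # SLURM_PROCID in 0..7 across the job and SLURM_LOCALID in 0..3 on each node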

        return global_rank, local_rank, world_size
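For context on the TypeError fallback in `ParallelRunner.predict_step` above, here is a minimal sketch of the model-side signature it assumes; this is a hypothetical stand-in, not the actual anemoi-models implementation:

```python
from typing import Optional

import torch
import torch.distributed as dist


class SketchModel(torch.nn.Module):
    """Hypothetical model illustrating the optional comm-group argument."""

    def predict_step(self, x: torch.Tensor, model_comm_group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
        # A parallel-aware model would shard its work across the group; an older
        # model whose predict_step only accepts the input tensor raises TypeError
        # when ParallelRunner passes the group, which is the case the runner
        # catches and reports as a version mismatch.
        if model_comm_group is not None:
            ...  # distributed branch would go here
        return x
```

Because the runner only passes the group when one exists, a model with this signature also works unchanged in the single-device case.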