Parallel inference #108

Open

cathalobrien wants to merge 27 commits into develop from feature/model-parallel

Changes from 17 commits

Commits (27)
ce74e6e
model parallel wip
cathalobrien Nov 25, 2024
936c60a
logging only on rank 0
cathalobrien Nov 26, 2024
d870289
fallback if env vars arent set and some work only done by rank 0
cathalobrien Nov 26, 2024
b39b796
changelog
cathalobrien Nov 26, 2024
b95e167
pre-commit checks and no model comm group for single gpu case
cathalobrien Nov 26, 2024
9fe691c
changelog
cathalobrien Nov 26, 2024
5f92574
added parallel inf
cathalobrien Jan 14, 2025
71fdf0e
precommit
cathalobrien Jan 14, 2025
9264754
9k parallel inference works
cathalobrien Jan 15, 2025
06a575d
refactor
cathalobrien Jan 15, 2025
fa89bb8
refactor
cathalobrien Jan 15, 2025
a6a4ea4
tidy
cathalobrien Jan 16, 2025
8a73f62
more compatible with older versions of models
cathalobrien Jan 16, 2025
db560eb
forgot precommit
cathalobrien Jan 16, 2025
b21d811
remove commented code
cathalobrien Jan 16, 2025
48ad37b
added license
cathalobrien Jan 16, 2025
b9ecc14
feedback
cathalobrien Jan 16, 2025
1a0ae49
Merge remote-tracking branch 'origin/develop' into feature/model-para…
cathalobrien Jan 17, 2025
27965ff
refactor to parallel runner
cathalobrien Jan 17, 2025
43167c5
refactored into explicit parallel runner class
cathalobrien Jan 17, 2025
6974ac3
allow MASTER_ADDR and MASTER_PORT to be set as env vars before runtime
cathalobrien Jan 20, 2025
2016c7b
readd line accicdentally deleted
cathalobrien Jan 21, 2025
bd391f5
added documentation
cathalobrien Jan 21, 2025
1cd4982
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 21, 2025
079036a
forgot precommit
cathalobrien Jan 21, 2025
d6a77ff
Merge branch 'feature/model-parallel' of github.com:ecmwf/anemoi-infe…
cathalobrien Jan 21, 2025
b8be926
docs feedback
cathalobrien Jan 21, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -21,6 +21,7 @@ Keep it human-readable, your future self will thank you!
- Add CONTRIBUTORS.md file (#36)
- Add sanetise command
- Add support for huggingface
- Added ability to run inference over multiple GPUs [#55](https://github.com/ecmwf/anemoi-inference/pull/55)

### Changed
- Change `write_initial_state` default value to `true`
96 changes: 96 additions & 0 deletions src/anemoi/inference/parallel.py
@@ -0,0 +1,96 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import datetime
import logging
import os
import socket
import subprocess

import numpy as np
import torch
import torch.distributed as dist

LOG = logging.getLogger(__name__)


def init_network():
"""Reads Slurm environment to set master address and port for parallel communication"""
Collaborator commented:

Here's one thing I wonder about in general: say you have Anemoi inference inside a Slurm job which happens to do other things as well, e.g. run Anemoi inference first, intended in non-distributed mode, and then some post-processing in parallel. That makes the SLURM_<...> variables visible to the Anemoi code without expecting it to heed them. Wouldn't that cause misbehaviour, i.e. make Anemoi inference think it is running in distributed mode when it actually shouldn't be?

Collaborator commented:

For the record, it's not a hypothetical question, I do exactly that in Cascade :) And we have two layers of the problem:

  • how to allow Anemoi parallel inference as well as Anemoi-in-Cascade-non-parallel, in a mutually non-disruptive way
  • how to allow Anemoi parallel inference in Cascade, that is, slurm-within-slurm :)

For the first one, it may be easiest to use custom env vars (ANEMOI_NODELIST, ANEMOI_JOBID, ...) and have the example bash submit script set them from Slurm's vars. For the second one, a more profound deliberation is needed, but let's keep it in mind and solve it later.

Contributor Author (@cathalobrien, Jan 16, 2025) commented:

> making Anemoi inference think it is running in distributed mode when it actually shouldn't be?

Very good point. I am reluctant to add more env vars. I think a 'num-gpus' entry in the anemoi-inference config, defaulting to 1, would suffice. But more thought is needed here.

Contributor Author commented:

@tmi I believe the first point is resolved now that parallelism is disabled by default.

# Get the master address from the SLURM_NODELIST environment variable
slurm_nodelist = os.environ.get("SLURM_NODELIST")
if not slurm_nodelist:
raise ValueError("SLURM_NODELIST environment variable is not set.")

# Use subprocess to execute scontrol and get the first hostname
result = subprocess.run(
["scontrol", "show", "hostname", slurm_nodelist], stdout=subprocess.PIPE, text=True, check=True
)
master_addr = result.stdout.splitlines()[0]

# Resolve the master hostname to an IP address
try:
resolved_addr = socket.gethostbyname(master_addr)
except socket.gaierror:
raise ValueError(f"Could not resolve hostname: {master_addr}")

# Set the resolved address as MASTER_ADDR
master_addr = resolved_addr

# Calculate the MASTER_PORT using SLURM_JOBID
slurm_jobid = os.environ.get("SLURM_JOBID")
if not slurm_jobid:
raise ValueError("SLURM_JOBID environment variable is not set.")

master_port = str(10000 + int(slurm_jobid[-4:]))

# Log the resolved address and port for confirmation
LOG.debug(f"MASTER_ADDR: {master_addr}")
LOG.debug(f"MASTER_PORT: {master_port}")

return master_addr, master_port


def init_parallel(device, global_rank, world_size):
"""Creates a model communication group to be used for parallel inference"""

if world_size > 1:

master_addr, master_port = init_network()

# use 'startswith' instead of '==' in case device is 'cuda:0'
if device.startswith("cuda"):
backend = "nccl"
else:
backend = "gloo"

dist.init_process_group(
backend=backend,
init_method=f"tcp://{master_addr}:{master_port}",
timeout=datetime.timedelta(minutes=3),
world_size=world_size,
rank=global_rank,
)

model_comm_group_ranks = np.arange(world_size, dtype=int)
model_comm_group = torch.distributed.new_group(model_comm_group_ranks)
else:
model_comm_group = None

return model_comm_group


def get_parallel_info():
"""Reads Slurm env vars, if they exist, to determine if inference is running in parallel"""
local_rank = int(os.environ.get("SLURM_LOCALID", 0)) # Rank within a node, between 0 and num_gpus
global_rank = int(os.environ.get("SLURM_PROCID", 0)) # Rank within all nodes
world_size = int(os.environ.get("SLURM_NTASKS", 1)) # Total number of processes

return global_rank, local_rank, world_size
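
For orientation, a minimal usage sketch of the helpers above (not part of the diff; the srun launch line is illustrative, not taken from this PR):

from anemoi.inference.parallel import get_parallel_info, init_parallel

# Assumed launch, e.g.: srun --ntasks=4 --ntasks-per-node=4 anemoi-inference run config.yaml
global_rank, local_rank, world_size = get_parallel_info()
# Outside Slurm the defaults are (0, 0, 1), so inference stays single-process.

device = f"cuda:{local_rank}"
model_comm_group = init_parallel(device, global_rank, world_size)
# With world_size == 1 this returns None. With world_size > 1, init_network()
# derives MASTER_ADDR from the first host in SLURM_NODELIST and MASTER_PORT from
# the last four digits of SLURM_JOBID (e.g. SLURM_JOBID=87654321 -> port 14321).
# A later commit in this PR (6974ac3) also allows MASTER_ADDR/MASTER_PORT to be
# preset as environment variables before runtime.
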
97 changes: 62 additions & 35 deletions src/anemoi/inference/runner.py
@@ -21,6 +21,8 @@

from .checkpoint import Checkpoint
from .context import Context
from .parallel import get_parallel_info
from .parallel import init_parallel
from .postprocess import Accumulator
from .postprocess import Noop
from .precisions import PRECISIONS
@@ -239,25 +241,39 @@ def model(self):
return model

def forecast(self, lead_time, input_tensor_numpy, input_state):

# determine processes rank for parallel inference and assign a device
global_rank, local_rank, world_size = get_parallel_info()
Member commented:

In training we seed the whole parallel model group with the same random seed; I think it would be nice to do the same here as well. The challenge is to generate appropriate seeds...
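
A minimal sketch of that seeding idea (not part of this PR's diff; seed_model_comm_group is a hypothetical helper), assuming the process group has already been initialised:

import torch
import torch.distributed as dist


def seed_model_comm_group(model_comm_group, device):
    # Hypothetical helper: rank 0 draws a seed and broadcasts it so that every
    # rank in the model comm group produces identical random numbers.
    seed = torch.zeros(1, dtype=torch.int64, device=device)
    if dist.get_rank(group=model_comm_group) == 0:
        seed[0] = torch.seed() % (2**31)
    dist.broadcast(seed, src=0, group=model_comm_group)
    torch.manual_seed(int(seed.item()))
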

if self.device == "cuda":
self.device = f"{self.device}:{local_rank}"
torch.cuda.set_device(local_rank)

self.model.eval()

torch.set_grad_enabled(False)

# Create pytorch input tensor
input_tensor_torch = torch.from_numpy(np.swapaxes(input_tensor_numpy, -2, -1)[np.newaxis, ...]).to(self.device)

LOG.info("Using autocast %s", self.autocast)

lead_time = to_timedelta(lead_time)
steps = lead_time // self.checkpoint.timestep

LOG.info("Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps)
if global_rank == 0:
LOG.info("World size: %d", world_size)
LOG.info("Using autocast %s", self.autocast)
LOG.info(
"Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps
)

result = input_state.copy() # We should not modify the input state
result["fields"] = dict()

start = input_state["date"]

# Create a model comm group for parallel inference
# A dummy comm group is created if only a single device is in use
model_comm_group = init_parallel(self.device, global_rank, world_size)

# The variable `check` is used to keep track of which variables have been updated
# In the input tensor. `reset` is used to reset `check` to False except
# when the values are of the constant in time variables
@@ -277,56 +293,67 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
for s in range(steps):
step = (s + 1) * self.checkpoint.timestep
date = start + step
LOG.info("Forecasting step %s (%s)", step, date)
if global_rank == 0:
LOG.info("Forecasting step %s (%s)", step, date)

result["date"] = date

# Predict next state of atmosphere
with torch.autocast(device_type=self.device, dtype=self.autocast):
y_pred = self.model.predict_step(input_tensor_torch)
if model_comm_group is None:
y_pred = self.model.predict_step(input_tensor_torch)
else:
try:
y_pred = self.model.predict_step(input_tensor_torch, model_comm_group)
except TypeError as err:
LOG.error("Please upgrade to a newer version of anemoi-models to use parallel inference")
raise err

# Detach tensor and squeeze (should we detach here?)
output = np.squeeze(y_pred.cpu().numpy()) # shape: (values, variables)
if global_rank == 0:
# Detach tensor and squeeze (should we detach here?)
output = np.squeeze(y_pred.cpu().numpy()) # shape: (values, variables)

# Update state
for i in range(output.shape[1]):
result["fields"][self.checkpoint.output_tensor_index_to_variable[i]] = output[:, i]
# Update state
for i in range(output.shape[1]):
result["fields"][self.checkpoint.output_tensor_index_to_variable[i]] = output[:, i]

if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
self._print_output_tensor("Output tensor", output)
if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
self._print_output_tensor("Output tensor", output)

yield result
yield result

# No need to prepare next input tensor if we are at the last step
if s == steps - 1:
continue
# No need to prepare next input tensor if we are at the last step
if s == steps - 1:
continue

# Update tensor for next iteration
# Update tensor for next iteration

check[:] = reset
check[:] = reset

input_tensor_torch = self.copy_prognostic_fields_to_input_tensor(input_tensor_torch, y_pred, check)
input_tensor_torch = self.copy_prognostic_fields_to_input_tensor(input_tensor_torch, y_pred, check)

del y_pred # Recover memory
del y_pred # Recover memory

input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(input_tensor_torch, input_state, date, check)
input_tensor_torch = self.add_boundary_forcings_to_input_tensor(
input_tensor_torch, input_state, date, check
)
input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(
input_tensor_torch, input_state, date, check
)
input_tensor_torch = self.add_boundary_forcings_to_input_tensor(
input_tensor_torch, input_state, date, check
)

if not check.all():
# Not all variables have been updated
missing = []
variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index
mapping = {v: k for k, v in variable_to_input_tensor_index.items()}
for i in range(check.shape[-1]):
if not check[i]:
missing.append(mapping[i])
if not check.all():
# Not all variables have been updated
missing = []
variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index
mapping = {v: k for k, v in variable_to_input_tensor_index.items()}
for i in range(check.shape[-1]):
if not check[i]:
missing.append(mapping[i])

raise ValueError(f"Missing variables in input tensor: {sorted(missing)}")
raise ValueError(f"Missing variables in input tensor: {sorted(missing)}")

if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
self._print_input_tensor("Next input tensor", input_tensor_torch)
if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
self._print_input_tensor("Next input tensor", input_tensor_torch)

def copy_prognostic_fields_to_input_tensor(self, input_tensor_torch, y_pred, check):
