From ce74e6eb1fcac29e295a7dc514ebfd4dcd02ab16 Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Mon, 25 Nov 2024 10:10:26 +0100 Subject: [PATCH 01/24] model parallel wip --- src/anemoi/inference/runner.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index 93e8bf3..ba55e50 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -14,8 +14,12 @@ import numpy as np import torch +import torch.distributed as dist from anemoi.utils.timer import Timer + +import os + from .checkpoint import Checkpoint LOGGER = logging.getLogger(__name__) @@ -375,6 +379,28 @@ def run( model_index, ) raise ValueError(f"Field '{name}' has NaNs and is not marked as imputable") + + + global_rank = int(os.environ["SLURM_PROCID"]) # Get rank of the current process, equivalent to dist.get_rank() + LOGGER.info("Global rank: %d", global_rank) + world_size = int(os.environ["SLURM_NTASKS"]) # Total number of processes + LOGGER.info("World size: %d", world_size) + dist.init_process_group( + backend="nccl", + init_method=f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}', + timeout=datetime.timedelta(minutes=3), + world_size=world_size, + rank=global_rank, + ) + + local_rank = int(os.environ.get("SLURM_LOCALID", 0)) + LOGGER.info("Global rank %d has also local rank: %d ...", global_rank, local_rank) + + # Ensure each process only uses one GPU + torch.cuda.set_device(local_rank) + + model_comm_group_ranks = np.arange(world_size, dtype=int) + model_comm_group = torch.distributed.new_group(model_comm_group_ranks) with Timer(f"Loading {self.checkpoint}"): try: @@ -443,7 +469,8 @@ def get_most_recent_datetime(input_fields): # Predict next state of atmosphere with torch.autocast(device_type=device, dtype=autocast): - y_pred = model.predict_step(input_tensor_torch) + #y_pred = model.predict_step(input_tensor_torch, model_comm_group) + y_pred = model.forward(input_tensor_torch.unsqueeze(2), model_comm_group) # Detach tensor and squeeze output = np.squeeze(y_pred.cpu().numpy()) From 936c60a9159d5ee6f55f6943fd755387864cc311 Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Tue, 26 Nov 2024 09:12:25 +0100 Subject: [PATCH 02/24] logging only on rank 0 --- src/anemoi/inference/runner.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index ba55e50..acfe88d 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -382,9 +382,9 @@ def run( global_rank = int(os.environ["SLURM_PROCID"]) # Get rank of the current process, equivalent to dist.get_rank() - LOGGER.info("Global rank: %d", global_rank) world_size = int(os.environ["SLURM_NTASKS"]) # Total number of processes - LOGGER.info("World size: %d", world_size) + if global_rank == 0: + LOGGER.info("World size: %d", world_size) dist.init_process_group( backend="nccl", init_method=f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}', @@ -393,9 +393,6 @@ def run( rank=global_rank, ) - local_rank = int(os.environ.get("SLURM_LOCALID", 0)) - LOGGER.info("Global rank %d has also local rank: %d ...", global_rank, local_rank) - # Ensure each process only uses one GPU torch.cuda.set_device(local_rank) @@ -424,7 +421,8 @@ def run( prognostic_output_mask = self.checkpoint.prognostic_output_mask diagnostic_output_mask = self.checkpoint.diagnostic_output_mask - LOGGER.info("Using autocast %s", autocast) + if global_rank == 0: + LOGGER.info("Using autocast %s", autocast) # Write dynamic fields def get_most_recent_datetime(input_fields): @@ -441,7 +439,8 @@ def get_most_recent_datetime(input_fields): prognostic_template = reference_fields[self.checkpoint.variable_to_index["lsm"]] else: first = list(self.checkpoint.variable_to_index.keys()) - LOGGER.warning("No LSM found to use as a GRIB template, using %s", first[0]) + if global_rank == 0: + LOGGER.warning("No LSM found to use as a GRIB template, using %s", first[0]) prognostic_template = reference_fields[0] accumulated_output = np.zeros( From d870289d42230f214d078c14ca57e43a9ac3c0bf Mon Sep 17 00:00:00 2001 From: Cathal OBrien Date: Tue, 26 Nov 2024 14:01:51 +0000 Subject: [PATCH 03/24] fallback if env vars arent set and some work only done by rank 0 --- src/anemoi/inference/runner.py | 94 ++++++++++++++++++---------------- 1 file changed, 49 insertions(+), 45 deletions(-) diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index acfe88d..5dd9d56 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -17,7 +17,7 @@ import torch.distributed as dist from anemoi.utils.timer import Timer - +import random import os from .checkpoint import Checkpoint @@ -381,14 +381,17 @@ def run( raise ValueError(f"Field '{name}' has NaNs and is not marked as imputable") - global_rank = int(os.environ["SLURM_PROCID"]) # Get rank of the current process, equivalent to dist.get_rank() - world_size = int(os.environ["SLURM_NTASKS"]) # Total number of processes + global_rank = int(os.getenv("SLURM_PROCID", 0)) # Get rank of the current process, equivalent to dist.get_rank() + world_size = int(os.getenv("SLURM_NTASKS", 1)) # Total number of processes + local_rank = int(os.getenv("SLURM_LOCALID", 0)) # Get GPU num of current process if global_rank == 0: LOGGER.info("World size: %d", world_size) + addr=os.getenv("MASTER_ADDR", 'localhost') #localhost should be sufficient to run on a single node + port=os.getenv("MASTER_PORT", 10000 + random.randint(0,10000)) #random port between 10,000 and 20,000 dist.init_process_group( backend="nccl", - init_method=f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}', - timeout=datetime.timedelta(minutes=3), + init_method=f'tcp://{addr}:{port}', + timeout=datetime.timedelta(minutes=1), world_size=world_size, rank=global_rank, ) @@ -471,53 +474,54 @@ def get_most_recent_datetime(input_fields): #y_pred = model.predict_step(input_tensor_torch, model_comm_group) y_pred = model.forward(input_tensor_torch.unsqueeze(2), model_comm_group) - # Detach tensor and squeeze - output = np.squeeze(y_pred.cpu().numpy()) - - prognostic_fields_numpy = output[:, prognostic_output_mask] - if len(diagnostic_output_mask): - diagnostic_fields_numpy = output[:, diagnostic_output_mask] + if global_rank == 0: + # Detach tensor and squeeze + output = np.squeeze(y_pred.cpu().numpy()) - for n, (m, param) in enumerate(zip(prognostic_data_from_retrieved_fields_mask, prognostic_params)): - template = reference_fields[m] - assert template.datetime()["valid_time"] == most_recent_datetime, ( - template.datetime()["valid_time"], - most_recent_datetime, - ) - output_callback( - prognostic_fields_numpy[:, n], - template=template, - step=step, - check_nans=True, # param in can_be_missing, - ) + prognostic_fields_numpy = output[:, prognostic_output_mask] + if len(diagnostic_output_mask): + diagnostic_fields_numpy = output[:, diagnostic_output_mask] - # Write diagnostic fields - if len(diagnostic_output_mask): - for n, param in enumerate(self.checkpoint.diagnostic_params): - accumulated_output[n] += np.maximum(0, diagnostic_fields_numpy[:, n]) - assert prognostic_template.datetime()["valid_time"] == most_recent_datetime, ( - prognostic_template.datetime()["valid_time"], + for n, (m, param) in enumerate(zip(prognostic_data_from_retrieved_fields_mask, prognostic_params)): + template = reference_fields[m] + assert template.datetime()["valid_time"] == most_recent_datetime, ( + template.datetime()["valid_time"], most_recent_datetime, ) + output_callback( + prognostic_fields_numpy[:, n], + template=template, + step=step, + check_nans=True, # param in can_be_missing, + ) - if param in accumulations_params: - output_callback( - accumulated_output[n], - stepType="accum", - template=prognostic_template, - startStep=0, - endStep=step, - param=param, - check_nans=True, # param in can_be_missing, - ) - else: - output_callback( - diagnostic_fields_numpy[:, n], - template=prognostic_template, - step=step, - check_nans=True, # param in can_be_missing, + # Write diagnostic fields + if len(diagnostic_output_mask): + for n, param in enumerate(self.checkpoint.diagnostic_params): + accumulated_output[n] += np.maximum(0, diagnostic_fields_numpy[:, n]) + assert prognostic_template.datetime()["valid_time"] == most_recent_datetime, ( + prognostic_template.datetime()["valid_time"], + most_recent_datetime, ) + if param in accumulations_params: + output_callback( + accumulated_output[n], + stepType="accum", + template=prognostic_template, + startStep=0, + endStep=step, + param=param, + check_nans=True, # param in can_be_missing, + ) + else: + output_callback( + diagnostic_fields_numpy[:, n], + template=prognostic_template, + step=step, + check_nans=True, # param in can_be_missing, + ) + # Next step # Compute new forcing From b39b796bc7ad751bf762dd295a814f3be4882849 Mon Sep 17 00:00:00 2001 From: Cathal OBrien Date: Tue, 26 Nov 2024 14:16:54 +0000 Subject: [PATCH 04/24] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d7a72bc..0bce58f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ Keep it human-readable, your future self will thank you! - Add anemoi-transform link to documentation - Add support for unstructured grids - Add CONTRIBUTORS.md file (#36) +- Added ability to run inference over multiple GPUs [#55](https://github.com/ecmwf/anemoi-inference/pull/55) ### Changed From b95e16758ceac258ad83c46c0c8507948d57d1d1 Mon Sep 17 00:00:00 2001 From: Cathal OBrien Date: Tue, 26 Nov 2024 14:35:57 +0000 Subject: [PATCH 05/24] pre-commit checks and no model comm group for single gpu case --- src/anemoi/inference/runner.py | 43 ++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index 5dd9d56..b42f7f4 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -10,6 +10,8 @@ import datetime import logging +import os +import random from functools import cached_property import numpy as np @@ -17,9 +19,6 @@ import torch.distributed as dist from anemoi.utils.timer import Timer -import random -import os - from .checkpoint import Checkpoint LOGGER = logging.getLogger(__name__) @@ -379,28 +378,32 @@ def run( model_index, ) raise ValueError(f"Field '{name}' has NaNs and is not marked as imputable") - - - global_rank = int(os.getenv("SLURM_PROCID", 0)) # Get rank of the current process, equivalent to dist.get_rank() + + global_rank = int( + os.getenv("SLURM_PROCID", 0) + ) # Get rank of the current process, equivalent to dist.get_rank() world_size = int(os.getenv("SLURM_NTASKS", 1)) # Total number of processes - local_rank = int(os.getenv("SLURM_LOCALID", 0)) # Get GPU num of current process + local_rank = int(os.getenv("SLURM_LOCALID", 0)) # Get GPU num of current process if global_rank == 0: LOGGER.info("World size: %d", world_size) - addr=os.getenv("MASTER_ADDR", 'localhost') #localhost should be sufficient to run on a single node - port=os.getenv("MASTER_PORT", 10000 + random.randint(0,10000)) #random port between 10,000 and 20,000 + addr = os.getenv("MASTER_ADDR", "localhost") # localhost should be sufficient to run on a single node + port = os.getenv("MASTER_PORT", 10000 + random.randint(0, 10000)) # random port between 10,000 and 20,000 dist.init_process_group( - backend="nccl", - init_method=f'tcp://{addr}:{port}', - timeout=datetime.timedelta(minutes=1), - world_size=world_size, - rank=global_rank, - ) - + backend="nccl", + init_method=f"tcp://{addr}:{port}", + timeout=datetime.timedelta(minutes=1), + world_size=world_size, + rank=global_rank, + ) + # Ensure each process only uses one GPU torch.cuda.set_device(local_rank) - - model_comm_group_ranks = np.arange(world_size, dtype=int) - model_comm_group = torch.distributed.new_group(model_comm_group_ranks) + + if world_size > 1: + model_comm_group_ranks = np.arange(world_size, dtype=int) + model_comm_group = torch.distributed.new_group(model_comm_group_ranks) + else: + model_comm_group = None with Timer(f"Loading {self.checkpoint}"): try: @@ -471,7 +474,7 @@ def get_most_recent_datetime(input_fields): # Predict next state of atmosphere with torch.autocast(device_type=device, dtype=autocast): - #y_pred = model.predict_step(input_tensor_torch, model_comm_group) + # y_pred = model.predict_step(input_tensor_torch, model_comm_group) y_pred = model.forward(input_tensor_torch.unsqueeze(2), model_comm_group) if global_rank == 0: From 9fe691cfa3e9780d480559c45df19f7cdd81ca1a Mon Sep 17 00:00:00 2001 From: Cathal OBrien Date: Tue, 26 Nov 2024 14:16:54 +0000 Subject: [PATCH 06/24] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 446bda1..6945ecd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ Keep it human-readable, your future self will thank you! - Add CONTRIBUTORS.md file (#36) - Add sanetise command - Add support for huggingface +- Added ability to run inference over multiple GPUs [#55](https://github.com/ecmwf/anemoi-inference/pull/55) ### Changed - Change `write_initial_state` default value to `true` From 5f925740efdc1f7007d2e0ade328dcf41eda650f Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Tue, 14 Jan 2025 16:49:44 +0100 Subject: [PATCH 07/24] added parallel inf --- src/anemoi/inference/runner.py | 96 ++++++++++++++++++++++------------ 1 file changed, 64 insertions(+), 32 deletions(-) diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index da60915..a7b5bb4 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -19,6 +19,9 @@ from anemoi.utils.text import table from anemoi.utils.timer import Timer # , Timers +import torch.distributed as dist +import os + from .checkpoint import Checkpoint from .context import Context from .postprocess import Accumulator @@ -239,10 +242,18 @@ def model(self): return model def forecast(self, lead_time, input_tensor_numpy, input_state): + local_rank = int(os.environ.get("SLURM_LOCALID", 0)) + self.device=f"{self.device}:{local_rank}" + torch.cuda.set_device(local_rank) self.model.eval() torch.set_grad_enabled(False) + global_rank = int(os.environ.get("SLURM_PROCID", 0)) # Get rank of the current process, equivalent to dist.get_rank() + world_size = int(os.environ.get("SLURM_NTASKS", 1)) # Total number of processes + if (local_rank == 0): + LOG.info("World size: %d", world_size) + # Create pytorch input tensor input_tensor_torch = torch.from_numpy(np.swapaxes(input_tensor_numpy, -2, -1)[np.newaxis, ...]).to(self.device) @@ -258,6 +269,21 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): start = input_state["date"] + if (world_size > 1): + dist.init_process_group( + backend="nccl", + init_method=f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}', + timeout=datetime.timedelta(minutes=3), + world_size=world_size, + rank=global_rank, + ) + + model_comm_group_ranks = np.arange(world_size, dtype=int) + model_comm_group_ranks = np.arange(world_size, dtype=int) + model_comm_group = torch.distributed.new_group(model_comm_group_ranks) + else: + model_comm_group = None + # The variable `check` is used to keep track of which variables have been updated # In the input tensor. `reset` is used to reset `check` to False except # when the values are of the constant in time variables @@ -277,56 +303,62 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): for s in range(steps): step = (s + 1) * self.checkpoint.timestep date = start + step - LOG.info("Forecasting step %s (%s)", step, date) + if (local_rank == 0): + LOG.info("Forecasting step %s (%s)", step, date) result["date"] = date # Predict next state of atmosphere with torch.autocast(device_type=self.device, dtype=self.autocast): - y_pred = self.model.predict_step(input_tensor_torch) + #y_pred = self.model.forward(input_tensor_torch.unsqueeze(2), model_comm_group) + #y_pred = self.model.forward(input_tensor_torch, model_comm_group) + y_pred = self.model.predict_step(input_tensor_torch, model_comm_group) - # Detach tensor and squeeze (should we detach here?) - output = np.squeeze(y_pred.cpu().numpy()) # shape: (values, variables) + if (local_rank == 0): + # Detach tensor and squeeze (should we detach here?) + output = np.squeeze(y_pred.cpu().numpy()) # shape: (values, variables) - # Update state - for i in range(output.shape[1]): - result["fields"][self.checkpoint.output_tensor_index_to_variable[i]] = output[:, i] + # Update state + for i in range(output.shape[1]): + result["fields"][self.checkpoint.output_tensor_index_to_variable[i]] = output[:, i] - if (s == 0 and self.verbosity > 0) or self.verbosity > 1: - self._print_output_tensor("Output tensor", output) + if (s == 0 and self.verbosity > 0) or self.verbosity > 1: + self._print_output_tensor("Output tensor", output) - yield result + yield result - # No need to prepare next input tensor if we are at the last step - if s == steps - 1: - continue + # No need to prepare next input tensor if we are at the last step + if s == steps - 1: + continue - # Update tensor for next iteration + # Update tensor for next iteration - check[:] = reset + check[:] = reset - input_tensor_torch = self.copy_prognostic_fields_to_input_tensor(input_tensor_torch, y_pred, check) + input_tensor_torch = self.copy_prognostic_fields_to_input_tensor(input_tensor_torch, y_pred, check) - del y_pred # Recover memory + del y_pred # Recover memory - input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(input_tensor_torch, input_state, date, check) - input_tensor_torch = self.add_boundary_forcings_to_input_tensor( - input_tensor_torch, input_state, date, check - ) + input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(input_tensor_torch, input_state, date, check) + input_tensor_torch = self.add_boundary_forcings_to_input_tensor( + input_tensor_torch, input_state, date, check + ) + + if not check.all(): + # Not all variables have been updated + missing = [] + variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index + mapping = {v: k for k, v in variable_to_input_tensor_index.items()} + for i in range(check.shape[-1]): + if not check[i]: + missing.append(mapping[i]) - if not check.all(): - # Not all variables have been updated - missing = [] - variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index - mapping = {v: k for k, v in variable_to_input_tensor_index.items()} - for i in range(check.shape[-1]): - if not check[i]: - missing.append(mapping[i]) + raise ValueError(f"Missing variables in input tensor: {sorted(missing)}") - raise ValueError(f"Missing variables in input tensor: {sorted(missing)}") + if (s == 0 and self.verbosity > 0) or self.verbosity > 1: + self._print_input_tensor("Next input tensor", input_tensor_torch) - if (s == 0 and self.verbosity > 0) or self.verbosity > 1: - self._print_input_tensor("Next input tensor", input_tensor_torch) + dist.destroy_process_group() def copy_prognostic_fields_to_input_tensor(self, input_tensor_torch, y_pred, check): From 71fdf0e887aedbd4fcde33928666c4c26e6fbe1e Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Tue, 14 Jan 2025 16:51:19 +0100 Subject: [PATCH 08/24] precommit --- src/anemoi/inference/runner.py | 37 ++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index a7b5bb4..9755ded 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -10,18 +10,17 @@ import datetime import logging +import os import warnings from functools import cached_property import numpy as np import torch +import torch.distributed as dist from anemoi.utils.dates import frequency_to_timedelta as to_timedelta from anemoi.utils.text import table from anemoi.utils.timer import Timer # , Timers -import torch.distributed as dist -import os - from .checkpoint import Checkpoint from .context import Context from .postprocess import Accumulator @@ -243,15 +242,17 @@ def model(self): def forecast(self, lead_time, input_tensor_numpy, input_state): local_rank = int(os.environ.get("SLURM_LOCALID", 0)) - self.device=f"{self.device}:{local_rank}" + self.device = f"{self.device}:{local_rank}" torch.cuda.set_device(local_rank) self.model.eval() torch.set_grad_enabled(False) - global_rank = int(os.environ.get("SLURM_PROCID", 0)) # Get rank of the current process, equivalent to dist.get_rank() + global_rank = int( + os.environ.get("SLURM_PROCID", 0) + ) # Get rank of the current process, equivalent to dist.get_rank() world_size = int(os.environ.get("SLURM_NTASKS", 1)) # Total number of processes - if (local_rank == 0): + if local_rank == 0: LOG.info("World size: %d", world_size) # Create pytorch input tensor @@ -269,13 +270,13 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): start = input_state["date"] - if (world_size > 1): + if world_size > 1: dist.init_process_group( - backend="nccl", - init_method=f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}', - timeout=datetime.timedelta(minutes=3), - world_size=world_size, - rank=global_rank, + backend="nccl", + init_method=f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}', + timeout=datetime.timedelta(minutes=3), + world_size=world_size, + rank=global_rank, ) model_comm_group_ranks = np.arange(world_size, dtype=int) @@ -303,18 +304,18 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): for s in range(steps): step = (s + 1) * self.checkpoint.timestep date = start + step - if (local_rank == 0): + if local_rank == 0: LOG.info("Forecasting step %s (%s)", step, date) result["date"] = date # Predict next state of atmosphere with torch.autocast(device_type=self.device, dtype=self.autocast): - #y_pred = self.model.forward(input_tensor_torch.unsqueeze(2), model_comm_group) - #y_pred = self.model.forward(input_tensor_torch, model_comm_group) + # y_pred = self.model.forward(input_tensor_torch.unsqueeze(2), model_comm_group) + # y_pred = self.model.forward(input_tensor_torch, model_comm_group) y_pred = self.model.predict_step(input_tensor_torch, model_comm_group) - if (local_rank == 0): + if local_rank == 0: # Detach tensor and squeeze (should we detach here?) output = np.squeeze(y_pred.cpu().numpy()) # shape: (values, variables) @@ -339,7 +340,9 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): del y_pred # Recover memory - input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(input_tensor_torch, input_state, date, check) + input_tensor_torch = self.add_dynamic_forcings_to_input_tensor( + input_tensor_torch, input_state, date, check + ) input_tensor_torch = self.add_boundary_forcings_to_input_tensor( input_tensor_torch, input_state, date, check ) From 06a575d1021c4806b85f592f10e26bcbf99fb6dc Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Wed, 15 Jan 2025 18:31:30 +0100 Subject: [PATCH 09/24] refactor --- src/anemoi/inference/parallel.py | 78 ++++++++++++++++++++++++++++++-- src/anemoi/inference/runner.py | 35 +++----------- 2 files changed, 79 insertions(+), 34 deletions(-) diff --git a/src/anemoi/inference/parallel.py b/src/anemoi/inference/parallel.py index b85db5d..f2a5451 100644 --- a/src/anemoi/inference/parallel.py +++ b/src/anemoi/inference/parallel.py @@ -4,17 +4,19 @@ import logging -LOG = logging.getLogger(__name__) -def getParallelLogger(): +def getParallelLogger(moduleName): global_rank = int(os.environ.get("SLURM_PROCID", 0)) - logger = logging.getLogger(__name__) + logger = logging.getLogger(moduleName) if global_rank != 0: logger.setLevel(logging.NOTSET) return logger +LOG = getParallelLogger(__name__) + +#TODO could replace env vars with regular variables now def init_network(): # Get the master address from the SLURM_NODELIST environment variable slurm_nodelist = os.environ.get("SLURM_NODELIST") @@ -48,5 +50,71 @@ def init_network(): os.environ["MASTER_PORT"] = str(master_port) # Print the results for confirmation - #print(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") - #print(f"MASTER_PORT: {os.environ['MASTER_PORT']}") + LOG.info(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") + LOG.info(f"MASTER_PORT: {os.environ['MASTER_PORT']}") + +def init_parallel(world_size): + if world_size > 1: + + init_network() + dist.init_process_group( + backend="nccl", + init_method=f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}', + timeout=datetime.timedelta(minutes=3), + world_size=world_size, + rank=global_rank, + ) + + model_comm_group_ranks = np.arange(world_size, dtype=int) + model_comm_group_ranks = np.arange(world_size, dtype=int) + model_comm_group = torch.distributed.new_group(model_comm_group_ranks) + else: + model_comm_group = None + + return model_comm_group + +def get_parallel_info(): + local_rank = int(os.environ.get("SLURM_LOCALID", 0)) + global_rank = int( + os.environ.get("SLURM_PROCID", 0) + ) # Get rank of the current process, equivalent to dist.get_rank() + world_size = int(os.environ.get("SLURM_NTASKS", 1)) # Total number of processes + + return global_rank, local_rank, world_size + + + lead_time = to_timedelta(lead_time) + steps = lead_time // self.checkpoint.timestep + + if global_rank == 0: + LOG.info("World size: %d", world_size) + LOG.info("Using autocast %s", self.autocast) + LOG.info("Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps) + + result = input_state.copy() # We should not modify the input state + result["fields"] = dict() + + start = input_state["date"] + + if world_size > 1: + + #only rank 0 logs + if (local_rank != 0): + LOG.handlers.clear() + + init_network() + + + dist.init_process_group( + backend="nccl", + init_method=f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}', + timeout=datetime.timedelta(minutes=3), + world_size=world_size, + rank=global_rank, + ) + + model_comm_group_ranks = np.arange(world_size, dtype=int) + model_comm_group_ranks = np.arange(world_size, dtype=int) + model_comm_group = torch.distributed.new_group(model_comm_group_ranks) + else: + model_comm_group = None diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index 74af21f..7fc3775 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -27,7 +27,7 @@ from .postprocess import Accumulator from .postprocess import Noop from .precisions import PRECISIONS -from .parallel import init_network +from .parallel import init_parallel, get_parallel_info LOG = logging.getLogger(__name__) @@ -243,24 +243,22 @@ def model(self): return model def forecast(self, lead_time, input_tensor_numpy, input_state): - local_rank = int(os.environ.get("SLURM_LOCALID", 0)) + + global_rank, local_rank, world_size = get_parallel_info() + self.device = f"{self.device}:{local_rank}" torch.cuda.set_device(local_rank) self.model.eval() torch.set_grad_enabled(False) - global_rank = int( - os.environ.get("SLURM_PROCID", 0) - ) # Get rank of the current process, equivalent to dist.get_rank() - world_size = int(os.environ.get("SLURM_NTASKS", 1)) # Total number of processes - # Create pytorch input tensor input_tensor_torch = torch.from_numpy(np.swapaxes(input_tensor_numpy, -2, -1)[np.newaxis, ...]).to(self.device) lead_time = to_timedelta(lead_time) steps = lead_time // self.checkpoint.timestep + #TODO make it so that only rank 0 logs by default if global_rank == 0: LOG.info("World size: %d", world_size) LOG.info("Using autocast %s", self.autocast) @@ -271,28 +269,7 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): start = input_state["date"] - if world_size > 1: - - #only rank 0 logs - if (local_rank != 0): - LOG.handlers.clear() - - init_network() - - - dist.init_process_group( - backend="nccl", - init_method=f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}', - timeout=datetime.timedelta(minutes=3), - world_size=world_size, - rank=global_rank, - ) - - model_comm_group_ranks = np.arange(world_size, dtype=int) - model_comm_group_ranks = np.arange(world_size, dtype=int) - model_comm_group = torch.distributed.new_group(model_comm_group_ranks) - else: - model_comm_group = None + model_comm_group = init_parallel(world_size) # The variable `check` is used to keep track of which variables have been updated # In the input tensor. `reset` is used to reset `check` to False except From fa89bb832a2258e415656cd025fa7f904e58b6b3 Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Wed, 15 Jan 2025 18:40:10 +0100 Subject: [PATCH 10/24] refactor --- src/anemoi/inference/parallel.py | 71 +++++++------------------------- src/anemoi/inference/runner.py | 16 +++---- 2 files changed, 22 insertions(+), 65 deletions(-) diff --git a/src/anemoi/inference/parallel.py b/src/anemoi/inference/parallel.py index f2a5451..4da9945 100644 --- a/src/anemoi/inference/parallel.py +++ b/src/anemoi/inference/parallel.py @@ -1,22 +1,17 @@ +import datetime +import logging import os -import subprocess import socket +import subprocess -import logging - - -def getParallelLogger(moduleName): - global_rank = int(os.environ.get("SLURM_PROCID", 0)) - logger = logging.getLogger(moduleName) - - if global_rank != 0: - logger.setLevel(logging.NOTSET) +import numpy as np +import torch +import torch.distributed as dist - return logger +LOG = logging.getLogger(__name__) -LOG = getParallelLogger(__name__) -#TODO could replace env vars with regular variables now +# TODO could replace env vars with regular variables now def init_network(): # Get the master address from the SLURM_NODELIST environment variable slurm_nodelist = os.environ.get("SLURM_NODELIST") @@ -25,10 +20,7 @@ def init_network(): # Use subprocess to execute scontrol and get the first hostname result = subprocess.run( - ["scontrol", "show", "hostname", slurm_nodelist], - stdout=subprocess.PIPE, - text=True, - check=True + ["scontrol", "show", "hostname", slurm_nodelist], stdout=subprocess.PIPE, text=True, check=True ) master_addr = result.stdout.splitlines()[0] @@ -50,10 +42,11 @@ def init_network(): os.environ["MASTER_PORT"] = str(master_port) # Print the results for confirmation - LOG.info(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") - LOG.info(f"MASTER_PORT: {os.environ['MASTER_PORT']}") + LOG.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") + LOG.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") + -def init_parallel(world_size): +def init_parallel(global_rank, world_size): if world_size > 1: init_network() @@ -73,6 +66,7 @@ def init_parallel(world_size): return model_comm_group + def get_parallel_info(): local_rank = int(os.environ.get("SLURM_LOCALID", 0)) global_rank = int( @@ -81,40 +75,3 @@ def get_parallel_info(): world_size = int(os.environ.get("SLURM_NTASKS", 1)) # Total number of processes return global_rank, local_rank, world_size - - - lead_time = to_timedelta(lead_time) - steps = lead_time // self.checkpoint.timestep - - if global_rank == 0: - LOG.info("World size: %d", world_size) - LOG.info("Using autocast %s", self.autocast) - LOG.info("Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps) - - result = input_state.copy() # We should not modify the input state - result["fields"] = dict() - - start = input_state["date"] - - if world_size > 1: - - #only rank 0 logs - if (local_rank != 0): - LOG.handlers.clear() - - init_network() - - - dist.init_process_group( - backend="nccl", - init_method=f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}', - timeout=datetime.timedelta(minutes=3), - world_size=world_size, - rank=global_rank, - ) - - model_comm_group_ranks = np.arange(world_size, dtype=int) - model_comm_group_ranks = np.arange(world_size, dtype=int) - model_comm_group = torch.distributed.new_group(model_comm_group_ranks) - else: - model_comm_group = None diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index 7fc3775..73da1ac 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -10,24 +10,22 @@ import datetime import logging -import os import warnings -import random from functools import cached_property import numpy as np import torch -import torch.distributed as dist from anemoi.utils.dates import frequency_to_timedelta as to_timedelta from anemoi.utils.text import table from anemoi.utils.timer import Timer # , Timers from .checkpoint import Checkpoint from .context import Context +from .parallel import get_parallel_info +from .parallel import init_parallel from .postprocess import Accumulator from .postprocess import Noop from .precisions import PRECISIONS -from .parallel import init_parallel, get_parallel_info LOG = logging.getLogger(__name__) @@ -258,18 +256,20 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): lead_time = to_timedelta(lead_time) steps = lead_time // self.checkpoint.timestep - #TODO make it so that only rank 0 logs by default + # TODO make it so that only rank 0 logs by default if global_rank == 0: LOG.info("World size: %d", world_size) LOG.info("Using autocast %s", self.autocast) - LOG.info("Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps) + LOG.info( + "Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps + ) result = input_state.copy() # We should not modify the input state result["fields"] = dict() start = input_state["date"] - model_comm_group = init_parallel(world_size) + model_comm_group = init_parallel(global_rank, world_size) # The variable `check` is used to keep track of which variables have been updated # In the input tensor. `reset` is used to reset `check` to False except @@ -347,7 +347,7 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): if (s == 0 and self.verbosity > 0) or self.verbosity > 1: self._print_input_tensor("Next input tensor", input_tensor_torch) - #dist.destroy_process_group() + # dist.destroy_process_group() def copy_prognostic_fields_to_input_tensor(self, input_tensor_torch, y_pred, check): From a6a4ea407ffce997cecebbeeb2c047c25070931e Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Thu, 16 Jan 2025 08:41:54 +0100 Subject: [PATCH 11/24] tidy --- src/anemoi/inference/parallel.py | 28 +++++++++++++++------------- src/anemoi/inference/runner.py | 8 ++++---- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/anemoi/inference/parallel.py b/src/anemoi/inference/parallel.py index 4da9945..7048198 100644 --- a/src/anemoi/inference/parallel.py +++ b/src/anemoi/inference/parallel.py @@ -11,8 +11,9 @@ LOG = logging.getLogger(__name__) -# TODO could replace env vars with regular variables now def init_network(): + """Reads Slurm environment to set master address and port for parallel communication""" + # Get the master address from the SLURM_NODELIST environment variable slurm_nodelist = os.environ.get("SLURM_NODELIST") if not slurm_nodelist: @@ -31,34 +32,36 @@ def init_network(): raise ValueError(f"Could not resolve hostname: {master_addr}") # Set the resolved address as MASTER_ADDR - os.environ["MASTER_ADDR"] = resolved_addr + master_addr = resolved_addr # Calculate the MASTER_PORT using SLURM_JOBID slurm_jobid = os.environ.get("SLURM_JOBID") if not slurm_jobid: raise ValueError("SLURM_JOBID environment variable is not set.") - master_port = 10000 + int(slurm_jobid[-4:]) - os.environ["MASTER_PORT"] = str(master_port) + master_port = str(10000 + int(slurm_jobid[-4:])) # Print the results for confirmation - LOG.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") - LOG.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") + LOG.debug(f"MASTER_ADDR: {master_addr}") + LOG.debug(f"MASTER_PORT: {master_port}") + + return master_addr, master_port def init_parallel(global_rank, world_size): + """Creates a model communication group to be used for parallel inference""" + if world_size > 1: - init_network() + master_addr, master_port = init_network() dist.init_process_group( backend="nccl", - init_method=f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}', + init_method=f"tcp://{master_addr}:{master_port}", timeout=datetime.timedelta(minutes=3), world_size=world_size, rank=global_rank, ) - model_comm_group_ranks = np.arange(world_size, dtype=int) model_comm_group_ranks = np.arange(world_size, dtype=int) model_comm_group = torch.distributed.new_group(model_comm_group_ranks) else: @@ -68,10 +71,9 @@ def init_parallel(global_rank, world_size): def get_parallel_info(): - local_rank = int(os.environ.get("SLURM_LOCALID", 0)) - global_rank = int( - os.environ.get("SLURM_PROCID", 0) - ) # Get rank of the current process, equivalent to dist.get_rank() + """Reads Slurm env vars, if they exist, to determine if inference is running in parallel""" + local_rank = int(os.environ.get("SLURM_LOCALID", 0)) # Rank within a node, between 0 and num_gpus + global_rank = int(os.environ.get("SLURM_PROCID", 0)) # Rank within all nodes world_size = int(os.environ.get("SLURM_NTASKS", 1)) # Total number of processes return global_rank, local_rank, world_size diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index 73da1ac..4c0b98d 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -242,10 +242,11 @@ def model(self): def forecast(self, lead_time, input_tensor_numpy, input_state): + # determine processes rank for parallel inference and assign a device global_rank, local_rank, world_size = get_parallel_info() - self.device = f"{self.device}:{local_rank}" torch.cuda.set_device(local_rank) + self.model.eval() torch.set_grad_enabled(False) @@ -256,7 +257,6 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): lead_time = to_timedelta(lead_time) steps = lead_time // self.checkpoint.timestep - # TODO make it so that only rank 0 logs by default if global_rank == 0: LOG.info("World size: %d", world_size) LOG.info("Using autocast %s", self.autocast) @@ -269,6 +269,8 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): start = input_state["date"] + # Create a model comm group for parallel inference + # A dummy comm group is created if only a single device is in use model_comm_group = init_parallel(global_rank, world_size) # The variable `check` is used to keep track of which variables have been updated @@ -297,8 +299,6 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): # Predict next state of atmosphere with torch.autocast(device_type=self.device, dtype=self.autocast): - # y_pred = self.model.forward(input_tensor_torch.unsqueeze(2), model_comm_group) - # y_pred = self.model.forward(input_tensor_torch, model_comm_group) y_pred = self.model.predict_step(input_tensor_torch, model_comm_group) if global_rank == 0: From 8a73f622dd5eda2ef21d6738840a2e3afc7a1c57 Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Thu, 16 Jan 2025 09:35:27 +0100 Subject: [PATCH 12/24] more compatible with older versions of models --- src/anemoi/inference/runner.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index 4c0b98d..194d683 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -299,7 +299,15 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): # Predict next state of atmosphere with torch.autocast(device_type=self.device, dtype=self.autocast): - y_pred = self.model.predict_step(input_tensor_torch, model_comm_group) + if model_comm_group is None: + y_pred = self.model.predict_step(input_tensor_torch) + else: + try: + y_pred = self.model.predict_step(input_tensor_torch, model_comm_group) + except TypeError as err: + LOG.error("Please upgrade to a newer version of anemoi-models to use parallel inference") + raise err + if global_rank == 0: # Detach tensor and squeeze (should we detach here?) From db560ebf9d99eb432807347dd4d11a624cf40a19 Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Thu, 16 Jan 2025 09:36:04 +0100 Subject: [PATCH 13/24] forgot precommit --- src/anemoi/inference/runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index 194d683..f565d53 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -308,7 +308,6 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): LOG.error("Please upgrade to a newer version of anemoi-models to use parallel inference") raise err - if global_rank == 0: # Detach tensor and squeeze (should we detach here?) output = np.squeeze(y_pred.cpu().numpy()) # shape: (values, variables) From b21d811e04b9e84f8a27573429401cdeb87e382b Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Thu, 16 Jan 2025 09:55:57 +0100 Subject: [PATCH 14/24] remove commented code --- src/anemoi/inference/runner.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index f565d53..9773320 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -354,8 +354,6 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): if (s == 0 and self.verbosity > 0) or self.verbosity > 1: self._print_input_tensor("Next input tensor", input_tensor_torch) - # dist.destroy_process_group() - def copy_prognostic_fields_to_input_tensor(self, input_tensor_torch, y_pred, check): # input_tensor_torch is shape: (batch, multi_step_input, values, variables) From 48ad37b7fa1846d59070766a619d53c92d889c94 Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Thu, 16 Jan 2025 10:33:52 +0100 Subject: [PATCH 15/24] added license --- src/anemoi/inference/parallel.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/anemoi/inference/parallel.py b/src/anemoi/inference/parallel.py index 7048198..3e9e4a6 100644 --- a/src/anemoi/inference/parallel.py +++ b/src/anemoi/inference/parallel.py @@ -1,3 +1,13 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + import datetime import logging import os From b9ecc149e0c120b7e6437b617abae0b835beca49 Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Thu, 16 Jan 2025 11:22:52 +0100 Subject: [PATCH 16/24] feedback --- src/anemoi/inference/parallel.py | 11 +++++++++-- src/anemoi/inference/runner.py | 7 ++++--- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/anemoi/inference/parallel.py b/src/anemoi/inference/parallel.py index 3e9e4a6..e4632c9 100644 --- a/src/anemoi/inference/parallel.py +++ b/src/anemoi/inference/parallel.py @@ -58,14 +58,21 @@ def init_network(): return master_addr, master_port -def init_parallel(global_rank, world_size): +def init_parallel(device, global_rank, world_size): """Creates a model communication group to be used for parallel inference""" if world_size > 1: master_addr, master_port = init_network() + + # use 'startswith' instead of '==' in case device is 'cuda:0' + if device.startswith("cuda"): + backend = "nccl" + else: + backend = "gloo" + dist.init_process_group( - backend="nccl", + backend=backend, init_method=f"tcp://{master_addr}:{master_port}", timeout=datetime.timedelta(minutes=3), world_size=world_size, diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index 9773320..b16f6d6 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -244,8 +244,9 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): # determine processes rank for parallel inference and assign a device global_rank, local_rank, world_size = get_parallel_info() - self.device = f"{self.device}:{local_rank}" - torch.cuda.set_device(local_rank) + if self.device == "cuda": + self.device = f"{self.device}:{local_rank}" + torch.cuda.set_device(local_rank) self.model.eval() @@ -271,7 +272,7 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): # Create a model comm group for parallel inference # A dummy comm group is created if only a single device is in use - model_comm_group = init_parallel(global_rank, world_size) + model_comm_group = init_parallel(self.device, global_rank, world_size) # The variable `check` is used to keep track of which variables have been updated # In the input tensor. `reset` is used to reset `check` to False except From 27965ff0f312ab81599f4998522fae9c31e27702 Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Fri, 17 Jan 2025 13:43:54 +0100 Subject: [PATCH 17/24] refactor to parallel runner --- src/anemoi/inference/runner.py | 51 +++++++++++++----------- src/anemoi/inference/runners/parallel.py | 49 +++++++++++++++++++++++ 2 files changed, 76 insertions(+), 24 deletions(-) create mode 100644 src/anemoi/inference/runners/parallel.py diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index 94a61ff..8bd8687 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -15,6 +15,7 @@ import numpy as np import torch +import torch.distributed as dist from anemoi.utils.dates import frequency_to_timedelta as to_timedelta from anemoi.utils.text import table from anemoi.utils.timer import Timer # , Timers @@ -320,38 +321,40 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): yield result - # No need to prepare next input tensor if we are at the last step - if s == steps - 1: - continue + # No need to prepare next input tensor if we are at the last step + if s == steps - 1: + continue - # Update tensor for next iteration + # Update tensor for next iteration - check[:] = reset + check[:] = reset - input_tensor_torch = self.copy_prognostic_fields_to_input_tensor(input_tensor_torch, y_pred, check) + input_tensor_torch = self.copy_prognostic_fields_to_input_tensor(input_tensor_torch, y_pred, check) - del y_pred # Recover memory + del y_pred # Recover memory - input_tensor_torch = self.add_dynamic_forcings_to_input_tensor( - input_tensor_torch, input_state, date, check - ) - input_tensor_torch = self.add_boundary_forcings_to_input_tensor( - input_tensor_torch, input_state, date, check - ) + input_tensor_torch = self.add_dynamic_forcings_to_input_tensor( + input_tensor_torch, input_state, date, check + ) + input_tensor_torch = self.add_boundary_forcings_to_input_tensor( + input_tensor_torch, input_state, date, check + ) - if not check.all(): - # Not all variables have been updated - missing = [] - variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index - mapping = {v: k for k, v in variable_to_input_tensor_index.items()} - for i in range(check.shape[-1]): - if not check[i]: - missing.append(mapping[i]) + if not check.all(): + # Not all variables have been updated + missing = [] + variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index + mapping = {v: k for k, v in variable_to_input_tensor_index.items()} + for i in range(check.shape[-1]): + if not check[i]: + missing.append(mapping[i]) - raise ValueError(f"Missing variables in input tensor: {sorted(missing)}") + raise ValueError(f"Missing variables in input tensor: {sorted(missing)}") - if (s == 0 and self.verbosity > 0) or self.verbosity > 1: - self._print_input_tensor("Next input tensor", input_tensor_torch) + if (s == 0 and self.verbosity > 0) or self.verbosity > 1: + self._print_input_tensor("Next input tensor", input_tensor_torch) + + #dist.destroy_process_group(model_comm_group) def copy_prognostic_fields_to_input_tensor(self, input_tensor_torch, y_pred, check): diff --git a/src/anemoi/inference/runners/parallel.py b/src/anemoi/inference/runners/parallel.py new file mode 100644 index 0000000..7f98eea --- /dev/null +++ b/src/anemoi/inference/runners/parallel.py @@ -0,0 +1,49 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging + +from . import runner_registry +from .default import DefaultRunner +from ..parallel import get_parallel_info +from ..outputs import create_output + +LOG = logging.getLogger(__name__) + + +@runner_registry.register("parallel") +class ParallelRunner(DefaultRunner): + + def __init__(self, context): + super().__init__(context) + global_rank, local_rank, world_size = get_parallel_info() + self.global_rank = global_rank + self.local_rank = local_rank + self.world_size = world_size + + def predict_step(self, model, input_tensor_torch, fcstep, **kwargs): + model_comm_group = kwargs.get("model_comm_group", None) + if model_comm_group is None: + return model.predict_step(input_tensor_torch) + else: + try: + return model.predict_step(input_tensor_torch, model_comm_group) + except TypeError as err: + LOG.error("Please upgrade to a newer version of anemoi-models to use parallel inference") + raise err + + def create_output(self): + if self.global_rank == 0: + output = create_output(self, self.config.output) + LOG.info("Output: %s", output) + return output + else: + output = create_output(self, "none") + return output From 43167c51a4c8778b68db893779a65eb3ad37c926 Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Fri, 17 Jan 2025 14:27:20 +0100 Subject: [PATCH 18/24] refactored into explicit parallel runner class --- src/anemoi/inference/parallel.py | 96 -------------------- src/anemoi/inference/runner.py | 50 +++-------- src/anemoi/inference/runners/parallel.py | 109 +++++++++++++++++++++-- 3 files changed, 115 insertions(+), 140 deletions(-) delete mode 100644 src/anemoi/inference/parallel.py diff --git a/src/anemoi/inference/parallel.py b/src/anemoi/inference/parallel.py deleted file mode 100644 index e4632c9..0000000 --- a/src/anemoi/inference/parallel.py +++ /dev/null @@ -1,96 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import datetime -import logging -import os -import socket -import subprocess - -import numpy as np -import torch -import torch.distributed as dist - -LOG = logging.getLogger(__name__) - - -def init_network(): - """Reads Slurm environment to set master address and port for parallel communication""" - - # Get the master address from the SLURM_NODELIST environment variable - slurm_nodelist = os.environ.get("SLURM_NODELIST") - if not slurm_nodelist: - raise ValueError("SLURM_NODELIST environment variable is not set.") - - # Use subprocess to execute scontrol and get the first hostname - result = subprocess.run( - ["scontrol", "show", "hostname", slurm_nodelist], stdout=subprocess.PIPE, text=True, check=True - ) - master_addr = result.stdout.splitlines()[0] - - # Resolve the master address using nslookup - try: - resolved_addr = socket.gethostbyname(master_addr) - except socket.gaierror: - raise ValueError(f"Could not resolve hostname: {master_addr}") - - # Set the resolved address as MASTER_ADDR - master_addr = resolved_addr - - # Calculate the MASTER_PORT using SLURM_JOBID - slurm_jobid = os.environ.get("SLURM_JOBID") - if not slurm_jobid: - raise ValueError("SLURM_JOBID environment variable is not set.") - - master_port = str(10000 + int(slurm_jobid[-4:])) - - # Print the results for confirmation - LOG.debug(f"MASTER_ADDR: {master_addr}") - LOG.debug(f"MASTER_PORT: {master_port}") - - return master_addr, master_port - - -def init_parallel(device, global_rank, world_size): - """Creates a model communication group to be used for parallel inference""" - - if world_size > 1: - - master_addr, master_port = init_network() - - # use 'startswith' instead of '==' in case device is 'cuda:0' - if device.startswith("cuda"): - backend = "nccl" - else: - backend = "gloo" - - dist.init_process_group( - backend=backend, - init_method=f"tcp://{master_addr}:{master_port}", - timeout=datetime.timedelta(minutes=3), - world_size=world_size, - rank=global_rank, - ) - - model_comm_group_ranks = np.arange(world_size, dtype=int) - model_comm_group = torch.distributed.new_group(model_comm_group_ranks) - else: - model_comm_group = None - - return model_comm_group - - -def get_parallel_info(): - """Reads Slurm env vars, if they exist, to determine if inference is running in parallel""" - local_rank = int(os.environ.get("SLURM_LOCALID", 0)) # Rank within a node, between 0 and num_gpus - global_rank = int(os.environ.get("SLURM_PROCID", 0)) # Rank within all nodes - world_size = int(os.environ.get("SLURM_NTASKS", 1)) # Total number of processes - - return global_rank, local_rank, world_size diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index 8bd8687..8bdb14d 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -15,15 +15,12 @@ import numpy as np import torch -import torch.distributed as dist from anemoi.utils.dates import frequency_to_timedelta as to_timedelta from anemoi.utils.text import table from anemoi.utils.timer import Timer # , Timers from .checkpoint import Checkpoint from .context import Context -from .parallel import get_parallel_info -from .parallel import init_parallel from .postprocess import Accumulator from .postprocess import Noop from .precisions import PRECISIONS @@ -247,15 +244,6 @@ def predict_step(self, model, input_tensor_torch, fcstep, **kwargs): return model.predict_step(input_tensor_torch) def forecast(self, lead_time, input_tensor_numpy, input_state): - - # determine processes rank for parallel inference and assign a device - global_rank, local_rank, world_size = get_parallel_info() - if self.device == "cuda": - self.device = f"{self.device}:{local_rank}" - torch.cuda.set_device(local_rank) - - self.model.eval() - torch.set_grad_enabled(False) # Create pytorch input tensor @@ -264,22 +252,14 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): lead_time = to_timedelta(lead_time) steps = lead_time // self.checkpoint.timestep - if global_rank == 0: - LOG.info("World size: %d", world_size) - LOG.info("Using autocast %s", self.autocast) - LOG.info( - "Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps - ) + LOG.info("Using autocast %s", self.autocast) + LOG.info("Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps) result = input_state.copy() # We should not modify the input state result["fields"] = dict() start = input_state["date"] - # Create a model comm group for parallel inference - # A dummy comm group is created if only a single device is in use - model_comm_group = init_parallel(self.device, global_rank, world_size) - # The variable `check` is used to keep track of which variables have been updated # In the input tensor. `reset` is used to reset `check` to False except # when the values are of the constant in time variables @@ -299,8 +279,7 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): for s in range(steps): step = (s + 1) * self.checkpoint.timestep date = start + step - if global_rank == 0: - LOG.info("Forecasting step %s (%s)", step, date) + LOG.info("Forecasting step %s (%s)", step, date) result["date"] = date @@ -308,18 +287,17 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): with torch.autocast(device_type=self.device, dtype=self.autocast): y_pred = self.predict_step(self.model, input_tensor_torch, fcstep=s) - if global_rank == 0: - # Detach tensor and squeeze (should we detach here?) - output = np.squeeze(y_pred.cpu().numpy()) # shape: (values, variables) + # Detach tensor and squeeze (should we detach here?) + output = np.squeeze(y_pred.cpu().numpy()) # shape: (values, variables) - # Update state - for i in range(output.shape[1]): - result["fields"][self.checkpoint.output_tensor_index_to_variable[i]] = output[:, i] + # Update state + for i in range(output.shape[1]): + result["fields"][self.checkpoint.output_tensor_index_to_variable[i]] = output[:, i] - if (s == 0 and self.verbosity > 0) or self.verbosity > 1: - self._print_output_tensor("Output tensor", output) + if (s == 0 and self.verbosity > 0) or self.verbosity > 1: + self._print_output_tensor("Output tensor", output) - yield result + yield result # No need to prepare next input tensor if we are at the last step if s == steps - 1: @@ -333,9 +311,7 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): del y_pred # Recover memory - input_tensor_torch = self.add_dynamic_forcings_to_input_tensor( - input_tensor_torch, input_state, date, check - ) + input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(input_tensor_torch, input_state, date, check) input_tensor_torch = self.add_boundary_forcings_to_input_tensor( input_tensor_torch, input_state, date, check ) @@ -354,8 +330,6 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): if (s == 0 and self.verbosity > 0) or self.verbosity > 1: self._print_input_tensor("Next input tensor", input_tensor_torch) - #dist.destroy_process_group(model_comm_group) - def copy_prognostic_fields_to_input_tensor(self, input_tensor_torch, y_pred, check): # input_tensor_torch is shape: (batch, multi_step_input, values, variables) diff --git a/src/anemoi/inference/runners/parallel.py b/src/anemoi/inference/runners/parallel.py index 7f98eea..18202c40 100644 --- a/src/anemoi/inference/runners/parallel.py +++ b/src/anemoi/inference/runners/parallel.py @@ -8,33 +8,53 @@ # nor does it submit to any jurisdiction. +import datetime import logging +import os +import socket +import subprocess +import numpy as np +import torch +import torch.distributed as dist + +from ..outputs import create_output from . import runner_registry from .default import DefaultRunner -from ..parallel import get_parallel_info -from ..outputs import create_output LOG = logging.getLogger(__name__) @runner_registry.register("parallel") class ParallelRunner(DefaultRunner): + """Runner which splits a model over multiple devices""" def __init__(self, context): super().__init__(context) - global_rank, local_rank, world_size = get_parallel_info() + global_rank, local_rank, world_size = self.__get_parallel_info() self.global_rank = global_rank self.local_rank = local_rank self.world_size = world_size + if self.device == "cuda": + self.device = f"{self.device}:{local_rank}" + torch.cuda.set_device(local_rank) + + # disable most logging on non-zero ranks + if self.global_rank != 0: + logging.getLogger().setLevel(logging.WARNING) + + # Create a model comm group for parallel inference + # A dummy comm group is created if only a single device is in use + model_comm_group = self.__init_parallel(self.device, self.global_rank, self.world_size) + self.model_comm_group = model_comm_group + def predict_step(self, model, input_tensor_torch, fcstep, **kwargs): - model_comm_group = kwargs.get("model_comm_group", None) - if model_comm_group is None: + if self.model_comm_group is None: return model.predict_step(input_tensor_torch) else: try: - return model.predict_step(input_tensor_torch, model_comm_group) + return model.predict_step(input_tensor_torch, self.model_comm_group) except TypeError as err: LOG.error("Please upgrade to a newer version of anemoi-models to use parallel inference") raise err @@ -47,3 +67,80 @@ def create_output(self): else: output = create_output(self, "none") return output + + def __del__(self): + if self.model_comm_group is not None: + dist.destroy_process_group() + + def __init_network(self): + """Reads Slurm environment to set master address and port for parallel communication""" + + # Get the master address from the SLURM_NODELIST environment variable + slurm_nodelist = os.environ.get("SLURM_NODELIST") + if not slurm_nodelist: + raise ValueError("SLURM_NODELIST environment variable is not set.") + + # Use subprocess to execute scontrol and get the first hostname + result = subprocess.run( + ["scontrol", "show", "hostname", slurm_nodelist], stdout=subprocess.PIPE, text=True, check=True + ) + master_addr = result.stdout.splitlines()[0] + + # Resolve the master address using nslookup + try: + resolved_addr = socket.gethostbyname(master_addr) + except socket.gaierror: + raise ValueError(f"Could not resolve hostname: {master_addr}") + + # Set the resolved address as MASTER_ADDR + master_addr = resolved_addr + + # Calculate the MASTER_PORT using SLURM_JOBID + slurm_jobid = os.environ.get("SLURM_JOBID") + if not slurm_jobid: + raise ValueError("SLURM_JOBID environment variable is not set.") + + master_port = str(10000 + int(slurm_jobid[-4:])) + + # Print the results for confirmation + LOG.debug(f"MASTER_ADDR: {master_addr}") + LOG.debug(f"MASTER_PORT: {master_port}") + + return master_addr, master_port + + def __init_parallel(self, device, global_rank, world_size): + """Creates a model communication group to be used for parallel inference""" + + if world_size > 1: + + master_addr, master_port = self.__init_network() + + # use 'startswith' instead of '==' in case device is 'cuda:0' + if device.startswith("cuda"): + backend = "nccl" + else: + backend = "gloo" + + dist.init_process_group( + backend=backend, + init_method=f"tcp://{master_addr}:{master_port}", + timeout=datetime.timedelta(minutes=3), + world_size=world_size, + rank=global_rank, + ) + LOG.info(f"Creating a model comm group with {world_size} devices with the {backend} backend") + + model_comm_group_ranks = np.arange(world_size, dtype=int) + model_comm_group = torch.distributed.new_group(model_comm_group_ranks) + else: + model_comm_group = None + + return model_comm_group + + def __get_parallel_info(self): + """Reads Slurm env vars, if they exist, to determine if inference is running in parallel""" + local_rank = int(os.environ.get("SLURM_LOCALID", 0)) # Rank within a node, between 0 and num_gpus + global_rank = int(os.environ.get("SLURM_PROCID", 0)) # Rank within all nodes + world_size = int(os.environ.get("SLURM_NTASKS", 1)) # Total number of processes + + return global_rank, local_rank, world_size From 6974ac33a70f20a334f295ca7ff38e727908c575 Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Mon, 20 Jan 2025 16:11:29 +0100 Subject: [PATCH 19/24] allow MASTER_ADDR and MASTER_PORT to be set as env vars before runtime --- src/anemoi/inference/runners/parallel.py | 52 ++++++++++++++---------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/src/anemoi/inference/runners/parallel.py b/src/anemoi/inference/runners/parallel.py index 18202c40..f77a19d 100644 --- a/src/anemoi/inference/runners/parallel.py +++ b/src/anemoi/inference/runners/parallel.py @@ -80,27 +80,37 @@ def __init_network(self): if not slurm_nodelist: raise ValueError("SLURM_NODELIST environment variable is not set.") - # Use subprocess to execute scontrol and get the first hostname - result = subprocess.run( - ["scontrol", "show", "hostname", slurm_nodelist], stdout=subprocess.PIPE, text=True, check=True - ) - master_addr = result.stdout.splitlines()[0] - - # Resolve the master address using nslookup - try: - resolved_addr = socket.gethostbyname(master_addr) - except socket.gaierror: - raise ValueError(f"Could not resolve hostname: {master_addr}") - - # Set the resolved address as MASTER_ADDR - master_addr = resolved_addr - - # Calculate the MASTER_PORT using SLURM_JOBID - slurm_jobid = os.environ.get("SLURM_JOBID") - if not slurm_jobid: - raise ValueError("SLURM_JOBID environment variable is not set.") - - master_port = str(10000 + int(slurm_jobid[-4:])) + # Check if MASTER_ADDR is given, otherwise try set it using 'scontrol' + master_addr = os.environ.get("MASTER_ADDR") + if master_addr is None: + LOG.debug("'MASTER_ADDR' environment variable not set. Trying to set via SLURM") + try: + result = subprocess.run( + ["scontrol", "show", "hostname", slurm_nodelist], stdout=subprocess.PIPE, text=True, check=True + ) + except subprocess.CalledProcessError as err: + LOG.error( + "Python could not execute 'scontrol show hostname $SLURM_NODELIST' while calculating MASTER_ADDR. You could avoid this error by setting the MASTER_ADDR env var manually." + ) + raise err + + master_addr = result.stdout.splitlines()[0] + + # Resolve the master address using nslookup + try: + master_addr = socket.gethostbyname(master_addr) + except socket.gaierror: + raise ValueError(f"Could not resolve hostname: {master_addr}") + + # Check if MASTER_PORT is given, otherwise generate one based on SLURM_JOBID + master_port = os.environ.get("MASTER_PORT") + if master_port is None: + LOG.debug("'MASTER_PORT' environment variable not set. Trying to set via SLURM") + slurm_jobid = os.environ.get("SLURM_JOBID") + if not slurm_jobid: + raise ValueError("SLURM_JOBID environment variable is not set.") + + master_port = str(10000 + int(slurm_jobid[-4:])) # Print the results for confirmation LOG.debug(f"MASTER_ADDR: {master_addr}") From 2016c7b780df246cc31d071e018890a3196fee36 Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Tue, 21 Jan 2025 15:48:29 +0100 Subject: [PATCH 20/24] readd line accicdentally deleted --- src/anemoi/inference/runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index 8bdb14d..13957f8 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -244,6 +244,8 @@ def predict_step(self, model, input_tensor_torch, fcstep, **kwargs): return model.predict_step(input_tensor_torch) def forecast(self, lead_time, input_tensor_numpy, input_state): + self.model.eval() + torch.set_grad_enabled(False) # Create pytorch input tensor From bd391f5c1a81bece5da5bd8c69beac6a9a7a9243 Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Tue, 21 Jan 2025 17:41:53 +0100 Subject: [PATCH 21/24] added documentation --- docs/parallel.rst | 60 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 docs/parallel.rst diff --git a/docs/parallel.rst b/docs/parallel.rst new file mode 100644 index 0000000..1f70775 --- /dev/null +++ b/docs/parallel.rst @@ -0,0 +1,60 @@ +################### + Parallel Inference +################### + +If the memory requirements of your model are too large to fit within a single GPU, you run Anemoi-Inference in parallel across multiple GPUs. + +Parallel inference requires SLURM to launch the parallel processes and to determine information about your network environment. If SLURM is not available to you, please create an issue on the Anemoi-Inference github page. + +************** + Configuration +************** + +To run in parallel, you must add '`runner:parallel`' to your inference config file. + +.. code:: yaml + + checkpoint: /path/to/inference-last.ckpt + lead_time: 60 + runner: parallel + input: + grib: /path/to/input.grib + output: + grib: /path/to/output.grib + + + + +****************************** + Running inference in parallel +****************************** + +Below is an example SLURM batch script to launch a parallel inference job across 4 GPUs. + +.. code:: bash + + #!/bin/bash + #SBATCH --nodes=1 + #SBATCH --ntasks-per-node=4 + #SBATCH --gpus-per-node=4 + #SBATCH --cpus-per-task=8 + #SBATCH --time=0:05:00 + #SBATCH --output=outputs/paralell_inf.%j.out + + source /path/to/venv/bin/activate + srun anemoi-inference run parallel.yaml + +.. warning:: + If you specify '`runner:parallel`' but you don't launch with '`srun`', your anemoi-inference job may hang as only 1 process will be launched. + +.. note:: + By default, anemoi-inference will determine your systems master address and port itself. If this fails (i.e. when running Anemoi-Inference inside a container), you can instead set these values yourself via environment variables in your SLURM batch script: + + .. code:: bash + + MASTER_ADDR=$(scontrol show hostname $SLURM_NODELIST | head -n 1) + export MASTER_ADDR=$(nslookup $MASTER_ADDR | grep -oP '(?<=Address: ).*') + export MASTER_PORT=$((10000 + RANDOM % 10000)) + + srun anemoi-inference run parallel.yaml + From 1cd498204ad55f2f3904b516e4b9f4b745759e96 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Jan 2025 16:42:33 +0000 Subject: [PATCH 22/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/parallel.rst | 73 ++++++++++++++++++++++++++--------------------- 1 file changed, 41 insertions(+), 32 deletions(-) diff --git a/docs/parallel.rst b/docs/parallel.rst index 1f70775..17412aa 100644 --- a/docs/parallel.rst +++ b/docs/parallel.rst @@ -1,54 +1,64 @@ -################### +#################### Parallel Inference -################### +#################### -If the memory requirements of your model are too large to fit within a single GPU, you run Anemoi-Inference in parallel across multiple GPUs. +If the memory requirements of your model are too large to fit within a +single GPU, you run Anemoi-Inference in parallel across multiple GPUs. -Parallel inference requires SLURM to launch the parallel processes and to determine information about your network environment. If SLURM is not available to you, please create an issue on the Anemoi-Inference github page. +Parallel inference requires SLURM to launch the parallel processes and +to determine information about your network environment. If SLURM is not +available to you, please create an issue on the Anemoi-Inference github +page. -************** +*************** Configuration -************** +*************** -To run in parallel, you must add '`runner:parallel`' to your inference config file. +To run in parallel, you must add '`runner:parallel`' to your inference +config file. .. code:: yaml - checkpoint: /path/to/inference-last.ckpt - lead_time: 60 - runner: parallel - input: - grib: /path/to/input.grib - output: - grib: /path/to/output.grib + checkpoint: /path/to/inference-last.ckpt + lead_time: 60 + runner: parallel + input: + grib: /path/to/input.grib + output: + grib: /path/to/output.grib - - - -****************************** +******************************* Running inference in parallel -****************************** +******************************* -Below is an example SLURM batch script to launch a parallel inference job across 4 GPUs. +Below is an example SLURM batch script to launch a parallel inference +job across 4 GPUs. .. code:: bash - #!/bin/bash - #SBATCH --nodes=1 - #SBATCH --ntasks-per-node=4 - #SBATCH --gpus-per-node=4 - #SBATCH --cpus-per-task=8 - #SBATCH --time=0:05:00 - #SBATCH --output=outputs/paralell_inf.%j.out + #!/bin/bash + #SBATCH --nodes=1 + #SBATCH --ntasks-per-node=4 + #SBATCH --gpus-per-node=4 + #SBATCH --cpus-per-task=8 + #SBATCH --time=0:05:00 + #SBATCH --output=outputs/paralell_inf.%j.out - source /path/to/venv/bin/activate - srun anemoi-inference run parallel.yaml + source /path/to/venv/bin/activate + srun anemoi-inference run parallel.yaml .. warning:: - If you specify '`runner:parallel`' but you don't launch with '`srun`', your anemoi-inference job may hang as only 1 process will be launched. + + If you specify '`runner:parallel`' but you don't launch with + '`srun`', your anemoi-inference job may hang as only 1 process will + be launched. .. note:: - By default, anemoi-inference will determine your systems master address and port itself. If this fails (i.e. when running Anemoi-Inference inside a container), you can instead set these values yourself via environment variables in your SLURM batch script: + + By default, anemoi-inference will determine your systems master + address and port itself. If this fails (i.e. when running + Anemoi-Inference inside a container), you can instead set these + values yourself via environment variables in your SLURM batch script: .. code:: bash @@ -57,4 +67,3 @@ Below is an example SLURM batch script to launch a parallel inference job across export MASTER_PORT=$((10000 + RANDOM % 10000)) srun anemoi-inference run parallel.yaml - From 079036af9be823e0f518f8ac899f7c2dec5989f5 Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Tue, 21 Jan 2025 17:46:32 +0100 Subject: [PATCH 23/24] forgot precommit --- docs/parallel.rst | 73 ++++++++++++++++++++++++++--------------------- 1 file changed, 41 insertions(+), 32 deletions(-) diff --git a/docs/parallel.rst b/docs/parallel.rst index 1f70775..17412aa 100644 --- a/docs/parallel.rst +++ b/docs/parallel.rst @@ -1,54 +1,64 @@ -################### +#################### Parallel Inference -################### +#################### -If the memory requirements of your model are too large to fit within a single GPU, you run Anemoi-Inference in parallel across multiple GPUs. +If the memory requirements of your model are too large to fit within a +single GPU, you run Anemoi-Inference in parallel across multiple GPUs. -Parallel inference requires SLURM to launch the parallel processes and to determine information about your network environment. If SLURM is not available to you, please create an issue on the Anemoi-Inference github page. +Parallel inference requires SLURM to launch the parallel processes and +to determine information about your network environment. If SLURM is not +available to you, please create an issue on the Anemoi-Inference github +page. -************** +*************** Configuration -************** +*************** -To run in parallel, you must add '`runner:parallel`' to your inference config file. +To run in parallel, you must add '`runner:parallel`' to your inference +config file. .. code:: yaml - checkpoint: /path/to/inference-last.ckpt - lead_time: 60 - runner: parallel - input: - grib: /path/to/input.grib - output: - grib: /path/to/output.grib + checkpoint: /path/to/inference-last.ckpt + lead_time: 60 + runner: parallel + input: + grib: /path/to/input.grib + output: + grib: /path/to/output.grib - - - -****************************** +******************************* Running inference in parallel -****************************** +******************************* -Below is an example SLURM batch script to launch a parallel inference job across 4 GPUs. +Below is an example SLURM batch script to launch a parallel inference +job across 4 GPUs. .. code:: bash - #!/bin/bash - #SBATCH --nodes=1 - #SBATCH --ntasks-per-node=4 - #SBATCH --gpus-per-node=4 - #SBATCH --cpus-per-task=8 - #SBATCH --time=0:05:00 - #SBATCH --output=outputs/paralell_inf.%j.out + #!/bin/bash + #SBATCH --nodes=1 + #SBATCH --ntasks-per-node=4 + #SBATCH --gpus-per-node=4 + #SBATCH --cpus-per-task=8 + #SBATCH --time=0:05:00 + #SBATCH --output=outputs/paralell_inf.%j.out - source /path/to/venv/bin/activate - srun anemoi-inference run parallel.yaml + source /path/to/venv/bin/activate + srun anemoi-inference run parallel.yaml .. warning:: - If you specify '`runner:parallel`' but you don't launch with '`srun`', your anemoi-inference job may hang as only 1 process will be launched. + + If you specify '`runner:parallel`' but you don't launch with + '`srun`', your anemoi-inference job may hang as only 1 process will + be launched. .. note:: - By default, anemoi-inference will determine your systems master address and port itself. If this fails (i.e. when running Anemoi-Inference inside a container), you can instead set these values yourself via environment variables in your SLURM batch script: + + By default, anemoi-inference will determine your systems master + address and port itself. If this fails (i.e. when running + Anemoi-Inference inside a container), you can instead set these + values yourself via environment variables in your SLURM batch script: .. code:: bash @@ -57,4 +67,3 @@ Below is an example SLURM batch script to launch a parallel inference job across export MASTER_PORT=$((10000 + RANDOM % 10000)) srun anemoi-inference run parallel.yaml - From b8be926429c7c543c9a113cab54afd6565d71bdf Mon Sep 17 00:00:00 2001 From: cathalobrien Date: Tue, 21 Jan 2025 18:02:20 +0100 Subject: [PATCH 24/24] docs feedback --- docs/parallel.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/parallel.rst b/docs/parallel.rst index 17412aa..1914087 100644 --- a/docs/parallel.rst +++ b/docs/parallel.rst @@ -3,12 +3,13 @@ #################### If the memory requirements of your model are too large to fit within a -single GPU, you run Anemoi-Inference in parallel across multiple GPUs. +single GPU, you can run Anemoi-Inference in parallel across multiple +GPUs. Parallel inference requires SLURM to launch the parallel processes and to determine information about your network environment. If SLURM is not available to you, please create an issue on the Anemoi-Inference github -page. +page `here `_. *************** Configuration @@ -42,7 +43,7 @@ job across 4 GPUs. #SBATCH --gpus-per-node=4 #SBATCH --cpus-per-task=8 #SBATCH --time=0:05:00 - #SBATCH --output=outputs/paralell_inf.%j.out + #SBATCH --output=outputs/parallel_inf.%j.out source /path/to/venv/bin/activate srun anemoi-inference run parallel.yaml