diff --git a/models/CHANGELOG.md b/models/CHANGELOG.md
index ba61ebe3..d8d88e60 100644
--- a/models/CHANGELOG.md
+++ b/models/CHANGELOG.md
@@ -18,6 +18,7 @@ Keep it human-readable, your future self will thank you!
 - Reduced memory usage when using chunking in the mapper [#84](https://github.com/ecmwf/anemoi-models/pull/84)
 - Added `supporting_arrays` argument, which contains arrays to store in checkpoints. [#97](https://github.com/ecmwf/anemoi-models/pull/97)
 - Add remappers, e.g. link functions to apply during training to facilitate learning of variables with a difficult distribution [#88](https://github.com/ecmwf/anemoi-models/pull/88)
+- Add sequence sharding strategy for TransformerProcessor [#67](https://github.com/ecmwf/anemoi-core/pull/67)
 
 ## [0.4.0](https://github.com/ecmwf/anemoi-models/compare/0.3.0...0.4.0) - Improvements to Model Design
 
diff --git a/models/src/anemoi/models/distributed/transformer.py b/models/src/anemoi/models/distributed/transformer.py
index 78691bba..83b60ade 100644
--- a/models/src/anemoi/models/distributed/transformer.py
+++ b/models/src/anemoi/models/distributed/transformer.py
@@ -8,6 +8,7 @@
 # nor does it submit to any jurisdiction.
 
 
+import logging
 from typing import Optional
 
 import torch
@@ -17,6 +18,8 @@ from anemoi.models.distributed.utils import get_memory_format
 
 
+LOGGER = logging.getLogger(__name__)
+
 
 def _headsalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = None) -> Tensor:
     """Apply all_to_all along the head dimension.
@@ -82,6 +85,72 @@ def _seqalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = N
     return torch.cat(output_list, dim=-3).contiguous(memory_format=input_format)
 
 
+def _halo_comm(input_: Tensor, halo_size: int, mgroup: ProcessGroup, bwd: bool = False) -> Tensor:
+    """Exchange halo regions between neighboring ranks.
+
+    Expected format is (batch_size, halo_size + sequence_length + halo_size, channels).
+
+    Parameters
+    ----------
+    input_ : Tensor
+        Input tensor
+    halo_size : int
+        Halo size (left, right)
+    mgroup : ProcessGroup
+        Model communication group
+    bwd : bool
+        Flag to indicate the backward pass
+
+    Returns
+    -------
+    Tensor
+        Tensor with halo regions from neighboring ranks
+    """
+    end = input_.shape[-2]
+
+    left_halo_slice = slice(0, halo_size)
+    right_halo_slice = slice(end - halo_size, end)
+    left_send_slice = slice(halo_size, 2 * halo_size)
+    right_send_slice = slice(end - 2 * halo_size, end - halo_size)
+
+    if bwd:  # reverse halo exchange direction for gradient accumulation
+        left_halo_slice, left_send_slice = left_send_slice, left_halo_slice
+        right_halo_slice, right_send_slice = right_send_slice, right_halo_slice
+
+    left_send = input_[:, left_send_slice, :]
+    right_send = input_[:, right_send_slice, :]
+
+    # setup neighbor ranks and tensor lists for all_to_all communication
+    group_rank = dist.get_rank(mgroup)
+    group_size = dist.get_world_size(mgroup)
+    left_rank = group_rank - 1 if group_rank > 0 else None
+    right_rank = group_rank + 1 if group_rank < group_size - 1 else None
+
+    input_list = [torch.empty(0, device=input_.device) for _ in range(group_size)]
+    if left_rank is not None:
+        input_list[left_rank] = left_send
+    if right_rank is not None:
+        input_list[right_rank] = right_send
+    output_list = [torch.empty_like(input_i, device=input_.device) for input_i in input_list]
+
+    dist.all_to_all(output_list, input_list, group=mgroup)
+
+    if bwd:  # add gradient contributions to halo regions and zero out send regions
+        if left_rank is not None:
+            input_[:, left_send_slice, :] = 0
+            input_[:, left_halo_slice, :] += output_list[left_rank]
+        if right_rank is not None:
+            input_[:, right_send_slice, :] = 0
+            input_[:, right_halo_slice, :] += output_list[right_rank]
+    else:  # add halo regions to input tensor
+        if left_rank is not None:
+            input_[:, left_halo_slice, :] = output_list[left_rank]
+        if right_rank is not None:
+            input_[:, right_halo_slice, :] = output_list[right_rank]
+
+    return input_
+
+
 def shard_heads(input_: Tensor, shapes: list, mgroup: ProcessGroup) -> Tensor:
     """Sync tensor.
@@ -130,6 +199,49 @@ def shard_sequence(input_: Tensor, shapes: list, mgroup: ProcessGroup) -> Tensor
     return _SplitSequenceParallelSection.apply(input_, shapes, mgroup)
 
 
+def add_halos(x: Tensor, halo_size: int, mgroup: ProcessGroup) -> Tensor:
+    halo_size_left = halo_size if dist.get_rank(mgroup) != 0 else 0
+    halo_size_right = halo_size if dist.get_rank(mgroup) != dist.get_world_size(mgroup) - 1 else 0
+
+    return (
+        torch.nn.functional.pad(x, pad=(0, 0, halo_size_left, halo_size_right), mode="constant", value=0),
+        halo_size_left,
+        halo_size_right,
+    )
+
+
+def remove_halos(x: Tensor, halo_size_left: int, halo_size_right: int) -> Tensor:
+    return x[:, :, halo_size_left : x.shape[-2] - halo_size_right, :]
+
+
+def halo_exchange(x: Tensor, halo_size: int, mgroup: Optional[ProcessGroup] = None) -> Tensor:
+    """Exchange halo regions between ranks.
+
+    Parameters
+    ----------
+    x : Tensor
+        Input tensor
+    halo_size : int
+        Halo size (left, right)
+    mgroup : ProcessGroup
+        Model communication group
+
+    Returns
+    -------
+    Tensor, int, int
+        Tensor appended with halo regions from neighboring ranks, left halo size, right halo size
+    """
+    if mgroup is None or dist.get_world_size(mgroup) == 1:
+        return x, 0, 0
+
+    # pad tensor with halo regions
+    out, halo_size_left, halo_size_right = add_halos(x, halo_size, mgroup)
+
+    out = _HaloExchangeParallelSection.apply(out, halo_size, mgroup)
+
+    return out, halo_size_left, halo_size_right
+
+
 class _SplitHeadsParallelSection(torch.autograd.Function):
     """Sync the input from parallel section."""
 
@@ -172,3 +284,27 @@ def backward(ctx, grad_output):
                 None,
             )
         return grad_output, None, None
+
+
+class _HaloExchangeParallelSection(torch.autograd.Function):
+    """Exchange halo regions between ranks."""
+
+    @staticmethod
+    def forward(ctx, input_, halo_size_, mgroup_):
+        ctx.halo_size = halo_size_
+        ctx.mgroup = mgroup_
+
+        if mgroup_:
+            return _halo_comm(input_, halo_size_, mgroup_)
+        return input_
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if ctx.mgroup:
+            return (
+                _halo_comm(grad_output, ctx.halo_size, ctx.mgroup, bwd=True),
+                None,
+                None,
+            )
+
+        return grad_output, None, None
diff --git a/models/src/anemoi/models/layers/attention.py b/models/src/anemoi/models/layers/attention.py
index d7f54920..96230913 100644
--- a/models/src/anemoi/models/layers/attention.py
+++ b/models/src/anemoi/models/layers/attention.py
@@ -25,6 +25,9 @@
 else:
     _FLASH_ATTENTION_AVAILABLE = True
 
+
+from anemoi.models.distributed.transformer import halo_exchange
+from anemoi.models.distributed.transformer import remove_halos
 from anemoi.models.distributed.transformer import shard_heads
 from anemoi.models.distributed.transformer import shard_sequence
 
@@ -42,6 +45,7 @@ def __init__(
         is_causal: bool = False,
         window_size: Optional[int] = None,
         dropout_p: float = 0.0,
+        shard_strategy: str = "shard_heads",
     ):
         super().__init__()
 
@@ -55,6 +59,7 @@
         self.window_size = (window_size, window_size)  # flash attention
         self.dropout_p = dropout_p
         self.is_causal = is_causal
+        self.shard_strategy = shard_strategy
 
         self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
         self.attention = attn_func
 
@@ -62,17 +67,50 @@
         if not _FLASH_ATTENTION_AVAILABLE:
             LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention")
 
+        if shard_strategy not in ["shard_heads", "shard_sequence"]:
+            raise ValueError(f"Invalid shard_strategy: {shard_strategy}")
+
+        if shard_strategy == "shard_sequence":  # remove this after PR #47 is merged (sliding window support)
+            assert _FLASH_ATTENTION_AVAILABLE, "Flash attention is required for shard_sequence strategy"
+
         self.projection = nn.Linear(embed_dim, embed_dim, bias=True)
 
-    def forward(
+    def get_qkv_shard_sequence(
         self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
     ) -> Tensor:
-        query, key, value = self.lin_qkv(x).chunk(3, -1)
+        assert (
+            shapes[-1][0] // 2 >= self.window_size[0]
+        ), f"Sharded sequence length ({shapes[-1][0]}) must be at least twice the window size (2*{self.window_size[0]})"
+
+        # unpack grid dimension first to allow for halo exchange
+        x = einops.rearrange(
+            x,
+            "(batch grid) channels -> batch grid channels",
+            batch=batch_size,
+        )
 
-        if model_comm_group:
-            assert (
-                model_comm_group.size() == 1 or batch_size == 1
-            ), "Only batch size of 1 is supported when model is sharded accross GPUs"
+        # communicate halos (adds halos to x)
+        x_plus_halos, halo_size_left, halo_size_right = halo_exchange(
+            x, halo_size=self.window_size[0], mgroup=model_comm_group
+        )
+
+        query, key, value = self.lin_qkv(x_plus_halos).chunk(3, -1)
+
+        query, key, value = (
+            einops.rearrange(
+                t,
+                "batch grid (heads vars) -> batch heads grid vars",
+                heads=self.num_heads,
+            )
+            for t in (query, key, value)
+        )
+
+        return query, key, value, halo_size_left, halo_size_right
+
+    def get_qkv_shard_heads(
+        self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
+    ) -> Tensor:
+        query, key, value = self.lin_qkv(x).chunk(3, -1)
 
         query, key, value = (
             einops.rearrange(
@@ -87,6 +125,24 @@ def forward(
         query = shard_heads(query, shapes=shapes, mgroup=model_comm_group)
         key = shard_heads(key, shapes=shapes, mgroup=model_comm_group)
         value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
+
+        return query, key, value
+
+    def forward(
+        self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
+    ) -> Tensor:
+        if model_comm_group:
+            assert (
+                model_comm_group.size() == 1 or batch_size == 1
+            ), "Only batch size of 1 is supported when model is sharded across GPUs"
+
+        if self.shard_strategy == "shard_sequence":
+            query, key, value, halo_size_left, halo_size_right = self.get_qkv_shard_sequence(
+                x, shapes, batch_size, model_comm_group
+            )
+        if self.shard_strategy == "shard_heads":
+            query, key, value = self.get_qkv_shard_heads(x, shapes, batch_size, model_comm_group)
+
         dropout_p = self.dropout_p if self.training else 0.0
 
         if _FLASH_ATTENTION_AVAILABLE:
@@ -104,7 +160,11 @@
                 dropout_p=dropout_p,
             )  # expects (batch heads grid variable) format
 
-        out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
+        if self.shard_strategy == "shard_sequence":
+            out = remove_halos(out, halo_size_left, halo_size_right)
+        if self.shard_strategy == "shard_heads":
+            out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
+
         out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")
 
         out = self.projection(out)
diff --git a/models/src/anemoi/models/layers/block.py b/models/src/anemoi/models/layers/block.py
index 72e487d2..d68f2e0e 100644
--- a/models/src/anemoi/models/layers/block.py
+++ b/models/src/anemoi/models/layers/block.py
@@ -69,6 +69,7 @@
         activation: str,
         window_size: int,
         dropout_p: float = 0.0,
+        shard_strategy: str = "shard_heads",
     ):
         super().__init__()
 
@@ -87,6 +88,7 @@
             bias=False,
             is_causal=False,
             dropout_p=dropout_p,
+            shard_strategy=shard_strategy,
         )
 
         self.mlp = nn.Sequential(
diff --git a/models/src/anemoi/models/layers/chunk.py b/models/src/anemoi/models/layers/chunk.py
index 5c4fae38..d3365f04 100644
--- a/models/src/anemoi/models/layers/chunk.py
+++ b/models/src/anemoi/models/layers/chunk.py
@@ -75,6 +75,7 @@
         mlp_hidden_ratio: int = 4,
         activation: str = "GELU",
         dropout_p: float = 0.0,
+        shard_strategy: str = "shard_heads",
     ) -> None:
         """Initialize TransformerProcessor.
 
@@ -92,6 +93,8 @@
             Activation function, by default "GELU"
         dropout_p: float
             Dropout probability used for multi-head self attention, default 0.0
+        shard_strategy: str
+            Strategy for sharding, either "shard_sequence" or "shard_heads", by default "shard_heads"
         """
         super().__init__(num_channels=num_channels, num_layers=num_layers)
 
@@ -103,6 +106,7 @@
             activation=activation,
             window_size=window_size,
             dropout_p=dropout_p,
+            shard_strategy=shard_strategy,
         )
 
     def forward(
diff --git a/models/src/anemoi/models/layers/processor.py b/models/src/anemoi/models/layers/processor.py
index 8dba1f66..c1b4d53e 100644
--- a/models/src/anemoi/models/layers/processor.py
+++ b/models/src/anemoi/models/layers/processor.py
@@ -97,6 +97,7 @@
         num_heads: int = 16,
         mlp_hidden_ratio: int = 4,
         dropout_p: float = 0.1,
+        shard_strategy: str = "shard_sequence",
         **kwargs,
     ) -> None:
         """Initialize TransformerProcessor.
@@ -117,6 +118,8 @@
             Activation function, by default "GELU"
         dropout_p: float, optional
             Dropout probability used for multi-head self attention, default 0.0
+        shard_strategy: str, optional
+            Strategy for sharding, either "shard_sequence" or "shard_heads", by default "shard_sequence"
         """
         super().__init__(
             num_channels=num_channels,
@@ -138,6 +141,7 @@
             window_size=window_size,
             activation=activation,
             dropout_p=dropout_p,
+            shard_strategy=shard_strategy,
         )
 
         self.offload_layers(cpu_offload)
diff --git a/training/src/anemoi/training/config/model/transformer.yaml b/training/src/anemoi/training/config/model/transformer.yaml
index cd6a1e7b..b4238c02 100644
--- a/training/src/anemoi/training/config/model/transformer.yaml
+++ b/training/src/anemoi/training/config/model/transformer.yaml
@@ -14,6 +14,7 @@ processor:
   num_heads: 16 # GraphTransformer or Transformer only
   window_size: 512
   dropout_p: 0.0 # GraphTransformer
+  shard_strategy: shard_sequence # Options: shard_sequence, shard_heads
 encoder:
   _target_: anemoi.models.layers.mapper.GraphTransformerForwardMapper
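
Note for reviewers: the snippet below is a minimal, single-process sketch of the halo logic added in `distributed/transformer.py`; the two-shard split, the shapes, and all variable names are illustrative only, and `torch.distributed` is deliberately left out. It mimics what `add_halos`, `_halo_comm`, and `remove_halos` do around the sliding-window attention when `shard_strategy: shard_sequence` is selected: each rank pads its sequence shard with window-sized halos, receives the neighboring shard's edge points into those halos, and trims them again after attention.

```python
# Minimal single-process sketch (illustrative names/shapes): emulate the
# sequence-sharding halo exchange with two pretend ranks and halo_size = 2.
import torch

halo_size = 2  # stands in for the sliding-attention window size
full = torch.arange(16, dtype=torch.float32).reshape(1, 8, 2)  # (batch, sequence, channels)

# split the sequence across two pretend ranks
shard0, shard1 = full[:, :4, :], full[:, 4:, :]

# add_halos: pad only towards an existing neighbor
# (rank 0 has no left neighbor, rank 1 has no right neighbor)
shard0_haloed = torch.nn.functional.pad(shard0, pad=(0, 0, 0, halo_size))
shard1_haloed = torch.nn.functional.pad(shard1, pad=(0, 0, halo_size, 0))

# _halo_comm: what dist.all_to_all delivers -- each rank writes its neighbor's
# edge points into its own halo region
shard0_haloed[:, -halo_size:, :] = shard1[:, :halo_size, :]
shard1_haloed[:, :halo_size, :] = shard0[:, -halo_size:, :]

# ... sliding-window attention would run on the haloed shards here ...

# remove_halos: trim the halos so each rank keeps only its own sequence points
out0 = shard0_haloed[:, : shard0_haloed.shape[1] - halo_size, :]
out1 = shard1_haloed[:, halo_size:, :]

assert torch.equal(torch.cat([out0, out1], dim=1), full)
```

Compared to `shard_heads`, which redistributes the full query/key/value tensors across ranks before and after attention, the sequence strategy keeps the sequence split fixed and only moves the window-sized halo regions between neighboring ranks.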