From 95dcc048ce95a65db3a476c09ded9a45ec56c938 Mon Sep 17 00:00:00 2001 From: cx <759046501@qq.com> Date: Tue, 10 Sep 2024 10:35:54 +0800 Subject: [PATCH] add vacab parallel embedding (#315) --- internlm/core/parallel/comm/isp.py | 9 +++--- internlm/core/parallel/comm/tensor.py | 15 ++++++++-- internlm/core/parallel/comm/utils.py | 37 ++++++++++++++++++++++++ internlm/model/modules/embedding.py | 41 ++++++++++++++++++++++++--- 4 files changed, 92 insertions(+), 10 deletions(-) diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py index 8a65052f..a1b63c62 100644 --- a/internlm/core/parallel/comm/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -177,7 +177,7 @@ class EmbeddingWeightParallelCommunicator: def __init__(self, parallel_mode: ParallelMode) -> None: self.parallel_mode = parallel_mode - self.emb_column = 1 + self.gather_dim = 0 self._cur_micro_step = 0 self._num_micro_step = gpc.config.data.micro_num @@ -186,6 +186,7 @@ def register_module_hook(self, module: Embedding1D) -> None: assert isinstance(module, Embedding1D), "Embbeding weight parallel communicator is only support Embedding1D" module.weight.evo_tensor = None + self.gather_dim = 0 if module.vocab_parallel else 1 class PreModuleWrapper(torch.autograd.Function): """ @@ -197,7 +198,7 @@ def forward(ctx, inputs: torch.Tensor): # pylint: disable=W0613 if module.weight.evo_tensor is None: module.weight.evo_tensor = module.weight.data - module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.emb_column) + module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.gather_dim) inputs = inputs.detach() return inputs @@ -220,7 +221,7 @@ def forward(ctx, output: torch.Tensor): # pylint: disable=W0613 @staticmethod def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor: # pylint: disable=W0613 - module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.emb_column) + module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.gather_dim) return grad_output def _pre_forward_hook(module, inputs): # pylint: disable=W0613 @@ -237,7 +238,7 @@ def _post_forward_hook(module, inputs, output): # pylint: disable=W0613 def grad_reduce_hook(self, param: torch.Tensor): _grad, _ = reduce_scatter_raw( - param.grad, gpc.get_group(self.parallel_mode), op=dist.ReduceOp.AVG, reduce_dim=self.emb_column + param.grad, gpc.get_group(self.parallel_mode), op=dist.ReduceOp.AVG, reduce_dim=self.gather_dim ) if param.evo_tensor.grad is None: param.evo_tensor.grad = _grad diff --git a/internlm/core/parallel/comm/tensor.py b/internlm/core/parallel/comm/tensor.py index 2dfc8bd2..453b33f1 100644 --- a/internlm/core/parallel/comm/tensor.py +++ b/internlm/core/parallel/comm/tensor.py @@ -19,6 +19,7 @@ all_gather_raw, all_reduce_raw, gather_forward_split_backward, + reduce_forward, reduce_scatter_raw, split_forward_gather_backward, ) @@ -341,7 +342,12 @@ def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tup """ _emb_dim = 2 # [bsz, seqlen, emb_dim] - return gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim) + if module.vocab_parallel: + output = reduce_forward(output, self._parallel_mode) + else: + output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim) + + return output class EmbeddingSequenceParallelCommunicator: @@ -363,7 +369,12 @@ def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tup """ _emb_dim, _seq_dim = 2, 1 # [bsz, seqlen, emb_dim] - output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim) + # tp: + if module.vocab_parallel: + output = reduce_forward(output, self._parallel_mode) + else: + output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim) + # sp: output = split_forward_gather_backward(output, self._parallel_mode, dim=_seq_dim) return output diff --git a/internlm/core/parallel/comm/utils.py b/internlm/core/parallel/comm/utils.py index abd291cf..a7f93c3b 100644 --- a/internlm/core/parallel/comm/utils.py +++ b/internlm/core/parallel/comm/utils.py @@ -117,6 +117,17 @@ def _gather(input_, parallel_mode, dim=-1): return output +def _reduce(input_, parallel_mode): + # skip if only one rank involved + if gpc.get_world_size(parallel_mode) == 1: + return input_ + + group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode) + dist.all_reduce(input_, group=group) + + return input_ + + class _GatherForwardSplitBackward(torch.autograd.Function): """Gather the input from model parallel region and concatenate. @@ -174,6 +185,32 @@ def split_forward_gather_backward(input_, parallel_mode, dim): return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim) +class _ReduceForward(torch.autograd.Function): + """ + All-reduce the input from the model parallel region. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + """ + + @staticmethod + def symbolic(input_): + return _reduce(input_, parallel_mode=None) + + @staticmethod + def forward(ctx, input_, parallel_mode): # pylint: disable=W0613 + return _reduce(input_, parallel_mode) + + @staticmethod + def backward(ctx, grad_output): # pylint: disable=W0613 + return grad_output, None + + +def reduce_forward(input_, parallel_mode): + return _ReduceForward.apply(input_, parallel_mode) + + def all_gather_raw( input_: Tensor, process_group: ProcessGroup, diff --git a/internlm/model/modules/embedding.py b/internlm/model/modules/embedding.py index 5c4e9f65..365ee46a 100644 --- a/internlm/model/modules/embedding.py +++ b/internlm/model/modules/embedding.py @@ -8,6 +8,7 @@ from einops import rearrange from torch import Tensor, nn +from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.model.ops.rotary_emb import apply_rotary_emb from internlm.utils.parallel import is_using_isp @@ -33,6 +34,7 @@ def __init__( *args, padding_idx: int = None, dtype: torch.dtype = None, + vocab_parallel: bool = False, **kwargs, ): super().__init__() @@ -42,14 +44,45 @@ def __init__( self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs + self.vocab_parallel = vocab_parallel - _parallel_size = gpc.weight_parallel_size if is_using_isp() else gpc.tensor_parallel_size + parallel_size = gpc.weight_parallel_size if is_using_isp() else gpc.tensor_parallel_size - embed_dim_per_partition = embedding_dim // _parallel_size - self.weight = nn.Parameter(torch.empty((num_embeddings, embed_dim_per_partition), dtype=dtype)) + if vocab_parallel: + assert num_embeddings % parallel_size == 0, f"{num_embeddings} is not divisible by {parallel_size}" + + self.num_embeddings_per_partition = num_embeddings // parallel_size + self.embed_dim_per_partition = embedding_dim + self.vocab_start_index = gpc.get_local_rank(ParallelMode.TENSOR) * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + else: + assert embedding_dim % parallel_size == 0, f"{embedding_dim} is not divisible by {parallel_size}" + + self.num_embeddings_per_partition = num_embeddings + self.embed_dim_per_partition = embedding_dim // parallel_size + self.vocab_start_index = 0 + self.vocab_end_index = self.num_embeddings_per_partition + + self.weight = nn.Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), dtype=dtype) + ) def forward(self, input_: Tensor) -> Tensor: - return F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + if self.vocab_parallel and not is_using_isp(): + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + else: + masked_input = input_ + + output = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + if self.vocab_parallel and not is_using_isp(): + output[input_mask, :] = 0.0 + + return output class RotaryEmbedding(torch.nn.Module):