diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE4_sft.py
index c558427c..4037c031 100644
--- a/configs/7B_MoE4_sft.py
+++ b/configs/7B_MoE4_sft.py
@@ -183,6 +183,10 @@
 weight parallel (dict):
     1. size: int, the size of weight parallel.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
+    3. launch_allgather_before: str, the module before which to launch the all-gather communication that
+        prefetches the next layer's weights, should be in ['wqkv', 'attn', 'wo', 'w1'], defaults to 'wo'.
+        Only takes effect when forward_overlap_per is 'layer'.
+    4. forward_overlap_per: str, all-gather prefetch granularity, per 'module' or per 'layer', defaults to 'layer'.
 expert parallel (dict):
     1. size: int
         * if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size
@@ -193,14 +197,18 @@
 expert weight parallel (dict):
     1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
+    3. launch_allgather_before: str, the module before which to launch the all-gather communication that
+        prefetches the next layer's weights, should be in ['wqkv', 'attn', 'wo', 'w1'], defaults to 'wo'.
+        Only takes effect when forward_overlap_per is 'layer'.
+    4. forward_overlap_per: str, all-gather prefetch granularity, per 'module' or per 'layer', defaults to 'layer'.
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
     tensor=dict(size=1, mode="mtp"),
     pipeline=dict(size=1, interleaved_overlap=True),
-    weight=dict(size=1, overlap=True),
+    weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
     expert=dict(size=-1, no_tp=False),
-    expert_weight=dict(size=1, overlap=True),
+    expert_weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
 )
 
 cudnn_deterministic = False
diff --git a/configs/7B_isp_sft.py b/configs/7B_isp_sft.py
index de99f917..ad68082d 100644
--- a/configs/7B_isp_sft.py
+++ b/configs/7B_isp_sft.py
@@ -186,6 +186,10 @@
 weight parallel (dict):
     1. size: int, the size of weight parallel.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
+    3. launch_allgather_before: str, the module before which to launch the all-gather communication that
+        prefetches the next layer's weights, should be in ['wqkv', 'attn', 'wo', 'w1'], defaults to 'wo'.
+        Only takes effect when forward_overlap_per is 'layer'.
+    4. forward_overlap_per: str, all-gather prefetch granularity, per 'module' or per 'layer', defaults to 'layer'.
 sequence_2D (dict):
     1. enable: bool, whether enable the 2D sequence parallel or not.
     2. head_size: int, the parallel degree of head parallelism (DeepSpeed Ulysses).
@@ -205,7 +209,7 @@
     zero1=dict(size=-1),
     tensor=dict(size=2, mode="isp"),
     pipeline=dict(size=1, interleaved_overlap=True),
-    weight=dict(size=4, overlap=True),
+    weight=dict(size=4, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
     sequence_2D=dict(
         enable=False,
         head_size=2,
diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py
index d4950c75..7e722c2f 100644
--- a/internlm/core/parallel/comm/isp.py
+++ b/internlm/core/parallel/comm/isp.py
@@ -266,7 +266,6 @@ def __init__(
         dtype: torch.dtype = torch.half,
         device: torch.device = None,
         activation_checkpointing: float = 0.0,
-        module_shapes: Dict[str, torch.Size] = None,
     ) -> None:
         self.dtype = dtype
         if device is None:
@@ -274,7 +273,6 @@ def __init__(
         else:
             self.device = device
         self.activation_checkpointing = activation_checkpointing
-        self.module_shapes = module_shapes
 
 
 class ISPOverlapState:
@@ -285,7 +283,7 @@
     def __init__(self) -> None:
         self.num_blocks: int = 0
         self.ckpt_block_num: int = 0
-        self.isp_outs: List[nn.Module] = []
+        self.isp_prefetch_launch_module: List[nn.Module] = []
         self.isp_modules: List[nn.Module] = []
         self.index_to_isp_modules: Dict[int, nn.Module] = {}
         self.index_to_block: Dict[int, nn.Module] = {}
@@ -315,8 +313,9 @@ def __init__(
         self.is_moe = is_moe
         self.is_forward = True
         self.reduce_scatter_handlers = {}
-        self._module_shapes = {}
         self._forward_prefetch_prerequisites = []
+        self._forward_overlap_per = self._get_forward_overlap_granularity()
+        self._launch_before_module = self._get_launch_before_module()
 
         # real overlap state for each chunk.
         self._overlap_states: Dict[int, ISPOverlapState] = {}
@@ -324,7 +323,7 @@ def __init__(
         # inner interface variables of overlap state.
         self._num_blocks = None
         self._ckpt_block_num = None
-        self._isp_outs = None
+        self._isp_prefetch_launch_module = None
         self._isp_modules = None
         # key: isp module; value: module global all-gather op handle
         self._weight_global_handle = None
@@ -351,7 +350,32 @@ def __init__(
         self._register_sync_parameters_hook()
         # switch to chunk 0 at first.
         self.switch_current_model_chunk(0)
-        self.model_conf.module_shapes = self._module_shapes
+
+    def _get_launch_before_module(self):
+        if self.is_moe is True:
+            _launch_before = gpc.config.parallel.expert_weight.get("launch_allgather_before", "wo")
+        else:
+            _launch_before = gpc.config.parallel.weight.get("launch_allgather_before", "wo")
+
+        if _launch_before == "wqkv":
+            return ["wqkv", "Wqkv", "qkv", "q_a_proj", "q_proj"]
+        elif _launch_before == "attn":
+            return ["attn"]
+        elif _launch_before == "wo":
+            return ["out_proj", "wo"]
+        elif _launch_before == "w1":
+            return ["w1", "fused_w1_w3"]
+        else:
+            assert False, "launch module should be in ['wqkv', 'attn', 'wo', 'w1']"
+
+    def _get_forward_overlap_granularity(self):
+        if self.is_moe is True:
+            _overlap_granularity = gpc.config.parallel.expert_weight.get("forward_overlap_per", "layer")
+        else:
+            _overlap_granularity = gpc.config.parallel.weight.get("forward_overlap_per", "layer")
+
+        assert _overlap_granularity in ["module", "layer"]
+        return _overlap_granularity
 
     def _parse_model_structure(self, cid: int, model: nn.Module) -> None:
         self._overlap_states[cid] = ISPOverlapState()
@@ -359,6 +383,13 @@ def _parse_model_structure(self, cid: int, model: nn.Module) -> None:
         def get_model(obj: nn.Module) -> nn.Module:
             return get_model(obj.model) if hasattr(obj, "model") else obj
 
+        def is_allgather_launch_module(name, module):
+            return (
+                hasattr(module, "is_attn_cls")
+                and getattr(module, "is_attn_cls")
+                and self._launch_before_module == ["attn"]
+            ) or (name.split(".")[-1] in self._launch_before_module)
+
         # Important: only works for llama-class models
         children_name = get_model(model).named_children()
         for _, children in children_name:
@@ -369,18 +400,12 @@ def get_model(obj: nn.Module) -> nn.Module:
                     self._overlap_states[cid].index_to_isp_modules[idx] = []
                     self._overlap_states[cid].index_to_block[idx] = block
                     for name, child in block.named_modules():
-                        if name.split(".")[-1] in ["out_proj", "wo"]:
-                            self._overlap_states[cid].isp_outs.append(child)
-                            self._overlap_states[cid].module_to_index[child] = idx
+                        if is_allgather_launch_module(name, child):
+                            self._overlap_states[cid].isp_prefetch_launch_module.append(child)
                         if isinstance(child, (ParallelLinearWithCommExt)):
                             if is_moe_param(child.weight) != self.is_moe:
                                 continue
-                            if name not in self._module_shapes:
-                                weight_parallel_size = dist.get_world_size(self.process_group)
-                                origin_shape = tuple(
-                                    [child.weight.shape[0] * weight_parallel_size] + list(child.weight.shape[1:])
-                                )
-                                self._module_shapes[name] = torch.Size(origin_shape)
+                            self._overlap_states[cid].module_to_index[child] = idx
                             self._overlap_states[cid].isp_modules.append(child)
                             self._overlap_states[cid].index_to_isp_modules[idx].append(child)
 
@@ -403,25 +428,28 @@ def get_model(obj: nn.Module) -> nn.Module:
         self._overlap_states[cid].num_blocks = len(self._overlap_states[cid].index_to_isp_modules)
 
     def _all_gather_module_weight(self, module):
+        assert module not in self._bias_global_output and module not in self._weight_global_output
         with_bias = module.bias is not None
 
         # submit the all-gather communication for weight and bias.
         if with_bias:
-            bias_output, bias_handle = all_gather_raw(
-                module.bias,
+            if module not in self._bias_global_output:
+                bias_output, bias_handle = all_gather_raw(
+                    module.bias,
+                    self.process_group,
+                    async_op=True,
+                )
+                self._bias_global_handle[module] = bias_handle
+                self._bias_global_output[module] = bias_output
+
+        if module not in self._weight_global_output:
+            weight_output, weight_handle = all_gather_raw(
+                module.weight,
                 self.process_group,
                 async_op=True,
             )
-            self._bias_global_handle[module] = bias_handle
-            self._bias_global_output[module] = bias_output
-
-        weight_output, weight_handle = all_gather_raw(
-            module.weight,
-            self.process_group,
-            async_op=True,
-        )
-        self._weight_global_handle[module] = weight_handle
-        self._weight_global_output[module] = weight_output
+            self._weight_global_handle[module] = weight_handle
+            self._weight_global_output[module] = weight_output
 
     def _all_gather_block_weight(self, block_index: int):
         block = self._index_to_block[block_index]
@@ -463,23 +491,20 @@ def _pre_forward_hook_for_first_block(self, *args):  # pylint: disable=W0613
         """
         prefetch weight for block 0 before forward.
         """
-        if self.is_forward is True:
+        if self._forward_overlap_per == "layer" and self.is_forward is True:
             self._all_gather_block_weight(0)
 
-    def _pre_forward_hook_for_last_ckpt_block(self, *args):  # pylint: disable=W0613
-        if self.is_forward is False:
-            self._all_gather_block_weight(self._ckpt_block_num - 1)
-
-    def _pre_forward_hook_for_out_proj(self, module: nn.Module, *args):  # pylint: disable=W0613
+    def _pre_forward_hook_for_prefetch_launch_module(self, module: nn.Module, *args):  # pylint: disable=W0613
         block_index = self._module_to_index[module]
 
-        if (block_index - 1 < self._ckpt_block_num) and self.is_forward is False:
-            if block_index - 1 >= 0:
-                self._all_gather_block_weight(block_index - 1)
-        else:
-            # start the all-gather for next block
-            if block_index + 1 < self._num_blocks:
-                self._all_gather_block_weight(block_index + 1)
+        if self._forward_overlap_per == "layer":
+            if (block_index - 1 < self._ckpt_block_num) and self.is_forward is False:
+                if block_index - 1 >= 0:
+                    self._all_gather_block_weight(block_index - 1)
+            else:
+                # start the all-gather for next block
+                if block_index + 1 < self._num_blocks:
+                    self._all_gather_block_weight(block_index + 1)
 
     def _pre_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: disable=W0613
         if module not in self._weight_global_handle:
@@ -487,6 +512,32 @@ def _pre_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: dis
             self._all_gather_module_weight(module)
 
         self._wait_handle(module)
 
+        if self._forward_overlap_per == "module":
+            # start the all-gather for next module
+            # 1. forward prefetch for the next module
+            module_index = self._isp_modules.index(module)
+            module_layer_id = self._module_to_index[module]
+            if module_index + 1 < len(self._isp_modules) and self.is_forward is True:
+                next_module = self._isp_modules[module_index + 1]
+                self._all_gather_module_weight(next_module)
+
+            # 2. recompute forward prefetch for the next module
+            if self.is_forward is False:
+                if module_index + 1 < len(self._isp_modules):
+                    next_module = self._isp_modules[module_index + 1]
+                    next_module_layer_id = self._module_to_index[next_module]
+                    if module_layer_id == next_module_layer_id:
+                        self._all_gather_module_weight(next_module)
+                    # if current module is the last module in current layer, prefetch previous layer's first module
+                    elif module_layer_id - 1 >= 0:
+                        next_module = self._index_to_isp_modules[module_layer_id - 1][0]
+                        self._all_gather_module_weight(next_module)
+                else:
+                    # if current module is the last module, prefetch previous layer's first module
+                    if module_layer_id - 1 >= 0:
+                        next_module = self._index_to_isp_modules[module_layer_id - 1][0]
+                        self._all_gather_module_weight(next_module)
+
     def _post_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: disable=W0613
         if not ((self._module_to_index[module] < self._ckpt_block_num) and self.is_forward is False):
             self._clear_handle(module)
@@ -515,29 +566,24 @@ def _register_sync_parameters_hook(self) -> None:
         register forward hooks and backward hooks for isp modules.
         """
         # register forward hooks
-        # 1. register pre_forward_hook @block_0 to prefetch for block 0
-        # 2. register pre_forward_hook @block_(ckpt_block_num-1) to prefetch for the last ckpt block
-        # 3. register pre_forward_hook @out_proj module to prefetch for next block,
-        #    notice that next block's all_gather op should be after current block's all_to_all op
-        # 4. register pre_forward_hook @isp_module to wait handle for current module
-        # 5. register post_forward_hook @isp_module to release resource
+        # 1. register pre_forward_hook @block_0 to prefetch weight for block 0.
+        # 2. register pre_forward_hook @prefetch_launch_module to prefetch weight for next block,
+        #    when forward overlap granularity is 'layer'.
+        # 3. register pre_forward_hook @isp_module to wait handle for current module,
+        #    and prefetch weight for next module when forward overlap granularity is 'module'.
+        # 4. register post_forward_hook @isp_module to release memory resource.
         self._index_to_block[0].register_forward_pre_hook(self._pre_forward_hook_for_first_block)
 
-        if self._ckpt_block_num >= 1:
-            self._index_to_block[self._ckpt_block_num - 1].register_forward_pre_hook(
-                self._pre_forward_hook_for_last_ckpt_block
-            )
-
-        for out_proj in self._isp_outs:
-            out_proj.register_forward_pre_hook(self._pre_forward_hook_for_out_proj)
+        for module in self._isp_prefetch_launch_module:
+            module.register_forward_pre_hook(self._pre_forward_hook_for_prefetch_launch_module)
 
         for module in self._isp_modules:
            module.register_forward_pre_hook(self._pre_forward_hook_for_module)
            module.register_forward_hook(self._post_forward_hook_for_module)
 
         # register backward hooks
-        # 1. register pre_backward_hook @isp_module to wait handle for current module and to prefetch for next module
-        # 2. register post_backward_hook @isp_module to release resource
+        # 1. register pre_backward_hook @isp_module to wait handle for current module and to prefetch for next module.
+        # 2. register post_backward_hook @isp_module to release memory resource.
         if self._ckpt_block_num < self._num_blocks:
             for module in self._isp_modules:
                 module.register_full_backward_pre_hook(self._pre_backward_hook_for_module)
@@ -556,7 +602,7 @@ def communication_mode(self) -> str:
         return "wp"
 
     def switch_current_model_chunk(self, chunk_id: int) -> None:
-        self._isp_outs = self._overlap_states[chunk_id].isp_outs
+        self._isp_prefetch_launch_module = self._overlap_states[chunk_id].isp_prefetch_launch_module
         self._isp_modules = self._overlap_states[chunk_id].isp_modules
         self._weight_global_handle = self._overlap_states[chunk_id].weight_global_handle
         self._bias_global_handle = self._overlap_states[chunk_id].bias_global_handle
@@ -872,9 +918,7 @@ def _q_kv(self, q: torch.Tensor, kv: torch.Tensor, *args, **kwargs) -> torch.Ten
 
         q, kv = _SeqAllToAll.apply(self.spg, [2, 3], [1, 1], q, kv)
 
-        torch.cuda.synchronize()
         context = self.local_attn(q, kv, *args, **kwargs)
-        torch.cuda.synchronize()
 
         context = _SeqAllToAll.apply(self.spg, 1, 2, context)
 
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 1ac8ef31..35b3d646 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -94,13 +94,17 @@ def args_sanity_check():
         gpc.config.parallel._add_item("tensor", dict(size=1, mode=TensorParallelMode.mtp.name))
 
     if "weight" not in gpc.config.parallel:
-        gpc.config.parallel._add_item("weight", dict(size=1, overlap=False))
+        gpc.config.parallel._add_item(
+            "weight", dict(size=1, overlap=False, launch_allgather_before="wo", forward_overlap_per="layer")
+        )
 
     if "expert" not in gpc.config.parallel:
         gpc.config.parallel._add_item("expert", dict(size=-1, no_tp=False))
 
     if "expert_weight" not in gpc.config.parallel:
-        gpc.config.parallel._add_item("expert_weight", dict(size=1, overlap=False))
+        gpc.config.parallel._add_item(
+            "expert_weight", dict(size=1, overlap=False, launch_allgather_before="wo", forward_overlap_per="layer")
+        )
 
     if isinstance(gpc.config.parallel.pipeline, int):
         pp = gpc.config.parallel.pipeline
diff --git a/internlm/model/ops/attention.py b/internlm/model/ops/attention.py
index c2639569..604ea77a 100644
--- a/internlm/model/ops/attention.py
+++ b/internlm/model/ops/attention.py
@@ -886,6 +886,8 @@ class SelfAttention(nn.Module):
         attention_dropout (float): Dropout rate for attention scores. Defaults to 0.0.
""" + is_attn_cls = True + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, layer_idx=0): super().__init__() self.causal = causal diff --git a/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py b/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py index 6d531158..5c22fed3 100644 --- a/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py +++ b/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py @@ -481,7 +481,6 @@ def forward( @staticmethod def backward(ctx, dout, *args): # pylint: disable=W0613 - torch.cuda.synchronize() q, k, v, out, softmax_lse = ctx.saved_tensors dq, dk, dv = zigzag_double_ring_flash_attn_backward( @@ -504,8 +503,6 @@ def backward(ctx, dout, *args): # pylint: disable=W0613 deterministic=ctx.deterministic, ) - torch.cuda.synchronize() - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index d1db7496..2fd8ad4c 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -109,7 +109,7 @@ def train( config.hybrid_zero_optimizer.overlap_sync_grad = False config.parallel.pipeline = dict(size=pp_size, mode=pp_mode) - config.parallel.weight = dict(size=wp_size, overlap=True) + config.parallel.weight = dict(size=wp_size, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer") if interleaved is True: config.parallel.pipeline = dict(size=pp_size, interleaved_overlap=True, mode=pp_mode) config.model.num_chunks = num_chunks