Commit

feat(isp): support switch for launch ag and forward overlap per module (
huangting4201 authored Dec 17, 2024
1 parent e60a50a commit 0ec6cdc
Showing 7 changed files with 126 additions and 67 deletions.
12 changes: 10 additions & 2 deletions configs/7B_MoE4_sft.py
@@ -183,6 +183,10 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. launch_allgather_before: str, the module before which to launch the all-gather communication that
prefetches the next layer's weights, should be one of ['wqkv', 'attn', 'wo', 'w1'], defaults to 'wo'.
Must be used together with forward_overlap_per='layer'.
4. forward_overlap_per: str, all-gather prefetch granularity, per 'module' or per 'layer', defaults to 'layer'.
expert parallel (dict):
1. size: int
* if size <= 0, ep size equals dp size, but if the number of experts is smaller than dp size, set ep size
@@ -193,14 +197,18 @@
expert weight parallel (dict):
1. size: int, the size of weight parallel for expert module, distinct from the global weight parallel size.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. launch_allgather_before: str, the module before which to launch the all-gather communication that
prefetches the next layer's weights, should be one of ['wqkv', 'attn', 'wo', 'w1'], defaults to 'wo'.
Must be used together with forward_overlap_per='layer'.
4. forward_overlap_per: str, all-gather prefetch granularity, per 'module' or per 'layer', defaults to 'layer'.
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True),
weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
expert=dict(size=-1, no_tp=False),
expert_weight=dict(size=1, overlap=True),
expert_weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
)

cudnn_deterministic = False
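The two keys added above can be combined per weight-parallel group. A minimal sketch (not part of this commit, values illustrative) for the non-MoE weight group, using only the keys documented in the docstring:

# per-layer prefetch (default): the next block's all-gather is launched right
# before the module named by launch_allgather_before ('wo' here).
weight = dict(size=4, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer")

# per-module prefetch: each ISP module prefetches the next module's weight;
# launch_allgather_before only takes effect together with 'layer', so it is left at its default.
weight = dict(size=4, overlap=True, forward_overlap_per="module")

The same two keys apply to the expert_weight group for MoE expert modules.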
6 changes: 5 additions & 1 deletion configs/7B_isp_sft.py
@@ -186,6 +186,10 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. launch_allgather_before: str, the module before which to launch the all-gather communication that
prefetches the next layer's weights, should be one of ['wqkv', 'attn', 'wo', 'w1'], defaults to 'wo'.
Must be used together with forward_overlap_per='layer'.
4. forward_overlap_per: str, all-gather prefetch granularity, per 'module' or per 'layer', defaults to 'layer'.
sequence_2D (dict):
1. enable: bool, whether enable the 2D sequence parallel or not.
2. head_size: int, the parallel degree of head parallelism (DeepSpeed Ulysses).
@@ -205,7 +209,7 @@
zero1=dict(size=-1),
tensor=dict(size=2, mode="isp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=4, overlap=True),
weight=dict(size=4, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
sequence_2D=dict(
enable=False,
head_size=2,
160 changes: 102 additions & 58 deletions internlm/core/parallel/comm/isp.py
@@ -266,15 +266,13 @@ def __init__(
dtype: torch.dtype = torch.half,
device: torch.device = None,
activation_checkpointing: float = 0.0,
module_shapes: Dict[str, torch.Size] = None,
) -> None:
self.dtype = dtype
if device is None:
self.device = get_current_device()
else:
self.device = device
self.activation_checkpointing = activation_checkpointing
self.module_shapes = module_shapes


class ISPOverlapState:
@@ -285,7 +283,7 @@ class ISPOverlapState:
def __init__(self) -> None:
self.num_blocks: int = 0
self.ckpt_block_num: int = 0
self.isp_outs: List[nn.Module] = []
self.isp_prefetch_launch_module: List[nn.Module] = []
self.isp_modules: List[nn.Module] = []
self.index_to_isp_modules: Dict[int, nn.Module] = {}
self.index_to_block: Dict[int, nn.Module] = {}
@@ -315,16 +313,17 @@ def __init__(
self.is_moe = is_moe
self.is_forward = True
self.reduce_scatter_handlers = {}
self._module_shapes = {}
self._forward_prefetch_prerequisites = []
self._forward_overlap_per = self._get_forward_overlap_granularity()
self._launch_before_module = self._get_launch_before_module()

# real overlap state for each chunk.
self._overlap_states: Dict[int, ISPOverlapState] = {}

# inner interface variables of overlap state.
self._num_blocks = None
self._ckpt_block_num = None
self._isp_outs = None
self._isp_prefetch_launch_module = None
self._isp_modules = None
# key: isp module; value: module global all-gather op handle
self._weight_global_handle = None
@@ -351,14 +350,46 @@ def __init__(
self._register_sync_parameters_hook()
# switch to chunk 0 at first.
self.switch_current_model_chunk(0)
self.model_conf.module_shapes = self._module_shapes

def _get_launch_before_module(self):
if self.is_moe is True:
_launch_before = gpc.config.parallel.expert_weight.get("launch_allgather_before", "wo")
else:
_launch_before = gpc.config.parallel.weight.get("launch_allgather_before", "wo")

if _launch_before == "wqkv":
return ["wqkv", "Wqkv", "qkv", "q_a_proj", "q_proj"]
elif _launch_before == "attn":
return ["attn"]
elif _launch_before == "wo":
return ["out_proj", "wo"]
elif _launch_before == "w1":
return ["w1", "fused_w1_w3"]
else:
assert False, "launch module should be in ['wqkv', 'attn', 'wo', 'w1']"

def _get_forward_overlap_granularity(self):
if self.is_moe is True:
_overlap_granularity = gpc.config.parallel.expert_weight.get("forward_overlap_per", "layer")
else:
_overlap_granularity = gpc.config.parallel.weight.get("forward_overlap_per", "layer")

assert _overlap_granularity in ["module", "layer"]
return _overlap_granularity

def _parse_model_structure(self, cid: int, model: nn.Module) -> None:
self._overlap_states[cid] = ISPOverlapState()

def get_model(obj: nn.Module) -> nn.Module:
return get_model(obj.model) if hasattr(obj, "model") else obj

def is_allgather_launch_module(name, module):
return (
hasattr(module, "is_attn_cls")
and getattr(module, "is_attn_cls")
and self._launch_before_module == ["attn"]
) or (name.split(".")[-1] in self._launch_before_module)

# Important: only works for llama-class models
children_name = get_model(model).named_children()
for _, children in children_name:
@@ -369,18 +400,12 @@ def get_model(obj: nn.Module) -> nn.Module:
self._overlap_states[cid].index_to_isp_modules[idx] = []
self._overlap_states[cid].index_to_block[idx] = block
for name, child in block.named_modules():
if name.split(".")[-1] in ["out_proj", "wo"]:
self._overlap_states[cid].isp_outs.append(child)
self._overlap_states[cid].module_to_index[child] = idx
if is_allgather_launch_module(name, child):
self._overlap_states[cid].isp_prefetch_launch_module.append(child)
if isinstance(child, (ParallelLinearWithCommExt)):
if is_moe_param(child.weight) != self.is_moe:
continue
if name not in self._module_shapes:
weight_parallel_size = dist.get_world_size(self.process_group)
origin_shape = tuple(
[child.weight.shape[0] * weight_parallel_size] + list(child.weight.shape[1:])
)
self._module_shapes[name] = torch.Size(origin_shape)

self._overlap_states[cid].module_to_index[child] = idx
self._overlap_states[cid].isp_modules.append(child)
self._overlap_states[cid].index_to_isp_modules[idx].append(child)
@@ -403,25 +428,28 @@ def get_model(obj: nn.Module) -> nn.Module:
self._overlap_states[cid].num_blocks = len(self._overlap_states[cid].index_to_isp_modules)

def _all_gather_module_weight(self, module):
assert module not in self._bias_global_output and module not in self._weight_global_output
with_bias = module.bias is not None

# submit the all-gather communication for weight and bias.
if with_bias:
bias_output, bias_handle = all_gather_raw(
module.bias,
if module not in self._bias_global_output:
bias_output, bias_handle = all_gather_raw(
module.bias,
self.process_group,
async_op=True,
)
self._bias_global_handle[module] = bias_handle
self._bias_global_output[module] = bias_output

if module not in self._weight_global_output:
weight_output, weight_handle = all_gather_raw(
module.weight,
self.process_group,
async_op=True,
)
self._bias_global_handle[module] = bias_handle
self._bias_global_output[module] = bias_output

weight_output, weight_handle = all_gather_raw(
module.weight,
self.process_group,
async_op=True,
)
self._weight_global_handle[module] = weight_handle
self._weight_global_output[module] = weight_output
self._weight_global_handle[module] = weight_handle
self._weight_global_output[module] = weight_output

def _all_gather_block_weight(self, block_index: int):
block = self._index_to_block[block_index]
@@ -463,30 +491,53 @@ def _pre_forward_hook_for_first_block(self, *args): # pylint: disable=W0613
"""
prefetch weight for block 0 before forward.
"""
if self.is_forward is True:
if self._forward_overlap_per == "layer" and self.is_forward is True:
self._all_gather_block_weight(0)

def _pre_forward_hook_for_last_ckpt_block(self, *args): # pylint: disable=W0613
if self.is_forward is False:
self._all_gather_block_weight(self._ckpt_block_num - 1)

def _pre_forward_hook_for_out_proj(self, module: nn.Module, *args): # pylint: disable=W0613
def _pre_forward_hook_for_prefetch_launch_module(self, module: nn.Module, *args): # pylint: disable=W0613
block_index = self._module_to_index[module]

if (block_index - 1 < self._ckpt_block_num) and self.is_forward is False:
if block_index - 1 >= 0:
self._all_gather_block_weight(block_index - 1)
else:
# start the all-gather for next block
if block_index + 1 < self._num_blocks:
self._all_gather_block_weight(block_index + 1)
if self._forward_overlap_per == "layer":
if (block_index - 1 < self._ckpt_block_num) and self.is_forward is False:
if block_index - 1 >= 0:
self._all_gather_block_weight(block_index - 1)
else:
# start the all-gather for next block
if block_index + 1 < self._num_blocks:
self._all_gather_block_weight(block_index + 1)

def _pre_forward_hook_for_module(self, module: nn.Module, *args): # pylint: disable=W0613
if module not in self._weight_global_handle:
self._all_gather_module_weight(module)

self._wait_handle(module)

if self._forward_overlap_per == "module":
# start the all-gather for next module
# 1.forward prefetch for next module
module_index = self._isp_modules.index(module)
module_layer_id = self._module_to_index[module]
if module_index + 1 < len(self._isp_modules) and self.is_forward is True:
next_module = self._isp_modules[module_index + 1]
self._all_gather_module_weight(next_module)

# 2.recompute forward prefetch for next module
if self.is_forward is False:
if module_index + 1 < len(self._isp_modules):
next_module = self._isp_modules[module_index + 1]
next_module_layer_id = self._module_to_index[next_module]
if module_layer_id == next_module_layer_id:
self._all_gather_module_weight(next_module)
# if current module is the last module in current layer, prefetch previous layer's first module
elif module_layer_id - 1 >= 0:
next_module = self._index_to_isp_modules[module_layer_id - 1][0]
self._all_gather_module_weight(next_module)
else:
# if current module is the last module, prefetch previous layer's first module
if module_layer_id - 1 >= 0:
next_module = self._index_to_isp_modules[module_layer_id - 1][0]
self._all_gather_module_weight(next_module)

def _post_forward_hook_for_module(self, module: nn.Module, *args): # pylint: disable=W0613
if not ((self._module_to_index[module] < self._ckpt_block_num) and self.is_forward is False):
self._clear_handle(module)
@@ -515,29 +566,24 @@ def _register_sync_parameters_hook(self) -> None:
register forward hooks and backward hooks for isp modules.
"""
# register forward hooks
# 1. register pre_forward_hook @block_0 to prefetch for block 0
# 2. register pre_forward_hook @block_(ckpt_block_num-1) to prefetch for the last ckpt block
# 3. register pre_forward_hook @out_proj module to prefetch for next block,
# notice that next block's all_gather op should be after current block's all_to_all op
# 4. register pre_forward_hook @isp_module to wait handle for current module
# 5. register post_forward_hook @isp_module to release resource
# 1. register pre_forward_hook @block_0 to prefetch weight for block 0.
# 2. register pre_forward_hook @prefetch_launch_module to prefetch weight for next block,
# when forward overlap granularity is 'layer'.
# 3. register pre_forward_hook @isp_module to wait handle for current module,
# and prefetch weight for next module when forward overlap granularity is 'module'.
# 4. register post_forward_hook @isp_module to release memory resource.
self._index_to_block[0].register_forward_pre_hook(self._pre_forward_hook_for_first_block)

if self._ckpt_block_num >= 1:
self._index_to_block[self._ckpt_block_num - 1].register_forward_pre_hook(
self._pre_forward_hook_for_last_ckpt_block
)

for out_proj in self._isp_outs:
out_proj.register_forward_pre_hook(self._pre_forward_hook_for_out_proj)
for module in self._isp_prefetch_launch_module:
module.register_forward_pre_hook(self._pre_forward_hook_for_prefetch_launch_module)

for module in self._isp_modules:
module.register_forward_pre_hook(self._pre_forward_hook_for_module)
module.register_forward_hook(self._post_forward_hook_for_module)

# register backward hooks
# 1. register pre_backward_hook @isp_module to wait handle for current module and to prefetch for next module
# 2. register post_backward_hook @isp_module to release resource
# 1. register pre_backward_hook @isp_module to wait handle for current module and to prefetch for next module.
# 2. register post_backward_hook @isp_module to release memory resource.
if self._ckpt_block_num < self._num_blocks:
for module in self._isp_modules:
module.register_full_backward_pre_hook(self._pre_backward_hook_for_module)
Expand All @@ -556,7 +602,7 @@ def communication_mode(self) -> str:
return "wp"

def switch_current_model_chunk(self, chunk_id: int) -> None:
self._isp_outs = self._overlap_states[chunk_id].isp_outs
self._isp_prefetch_launch_module = self._overlap_states[chunk_id].isp_prefetch_launch_module
self._isp_modules = self._overlap_states[chunk_id].isp_modules
self._weight_global_handle = self._overlap_states[chunk_id].weight_global_handle
self._bias_global_handle = self._overlap_states[chunk_id].bias_global_handle
@@ -872,9 +918,7 @@ def _q_kv(self, q: torch.Tensor, kv: torch.Tensor, *args, **kwargs) -> torch.Ten

q, kv = _SeqAllToAll.apply(self.spg, [2, 3], [1, 1], q, kv)

torch.cuda.synchronize()
context = self.local_attn(q, kv, *args, **kwargs)
torch.cuda.synchronize()

context = _SeqAllToAll.apply(self.spg, 1, 2, context)

8 changes: 6 additions & 2 deletions internlm/initialize/launch.py
@@ -94,13 +94,17 @@ def args_sanity_check():
gpc.config.parallel._add_item("tensor", dict(size=1, mode=TensorParallelMode.mtp.name))

if "weight" not in gpc.config.parallel:
gpc.config.parallel._add_item("weight", dict(size=1, overlap=False))
gpc.config.parallel._add_item(
"weight", dict(size=1, overlap=False, launch_allgather_before="wo", forward_overlap_per="layer")
)

if "expert" not in gpc.config.parallel:
gpc.config.parallel._add_item("expert", dict(size=-1, no_tp=False))

if "expert_weight" not in gpc.config.parallel:
gpc.config.parallel._add_item("expert_weight", dict(size=1, overlap=False))
gpc.config.parallel._add_item(
"expert_weight", dict(size=1, overlap=False, launch_allgather_before="wo", forward_overlap_per="layer")
)

if isinstance(gpc.config.parallel.pipeline, int):
pp = gpc.config.parallel.pipeline
2 changes: 2 additions & 0 deletions internlm/model/ops/attention.py
@@ -886,6 +886,8 @@ class SelfAttention(nn.Module):
attention_dropout (float): Dropout rate for attention scores. Defaults to 0.0.
"""

is_attn_cls = True

def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, layer_idx=0):
super().__init__()
self.causal = causal
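The new is_attn_cls flag is what is_allgather_launch_module (in internlm/core/parallel/comm/isp.py above) checks when launch_allgather_before="attn". As a hedged sketch (not part of this commit), a user-defined attention module would opt into that launch point the same way:

from torch import nn

class MyCustomAttention(nn.Module):  # hypothetical attention wrapper, for illustration only
    is_attn_cls = True  # picked up by the ISP communicator when launch_allgather_before="attn"

    def forward(self, qkv, **kwargs):
        ...  # attention computation omitted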
@@ -481,7 +481,6 @@ def forward(
@staticmethod
def backward(ctx, dout, *args): # pylint: disable=W0613

torch.cuda.synchronize()
q, k, v, out, softmax_lse = ctx.saved_tensors

dq, dk, dv = zigzag_double_ring_flash_attn_backward(
@@ -504,8 +503,6 @@ def backward(ctx, dout, *args): # pylint: disable=W0613
deterministic=ctx.deterministic,
)

torch.cuda.synchronize()

return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None


2 changes: 1 addition & 1 deletion tests/test_training/test_loss.py
@@ -109,7 +109,7 @@ def train(
config.hybrid_zero_optimizer.overlap_sync_grad = False

config.parallel.pipeline = dict(size=pp_size, mode=pp_mode)
config.parallel.weight = dict(size=wp_size, overlap=True)
config.parallel.weight = dict(size=wp_size, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer")
if interleaved is True:
config.parallel.pipeline = dict(size=pp_size, interleaved_overlap=True, mode=pp_mode)
config.model.num_chunks = num_chunks
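The loss test above pins the per-'layer' default; a hypothetical variant that also exercised the per-module path would only swap the weight config line (sketch, not in this commit):

config.parallel.weight = dict(size=wp_size, overlap=True, forward_overlap_per="module")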
