
Feat/refactor process group (#358)
mwiacx authored Dec 10, 2024
1 parent 71c32c8 commit cd53c32
Showing 5 changed files with 415 additions and 69 deletions.
6 changes: 2 additions & 4 deletions configs/57B_qwen2_MoE.py
@@ -190,7 +190,6 @@
 weight parallel (dict):
     1. size: int, the size of weight parallel.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-    3. memory_pool: bool, enable/disable memory pool, defaults to False.
 expert parallel (dict):
     1. size: int
         * if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size
@@ -201,15 +200,14 @@
 expert weight parallel (dict):
     1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-    3. memory_pool: bool, enable/disable memory pool, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
     tensor=dict(size=1, mode="mtp"),
     pipeline=dict(size=1, interleaved_overlap=True),
-    weight=dict(size=1, overlap=True, memory_pool=True),
+    weight=dict(size=1, overlap=True),
     expert=dict(size=-1, no_tp=False),
-    expert_weight=dict(size=1, overlap=True, memory_pool=True),
+    expert_weight=dict(size=1, overlap=True),
 )
 
 cudnn_deterministic = False
6 changes: 2 additions & 4 deletions configs/8x22B_mixtral.py
@@ -191,7 +191,6 @@
 weight parallel (dict):
     1. size: int, the size of weight parallel.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-    3. memory_pool: bool, enable/disable memory pool, defaults to False.
 expert parallel (dict):
     1. size: int
         * if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size
@@ -202,15 +201,14 @@
 expert weight parallel (dict):
     1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-    3. memory_pool: bool, enable/disable memory pool, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
     tensor=dict(size=1, mode="mtp"),
     pipeline=dict(size=1, interleaved_overlap=True),
-    weight=dict(size=1, overlap=True, memory_pool=True),
+    weight=dict(size=1, overlap=True),
     expert=dict(size=-1, no_tp=False),
-    expert_weight=dict(size=1, overlap=True, memory_pool=True),
+    expert_weight=dict(size=1, overlap=True),
 )
 
 cudnn_deterministic = False
6 changes: 2 additions & 4 deletions configs/8x7B_mixtral.py
@@ -191,7 +191,6 @@
 weight parallel (dict):
     1. size: int, the size of weight parallel.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-    3. memory_pool: bool, enable/disable memory pool, defaults to False.
 expert parallel (dict):
     1. size: int
         * if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size
@@ -202,15 +201,14 @@
 expert weight parallel (dict):
     1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size.
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-    3. memory_pool: bool, enable/disable memory pool, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
     tensor=dict(size=1, mode="mtp"),
     pipeline=dict(size=1, interleaved_overlap=True),
-    weight=dict(size=1, overlap=True, memory_pool=True),
+    weight=dict(size=1, overlap=True),
     expert=dict(size=-1, no_tp=False),
-    expert_weight=dict(size=1, overlap=True, memory_pool=True),
+    expert_weight=dict(size=1, overlap=True),
 )
 
 cudnn_deterministic = False
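The change to all three MoE configs is identical: the memory_pool option is dropped from the weight-parallel and expert-weight-parallel docstrings and from the weight and expert_weight entries. Reassembled from the hunks above, the resulting parallel block in each of these configs reads:

# parallel block shared by 57B_qwen2_MoE.py, 8x22B_mixtral.py and 8x7B_mixtral.py after this commit
parallel = dict(
    zero1=dict(size=-1, fsdp=False),
    tensor=dict(size=1, mode="mtp"),
    pipeline=dict(size=1, interleaved_overlap=True),
    weight=dict(size=1, overlap=True),         # memory_pool flag removed
    expert=dict(size=-1, no_tp=False),
    expert_weight=dict(size=1, overlap=True),  # memory_pool flag removed
)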
110 changes: 58 additions & 52 deletions internlm/core/context/parallel_context.py
@@ -21,8 +21,14 @@
 from internlm.utils.timeout import LLM_NCCL_TIMEOUT
 from internlm.utils.utils import TensorParallelMode
 
-from . import process_group_initializer as pgroup_initializer
-from .process_group_initializer import ParallelMode
+from .process_group_initializer import (
+    GroupConfig,
+    ParallelMode,
+    create_parallel_process_groups,
+    create_single_process_group,
+    generate_2d_attn_process_group,
+    generate_parallel_group_configs,
+)
 from .random import add_seed, get_seeds, set_mode
 
 # for layernorm
@@ -633,60 +639,60 @@ def init_parallel_groups(self):
 
         self.check_sanity()
 
-        initializer_args = [
-            rank,
-            world_size,
-            self.weight_parallel_size,
-            self.weight_data_parallel_size,
-            self.sequence_parallel_size,
-            self.data_parallel_size,
-            self.pipeline_parallel_size,
-            self.tensor_parallel_size,
-            self.zero1_parallel_size,
-            self.nettest_parallel_size,
-            self.expert_parallel_size,
-            self.expert_tensor_parallel_size,
-            self.expert_weight_parallel_size,
-            self.expert_data_parallel_size,
-            parallel_config.sequence_2D,
-        ]
-
-        # run initialization of different process groups
-        initializers = []
-        if "gqa" in parallel_config and parallel_config["gqa"] is True:
-            initializers.append(pgroup_initializer.Initializer_GQA(*initializer_args))
-        initializers.append(pgroup_initializer.Initializer_Weight(*initializer_args))
-        initializers.append(pgroup_initializer.Initializer_Weight_Data(*initializer_args))
-        initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
-        initializers.append(pgroup_initializer.Initializer_Data(*initializer_args))
-        initializers.append(pgroup_initializer.Initializer_ISP_Data(*initializer_args))
+        parallel_sizes = {
+            ParallelMode.TENSOR: self.tensor_parallel_size,
+            ParallelMode.SEQUENCE: self.sequence_parallel_size,
+            ParallelMode.PIPELINE: self.pipeline_parallel_size,
+            ParallelMode.DATA: self.data_parallel_size,
+            ParallelMode.ZERO1: self.zero1_parallel_size,
+            ParallelMode.WEIGHT: self.weight_parallel_size,
+            ParallelMode.WEIGHT_DATA: self.weight_data_parallel_size,
+            ParallelMode.NETTEST: self.nettest_parallel_size,
+            ParallelMode.EXPERT: self.expert_parallel_size,
+            ParallelMode.EXPERT_WEIGHT: self.expert_weight_parallel_size,
+            ParallelMode.EXPERT_TENSOR: self.expert_tensor_parallel_size,
+            ParallelMode.EXPERT_DATA: self.expert_data_parallel_size,
+        }
+
+        # process groups for parallelism.
+        enable_moe = self.config.model.get("num_experts", 1) > 1
+        tp_mode = "mtp" if isinstance(parallel_config.tensor, int) else parallel_config.tensor.get("mode", "mtp")
+        is_fsdp = False if isinstance(parallel_config.zero1, int) else parallel_config.zero1.get("fsdp", False)
+        parallel_strategy = "fsdp" if is_fsdp else tp_mode
+        group_configs = generate_parallel_group_configs(parallel_strategy, parallel_sizes, enable_moe)
+        group_results = create_parallel_process_groups(world_size, rank, group_configs, with_cpu_group=False)
+
+        # process group for extra gqa tensor parallel.
         if (
-            isinstance(parallel_config["tensor"], dict)
-            and parallel_config["tensor"]["mode"] == TensorParallelMode.isp.name
+            "num_kv_attention_heads" in self.config.model
+            and self.config.model.num_kv_attention_heads < self.tensor_parallel_size
         ):
-            initializers.append(pgroup_initializer.Initializer_Zero1_ISP(*initializer_args))
-        else:
-            initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
-        if isinstance(parallel_config["zero1"], dict) and parallel_config["zero1"].get("fsdp", False):
-            initializers.append(pgroup_initializer.Initializer_Zero3_dp(*initializer_args))
-        initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
-        if self.pipeline_parallel_size > 1:
-            initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
-        if self.config.model.get("num_experts", 1) > 1:
-            if isinstance(parallel_config["tensor"], dict) and parallel_config["tensor"]["mode"] == "isp":
-                initializers.append(pgroup_initializer.Initializer_Expert_Weight_Data(*initializer_args))
-            else:
-                initializers.append(pgroup_initializer.Initializer_Expert_Data(*initializer_args))
+            group_results.append(
+                create_single_process_group(
+                    world_size,
+                    rank,
+                    GroupConfig(ParallelMode.GQA, self.tensor_parallel_size // self.num_kv_attention_heads),
+                )
+            )
+
+        # process group for network test.
+        group_results.append(
+            create_single_process_group(
+                world_size,
+                rank,
+                GroupConfig(ParallelMode.NETTEST, self.nettest_parallel_size, allow_partial_group=True),
+            )
+        )
+
+        # process group for isp 2D attn.
         if parallel_config.sequence_2D.get("enable", False) is True:
-            initializers.append(pgroup_initializer.Initializer_2D_SEQUENCE_PARALLEL(*initializer_args))
+            group_results.extend(
+                generate_2d_attn_process_group(world_size, rank, parallel_config.sequence_2D, parallel_sizes)
+            )
 
-        for initializer in initializers:
-            parallel_setting = initializer.init_dist_group()
-            if isinstance(parallel_setting, list):
-                for args in parallel_setting:
-                    self._register_dist(*args)
-            else:
-                self._register_dist(*parallel_setting)
+        # register process groups
+        for result in group_results:
+            self._register_dist(*result)
 
     def is_initialized(self, parallel_mode: ParallelMode):
         """Returns a boolean value indicating whether `parallel_mode` is initialized
The diff for the remaining changed file (presumably internlm/core/context/process_group_initializer.py, from which the new helpers above are imported) is not loaded in this view.
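Since that module is not expanded here, the sketch below only illustrates the pattern the new call sites imply: each parallel mode is described declaratively by a GroupConfig and turned into a process group in one generic pass, instead of going through a per-mode Initializer_* class. The GroupConfig fields, the helper body, and the returned tuple are assumptions mirroring the call sites in init_parallel_groups above; torch.distributed.new_group is the only real API used, and torch.distributed is assumed to be initialized already. As a worked example of the GQA branch above, a tensor parallel size of 8 with 2 key/value attention heads requests a GQA group of size 8 // 2 = 4.

# Illustrative sketch only, not the implementation in process_group_initializer.py.
from dataclasses import dataclass

import torch.distributed as dist


@dataclass
class GroupConfig:
    """Hypothetical declarative description of one parallel mode (mirrors the call sites above)."""

    mode: str                          # the real code passes ParallelMode members, e.g. ParallelMode.GQA
    size: int                          # number of ranks per group
    allow_partial_group: bool = False  # let the trailing group be smaller than size


def create_single_process_group(world_size: int, rank: int, config: GroupConfig):
    """Carve world_size ranks into consecutive groups of config.size and return the
    group containing rank as an assumed (group, ranks_in_group, mode) tuple."""
    my_group, my_ranks = None, None
    for start in range(0, world_size, config.size):
        ranks = list(range(start, min(start + config.size, world_size)))
        if len(ranks) < config.size and not config.allow_partial_group:
            break  # drop an incomplete trailing group unless partial groups are allowed
        # dist.new_group must be called by every rank, even for groups it does not join
        group = dist.new_group(ranks)
        if rank in ranks:
            my_group, my_ranks = group, ranks
    return my_group, my_ranks, config.mode

The real create_parallel_process_groups presumably performs a similar walk over the GroupConfig list produced by generate_parallel_group_configs, and each resulting tuple is handed to self._register_dist(*result) by the loop at the end of the new init_parallel_groups.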
