feat(parallel_context.py): remove useless gqa process group (#390)
huangting4201 authored Dec 17, 2024
1 parent 5ad2eb0 commit e60a50a
Showing 2 changed files with 0 additions and 94 deletions.
13 changes: 0 additions & 13 deletions internlm/core/context/parallel_context.py
@@ -662,19 +662,6 @@ def init_parallel_groups(self):
        group_configs = generate_parallel_group_configs(parallel_strategy, parallel_sizes, enable_moe)
        group_results = create_parallel_process_groups(world_size, rank, group_configs, with_cpu_group=False)

        # process group for extra gqa tensor parallel.
        if (
            "num_kv_attention_heads" in self.config.model
            and self.config.model.num_kv_attention_heads < self.tensor_parallel_size
        ):
            group_results.append(
                create_single_process_group(
                    world_size,
                    rank,
                    GroupConfig(ParallelMode.GQA, self.tensor_parallel_size // self.num_kv_attention_heads),
                )
            )

        # process group for network test.
        group_results.append(
            create_single_process_group(
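The branch deleted above only ran when the model declared fewer kv attention heads than tensor-parallel ranks, i.e. when each kv head had to be replicated across several tensor-parallel ranks. A minimal sketch of that condition follows; the sizes (tensor_parallel_size=8, num_kv_attention_heads=2) are hypothetical and not taken from this commit:

# Hedged sketch of the removed condition; the sizes below are hypothetical.
tensor_parallel_size = 8
num_kv_attention_heads = 2

if num_kv_attention_heads < tensor_parallel_size:
    # Each kv head is replicated across this many tensor-parallel ranks;
    # this was the group size passed to GroupConfig(ParallelMode.GQA, ...) above.
    kv_head_repeats_num = tensor_parallel_size // num_kv_attention_heads
    print(kv_head_repeats_num)  # -> 4
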
81 changes: 0 additions & 81 deletions internlm/core/context/process_group_initializer.py
@@ -74,9 +74,6 @@ class ParallelMode(Enum):
    # real data parallel for isp
    ISP_DATA = "isp_data"

    # grouped query attention
    GQA = "gqa"

    # sequence 2D parallel
    HEAD = "head"
    CONTEXT = "context"
@@ -1454,84 +1451,6 @@ def init_dist_group(self, use_cpu: bool = False):
        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode


class Initializer_GQA(ProcessGroupInitializer):
    """A ProcessGroupInitializer for allreduce kv gradients with common attention head.

    Args:
        rank (int): The rank of current process.
        world_size (int): Size of whole communication world.
        weight_parallel_size (int): Size of model weight parallel.
        weight_data_parallel_size (int): Size of data parallel for common weight.
        sequence_parallel_size (int): Size of data sequence parallel.
        data_parallel_size (int): Size of data parallel.
        pipeline_parallel_size (int): Size of pipeline parallel.
        tensor_parallel_size (int): Size of tensor parallel.
        zero1_parallel_size (int): Size of zero1 parallel.
        nettest_parallel_size (int): Size of net testing parallel.
        expert_parallel_size (int): Size of expert parallel.
    """

    def __init__(self, *args, **kwargs):
        self.num_attention_heads = kwargs.pop("num_attention_heads")
        self.num_kv_attention_heads = kwargs.pop("num_kv_attention_heads")
        super().__init__(*args, **kwargs)
        self.kv_head_repeats_num = self.tensor_parallel_size // self.num_kv_attention_heads
        self.num_kv_group_per_tp = self.num_kv_attention_heads
        self.num_kv_groups = self.num_kv_group_per_tp * self.data_parallel_size

        assert self.world_size % self.tensor_parallel_size == 0
        assert self.world_size % (self.pipeline_parallel_size * self.tensor_parallel_size) == 0
        assert self.pipeline_parallel_size == 1

    def init_dist_group(self, use_cpu: bool = False):
        """Initialize weight's data parallel groups, and assign local_ranks and groups to each gpu.

        Returns:
            Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
                A WEIGHT_DATA parallelism's information tuple.

            n=128 sp=32 wp=64 zo1=1 with nopp
            sp groups: [0-31] [32-63] [64-95] [96-127]
            wp groups: [0-63] [64-127]
            kv_head groups: [0,1,2,3] [4,5,6,7] [8,9,10,11] [12,13,14,15]
                            [16,17,18,19] [20,21,22,23] [24,25,26,27] [28,29,30,31]
                            ...
        """
        local_rank = None
        ranks_in_group = None
        process_group = None
        cpu_group = None
        group_world_size = None
        mode = ParallelMode.GQA

        for i in range(self.data_parallel_size):
            for j in range(self.num_kv_group_per_tp):
                ranks = [
                    i * self.tensor_parallel_size + j * self.kv_head_repeats_num + k
                    for k in range(self.kv_head_repeats_num)
                ]
                group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
                if use_cpu:
                    group_cpu = (
                        dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
                        if dist.get_backend() != "gloo"
                        else group
                    )
                else:
                    group_cpu = None

                if self.rank in ranks:
                    local_rank = ranks.index(self.rank)
                    group_world_size = len(ranks)
                    process_group = group
                    cpu_group = group_cpu
                    ranks_in_group = ranks

        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode


class Initializer_2D_SEQUENCE_PARALLEL(ProcessGroupInitializer):
    """
    A ProcessGroupInitializer for 2D sequence parallel.
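For reference, the rank grouping built by the deleted Initializer_GQA.init_dist_group can be reproduced outside the initializer. The following is only a sketch under assumed sizes (data_parallel_size=2, tensor_parallel_size=8, num_kv_attention_heads=2 are hypothetical, not from this commit); it mirrors the double loop above without creating any torch.distributed groups:

# Hedged sketch of the GQA rank grouping from the deleted initializer.
# All sizes are hypothetical; no process groups are actually created here.
data_parallel_size = 2
tensor_parallel_size = 8
num_kv_attention_heads = 2
kv_head_repeats_num = tensor_parallel_size // num_kv_attention_heads  # 4

groups = []
for i in range(data_parallel_size):
    for j in range(num_kv_attention_heads):
        ranks = [
            i * tensor_parallel_size + j * kv_head_repeats_num + k
            for k in range(kv_head_repeats_num)
        ]
        groups.append(ranks)

print(groups)
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]

Ranks that end up in the same group here are the ones that held copies of the same kv head, which is why the initializer used each group to allreduce kv gradients.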
