fix(context/process_group_initializer.py): fix gqa process group (#58)
huangting4201 authored Feb 26, 2024
1 parent 8baa139 commit 73b6894
Showing 1 changed file with 4 additions and 5 deletions.
internlm/core/context/process_group_initializer.py
@@ -894,10 +894,9 @@ def init_dist_group(self, use_cpu: bool = False):
         n=128 sp=32 wp=64 zo1=1 with nopp
         sp groups: [0-31] [32-63] [64-95] [96-127]
         wp groups: [0-63] [64-127]
-        kv_head groups: [0,8,16,24] [1,9,17,25] [2,10,18,26] [3,11,19,27]
-                        [4,12,20,28] [5,13,21,29] [6,14,22,30] [7,15,23,31]
-                        [32,40,48,56] [33,41,49,57] [34,42,50,58] [35,43,51,59]
-                        [36,44,52,60] [37,45,53,61] [38,46,54,62] [39,47,55,63]
+        kv_head groups: [0,1,2,3] [4,5,6,7] [8,9,10,11] [12,13,14,15]
+                        [16,17,18,19] [20,21,22,23] [24,25,26,27] [28,29,30,31]
         ...
         ...
         ...
         """
@@ -912,7 +911,7 @@ def init_dist_group(self, use_cpu: bool = False):
         for i in range(self.data_parallel_size):
             for j in range(self.num_kv_group_per_tp):
                 ranks = [
-                    i * self.tensor_parallel_size + j + k * self.num_kv_attention_heads
+                    i * self.tensor_parallel_size + j * self.kv_head_repeats_num + k
                     for k in range(self.kv_head_repeats_num)
                 ]
                 group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
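
Why the change matters: the fixed docstring shows that the ranks sharing one KV head form contiguous blocks (presumably because query heads are partitioned contiguously across tensor-parallel ranks), while the old expression strided by the number of KV heads and therefore grouped ranks holding different KV heads. The standalone sketch below reproduces both layouts for the first data-parallel replica; the concrete values (8 KV attention heads, hence a repeat factor of 32 / 8 = 4) are inferred from the docstring example rather than stated explicitly in the diff.

# Minimal sketch of the old vs. fixed kv_head group layout, assuming the
# configuration implied by the docstring example: tp/sp = 32 and 8 KV
# attention heads, so each KV head is shared by 32 / 8 = 4 ranks.
tensor_parallel_size = 32
num_kv_attention_heads = 8  # inferred from the docstring example
kv_head_repeats_num = tensor_parallel_size // num_kv_attention_heads  # 4
num_kv_group_per_tp = num_kv_attention_heads  # 8 groups per tp block

i = 0  # first data-parallel replica

# Old formula: strides by num_kv_attention_heads, so each group mixes
# ranks that hold *different* KV heads.
old_groups = [
    [i * tensor_parallel_size + j + k * num_kv_attention_heads
     for k in range(kv_head_repeats_num)]
    for j in range(num_kv_group_per_tp)
]

# Fixed formula: contiguous blocks, grouping exactly the ranks that
# share the same KV head.
new_groups = [
    [i * tensor_parallel_size + j * kv_head_repeats_num + k
     for k in range(kv_head_repeats_num)]
    for j in range(num_kv_group_per_tp)
]

print(old_groups[:2])  # [[0, 8, 16, 24], [1, 9, 17, 25]] -- strided
print(new_groups[:2])  # [[0, 1, 2, 3], [4, 5, 6, 7]]     -- contiguous

Running the sketch prints the strided groups from the old docstring and the contiguous groups from the fixed one, matching the before/after examples in the hunk above.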