From 73b6894e2e7f2984a804f85cca7ff088e66ef28c Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Mon, 26 Feb 2024 11:17:02 +0800
Subject: [PATCH] fix(context/process_group_initializer.py): fix gqa process group (#58)

---
 internlm/core/context/process_group_initializer.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py
index 51a722ab..5519c7d8 100644
--- a/internlm/core/context/process_group_initializer.py
+++ b/internlm/core/context/process_group_initializer.py
@@ -894,10 +894,9 @@ def init_dist_group(self, use_cpu: bool = False):
         n=128 sp=32 wp=64 zo1=1 with nopp
         sp groups: [0-31] [32-63] [64-95] [96-127]
         wp groups: [0-63] [64-127]
-        kv_head groups: [0,8,16,24] [1,9,17,25] [2,10,18,26] [3,11,19,27]
-                        [4,12,20,28] [5,13,21,29] [6,14,22,30] [7,15,23,31]
-                        [32,40,48,56] [33,41,49,57] [34,42,50,58] [35,43,51,59]
-                        [36,44,52,60] [37,45,53,61] [38,46,54,62] [39,47,55,63]
+        kv_head groups: [0,1,2,3] [4,5,6,7] [8,9,10,11] [12,13,14,15]
+                        [16,17,18,19] [20,21,22,23] [24,25,26,27] [28,29,30,31]
+                        ...
                         ...
                         ...
         """
@@ -912,7 +911,7 @@ def init_dist_group(self, use_cpu: bool = False):
         for i in range(self.data_parallel_size):
             for j in range(self.num_kv_group_per_tp):
                 ranks = [
-                    i * self.tensor_parallel_size + j + k * self.num_kv_attention_heads
+                    i * self.tensor_parallel_size + j * self.kv_head_repeats_num + k
                     for k in range(self.kv_head_repeats_num)
                 ]
                 group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)