From acfb11e9d1c3ef608cdfb4cdf537830c7e7c9968 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Wed, 31 Jul 2024 16:23:11 +0800
Subject: [PATCH] fix(moe): change moe norm reduced group (#289)

---
 internlm/solver/optimizer/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py
index 574e7cab..85bca3de 100644
--- a/internlm/solver/optimizer/utils.py
+++ b/internlm/solver/optimizer/utils.py
@@ -349,7 +349,7 @@ def compute_norm(gradients, parameters, norm_type=2, zero_mode=ParallelMode.ZERO
     # model and zero have been reduced!!!
     if zero_mode == ParallelMode.EXPERT_DATA:
         pg = gpc.get_group(ParallelMode.EXPERT)
-        scaled_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
+        scaled_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT))
         scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
         dist.all_reduce(scaled_norm_tensor, group=pg)
         total_norm = scaled_norm_tensor.item()