diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py index 574e7cab..85bca3de 100644 --- a/internlm/solver/optimizer/utils.py +++ b/internlm/solver/optimizer/utils.py @@ -349,7 +349,7 @@ def compute_norm(gradients, parameters, norm_type=2, zero_mode=ParallelMode.ZERO # model and zero have been reduced!!! if zero_mode == ParallelMode.EXPERT_DATA: pg = gpc.get_group(ParallelMode.EXPERT) - scaled_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA)) + scaled_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT)) scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float) dist.all_reduce(scaled_norm_tensor, group=pg) total_norm = scaled_norm_tensor.item()