
tensor model parallel isn't handled properly for shared_expert_gate in Qwen2-MoE and kv_a_proj_with_mqa in DeepSeekV2 #397

Open
cosmosZhou opened this issue Dec 15, 2024 · 3 comments · May be fixed by cosmosZhou/Pai-Megatron-Patch#1

Comments

cosmosZhou (Contributor) commented Dec 15, 2024

Inconsistent gradient calculation within the tensor model parallel group in Qwen2-MoE/DeepSeekV2

Problem of shared_expert_gate in Qwen2-MoE

Let's inspect the code in detail:
shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states).view(-1, 1)) * shared_expert_output.view(-1, hidden_states.shape[-1])

Initial status of shared_expert_gate

At start-up, shared_expert_gate.weight is cloned on every partition of the tensor model parallel group; in other words, all partitions start from the same shared_expert_gate weight.

Training status of shared_expert_gate

The input hidden_states has shape (seq_length / tensor_model_parallel_world_size, batch_size, hidden_dim).
As a result, each rank's copy of shared_expert_gate only ever operates on its own partition of the sequence. For inference this isn't an issue, but for training the gradient of shared_expert_gate is also computed only from that local partition, so the copies of shared_expert_gate drift into different versions across the tensor model parallel partitions.
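
A toy illustration of this drift, assuming two ranks that start from identical gate weights but see different sequence slices (a standalone PyTorch sketch with made-up shapes, not Megatron code):

import torch

# Two "ranks" start from the same shared_expert_gate weight ...
hidden_dim = 16
gate_rank0 = torch.nn.Linear(hidden_dim, 1, bias=False)
gate_rank1 = torch.nn.Linear(hidden_dim, 1, bias=False)
gate_rank1.load_state_dict(gate_rank0.state_dict())

# ... but each rank only sees its own slice of the sequence.
slice_rank0 = torch.randn(8, hidden_dim)  # this rank's sequence chunk
slice_rank1 = torch.randn(8, hidden_dim)  # another rank's sequence chunk

torch.sigmoid(gate_rank0(slice_rank0)).sum().backward()
torch.sigmoid(gate_rank1(slice_rank1)).sum().backward()

# Without an all-reduce of these gradients across the TP group, a single
# optimizer step is enough to make the two copies diverge.
print(torch.allclose(gate_rank0.weight.grad, gate_rank1.weight.grad))  # almost surely False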

Problem of kv_a_proj_with_mqa in DeepSeekV2

Let's inspect the code in detail:
compressed_kv, _ = self.linear_kv_a_proj_with_mqa(hidden_states)

Initial status of kv_a_proj_with_mqa

In this implementation, kv_a_proj_with_mqa.weight is likewise cloned on every partition of the tensor model parallel group at start-up, i.e. all partitions start from the same kv_a_proj_with_mqa weight: because the output size below is multiplied by self.world_size before the column-parallel split, each rank ends up holding a full, independent copy of the projection. This is all done in SelfAttention.__init__:

self.linear_kv_a_proj_with_mqa = build_module(
    submodules.linear_kv_a_proj_with_mqa,
    self.config.hidden_size,
    (self.config.kv_lora_rank + self.config.qk_rope_head_dim)*self.world_size,  # scaled by world_size, so each rank's column-parallel shard has the full output size
    config=self.config,
    init_method=self.config.init_method,
    gather_output=False,
    bias=False,
    skip_bias_add=False,
    is_expert=False,
)

Training status of kv_a_proj_with_mqa

Likewise, the input hidden_states has shape (seq_length / tensor_model_parallel_world_size, batch_size, hidden_dim).
As a result, there are multiple independent copies of kv_a_proj_with_mqa (one per tensor model parallel rank), each operating on a different partition of the sequence and therefore computing different gradients. For inference this isn't an issue, but for training each copy is updated independently of the others, leaving the ranks with inconsistent versions of kv_a_proj_with_mqa.

Consequence

Confusion upon Saving Models

The issue is: when saving shared_expert_gate/kv_a_proj_with_mqa at the end of an epoch, which rank's version of shared_expert_gate/kv_a_proj_with_mqa should be saved?

Weird Training Instability

Occasionally, for Qwen2-MoE, the distributed training process hangs because of inconsistent routing results in the MoE layer's forward step.
Since the inputs to the MoE layer are inconsistent across the tensor model parallel group, the router selects different expert indices on different partitions within the group. The same local expert then receives inputs of different shapes on different partitions, its forward step gets stuck, and the whole distributed training process is suspended; a toy illustration follows.
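
As a toy illustration (made-up sizes, standalone PyTorch, not Megatron code): the same router applied to slightly diverged hidden states on two tensor model parallel ranks can select different experts, so the per-expert token counts, and hence tensor shapes, stop matching across ranks:

import torch

num_experts, hidden_dim = 8, 16
router = torch.randn(num_experts, hidden_dim)

hidden_rank0 = torch.randn(4, hidden_dim)
hidden_rank1 = hidden_rank0 + 1e-2 * torch.randn(4, hidden_dim)  # drifted copy of the inputs

top2_rank0 = torch.topk(hidden_rank0 @ router.t(), k=2, dim=-1).indices
top2_rank1 = torch.topk(hidden_rank1 @ router.t(), k=2, dim=-1).indices

# Number of tokens routed to expert 0 on each rank; when these counts differ,
# the collective ops inside the expert forward no longer line up.
print((top2_rank0 == 0).sum().item(), (top2_rank1 == 0).sum().item())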

Solution

shared_expert_gate in Qwen2-MoE

self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

should be modified to:

self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
self.shared_expert_gate.weight.sequence_parallel = config.sequence_parallel
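
This works because Megatron-LM's gradient finalization all-reduces, across the tensor model parallel group, the gradients of every parameter whose sequence_parallel attribute is set. Conceptually (a simplified sketch; the real Megatron code coalesces the gradients into buffers first):

import torch

def allreduce_sequence_parallel_grads(model, tp_group):
    # Sum the gradients of parameters flagged with a truthy `sequence_parallel`
    # attribute across the tensor model parallel group, so every replica of
    # shared_expert_gate applies the same update.
    for param in model.parameters():
        if param.requires_grad and getattr(param, "sequence_parallel", False):
            grad = param.main_grad if hasattr(param, "main_grad") else param.grad
            if grad is not None:
                torch.distributed.all_reduce(grad, group=tp_group)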

kv_a_proj_with_mqa in DeepSeekV2

initialization

in megatron_patch/model/deepseek_v2/transformer/attention.py, the code:
(self.config.kv_lora_rank + self.config.qk_rope_head_dim)*self.world_size
should be modified to:

self.config.kv_lora_rank + self.config.qk_rope_head_dim

so that kv_a_proj_with_mqa becomes a single logical projection whose output dimension is sharded across the tensor model parallel group, instead of world_size independent full copies, during the distributed training process.
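
For clarity, the resulting initialization would look roughly as follows (identical to the original call except for the output-size argument):

self.linear_kv_a_proj_with_mqa = build_module(
    submodules.linear_kv_a_proj_with_mqa,
    self.config.hidden_size,
    # no longer multiplied by world_size: one logical projection, sharded across TP ranks
    self.config.kv_lora_rank + self.config.qk_rope_head_dim,
    config=self.config,
    init_method=self.config.init_method,
    gather_output=False,
    bias=False,
    skip_bias_add=False,
    is_expert=False,
)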

forward step

in megatron_patch/model/deepseek_v2/transformer/attention.py, the code:
compressed_kv, _ = self.linear_kv_a_proj_with_mqa(hidden_states)
should be modified to:

compressed_kv, _ = self.linear_kv_a_proj_with_mqa(hidden_states)
compressed_kv = tensor_parallel.gather_from_tensor_model_parallel_region(compressed_kv) 
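
Here gather_from_tensor_model_parallel_region (from megatron.core.tensor_parallel) all-gathers the column-parallel output along its last dimension in the forward pass and splits the gradient back to the local shard in the backward pass, so every rank sees the full compressed_kv while each weight shard still receives only its own gradient slice.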

hf-to-te weight conversion

When converting DeepSeekV2 weights from the Hugging Face format to the transformer_engine format, one minor alteration is needed to match the new logic: the kv_a_proj_with_mqa weight must now be scattered across the tensor model parallel group instead of being replicated on every rank; a sketch of that split follows.
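
As an illustration, a minimal sketch of the conversion change, assuming the usual column-parallel layout in which the output dimension (dim 0 of the weight) is split evenly across the tensor model parallel ranks (the function name here is hypothetical, not the actual converter's):

import torch

def shard_kv_a_proj_with_mqa(hf_weight: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    # hf_weight: [kv_lora_rank + qk_rope_head_dim, hidden_size] from the HF checkpoint.
    # Previously the full tensor was copied to every rank; with the fix, each rank
    # keeps only its own slice of the output dimension.
    return torch.chunk(hf_weight, tp_size, dim=0)[tp_rank].contiguous()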

benefits

These modifications accumulate the gradients correctly across the different partitions within the tensor model parallel group, so shared_expert_gate/kv_a_proj_with_mqa stays consistent on every partition.
They bring the following benefits:

  • When saving the weights of shared_expert_gate/kv_a_proj_with_mqa, every partition within the tensor model parallel group is guaranteed to hold the same values, so it no longer matters which rank's copy is written out.
  • They eliminate the potential crash/suspension of the distributed training process.
cosmosZhou changed the title from "shared_expert_gate doesn't handle sequence parallel properly in Qwen2-MoE" to "tensor model parallel isn't handled properly for shared_expert_gate in Qwen2-MoE and kv_a_proj_with_mqa in DeepSeekV2" on Dec 15, 2024
jerryli1981 (Collaborator) commented: Hello, understood. We will look into it along the lines you suggested, thanks.

jerryli1981 (Collaborator) commented: If it's convenient, you can also join DingTalk group 2 and chat with me on DingTalk, or add me on WeChat.

cosmosZhou added a commit to cosmosZhou/Pai-Megatron-Patch that referenced this issue on Dec 22, 2024: fix Problem of shared_expert_gate in Qwen2-MoE in alibaba#397
cosmosZhou linked a pull request on Dec 22, 2024 that will close this issue
jerryli1981 (Collaborator) commented: Hello, it has been fixed. Please review: #425

jerryli1981 added a commit that referenced this issue on Jan 9, 2025: fix Problem of shared_expert_gate in Qwen2-MoE in #397 (Co-authored-by: Jerry Li <jerryli1981@users.noreply.github.com>)