support split qkv linear and sp overlap comm #415

Open · wants to merge 14 commits into base: main from split_qkv_overlap_comm
Merge remote-tracking branch 'origin/main' into split_qkv_overlap_comm
inkcherry committed Dec 6, 2024

commit ec194e7f87094012053fa9f880516e66c37beccb
6 changes: 6 additions & 0 deletions megatron/arguments.py
@@ -983,6 +983,12 @@ def _add_training_args(parser):
     group.add_argument('--ds-sequence-parallel-overlap-comm', action='store_true',
                        help='overlap comm for ds-sequence-parallel',
                        dest='ds_sequence_parallel_overlap_comm')
+    group.add_argument('--ds-sequence-parallel-fpdt', action='store_true',
+                       help='use DeepSpeed sequence parallelism with FPDT.')
+    group.add_argument('--ds-sequence-parallel-fpdt-chunk-size', type=int, default=65536,
+                       help='Chunk size used in FPDT attention.')
+    group.add_argument('--ds-sequence-parallel-fpdt-offloading', action='store_true',
+                       help='use DeepSpeed sequence parallelism FPDT with offloading.')
     group.add_argument('--no-gradient-accumulation-fusion',
                        action='store_false',
                        help='Disable fusing gradient accumulation to weight '
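As context (not part of the diff): the three --ds-sequence-parallel-fpdt* flags added here, together with the existing --ds-sequence-parallel-overlap-comm flag, are plain argparse store_true/int options, so they can be exercised in isolation. Below is a minimal, self-contained sketch that mirrors the add_argument calls above; the sample command line is purely illustrative and not an officially documented invocation.

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='training')
group.add_argument('--ds-sequence-parallel-overlap-comm', action='store_true',
                   help='overlap comm for ds-sequence-parallel',
                   dest='ds_sequence_parallel_overlap_comm')
group.add_argument('--ds-sequence-parallel-fpdt', action='store_true',
                   help='use DeepSpeed sequence parallelism with FPDT.')
group.add_argument('--ds-sequence-parallel-fpdt-chunk-size', type=int, default=65536,
                   help='Chunk size used in FPDT attention.')
group.add_argument('--ds-sequence-parallel-fpdt-offloading', action='store_true',
                   help='use DeepSpeed sequence parallelism FPDT with offloading.')

# Hypothetical command line, for illustration only.
args = parser.parse_args(['--ds-sequence-parallel-overlap-comm',
                          '--ds-sequence-parallel-fpdt',
                          '--ds-sequence-parallel-fpdt-chunk-size', '32768'])
print(args.ds_sequence_parallel_overlap_comm)     # True
print(args.ds_sequence_parallel_fpdt_chunk_size)  # 32768
print(args.ds_sequence_parallel_fpdt_offloading)  # False (flag not passed)

The --ds-sequence-parallel-overlap-comm flag is what the "sp overlap comm" part of the PR title refers to: overlapping sequence-parallel communication with computation. The PR's actual overlap implementation is not shown in this condensed view; purely as background, the generic pattern of overlapping an asynchronous all-to-all with independent work in torch.distributed looks roughly like the sketch below (function name, buffers, and shapes are hypothetical).

import torch
import torch.distributed as dist

def overlapped_all_to_all(send_buf: torch.Tensor, independent_input: torch.Tensor):
    """Generic overlap pattern (background sketch, not this PR's code).

    Assumes torch.distributed is already initialized with a backend that
    supports all_to_all_single (e.g. NCCL under torchrun).
    """
    recv_buf = torch.empty_like(send_buf)
    # async_op=True returns a work handle instead of blocking the caller.
    handle = dist.all_to_all_single(recv_buf, send_buf, async_op=True)

    # Computation that does not depend on recv_buf can run while the
    # all-to-all is in flight, hiding part of the communication cost.
    independent_out = independent_input * 2.0

    handle.wait()  # block only at the point where recv_buf is actually needed
    return recv_buf, independent_out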
29 changes: 15 additions & 14 deletions megatron/model/transformer.py
@@ -753,19 +753,21 @@ def forward(self, hidden_states, attention_mask,
         # =====================
         if self.attention_type == AttnType.self_attn:
             if not self.split_qkv:
-                # Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)] hidden_states 4096, 1, 2048
-                mixed_x_layer, _ = self.query_key_value(hidden_states) #heads16 hidden 2048 num_per_head 128
-                #[4096, 1,6144] -> 16,3,128
-                # [sq, b, ((nq + 2 * nkv) * hn)] --> [sq, b, nkv, (nq // nkv + 2), hn]
-                new_tensor_shape = mixed_x_layer.size()[:-1] + \
-                    (-1, (self.num_key_value_groups + 2),
-                     self.hidden_size_per_attention_head)
-                mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
-
-                # [sq, b, nkv, (nq // nkv + 2), hn] --> 3 [sq, b, np, hn]
-                (query_layer,
-                 key_layer,
-                 value_layer) = self.split_tensor(mixed_x_layer)
+                # Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)]
+                mixed_x_layer, _ = self.query_key_value(hidden_states)
+
+                if self.enable_ds_sequence_parallel:
+                    assert self.projection_size == self.kv_projection_size
+                    seq_len, bs = mixed_x_layer.shape[0], mixed_x_layer.shape[1]
+                    query_layer = mixed_x_layer[:, :, :self.projection_size].reshape(seq_len, bs, -1, self.head_dim)
+                    key_layer = mixed_x_layer[:, :, self.projection_size:self.projection_size+self.kv_projection_size].reshape(seq_len, bs, -1, self.head_dim)
+                    value_layer = mixed_x_layer[:, :, self.projection_size+self.kv_projection_size:].reshape(seq_len, bs, -1, self.head_dim)
+                if self.sequence_parallel or not self.enable_ds_sequence_parallel:
+                    seq_len, bs = mixed_x_layer.shape[0], mixed_x_layer.shape[1]
+                    each_hidden_size = mixed_x_layer.shape[-1] // 3
+                    query_layer = mixed_x_layer[:, :, :each_hidden_size].reshape(seq_len, bs, -1, self.head_dim)
+                    key_layer = mixed_x_layer[:, :, each_hidden_size:each_hidden_size+each_hidden_size].reshape(seq_len, bs, -1, self.head_dim)
+                    value_layer = mixed_x_layer[:, :, each_hidden_size+each_hidden_size:].reshape(seq_len, bs, -1, self.head_dim)
             else:
                 assert self.ds_sp_overlap, """
                     Currently, the split_qkv operation is only applicable
@@ -785,7 +787,6 @@ def forward(self, hidden_states, attention_mask,
                 value_layer,_ = self.value_linear(hidden_states)
                 value_layer=value_layer.reshape(value_layer.shape[0],value_layer.shape[1],self.num_attention_heads,-1)
 
-
             # Repeat kv
             if self.use_gqa:
                 key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
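To make the shapes in the hunk above concrete, here is a self-contained PyTorch sketch of the fused-QKV slicing and the subsequent KV repetition for grouped-query attention. All sizes are made-up example values; note that the ds-sequence-parallel branch in the diff additionally asserts projection_size == kv_projection_size, whereas this sketch keeps them distinct so the GQA repeat step is visible, and repeat_interleave only stands in for the module's repeat_kv helper.

import torch

seq_len, bs = 8, 2                            # [sq, b, ...] layout, as in the diff comments
num_heads, num_kv_heads, head_dim = 16, 4, 128
projection_size = num_heads * head_dim        # query width   (nq * hn)
kv_projection_size = num_kv_heads * head_dim  # key/value width (nkv * hn)

# Output of a fused QKV linear: [sq, b, (nq + 2 * nkv) * hn]
mixed_x_layer = torch.randn(seq_len, bs, projection_size + 2 * kv_projection_size)

# Slice along the hidden dimension, then fold the last dim into heads.
query_layer = mixed_x_layer[:, :, :projection_size].reshape(seq_len, bs, -1, head_dim)
key_layer = mixed_x_layer[:, :, projection_size:projection_size + kv_projection_size].reshape(seq_len, bs, -1, head_dim)
value_layer = mixed_x_layer[:, :, projection_size + kv_projection_size:].reshape(seq_len, bs, -1, head_dim)

print(query_layer.shape)  # torch.Size([8, 2, 16, 128])
print(key_layer.shape)    # torch.Size([8, 2, 4, 128])

# "Repeat kv": with GQA each key/value head serves num_key_value_groups query
# heads, so expand the KV heads before attention.
num_key_value_groups = num_heads // num_kv_heads
key_layer = key_layer.repeat_interleave(num_key_value_groups, dim=2)
value_layer = value_layer.repeat_interleave(num_key_value_groups, dim=2)
print(key_layer.shape)    # torch.Size([8, 2, 16, 128])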
You are viewing a condensed version of this merge commit.