fix(npu): fix npu dim incorrect squeeze when head num=1
SolenoidWGT committed Aug 13, 2024 (1 parent 708260f, commit bd28606)
Showing 2 changed files with 3 additions and 4 deletions.
internlm/model/ops/attention.py (2 additions, 2 deletions)
@@ -376,7 +376,7 @@ def _npu_varlen_kvpacked_attn(
 ):
     # TODO: support npu native varlen flash attention
     k, v = kv.unbind(dim=2)
-    k, v = k.squeeze(dim=2), v.squeeze(dim=2)
+    # k, v = k.squeeze(dim=2), v.squeeze(dim=2)
     return _npu_varlen_qkvsplited_attn(
         q,
         k,
@@ -393,7 +393,7 @@ def _npu_varlen_kvpacked_attn(
 
 def _npu_fixedlen_kvpacked_attn(q: torch.Tensor, kv: torch.Tensor, dropout_p: float, softmax_scale=None, causal=False):
     k, v = kv.unbind(dim=2)
-    k, v = k.squeeze(dim=2), v.squeeze(dim=2)
+    # k, v = k.squeeze(dim=2), v.squeeze(dim=2)
     return _npu_fixedlen_qkvsplited_attn(q, k, v, dropout_p, softmax_scale, causal)
 
 
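Why the squeeze was wrong: `kv.unbind(dim=2)` already removes the k/v packing dimension, so `k` and `v` come out with the head dimension intact. The follow-up `squeeze(dim=2)` was a no-op whenever the KV head count was greater than 1, but with a single KV head it silently collapsed the head dimension too. A minimal sketch of the failure, with shapes assumed for illustration (fixed-length packed-KV layout):

```python
import torch

# Assumed packed-KV layout: (batch, seqlen, 2, num_kv_heads, head_dim)
batch, seqlen, num_kv_heads, head_dim = 2, 16, 1, 64
kv = torch.randn(batch, seqlen, 2, num_kv_heads, head_dim)

# unbind(dim=2) already strips the k/v packing dimension.
k, v = kv.unbind(dim=2)
assert k.shape == (batch, seqlen, num_kv_heads, head_dim)

# The removed squeeze(dim=2) was harmless for num_kv_heads > 1 (no-op),
# but for num_kv_heads == 1 it also collapsed the head dimension:
k_bad = k.squeeze(dim=2)
assert k_bad.shape == (batch, seqlen, head_dim)  # head dim silently lost
```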
tests/test_model/test_npu_ops/test_flash_attention.py (1 addition, 2 deletions)
@@ -21,10 +21,9 @@
 HEAD_NUM = 32
 HIDDEN_SZIE = 4096
 SEQ_LEN = [2048, 4096]
-MICRO_BSZ = 1
 HEAD_DIM = HIDDEN_SZIE // HEAD_NUM
 VOCAB_SIZE = 32000
-NUM_KV_HEAD_LIST = [8, 32]
+NUM_KV_HEAD_LIST = [1, 8, 32]
 MICRO_BSZ_LIST = [1, 2]
 DTYPE_LIST = [torch.bfloat16, torch.float16]
 
