mha.py
import torch
from torch import nn
from typing import Optional


def apply_rope(x: torch.Tensor, *args, **kwargs):
    # Placeholder: a real implementation would apply rotary position embeddings.
    return x


def update_kv_cache(key_states: torch.Tensor, value_states: torch.Tensor):
    # Placeholder: simulates a filled cache by repeating k/v 5x along dim 2 (seq).
    return key_states.repeat(1, 1, 5, 1), value_states.repeat(1, 1, 5, 1)
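

# Illustrative sketch only (not part of the original file and not used by the
# module below): a real KV cache would store the past keys/values and
# concatenate the new ones along the sequence dimension at each decoding step.
class SimpleKVCache:
    def __init__(self):
        self.keys: Optional[torch.Tensor] = None
        self.values: Optional[torch.Tensor] = None

    def update(self, key_states: torch.Tensor, value_states: torch.Tensor):
        # k/v in: (bsz, num_heads, q_len, head_dim)
        # k/v out: (bsz, num_heads, kv_len, head_dim), kv_len = past_len + q_len
        if self.keys is None:
            self.keys, self.values = key_states, value_states
        else:
            self.keys = torch.cat([self.keys, key_states], dim=2)
            self.values = torch.cat([self.values, value_states], dim=2)
        return self.keys, self.values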


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        head_dim: Optional[int] = None,
        use_cache: bool = False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim or hidden_size // num_heads
        self.scale = self.head_dim**-0.5
        self.use_cache = use_cache

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        bsz, q_len, _ = hidden_states.size()

        # q/k/v: (bsz, num_heads, q_len, head_dim)
        query_states = (
            self.q_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        key_states = (
            self.k_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        value_states = (
            self.v_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        # RoPE on queries and keys
        query_states = apply_rope(query_states)
        key_states = apply_rope(key_states)

        if self.use_cache:
            # Update the KV cache and get the full keys/values
            # k/v: (bsz, num_heads, kv_len, head_dim)
            key_states, value_states = update_kv_cache(key_states, value_states)
        kv_len = key_states.shape[2]  # kv_len >= q_len when the cache is used

        # Softmax(Q @ K^T / sqrt(d_k))
        # attn_weights: (bsz, num_heads, q_len, kv_len)
        attn_weights = (
            torch.einsum("bhld, bhnd -> bhln", query_states, key_states) * self.scale
        )
        attn_weights = attn_weights.softmax(dim=-1)

        # A @ V
        # attn_output: (bsz, num_heads, q_len, head_dim)
        attn_output = torch.einsum("bhln, bhnd -> bhld", attn_weights, value_states)

        # output: (bsz, q_len, hidden_size)
        attn_output = self.o_proj(
            attn_output.transpose(1, 2)
            .contiguous()
            .view(bsz, q_len, self.num_heads * self.head_dim)
        )
        return attn_output
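

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative only: the batch size, sequence length and
# model sizes below are arbitrary and not part of the original module).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    attn = MultiHeadAttention(hidden_size=64, num_heads=8, use_cache=True)
    hidden_states = torch.randn(2, 10, 64)  # (bsz, q_len, hidden_size)

    out = attn(hidden_states)
    print(out.shape)  # torch.Size([2, 10, 64])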