这个项目主要记录从多头注意力(MHA)到多头潜在注意力(MLA)的发展过程及其简要实现, 为突出重点, 代码中忽略了layernorm
,RoPE
, KV Cache
等其他相对不那么重要的组件的具体实现
python main.py
如上图所示, 在最原始的 MHA 中, 每一个 token 的 head 是 q,k,v
一一对应的, 保存 kvcache 需要占用很大的显存空间, 实现见代码 mha.py
为了缓解因为保存所有的 head 带来的显存压力, 于是就有了 MQA, k, v
只有一个头, 保存 kvcache 时现存开销就大大减少了, 需要做注意力时再将 kv 按 head 来 repeat 即可, 见 mqa.py
MQA 虽然极大减少了现存开销, 但是 kv 多个头的内容都是一样的, 这无疑也降低了模型获取信息的能力, 于是介于 MHA 和 MQA 二者之间的 GQA 就出现了. 通过将 q
的 head 分成多组, 每个组对应一组 kv
, 这样相比 MHA 节约了显存空间, 相比 MQA 又增强了模型的能力, 见 gqa.py
GQA 虽然相比 MHA 节约了显存, 但是在长序列推理场景下还是存在瓶颈问题. MLA 试图通过将 kv 映射到低秩空间, 在尽可能无损精度的前提下降低 kvcache 开销, 提高推理速度
如上图所示, MLA 将 k, v
映射到低秩空间后重新映射回高维空间. 但是如果仔细观察示意图, 不免会产生下面几个疑惑:
- 为什么对
q
也要重新映射? - 为什么不对所有的
q, k
做 RoPE, 反而是拆分成两部分, 一部分做 RoPE 而另一部分不做处理(NoPE)? - 在上面的图中, KV Cache 存的还是全部头的 kv, 这好像并没有降低 kvcache 开销吧?
- 上面的计算比原先的 MHA, MQA, GQA 要复杂得多, 凭什么说它能提高速度?
在下面的部分中, 将会对这几个问题一一解答.
对于第一个问题, 最直接的想法就是与 kv 的做法保持一致性, 但是更深层次上则要放到后面第三部分与第二个问题来解答. 如果只看这一张图, 我们很快就能想到第三个问题的解决方案, 即下面的第二部分. 这个第一部分只是最基础的版本, 存在很多问题, 具体代码见 mla_cd.py
仔细观察上下两幅图, 最直观的对比就是 KV Cache 的地方变了, 而这也就是上面第三个问题的解决方案. 我们可以直接将低秩映射后的 compressed_kv
当成一个整体来存储. 但是这又出现了一个问题, 如果不对全部的 C 做 RoPE, 模型的能力会下降; 如果做 RoPE, 这个 C 里面又包含了 V 的内容, V 是不应该加 RoPE 的, 同时 C 接下来还要重新映射回高维空间, 很难保证位置信息不丢失. 于是, 图中 RoPE & NoPE 的妙处之一就体现出来了: 做 RoPE 的部分负责处理位置信息, NoPE 的部分负责保留剩余的 k 的信息和全部的 v 的信息. 见代码 mla_cc.py
观察此时的结构, 会发现最明显的不同是组件变少了, 而这也就是这部分中将提到的 "吸收".
记某一次推理中传入的 hidden_states
为 compressed_kv
变为
注意到
$$
\begin{align*}
A &= QK^T \
&= (c_t^q W^{UQ})(C_t W^{UKV})^T \
&= c_t^q (W^{UQ} (W^{UKV})^T) C_t^T \
&= c_t^q {W^{UQ}}^{\prime} C_t^T, \
{W^{UQ}}^{\prime} &:= W^{UQ} (W^{UKV})^T \in \mathbb{R}^{a_q \times a_{kv}}.
\end{align*}
$$
这样, 通过矩阵乘法的变换操作, 可以用 compressed_kv
参与 attention 计算, 而不需要显式的 K, 减少计算量. 但是注意到这存在一个问题, 上面的式子是在没有 RoPE 的时候成立的, 如果加了位置编码, 由于
同理, 对 V 一样可以做简化. 记
有了上面的基础后, 就可以回答上面的问题了.
对于第一个问题, 现在可以发现, 既然
第二个问题的做法有两个妙处, 一是方便对 KV Cache 压缩, 二是在保证矩阵 "吸收" 的同时保证位置编码信息.
至于第四个问题, 显然通过上面的简化, MLA 的计算量已经大幅度减少了, 而更为重要的一点是, MLA 最开始是为了减轻 kvcache 调用带来的时延问题的, 在 decode 阶段主要是 memory bound, 于是减少 kvcache 的显存占用是尤为重要的, MLA 通过将 kv 低秩映射到更小的压缩张量 c 来一起存储, 这无疑大大减少了显存开销.
做了 absorb 的 MLA 实现代码可见 mla_absorb.py