Skip to content

haukzero/from-mha-to-mla

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

1 Commit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

From MHA to MLA

这个项目主要记录从多头注意力(MHA)到多头潜在注意力(MLA)的发展过程及其简要实现, 为突出重点, 代码中忽略了layernorm,RoPE, KV Cache 等其他相对不那么重要的组件的具体实现

How to start

python main.py

Multi-head Attention (MHA)

mha_mqa_gqa

如上图所示, 在最原始的 MHA 中, 每一个 token 的 head 是 q,k,v 一一对应的, 保存 kvcache 需要占用很大的显存空间, 实现见代码 mha.py

Multi-query Attention (MQA)

为了缓解因为保存所有的 head 带来的显存压力, 于是就有了 MQA, k, v 只有一个头, 保存 kvcache 时现存开销就大大减少了, 需要做注意力时再将 kv 按 head 来 repeat 即可, 见 mqa.py

Grouped-query Attention (GQA)

MQA 虽然极大减少了现存开销, 但是 kv 多个头的内容都是一样的, 这无疑也降低了模型获取信息的能力, 于是介于 MHA 和 MQA 二者之间的 GQA 就出现了. 通过将 q 的 head 分成多组, 每个组对应一组 kv, 这样相比 MHA 节约了显存空间, 相比 MQA 又增强了模型的能力, 见 gqa.py

Multi-head Latent Attention (MLA)

GQA 虽然相比 MHA 节约了显存, 但是在长序列推理场景下还是存在瓶颈问题. MLA 试图通过将 kv 映射到低秩空间, 在尽可能无损精度的前提下降低 kvcache 开销, 提高推理速度

Cache Decompressed (CD)

mla_cd

如上图所示, MLA 将 k, v 映射到低秩空间后重新映射回高维空间. 但是如果仔细观察示意图, 不免会产生下面几个疑惑:

  • 为什么对 q 也要重新映射?
  • 为什么不对所有的 q, k 做 RoPE, 反而是拆分成两部分, 一部分做 RoPE 而另一部分不做处理(NoPE)?
  • 在上面的图中, KV Cache 存的还是全部头的 kv, 这好像并没有降低 kvcache 开销吧?
  • 上面的计算比原先的 MHA, MQA, GQA 要复杂得多, 凭什么说它能提高速度?

在下面的部分中, 将会对这几个问题一一解答.

对于第一个问题, 最直接的想法就是与 kv 的做法保持一致性, 但是更深层次上则要放到后面第三部分与第二个问题来解答. 如果只看这一张图, 我们很快就能想到第三个问题的解决方案, 即下面的第二部分. 这个第一部分只是最基础的版本, 存在很多问题, 具体代码见 mla_cd.py

Cache Compressed (CC)

mla_cc

仔细观察上下两幅图, 最直观的对比就是 KV Cache 的地方变了, 而这也就是上面第三个问题的解决方案. 我们可以直接将低秩映射后的 compressed_kv 当成一个整体来存储. 但是这又出现了一个问题, 如果不对全部的 C 做 RoPE, 模型的能力会下降; 如果做 RoPE, 这个 C 里面又包含了 V 的内容, V 是不应该加 RoPE 的, 同时 C 接下来还要重新映射回高维空间, 很难保证位置信息不丢失. 于是, 图中 RoPE & NoPE 的妙处之一就体现出来了: 做 RoPE 的部分负责处理位置信息, NoPE 的部分负责保留剩余的 k 的信息和全部的 v 的信息. 见代码 mla_cc.py

Absorb

mla_absorb

观察此时的结构, 会发现最明显的不同是组件变少了, 而这也就是这部分中将提到的 "吸收".

记某一次推理中传入的 hidden_states$h_t$, 经过对应于 Q 的低秩变换后为 $c_t^q = h_t W^{DQ}, W^{DQ} \in \mathbb{R}^{h \times r_q}$, 经过对应于 KV 的低秩变换后为 $c_t^{kv} = h_t W^{DKV}, W^{DKV} \in \mathbb{R}^{h \times r_{kv}}$. 如果有 kvcache, 则 compressed_kv 变为 $C_t = \left[ c_0^{kv}, c_1^{kv}, \dots, c_t^{kv} \right] \in \mathbb{R}^{t \times a_{kv}}$.

注意到 $$ \begin{align*} A &= QK^T \ &= (c_t^q W^{UQ})(C_t W^{UKV})^T \ &= c_t^q (W^{UQ} (W^{UKV})^T) C_t^T \ &= c_t^q {W^{UQ}}^{\prime} C_t^T, \ {W^{UQ}}^{\prime} &:= W^{UQ} (W^{UKV})^T \in \mathbb{R}^{a_q \times a_{kv}}. \end{align*} $$ 这样, 通过矩阵乘法的变换操作, 可以用 $W^{UQ}$$W^{UKV}$ "吸收", 直接让 compressed_kv 参与 attention 计算, 而不需要显式的 K, 减少计算量. 但是注意到这存在一个问题, 上面的式子是在没有 RoPE 的时候成立的, 如果加了位置编码, 由于 $R_i^T = R_{-i}$, 上面的式子中, 就单独考虑 K 的某一项 $k_i (i \le t)$, 会变成 $$ \begin{align*} A &= Qk_i^T \ &= (c_t^q W^{UQ}R_t)(c_i^{kv} W^{UKV}R_i)^T \ &= c_t^q (W^{UQ} R_t R_i^T (W^{UKV})^T) (c_t^{kv})^T \ &= c_t^q (W^{UQ} R_{t - i} (W^{UKV})^T) (c_t^{kv})^T, \ \end{align*} $$ 这样一来, 中间的式子中就存在了一个与位置相关的不确定项 $R_{t - i}$. 此时, 对 qk 拆分成 RoPE & NoPE 的两部分的妙处就凸显出来了.

同理, 对 V 一样可以做简化. 记 $W^{UKV} = \left[W^{UK}, W^{UV}\right]$, 其中 $W^{UV} \in \mathbb{R}^{a_{kv} \times {d_v}}$, 而又有 $W^O \in \mathbb{R}^{d_v \times h}$, 注意到 $$ \begin{align*} S &= \left( \frac{e^{A_i}}{\sum_{i \le t} e^{A_i}} \right) \ O &= SV \ U &= OW^O \ &= S \left(C_t W^{UV}\right) W^O \ &= S C_t \left( W^{UV} W^O \right) \ &= S C_t {W^O}^{\prime} \ {W^O}^{\prime} &:= W^{UV} W^O \in \mathbb{R}^{a_{kv} \times h}. \end{align*} $$ 于是, $W^{UV}$ 一样可以被 $W^O$ "吸收", 而不需要显式的 V 参与计算, 这也减轻了计算量.

有了上面的基础后, 就可以回答上面的问题了.

对于第一个问题, 现在可以发现, 既然 $W^{UQ}$ 可以将 $W^{UKV}$ "吸收", 那么从减少计算量的角度来说, 自然是希望 $W^{UQ}$ 左边的维度尽可能小, 所以对 q 也做了低秩映射.

第二个问题的做法有两个妙处, 一是方便对 KV Cache 压缩, 二是在保证矩阵 "吸收" 的同时保证位置编码信息.

至于第四个问题, 显然通过上面的简化, MLA 的计算量已经大幅度减少了, 而更为重要的一点是, MLA 最开始是为了减轻 kvcache 调用带来的时延问题的, 在 decode 阶段主要是 memory bound, 于是减少 kvcache 的显存占用是尤为重要的, MLA 通过将 kv 低秩映射到更小的压缩张量 c 来一起存储, 这无疑大大减少了显存开销.

做了 absorb 的 MLA 实现代码可见 mla_absorb.py

About

MHA, MQA, GQA, MLA 相关原理及简要实现

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages