这个项目主要记录从多头注意力(MHA)到多头潜在注意力(MLA)的发展过程及其简要实现, 为突出重点, 代码中忽略了layernorm
,RoPE
, KV Cache
等其他相对不那么重要的组件的具体实现
python main.py
如上图所示, 在最原始的 MHA 中, 每一个 token 的 head 是 q,k,v
一一对应的, 保存 kvcache 需要占用很大的显存空间, 实现见代码 mha.py
为了缓解因为保存所有的 head 带来的显存压力, 于是就有了 MQA, k, v
只有一个头, 保存 kvcache 时现存开销就大大减少了, 需要做注意力时再将 kv 按 head 来 repeat 即可, 见 mqa.py
MQA 虽然极大减少了现存开销, 但是 kv 多个头的内容都是一样的, 这无疑也降低了模型获取信息的能力, 于是介于 MHA 和 MQA 二者之间的 GQA 就出现了. 通过将 q
的 head 分成多组, 每个组对应一组 kv
, 这样相比 MHA 节约了显存空间, 相比 MQA 又增强了模型的能力, 见 gqa.py
GQA 虽然相比 MHA 节约了显存, 但是在长序列推理场景下还是存在瓶颈问题. MLA 试图通过将 kv 映射到低秩空间, 在尽可能无损精度的前提下降低 kvcache 开销, 提高推理速度
如上图所示, MLA 将 k, v
映射到低秩空间后重新映射回高维空间. 但是如果仔细观察示意图, 不免会产生下面几个疑惑:
q
也要重新映射?q, k
做 RoPE, 反而是拆分成两部分, 一部分做 RoPE 而另一部分不做处理(NoPE)?在下面的部分中, 将会对这几个问题一一解答.
对于第一个问题, 最直接的想法就是与 kv 的做法保持一致性, 但是更深层次上则要放到后面第三部分与第二个问题来解答. 如果只看这一张图, 我们很快就能想到第三个问题的解决方案, 即下面的第二部分. 这个第一部分只是最基础的版本, 存在很多问题, 具体代码见 mla_cd.py
仔细观察上下两幅图, 最直观的对比就是 KV Cache 的地方变了, 而这也就是上面第三个问题的解决方案. 我们可以直接将低秩映射后的 compressed_kv
当成一个整体来存储. 但是这又出现了一个问题, 如果不对全部的 C 做 RoPE, 模型的能力会下降; 如果做 RoPE, 这个 C 里面又包含了 V 的内容, V 是不应该加 RoPE 的, 同时 C 接下来还要重新映射回高维空间, 很难保证位置信息不丢失. 于是, 图中 RoPE & NoPE 的妙处之一就体现出来了: 做 RoPE 的部分负责处理位置信息, NoPE 的部分负责保留剩余的 k 的信息和全部的 v 的信息. 见代码 mla_cc.py
观察此时的结构, 会发现最明显的不同是组件变少了, 而这也就是这部分中将提到的 "吸收".
记某一次推理中传入的 hidden_states
为 ht, 经过对应于 Q 的低秩变换后为 cqt=htWDQ,WDQ∈Rh×rq, 经过对应于 KV 的低秩变换后为 ckvt=htWDKV,WDKV∈Rh×rkv. 如果有 kvcache, 则 compressed_kv
变为 Ct=[ckv0,ckv1,…,ckvt]∈Rt×akv.
注意到
A=QKT=(cqtWUQ)(CtWUKV)T=cqt(WUQ(WUKV)T)CTt=cqtWUQ′CTt,WUQ′:=WUQ(WUKV)T∈Raq×akv.
这样, 通过矩阵乘法的变换操作, 可以用 WUQ 将 WUKV "吸收", 直接让 compressed_kv
参与 attention 计算, 而不需要显式的 K, 减少计算量. 但是注意到这存在一个问题, 上面的式子是在没有 RoPE 的时候成立的, 如果加了位置编码, 由于 RTi=R−i, 上面的式子中, 就单独考虑 K 的某一项 ki(i≤t), 会变成
A=QkTi=(cqtWUQRt)(ckviWUKVRi)T=cqt(WUQRtRTi(WUKV)T)(ckvt)T=cqt(WUQRt−i(WUKV)T)(ckvt)T,
这样一来, 中间的式子中就存在了一个与位置相关的不确定项 Rt−i. 此时, 对 qk 拆分成 RoPE & NoPE 的两部分的妙处就凸显出来了.
同理, 对 V 一样可以做简化. 记 WUKV=[WUK,WUV], 其中 WUV∈Rakv×dv, 而又有 WO∈Rdv×h, 注意到 S=(eAi∑i≤teAi)O=SVU=OWO=S(CtWUV)WO=SCt(WUVWO)=SCtWO′WO′:=WUVWO∈Rakv×h. 于是, WUV 一样可以被 WO "吸收", 而不需要显式的 V 参与计算, 这也减轻了计算量.
有了上面的基础后, 就可以回答上面的问题了.
对于第一个问题, 现在可以发现, 既然 WUQ 可以将 WUKV "吸收", 那么从减少计算量的角度来说, 自然是希望 WUQ 左边的维度尽可能小, 所以对 q 也做了低秩映射.
第二个问题的做法有两个妙处, 一是方便对 KV Cache 压缩, 二是在保证矩阵 "吸收" 的同时保证位置编码信息.
至于第四个问题, 显然通过上面的简化, MLA 的计算量已经大幅度减少了, 而更为重要的一点是, MLA 最开始是为了减轻 kvcache 调用带来的时延问题的, 在 decode 阶段主要是 memory bound, 于是减少 kvcache 的显存占用是尤为重要的, MLA 通过将 kv 低秩映射到更小的压缩张量 c 来一起存储, 这无疑大大减少了显存开销.
做了 absorb 的 MLA 实现代码可见 mla_absorb.py
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。