Processing math: 100%
1 Star 1 Fork 0

Hauk Zero/from-mha-to-mla

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
贡献代码
同步代码
Hauk Zero add all bdc1616 2个月前
取消
提示: 由于 Git 不支持空文件夾,创建文件夹后会生成空的 .keep 文件
Loading...
README

From MHA to MLA

这个项目主要记录从多头注意力(MHA)到多头潜在注意力(MLA)的发展过程及其简要实现, 为突出重点, 代码中忽略了layernorm,RoPE, KV Cache 等其他相对不那么重要的组件的具体实现

How to start

python main.py

Multi-head Attention (MHA)

mha_mqa_gqa

如上图所示, 在最原始的 MHA 中, 每一个 token 的 head 是 q,k,v 一一对应的, 保存 kvcache 需要占用很大的显存空间, 实现见代码 mha.py

Multi-query Attention (MQA)

为了缓解因为保存所有的 head 带来的显存压力, 于是就有了 MQA, k, v 只有一个头, 保存 kvcache 时现存开销就大大减少了, 需要做注意力时再将 kv 按 head 来 repeat 即可, 见 mqa.py

Grouped-query Attention (GQA)

MQA 虽然极大减少了现存开销, 但是 kv 多个头的内容都是一样的, 这无疑也降低了模型获取信息的能力, 于是介于 MHA 和 MQA 二者之间的 GQA 就出现了. 通过将 q 的 head 分成多组, 每个组对应一组 kv, 这样相比 MHA 节约了显存空间, 相比 MQA 又增强了模型的能力, 见 gqa.py

Multi-head Latent Attention (MLA)

GQA 虽然相比 MHA 节约了显存, 但是在长序列推理场景下还是存在瓶颈问题. MLA 试图通过将 kv 映射到低秩空间, 在尽可能无损精度的前提下降低 kvcache 开销, 提高推理速度

Cache Decompressed (CD)

mla_cd

如上图所示, MLA 将 k, v 映射到低秩空间后重新映射回高维空间. 但是如果仔细观察示意图, 不免会产生下面几个疑惑:

  • 为什么对 q 也要重新映射?
  • 为什么不对所有的 q, k 做 RoPE, 反而是拆分成两部分, 一部分做 RoPE 而另一部分不做处理(NoPE)?
  • 在上面的图中, KV Cache 存的还是全部头的 kv, 这好像并没有降低 kvcache 开销吧?
  • 上面的计算比原先的 MHA, MQA, GQA 要复杂得多, 凭什么说它能提高速度?

在下面的部分中, 将会对这几个问题一一解答.

对于第一个问题, 最直接的想法就是与 kv 的做法保持一致性, 但是更深层次上则要放到后面第三部分与第二个问题来解答. 如果只看这一张图, 我们很快就能想到第三个问题的解决方案, 即下面的第二部分. 这个第一部分只是最基础的版本, 存在很多问题, 具体代码见 mla_cd.py

Cache Compressed (CC)

mla_cc

仔细观察上下两幅图, 最直观的对比就是 KV Cache 的地方变了, 而这也就是上面第三个问题的解决方案. 我们可以直接将低秩映射后的 compressed_kv 当成一个整体来存储. 但是这又出现了一个问题, 如果不对全部的 C 做 RoPE, 模型的能力会下降; 如果做 RoPE, 这个 C 里面又包含了 V 的内容, V 是不应该加 RoPE 的, 同时 C 接下来还要重新映射回高维空间, 很难保证位置信息不丢失. 于是, 图中 RoPE & NoPE 的妙处之一就体现出来了: 做 RoPE 的部分负责处理位置信息, NoPE 的部分负责保留剩余的 k 的信息和全部的 v 的信息. 见代码 mla_cc.py

Absorb

mla_absorb

观察此时的结构, 会发现最明显的不同是组件变少了, 而这也就是这部分中将提到的 "吸收".

记某一次推理中传入的 hidden_statesht, 经过对应于 Q 的低秩变换后为 cqt=htWDQ,WDQRh×rq, 经过对应于 KV 的低秩变换后为 ckvt=htWDKV,WDKVRh×rkv. 如果有 kvcache, 则 compressed_kv 变为 Ct=[ckv0,ckv1,,ckvt]Rt×akv.

注意到 A=QKT=(cqtWUQ)(CtWUKV)T=cqt(WUQ(WUKV)T)CTt=cqtWUQCTt,WUQ:=WUQ(WUKV)TRaq×akv. 这样, 通过矩阵乘法的变换操作, 可以用 WUQWUKV "吸收", 直接让 compressed_kv 参与 attention 计算, 而不需要显式的 K, 减少计算量. 但是注意到这存在一个问题, 上面的式子是在没有 RoPE 的时候成立的, 如果加了位置编码, 由于 RTi=Ri, 上面的式子中, 就单独考虑 K 的某一项 ki(it), 会变成 A=QkTi=(cqtWUQRt)(ckviWUKVRi)T=cqt(WUQRtRTi(WUKV)T)(ckvt)T=cqt(WUQRti(WUKV)T)(ckvt)T, 这样一来, 中间的式子中就存在了一个与位置相关的不确定项 Rti. 此时, 对 qk 拆分成 RoPE & NoPE 的两部分的妙处就凸显出来了.

同理, 对 V 一样可以做简化. 记 WUKV=[WUK,WUV], 其中 WUVRakv×dv, 而又有 WORdv×h, 注意到 S=(eAiiteAi)O=SVU=OWO=S(CtWUV)WO=SCt(WUVWO)=SCtWOWO:=WUVWORakv×h. 于是, WUV 一样可以被 WO "吸收", 而不需要显式的 V 参与计算, 这也减轻了计算量.

有了上面的基础后, 就可以回答上面的问题了.

对于第一个问题, 现在可以发现, 既然 WUQ 可以将 WUKV "吸收", 那么从减少计算量的角度来说, 自然是希望 WUQ 左边的维度尽可能小, 所以对 q 也做了低秩映射.

第二个问题的做法有两个妙处, 一是方便对 KV Cache 压缩, 二是在保证矩阵 "吸收" 的同时保证位置编码信息.

至于第四个问题, 显然通过上面的简化, MLA 的计算量已经大幅度减少了, 而更为重要的一点是, MLA 最开始是为了减轻 kvcache 调用带来的时延问题的, 在 decode 阶段主要是 memory bound, 于是减少 kvcache 的显存占用是尤为重要的, MLA 通过将 kv 低秩映射到更小的压缩张量 c 来一起存储, 这无疑大大减少了显存开销.

做了 absorb 的 MLA 实现代码可见 mla_absorb.py

空文件

简介

MHA, MQA, GQA, MLA 相关原理及简要实现 展开 收起
取消

发行版

暂无发行版

贡献者

全部

近期动态

不能加载更多了
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/haukzero/from-mha-to-mla.git
git@gitee.com:haukzero/from-mha-to-mla.git
haukzero
from-mha-to-mla
from-mha-to-mla
master

搜索帮助

371d5123 14472233 46e8bd33 14472233