一、问题现象(附报错日志上下文):
我在训练qwen2.5-14b模型时(40 muli_head),设置16k序列,报错:
猜测这个值等于
2(bf16)*40(muli_head)*(16*1024)^2=20G,
请问有没有优化的方式。
另外我阅读了 FlashAttentionScore
是不是可以认为sdpa就是使用FlashAttention实现,但为什么显存占用这么高呢,与seq_len平方增长
二、软件版本:
-- CANN 版本 (e.g., CANN 3.0.x,5.x.x): 8.0.RC1
--Tensorflow/Pytorch/MindSpore 版本: PyTorch version: 2.2.0 (NPU)
--Python 版本 (e.g., Python 3.7.5): 3.9.10
--操作系统版本 (e.g., Ubuntu 18.04): Linux-4.19.90-vhulk2211.3.0.h1543.eulerosv2r10.aarch64-aarch64-with-glibc2.28
sdpa不完全是FlashAttention实现,只在满足条件的时候走FlashAttention实现,可以通过profiling看下具体走到哪个分支了
是只有这几个形状的才支持吗
https://www.hiascend.com/document/detail/zh/Pytorch/60RC2/apiref/apilist/ptaoplist_000006.html
可以查看这里关于torch.nn.functional.scaled_dot_product_attention的约束,满足时会走FlashAttention实现,同时您也可以通过profiling看下具体走哪个算子实现
https://www.hiascend.com/document/detail/zh/canncommercial/80RC22/devaids/auxiliarydevtool/atlasprofiling_16_0006.html
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。
登录 后才可以发表评论