77 Star 596 Fork 1.2K

Ascend/pytorch

关于sdpa算子的显存占用过高,是否有优化方式

DONE
训练问题
创建于  
2024-09-27 10:50

一、问题现象(附报错日志上下文):
我在训练qwen2.5-14b模型时(40 muli_head),设置16k序列,报错:
输入图片说明
猜测这个值等于
2(bf16)*40(muli_head)*(16*1024)^2=20G,
请问有没有优化的方式。
另外我阅读了 FlashAttentionScore
是不是可以认为sdpa就是使用FlashAttention实现,但为什么显存占用这么高呢,与seq_len平方增长

二、软件版本:
-- CANN 版本 (e.g., CANN 3.0.x,5.x.x): 8.0.RC1
--Tensorflow/Pytorch/MindSpore 版本: PyTorch version: 2.2.0 (NPU)
--Python 版本 (e.g., Python 3.7.5): 3.9.10
--操作系统版本 (e.g., Ubuntu 18.04): Linux-4.19.90-vhulk2211.3.0.h1543.eulerosv2r10.aarch64-aarch64-with-glibc2.28

评论 (3)

GlowwormX 创建了训练问题 9个月前
GlowwormX 修改了描述 9个月前
GlowwormX 修改了描述 9个月前
展开全部操作日志

sdpa不完全是FlashAttention实现,只在满足条件的时候走FlashAttention实现,可以通过profiling看下具体走到哪个分支了

是只有这几个形状的才支持吗
输入图片说明

https://www.hiascend.com/document/detail/zh/Pytorch/60RC2/apiref/apilist/ptaoplist_000006.html
可以查看这里关于torch.nn.functional.scaled_dot_product_attention的约束,满足时会走FlashAttention实现,同时您也可以通过profiling看下具体走哪个算子实现
https://www.hiascend.com/document/detail/zh/canncommercial/80RC22/devaids/auxiliarydevtool/atlasprofiling_16_0006.html

huangyunlong 任务状态TODO 修改为DONE 8个月前

登录 后才可以发表评论

状态
负责人
项目
里程碑
Pull Requests
关联的 Pull Requests 被合并后可能会关闭此 issue
分支
优先级
预计工期 (小时)
开始日期   -   截止日期
-
置顶选项
参与者(2)
huangyunlong-huangyunlong2022 GlowwormX-GlowwormX
Python
1
https://gitee.com/ascend/pytorch.git
git@gitee.com:ascend/pytorch.git
ascend
pytorch
pytorch

搜索帮助