197 Star 1.3K Fork 1.2K

GVPAscend/MindSpeed-LLM

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
ci
configs
docs
mindspore
pytorch
features
cc_lora.md
communication-over-computation.md
deterministic_computation.md
environment_variable.md
fine-tuning-with-context-parallel.md
fused_mlp.md
mc2.md
o2.md
profiling.md
recompute_relative.md
ring-attention-context-parallel.md
tensor_parallel_2d.md
variable_length_flash_attention.md
virtual_pipeline_parallel.md
yarn.md
models
solutions
install_guide.md
quick_start.md
readme.md
examples
mindspeed_llm
sources/images
tests
.gitignore
LICENSE
OWNERS
README.md
SECURITYNOTE.md
Third_Party_Open_Source_Software_Notice
convert_ckpt.py
evaluation.py
img.png
inference.py
posttrain_gpt.py
preprocess_data.py
pretrain_gpt.py
pretrain_mamba.py
requirements.txt
setup.py
克隆/下载
variable_length_flash_attention.md 1.92 KB
一键复制 编辑 原始数据 按行查看 历史
jzh 提交于 3个月前 . !2696docs fix some bug

支持variable length flash attention训练场景

训练场景

通常一个批次中输入进模型的文本序列是由多个文档(doc)拼接得到。在默认情况下,多个文档被视为同一序列,互相间的self attention没有掩盖。在特定情况下,多个文档间要求独立,文档间不能互相做self attention,在这种情况下attention mask和position ids需要在每个文档结束的位置(EOD)被重新设置。

解决方案

通过调用底层flash-attention算子的可变长模式,支持当前训练场景。

使用方式

1. 数据准备

首先确保每一个文档的末尾都添加了EOD Token。

python ./preprocess_data.py \
   --input ./dataset/train-00000-of-00001-a09b74b3ef9c3b56.parquet \
   --tokenizer-name-or-path ./model_from_hf/Llama3-hf/ \
   --output-prefix ./dataset/enwiki \
   --workers 4 \
   --log-interval 1000  \
   --append-eod \ #使能TND必须增加eod
   --tokenizer-type PretrainedFromHF

2. 训练参数设置

在模型训练脚本中增加 --reset-position-ids

使能之后,会根据eod位置,生成变量actual_seq_len,标识多个文档(doc)拼接的实际长度。

原理

1.使能前

初始化atten_mask为压缩下三角矩阵(2048*2048);

casual_mask.png

多个文档被视为同一序列,互相间的self attention没有掩盖,所有token均参与计算。

2.使能后

不初始化mask,根据eod位置生成actual_seq_len,假设一个序列中真实的文本长度分别为[2,2,0,2,2],则actual_seq_len为[2,4,4,6,8]; 实际计算量由actual_seq_len决定;

类似的attn_mask可以类似的表示为(实际计算时不生成):

varlen_mask.png

其中左下角空白位置不参与计算。

Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/ascend/MindSpeed-LLM.git
git@gitee.com:ascend/MindSpeed-LLM.git
ascend
MindSpeed-LLM
MindSpeed-LLM
2.1.0

搜索帮助