Fast Multihead Attention

This implementation has two main features:

  • A C++ implementation that avoids the PyTorch CPU overheads that become significant at small batch sizes.
  • The removal of all the copies and transposes found in standard implementations of Multihead Attention.
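To make the second feature concrete, here is a pure-Python sketch of the idea under stated assumptions. Suppose the input projection writes Q, K and V for every (sequence position, batch*heads) pair into one packed buffer laid out as [seq_len, batch*heads, 3, head_dim]. Instead of copying Q and K out into contiguous per-head matrices (what a reshape, transpose and contiguous copy would do), a strided batched GEMM can read them in place from a base offset and a row stride. The layout and every function name here are hypothetical illustrations, not apex's kernel code.

```python
from typing import List


def make_buffer(seq_len: int, bh: int, head_dim: int) -> List[float]:
    """Hypothetical packed projection output, laid out as [seq_len, bh, 3, head_dim]."""
    return [float(i % 7 - 3) for i in range(seq_len * bh * 3 * head_dim)]


def offset(s: int, n: int, which: int, d: int, bh: int, head_dim: int) -> int:
    # Row-major offset of element (s, n, which, d); which = 0 for Q, 1 for K, 2 for V.
    return ((s * bh + n) * 3 + which) * head_dim + d


def scores_strided(buf: List[float], n: int, seq_len: int, bh: int, head_dim: int) -> List[List[float]]:
    """Q_n @ K_n^T read directly from the packed buffer, with no copies.

    Matrix n of Q starts at n*3*head_dim and K_n starts head_dim later;
    consecutive rows are bh*3*head_dim elements apart, columns are contiguous.
    These (base, row_stride) pairs are what a strided batched GEMM consumes.
    """
    base_q = n * 3 * head_dim
    base_k = base_q + head_dim
    row_stride = bh * 3 * head_dim
    out = [[0.0] * seq_len for _ in range(seq_len)]
    for i in range(seq_len):
        for j in range(seq_len):
            out[i][j] = sum(
                buf[base_q + i * row_stride + d] * buf[base_k + j * row_stride + d]
                for d in range(head_dim)
            )
    return out


def scores_with_copies(buf: List[float], n: int, seq_len: int, bh: int, head_dim: int) -> List[List[float]]:
    """Reference path: copy Q_n and K_n into contiguous matrices, then multiply."""
    q = [[buf[offset(s, n, 0, d, bh, head_dim)] for d in range(head_dim)] for s in range(seq_len)]
    k = [[buf[offset(s, n, 1, d, bh, head_dim)] for d in range(head_dim)] for s in range(seq_len)]
    return [[sum(a * b for a, b in zip(q[i], k[j])) for j in range(seq_len)] for i in range(seq_len)]


if __name__ == "__main__":
    buf = make_buffer(seq_len=4, bh=6, head_dim=8)
    for n in range(6):
        assert scores_strided(buf, n, 4, 6, 8) == scores_with_copies(buf, n, 4, 6, 8)
    print("strided and copied results match")
```

Both paths produce the same scores; the strided one simply never materializes the per-head Q and K copies.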
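| Feature                                    | Python Version | C++ Version |
|--------------------------------------------|----------------|-------------|
| Layer Norm and Residual Add Variant        | X              | X           |
| Includes Linear Biases                     | X              |             |
| Reduces CPU Overheads                      |                | X           |
| Fuses Masking with Softmax                 |                | X           |
| Removes Transposes and Copies              | X              | X           |
| Includes Self and Encoder/Decoder Variants | X              | X           |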
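The "Fuses masking with Softmax" row can be illustrated with a small pure-Python sketch. The unfused version first builds a masked copy of the scores and then runs a softmax over it; the fused version consults the mask inside the softmax loops, so the intermediate masked copy never exists. Here mask[i] is True for positions that may be attended to, which is an assumed convention, and at least one position must be unmasked. This shows the idea only; it is not apex's CUDA kernel.

```python
import math
from typing import List, Sequence


def masked_softmax_unfused(scores: Sequence[float], mask: Sequence[bool]) -> List[float]:
    # Step 1: materialize a masked copy of the scores (an extra tensor in a real framework).
    masked = [s if keep else -math.inf for s, keep in zip(scores, mask)]
    # Step 2: numerically stable softmax over that copy; exp(-inf) evaluates to 0.0.
    m = max(masked)
    exps = [math.exp(x - m) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]


def masked_softmax_fused(scores: Sequence[float], mask: Sequence[bool]) -> List[float]:
    # The mask is checked while computing the max and the exponentials,
    # so no masked copy of the scores is ever stored.
    m = max(s for s, keep in zip(scores, mask) if keep)
    exps = [math.exp(s - m) if keep else 0.0 for s, keep in zip(scores, mask)]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    scores = [2.0, 1.0, 0.5, 3.0]
    mask = [True, True, False, False]  # hypothetical: last two positions are padding
    print(masked_softmax_fused(scores, mask))
```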

How to Instantiate

SelfMultiheadAttn(hidden_dim, heads, dropout=prob, bias=bool, include_norm_add=bool, impl='fast')
EncdecMultiheadAttn(hidden_dim, heads, dropout=prob, bias=bool, include_norm_add=bool, impl='fast')

The impl argument has two options:

  • 'fast' uses the C++ version
  • 'default' uses the Python version
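A hypothetical helper (not part of apex) that captures two assumptions behind these constructor arguments: hidden dim must split evenly across heads, as in standard multihead attention, and impl must be one of the two documented values. It also derives the per-head size and the usual 1/sqrt(head_dim) attention scaling.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class AttnConfig:
    hidden_dim: int
    heads: int
    impl: str
    head_dim: int
    scale: float  # factor applied to Q·K^T scores in scaled dot-product attention


def make_attn_config(hidden_dim: int, heads: int, impl: str = "fast") -> AttnConfig:
    """Validate SelfMultiheadAttn/EncdecMultiheadAttn-style arguments (illustrative only)."""
    if impl not in ("fast", "default"):  # 'fast' -> C++ version, 'default' -> Python version
        raise ValueError(f"impl must be 'fast' or 'default', got {impl!r}")
    if hidden_dim % heads != 0:
        raise ValueError(f"hidden_dim={hidden_dim} is not divisible by heads={heads}")
    head_dim = hidden_dim // heads
    return AttnConfig(hidden_dim, heads, impl, head_dim, 1.0 / math.sqrt(head_dim))


if __name__ == "__main__":
    print(make_attn_config(1024, 16))  # 16 heads of size 64
```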

Instructions to build on Linux

$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" ./
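After installing, one quick check that the extension built is to try importing the Python package that wraps it. The module path apex.contrib.multihead_attn is an assumption about where SelfMultiheadAttn and EncdecMultiheadAttn live; adjust it if your checkout differs. The sketch uses only importlib from the standard library.

```python
import importlib


def module_importable(name: str) -> bool:
    """Return True if `name` imports cleanly, False if it raises ImportError.

    ModuleNotFoundError is a subclass of ImportError, so a missing package
    and a package whose compiled extension is missing are both reported as False.
    """
    try:
        importlib.import_module(name)
    except ImportError:
        return False
    return True


if __name__ == "__main__":
    target = "apex.contrib.multihead_attn"  # assumed module path
    status = "OK" if module_importable(target) else "not importable (extension not built?)"
    print(target, status)
```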

Try Performance Tests Yourself!

The performance test script, perf_test_multihead_attn.py, is in the multihead_attn examples directory:

cd contrib/examples/multihead_attn

Fast Multihead Attention (Python version)

python perf_test_multihead_attn.py --ref

Fast Multihead Attention with C++ Implementation

python perf_test_multihead_attn.py

Compare with torch.nn.MultiheadAttention

python perf_test_multihead_attn.py --native

Test your own range!

python perf_test_multihead_attn.py --seq-length 64 --num-seqs-start 10 --num-seqs-stop 120 --num-seqs-inc 5
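The three comparisons and the range flags can be scripted. This hypothetical helper only assembles the command lines shown above; it assumes the range flags can be combined with --ref and --native, which the commands above do not demonstrate.

```python
import shlex
from typing import Dict, List, Optional

SCRIPT = "perf_test_multihead_attn.py"

# Mode flags taken from the commands above.
MODES: Dict[str, List[str]] = {
    "ref": ["--ref"],        # Python version
    "cpp": [],               # C++ version (no extra flag)
    "native": ["--native"],  # torch.nn.MultiheadAttention
}


def perf_command(mode: str, seq_length: Optional[int] = None, start: Optional[int] = None,
                 stop: Optional[int] = None, inc: Optional[int] = None) -> List[str]:
    """Build one perf-test command line; options left as None are omitted."""
    cmd = ["python", SCRIPT] + MODES[mode]
    for flag, value in (("--seq-length", seq_length), ("--num-seqs-start", start),
                        ("--num-seqs-stop", stop), ("--num-seqs-inc", inc)):
        if value is not None:
            cmd += [flag, str(value)]
    return cmd


if __name__ == "__main__":
    for mode in MODES:
        print(shlex.join(perf_command(mode, seq_length=64, start=10, stop=120, inc=5)))
```

Each list can be passed to subprocess.run(cmd, check=True) from inside the examples directory.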

Performance Comparisons

  • Performance was measured on an NVIDIA Titan V using sequences of 64 tokens.
  • Time is measured across multiple stacked layers to simulate running inside a model.
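A sketch of the measurement approach described above, under assumptions: each layer is a callable whose output feeds the next, a few warm-up passes are discarded, and the result is averaged over several trials. The sync hook exists because GPU work is asynchronous; with PyTorch on CUDA you would pass torch.cuda.synchronize. This is not the timing code in perf_test_multihead_attn.py.

```python
import time
from typing import Callable, Sequence


def time_layer_stack(layers: Sequence[Callable[[object], object]], x: object,
                     warmup: int = 3, trials: int = 10,
                     sync: Callable[[], None] = lambda: None) -> float:
    """Average wall-clock seconds for one pass through the whole stack of layers."""

    def run_stack() -> object:
        out = x
        for layer in layers:
            out = layer(out)  # each layer consumes the previous layer's output
        return out

    for _ in range(warmup):  # discard warm-up passes (caches, lazy initialization)
        run_stack()
    sync()                   # make sure no queued work leaks into the timed region
    start = time.perf_counter()
    for _ in range(trials):
        run_stack()
    sync()                   # wait for all timed work to finish before reading the clock
    return (time.perf_counter() - start) / trials


if __name__ == "__main__":
    toy_layers = [lambda v: v + 1 for _ in range(4)]  # stand-ins for attention layers
    print(f"{time_layer_stack(toy_layers, 0):.2e} s per stack pass")
```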

[Figures: Multihead Attention Forward; Multihead Attention Backward]
