This repository contains the code in both PyTorch and Jax for our paper:
Improving Transformers with Dynamically Composable Multi-Head Attention
Da Xiao, Qingye Meng, Shengping Li, Xingyuan Yuan
ICML 2024 (oral)
We propose Dynamically Composable Multi-Head Attention (DCMHA), a parameter- and computation-efficient attention architecture that tackles the shortcomings of Multi-Head Attention (MHA) and increases the expressive power of the model by dynamically composing attention heads. At the core of DCMHA is a `Compose` function that transforms the attention score and weight matrices in an input-dependent way. DCMHA can be used as a drop-in replacement for MHA in any transformer architecture to obtain the corresponding DCFormer.
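To make the idea concrete, below is a minimal PyTorch sketch of what such a `Compose` operation can look like. It is an illustration under simplifying assumptions, not the released implementation: the class name, tensor shapes, rank, and the query-only conditioning are all placeholders chosen for brevity (the actual code lives in the `pytorch/` and `jax/` folders).

```python
import torch
import torch.nn as nn


class ComposeSketch(nn.Module):
    """Illustrative sketch of a DCMHA-style Compose operation (not the official code).

    It mixes the H per-head attention matrices across the head dimension with
    (i) a static learned H x H matrix, (ii) an input-dependent low-rank
    projection computed from the query vectors, and (iii) an input-dependent
    per-head gate. The real DCMHA Compose also conditions on the key vectors.
    """

    def __init__(self, num_heads: int, model_dim: int, rank: int = 2):
        super().__init__()
        self.h, self.r = num_heads, rank
        # Static cross-head mixing, initialized to identity so training starts from plain MHA.
        self.static_mix = nn.Parameter(torch.eye(num_heads))
        # Projections producing the dynamic (query-conditioned) composition weights.
        self.dyn_down = nn.Linear(model_dim, num_heads * rank, bias=False)
        self.dyn_up = nn.Linear(model_dim, rank * num_heads, bias=False)
        self.dyn_gate = nn.Linear(model_dim, num_heads, bias=False)

    def forward(self, attn: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
        # attn:  (B, H, T, S) attention scores or weights
        # query: (B, T, E)    per-position query vectors, E = model_dim
        B, H, T, S = attn.shape

        # (i) static composition across heads
        static = torch.einsum('hg,bgts->bhts', self.static_mix, attn)

        # (ii) dynamic low-rank composition: project the head axis down to rank R
        #      and back up, with weights that depend on the query at each position
        w_down = self.dyn_down(query).view(B, T, H, self.r)   # (B, T, H, R)
        w_up = self.dyn_up(query).view(B, T, self.r, H)       # (B, T, R, H)
        low = torch.einsum('bhts,bthr->brts', attn, torch.tanh(w_down))
        dynamic = torch.einsum('btrh,brts->bhts', w_up, low)

        # (iii) dynamic per-head gating
        gate = torch.einsum('bhts,bth->bhts', attn, self.dyn_gate(query))

        return static + dynamic + gate


if __name__ == "__main__":
    # Tiny smoke test with made-up shapes
    B, H, T, D = 2, 8, 16, 64
    compose = ComposeSketch(num_heads=H, model_dim=H * D)
    attn = torch.randn(B, H, T, T)
    q = torch.randn(B, T, H * D)
    print(compose(attn, q).shape)  # torch.Size([2, 8, 16, 16])
```

In DCMHA, this kind of composition is applied to both the pre-softmax attention scores and the post-softmax attention weights, which is where the added expressive power comes from.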
In practice, we train DCFormer on TPU for efficiency and then run inference on GPU for convenience, so we open-source the Jax training code and the PyTorch inference code in two separate folders:
- `jax/` folder: supports training DCFormer on TPU or GPU with google/MaxText. See `jax/README.md` for details.
- `pytorch/` folder: supports accelerated inference with `torch.compile` (a minimal usage sketch follows this list). See `pytorch/README.md` for details.
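For the PyTorch side, the snippet below is a rough sketch of `torch.compile`-accelerated generation through a generic Hugging Face causal-LM interface. The checkpoint path is a placeholder and the exact loading steps may differ; treat `pytorch/README.md` as the authoritative reference.

```python
# Sketch of torch.compile-accelerated generation with a Hugging Face causal-LM
# interface. The checkpoint path below is a placeholder, not a real model ID;
# see pytorch/README.md for the actual DCFormer checkpoints and loading steps.
# Requires a CUDA GPU.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "path/or/hub-id-of-a-DCFormer-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    ckpt, trust_remote_code=True, torch_dtype=torch.float16
).to("cuda").eval()

# Compile the forward pass; the first call is slow (compilation), later calls are fast.
model.forward = torch.compile(model.forward, mode="reduce-overhead")

inputs = tokenizer("Dynamically composing attention heads", return_tensors="pt").to("cuda")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```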
The synthetic tasks dataset used in the paper is located at `data/synthetic_dataset.jsonl`. Each line in the file contains an in-context learning sample, where the words in brackets `[]` are compared to calculate accuracy. E.g.
< Elizabeth has jeans. Steven has strawberries. Steven has a fox. >. Steven does not have a kind of [ clothing ]
< Karen has a taxi. Karen has a pig. Paul has a cocktail. >. Karen does not have a kind of [ drink ]
< Ruth has a sweater. Steven has a taxi. Ruth has a handgun. >. Ruth does not have a kind of [ vehicle ]
...
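As an illustration of how such samples could be scored, the sketch below splits each line at the bracketed word and computes exact-match accuracy. The JSON field name `text` is an assumption about the schema of `data/synthetic_dataset.jsonl`; adjust it to the actual keys in the file.

```python
# Sketch of splitting the bracketed target from each sample for exact-match
# scoring. The field name "text" is an assumed JSONL key, not confirmed.
import json
import re


def load_samples(path: str):
    """Yield (prompt, target) pairs: text before '[' is the prompt,
    the bracketed word is the answer compared for accuracy."""
    with open(path) as f:
        for line in f:
            text = json.loads(line)["text"]  # assumed field name
            m = re.match(r"(.*)\[\s*(.+?)\s*\]\s*$", text)
            if m:
                yield m.group(1).rstrip(), m.group(2)


def exact_match_accuracy(predictions, targets):
    correct = sum(p.strip() == t.strip() for p, t in zip(predictions, targets))
    return correct / max(len(targets), 1)
```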