# flash-dmattn
**Repository Path**: ftgreat/flash-dmattn
## Basic Information
- **Project Name**: flash-dmattn
- **Description**: No description available
- **Primary Language**: Unknown
- **License**: BSD-3-Clause
- **Default Branch**: main
- **Homepage**: None
- **GVP Project**: No
## Statistics
- **Stars**: 0
- **Forks**: 0
- **Created**: 2025-10-25
- **Last Updated**: 2025-10-29
## Categories & Tags
**Categories**: Uncategorized
**Tags**: None
## README
[English](./README.md) | **简体中文**

Flash-DMA 是一个高性能的注意力实现,将 Flash Attention 的内存效率与动态掩码注意力的稀疏计算能力相结合,用于在 Transformer 模型中处理超长序列。
## 主要特性
### 🎯 核心内核优势
- **Mask & Bias 支持**: 原生支持 `({1|batch_size}, {1|num_kv_heads|num_heads}, {1|query_len}, {1|key_len})` 形状的 attention_mask 和 attention_bias 张量
- **智能计算跳过**: 基于 attention_mask 的 block-level 自动跳过机制,完全跳过全零 mask 区块的计算和内存访问
- **完整梯度支持**: 内置 attention_bias 的完整梯度计算路径,支持端到端训练
### 🚀 性能与效率
- **动态稀疏注意力**: 为每个查询动态选择最重要的键,将计算复杂度从 $O(N^2)$ 降低到 $O(N \cdot w)$,其中 $w \ll N$, 支持可训练的稀疏结构
- **内存效率**: 保持 Flash Attention 的 $O(N)$ 内存复杂度,无需实例化完整的注意力矩阵
- **CUDA 深度优化**: 自定义 CUDA 内核,含共享内存别名、流水线预取、按块跳过,实现高吞吐与低访存开销
- **超长上下文支持**: 通过动态掩码窗口裁剪,在保持精度的前提下支撑 128K+ 令牌级别的上下文处理
## 性能
我们展示了带有mask与bias条件下 Flash-DMA 相对于标准 PyTorch SDPA 的预期加速效果。

---
### 前向传播性能
以下表格是我们在NVIDIA A100-SXM4-80GB上对Flash-DMA与标准PyTorch SDPA在不同配置下的前向性能对比测试结果。结果为预热两次, 运行三次的平均值。
| Mode | Q len | K len | Window W | SDPA (ms) | FDMA (ms) | Speedup |
|--------|-------|--------|----------|-----------|-----------|---------|
| Train | 256 | 256 | 1024 | 0.29 | 0.19 | 1.58x |
| Train | 512 | 512 | 1024 | 0.35 | 0.19 | 1.86x |
| Train | 1024 | 1024 | 1024 | 0.51 | 0.18 | 2.81x |
| Train | 2048 | 2048 | 1024 | 1.04 | 0.18 | 5.68x |
| Train | 4096 | 4096 | 1024 | 2.53 | 0.24 | 10.41x |
| Train | 8192 | 8192 | 1024 | 9.38 | 0.36 | 25.93x |
| Train | 16384 | 16384 | 1024 | 28.39 | 0.81 | 35.25x |
| Train | 32768 | 32768 | 1024 | 111.87 | 2.25 | 49.78x |
| Train | 32768 | 32768 | 32 | 113.19 | 2.10 | 53.97x |
| Train | 32768 | 32768 | 64 | 113.17 | 2.12 | 53.32x |
| Train | 32768 | 32768 | 128 | 113.14 | 2.10 | 53.78x |
| Train | 32768 | 32768 | 256 | 113.18 | 2.13 | 53.18x |
| Train | 32768 | 32768 | 512 | 113.19 | 2.17 | 52.17x |
| Train | 32768 | 32768 | 1024 | 113.19 | 2.24 | 50.45x |
| Train | 32768 | 32768 | 2048 | 113.15 | 2.39 | 47.35x |
| Train | 32768 | 32768 | 4096 | 113.16 | 2.67 | 42.39x |
| Train | 32768 | 32768 | 8192 | 113.11 | 3.20 | 35.29x |
| Train | 32768 | 32768 | 16384 | 113.15 | 3.97 | 28.51x |
| Train | 32768 | 32768 | 32768 | 113.11 | 4.90 | 23.10x |
| Infer | 1 | 256 | 1024 | 0.25 | 0.19 | 1.28x |
| Infer | 1 | 512 | 1024 | 0.25 | 0.19 | 1.27x |
| Infer | 1 | 1024 | 1024 | 0.25 | 0.20 | 1.28x |
| Infer | 1 | 2048 | 1024 | 0.25 | 0.20 | 1.24x |
| Infer | 1 | 4096 | 1024 | 0.25 | 0.19 | 1.29x |
| Infer | 1 | 8192 | 1024 | 0.25 | 0.20 | 1.25x |
| Infer | 1 | 16384 | 1024 | 0.25 | 0.19 | 1.29x |
| Infer | 1 | 32768 | 1024 | 0.27 | 0.20 | 1.33x |
| Infer | 1 | 65536 | 1024 | 0.42 | 0.20 | 2.10x |
| Infer | 1 | 131072 | 1024 | 0.72 | 0.20 | 3.65x |
| Infer | 1 | 262144 | 1024 | 1.31 | 0.22 | 6.06x |
| Infer | 1 | 524288 | 1024 | 2.49 | 0.24 | 10.45x |
| Infer | 1 | 524288 | 32 | 2.48 | 0.21 | 11.60x |
| Infer | 1 | 524288 | 64 | 2.44 | 0.21 | 11.66x |
| Infer | 1 | 524288 | 128 | 2.45 | 0.21 | 11.47x |
| Infer | 1 | 524288 | 256 | 2.43 | 0.21 | 11.47x |
| Infer | 1 | 524288 | 512 | 2.44 | 0.22 | 10.89x |
| Infer | 1 | 524288 | 1024 | 2.44 | 0.24 | 10.31x |
| Infer | 1 | 524288 | 2048 | 2.44 | 0.27 | 9.07x |
| Infer | 1 | 524288 | 4096 | 2.45 | 0.33 | 7.41x |
| Infer | 1 | 524288 | 8192 | 2.44 | 0.35 | 6.93x |
| Infer | 1 | 524288 | 16384 | 2.44 | 0.35 | 6.93x |
| Infer | 1 | 524288 | 32768 | 2.45 | 0.35 | 6.96x |
| Infer | 1 | 524288 | 65536 | 2.44 | 0.35 | 6.88x |
---
### 反向传播性能
以下表格是我们在NVIDIA A100-SXM4-80GB上对Flash-DMA与标准PyTorch SDPA在不同配置下的反向性能对比测试结果。结果为预热两次, 运行三次的平均值。
| Mode | Q len | K len | Window W | SDPA-BWD (ms) | FDMA-BWD (ms) | Speedup |
|-------|-------|--------|----------|---------------|---------------|---------|
| Train | 256 | 256 | 1024 | 0.42 | 0.62 | 0.7x |
| Train | 512 | 512 | 1024 | 0.56 | 0.60 | 0.9x |
| Train | 1024 | 1024 | 1024 | 0.94 | 0.61 | 1.5x |
| Train | 2048 | 2048 | 1024 | 1.79 | 0.69 | 2.6x |
| Train | 4096 | 4096 | 1024 | 3.76 | 1.08 | 3.5x |
| Train | 8192 | 8192 | 1024 | 14.39 | 2.06 | 7.0x |
| Train | 16384 | 16384 | 1024 | 39.56 | 4.97 | 8.0x |
| Train | 32768 | 32768 | 1024 | 142.07 | 25.63 | 5.5x |
| Train | 32768 | 32768 | 32 | 142.70 | 21.91 | 6.5x |
| Train | 32768 | 32768 | 64 | 142.65 | 22.29 | 6.4x |
| Train | 32768 | 32768 | 128 | 142.69 | 23.04 | 6.2x |
| Train | 32768 | 32768 | 256 | 142.69 | 24.27 | 5.9x |
| Train | 32768 | 32768 | 512 | 142.67 | 25.12 | 5.7x |
| Train | 32768 | 32768 | 1024 | 142.55 | 25.58 | 5.6x |
| Train | 32768 | 32768 | 2048 | 142.75 | 25.64 | 5.6x |
| Train | 32768 | 32768 | 4096 | 142.61 | 24.84 | 5.7x |
| Train | 32768 | 32768 | 8192 | 142.33 | 25.63 | 5.6x |
| Train | 32768 | 32768 | 16384 | 142.40 | 25.62 | 5.6x |
| Train | 32768 | 32768 | 32768 | 142.43 | 25.63 | 5.6x |
---
## 安装
### 依赖
- **Linux**: Ubuntu 22.04 或更高版本
- **NVIDIA GPU**: 计算能力 8.0 或更高
- **C++ 编译器**: GCC 7+
- **CUDA**: 11.8 或更高版本
- **Python**: 3.9 或更高版本
- **PyTorch**: 2.5.1 或更高版本
### 安装
您可以通过预编译的轮子安装 Flash-DMA:
```bash
pip install flash-dmattn --no-build-isolation
```
或者,您可以从源代码编译和安装:
```bash
git clone https://github.com/SmallDoges/flash-dmattn.git
cd flash-dmattn
pip install . --no-build-isolation
```
## 快速开始
### 基本用法
```python
import torch
from flash_dmattn import flash_dmattn_func_auto
from flash_dmattn.utils.mask import create_mask
import math
# 设置
batch_size, seq_len, num_heads, num_kv_heads, head_dim = 1, 256, 2, 1, 64
window_size = 128
device = torch.device('cuda')
dtype = torch.bfloat16
min_dtype = torch.finfo(dtype).min # dtype 的最小值
# 输入张量
query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
# 为稀疏注意力创建 bias
attn_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
# 基于 bias 生成动态 mask
if seq_len > window_size:
attn_mask = create_mask(
attention_bias=attn_bias,
attention_mask=None,
batch_size=batch_size,
query_len=seq_len,
key_len=seq_len,
window_size=window_size,
min_dtype=min_dtype,
)
# 选择 FDMA 内核
flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
# 运行 FDMA
output = flash_dmattn_func(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
attn_bias=attn_bias,
is_causal=True,
softmax_scale=1.0/math.sqrt(head_dim),
)
print(f"输出形状: {output.shape}") # [1, 256, 2, 64]
```
### 梯度计算示例
```python
# 开启梯度计算
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
attn_bias.requires_grad_(True)
# 前向传播
output = flash_dmattn_func(
query=query, key=key, value=value,
attn_mask=attn_mask,
attn_bias=attn_bias,
is_causal=True,
softmax_scale=1.0/math.sqrt(head_dim)
)
# 反向传播
loss = output.sum()
loss.backward()
print(f"Query 梯度形状: {query.grad.shape}")
print(f"Key 梯度形状: {key.grad.shape}")
print(f"Value 梯度形状: {value.grad.shape}")
print(f"Bias 梯度形状: {attn_bias.grad.shape}")
```
## 工作原理
Flash-DMA 通过将 Flash Attention 的高效内存访问模式与动态掩码注意力的稀疏计算能力相结合,实现了高效的注意力机制。
### 核心技术融合
- **🎯 Mask & Bias 原生支持**: 内核直接处理 `({1|batch_size}, {1|num_kv_heads|num_heads}, {1|query_len}, {1|key_len})` 形状的张量
- **⚡ Block-level 智能跳过**: 基于 mask 的统一 OR-reduction 跳过逻辑,完全避免全零区块的计算和内存访问
- **🔄 完整梯度链路**: 内置 attention bias 梯度计算,支持端到端可微分训练
### 关键优化策略
1. **统一跳过逻辑**: 前向和反向过程使用相同的 block-level 跳过决策
2. **内存访问优化**: 只有当 `OR(mask_block) == true` 时才加载 K/V 数据
3. **梯度路径完整性**: dbias 梯度计算完全融合在反向内核中
4. **共享内存复用**: sMask ↔ sP, sBias ↔ sdS 智能别名化
## 文档
📚 **完整文档可在 [docs](docs/) 目录中找到:**
- **[API 参考](docs/api_reference.md)** - 完整的函数文档和使用示例
- **[集成指南](docs/integration.md)** - Flash Attention 集成的详细技术文档
## 从源码构建
### 开发环境设置
```bash
# 克隆包含子模块
git clone https://github.com/SmallDoges/flash-dmattn.git
cd flash-dmattn
# 在开发模式下构建
pip install -e .
# 运行测试以验证安装
python -c "import flash_dma_cuda; print('✅ Flash DMA CUDA 扩展导入成功')"
```
### 构建要求
- CUDA Toolkit 11.8+
- CUTLASS 库
- 支持 CUDA 的 PyTorch
### 支持的架构
- **SM 8.0**
- **SM 9.0**
- **SM 10.0**
- **SM 12.0**
**注意**: Flash 动态掩码注意力需要 CUDA 计算能力 8.0+ 才能获得最佳性能。不支持更早的架构。
## 基准测试
Flash-DMA 提供全面的基准测试工具,用于评估不同配置下的性能:
### 前向传播等效性
```bash
python benchmarks/forward_equivalence.py
```
验证 Python 参考实现与 CUDA 实现之间的数值一致性。
### 前向传播性能基准测试
```bash
python benchmarks/forward_performance.py
```
在各种序列长度和批大小下比较 Flash-DMA 与标准 SDPA。
### 反向传播等效性
```bash
python benchmarks/backward_equivalence.py
```
验证 Python 参考实现与 CUDA 实现之间的数值一致性。
### 反向传播性能基准测试
```bash
python benchmarks/backward_performance.py
```
比较 Flash-DMA 与标准 SDPA 在各种序列长度和批大小下的性能。
### 梯度计算
```bash
python benchmarks/grad_equivalence.py
```
测试反向传播实现和梯度等效性。
## 故障排除
### 常见问题
**编译错误**
```bash
# 确保 CUDA_HOME 设置正确
echo $CUDA_HOME # Linux/Mac
echo $env:CUDA_HOME # Windows PowerShell
# 检查 CUDA 工具包版本
nvcc --version
# 验证 PyTorch CUDA 支持
python -c "import torch; print(f'CUDA 可用: {torch.cuda.is_available()}')"
```
**导入错误**
```python
# 测试基本导入
try:
from flash_dmattn import flash_dmattn_func, get_available_backends
print("✅ Flash 动态掩码注意力导入成功")
print(f"可用后端: {get_available_backends()}")
except ImportError as e:
print(f"❌ 导入失败: {e}")
print("请确保包已正确安装,使用: pip install -e .")
```
**性能问题**
```python
# 监控 GPU 内存使用
from flash_dmattn import flash_dmattn_func
def print_memory_stats():
if torch.cuda.is_available():
print(f"GPU 内存: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print_memory_stats()
output = flash_dmattn_func(q=query, k=key, v=value, is_causal=True)
print_memory_stats()
# 如需要,清除缓存
torch.cuda.empty_cache()
```
## 贡献
我们欢迎社区的贡献!Flash-DMA 是一个开源项目,我们重视所有类型的贡献。
### 如何贡献
- **报告错误**: 发现了错误?请[提交 issue](https://github.com/SmallDoges/flash-dmattn/issues/new/choose)
- **功能请求**: 有改进想法?[告诉我们](https://github.com/SmallDoges/flash-dmattn/issues/new/choose)
- **提交代码**: 准备贡献代码?查看我们的[贡献指南](CONTRIBUTING.md)
- **改进文档**: 帮助我们完善文档
### 贡献者快速入门
1. Fork 仓库
2. 创建功能分支: `git checkout -b feature-name`
3. 进行修改并测试
4. 提交 Pull Request
详细说明请参见我们的[贡献指南](CONTRIBUTING.md)。
### 行为准则
本项目遵循[贡献者公约行为准则](CODE_OF_CONDUCT.md)。参与时,您需要遵守此准则。
## 许可证
本项目采用 BSD 3-Clause 许可证。详情请参见 [LICENSE](LICENSE)。
## 引用
如果您在研究中使用 Flash-DMA,请引用:
```bibtex
@misc{shi2025trainabledynamicmasksparse,
title={Trainable Dynamic Mask Sparse Attention},
author={Jingze Shi and Yifan Wu and Bingheng Wu and Yiran Peng and Liangdong Wang and Guang Liu and Yuyu Luo},
year={2025},
eprint={2508.02124},
archivePrefix={arXiv},
primaryClass={cs.AI},
url={https://arxiv.org/abs/2508.02124},
}
```
## 致谢
本项目基于并集成了几个优秀的工作:
- **[OpenSeek](https://github.com/FlagAI-Open/OpenSeek)** - 内核开发支持
- **[Flash-Attention](https://github.com/Dao-AILab/flash-attention)** - 内存高效的注意力计算
- **[NVIDIA CUTLASS](https://github.com/NVIDIA/cutlass)** - 高性能矩阵运算库
我们感谢开源社区对高效 Transformer 实现的贡献。🤗