From 1a49eb0085bc0f212e3a95231ca6d63af8e6218b Mon Sep 17 00:00:00 2001 From: hujiahui8 Date: Fri, 19 Sep 2025 16:43:29 +0800 Subject: [PATCH] attention operator generation --- .../ai_kernel_generator/core/agent/coder.py | 2 +- .../docs/triton_docs/suggestion_docs.md | 86 ++++++++++++--- .../attention_op/attention_op_task.py | 62 +++++++++++ .../flash_liner_attention_op_task.py | 103 ++++++++++++++++++ .../attention_op/liner_attention_op_task.py | 75 +++++++++++++ 5 files changed, 311 insertions(+), 17 deletions(-) create mode 100644 aikg/tests/resources/attention_op/attention_op_task.py create mode 100644 aikg/tests/resources/attention_op/flash_liner_attention_op_task.py create mode 100644 aikg/tests/resources/attention_op/liner_attention_op_task.py diff --git a/aikg/python/ai_kernel_generator/core/agent/coder.py b/aikg/python/ai_kernel_generator/core/agent/coder.py index a185a94ed..98838d359 100644 --- a/aikg/python/ai_kernel_generator/core/agent/coder.py +++ b/aikg/python/ai_kernel_generator/core/agent/coder.py @@ -187,7 +187,7 @@ class Coder(AgentBase): # 查找所有Python文件 for file_path in local_dir.glob("*.py"): # 检查文件名是否包含op_name(不区分大小写) - if self.op_name.lower() in file_path.stem.lower(): + if file_path.stem.lower() in self.op_name.lower(): try: with open(file_path, "r", encoding="utf-8") as f: content = f.read().strip() diff --git a/aikg/python/ai_kernel_generator/resources/docs/triton_docs/suggestion_docs.md b/aikg/python/ai_kernel_generator/resources/docs/triton_docs/suggestion_docs.md index 013ddfcde..dcc821a66 100644 --- a/aikg/python/ai_kernel_generator/resources/docs/triton_docs/suggestion_docs.md +++ b/aikg/python/ai_kernel_generator/resources/docs/triton_docs/suggestion_docs.md @@ -2,7 +2,62 @@ 本文档提供 Triton 开发的技巧、性能优化和问题排查指南。 -## 1. 性能优化 +## 1. 
特定算子优化 + +### matmul 算子 + +**注意**:合理的切分是提升matmul算子性能的关键。 + +#### Ascend后端切分优化:充分发挥带宽,算子行宽为512B的整数倍,且单次行数尽量大,以fp16和bf16为例: +- **A、B都不转置**:分块行宽分别为K0和N0,则M0=128,K0=256,N0=256 +- **A不转置,B转置**:分块行宽都是K0,则K0=256,M0和N0影响较小 +- **A、B都转置**:分块行宽分别为M0和K0,则M0=256,K0=256,N0=128 +- **A转置,B不转置**:分块行宽分别为M0和N0,则左右矩阵均无法同时满足512B的整数倍,需根据实际情况调整 + +### Attention 算子 + +#### 标准Attention计算流程: +1. **QK^T计算**:`scores = Q @ K^T / sqrt(d_k)`,计算注意力分数 +2. **Softmax归一化**:`attn_weights = softmax(scores)`,确保权重和为1 +3. **加权求和**:`output = attn_weights @ V`,得到最终输出 + +#### Flash Attention优化策略: +- **分块计算**:将大矩阵分块处理,减少内存占用 +- **在线Softmax**:使用增量式softmax算法,分块计算,维护全局最大值和归一化因子,避免存储完整注意力矩阵,具体逻辑如下: +```python +# 初始化全局统计量 +m_i = -float("inf") # 全局最大值 +l_i = 0.0 # 全局exp和 +acc = 0.0 # 输出累加器 + +# 分块处理 +for start_n in range(0, seq_len, BLOCK_SIZE): + # 1. 加载当前块的分数 + scores = tl.load(scores_ptr + start_n, mask=load_mask, other=-float("inf")) + + # 2. 计算包含当前块在内的新全局最大值 + m_ij = tl.maximum(m_i, tl.max(scores, 0)) + + # 3. 计算当前块的exp值(数值稳定化;使用exp2时scores需预先乘以log2(e)) + scores = scores - m_ij + p = tl.math.exp2(scores) + + # 4. 更新全局exp和 + l_ij = tl.sum(p, 0) + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + + # 5. 更新输出累加器(Attention中p需先与V相乘,即 acc * alpha + p @ V) + acc = acc * alpha + p + + # 6. 保存新的全局最大值供下一块使用 + m_i = m_ij + +# 最终归一化 +acc = acc / l_i +``` + +## 2. 性能优化 ### 块大小选择策略 - **基础**: 2的幂(256, 512, 1024) @@ -17,7 +72,7 @@ ### 算子拆分策略 - **复杂算子**: 拆分为多个简单kernel,避免单个kernel过于复杂 -## 2. 数值稳定性技巧 +## 3. 数值稳定性技巧 ### 防溢出处理 ```python @@ -34,13 +89,15 @@ exp_data = tl.exp(stable_data) - **中间计算**: 关键步骤转为float32提升精度 - **累加操作**: 使用高精度累加器防止精度丢失 -## 3. API使用限制与替代方案 +## 4. 
API使用限制与替代方案 ### 禁止使用的语法 - 禁止 `return`, `break`, `continue` → 使用mask控制 - 禁止 lambda表达式 → 使用内联函数或tl.where - 禁止 链式布尔运算 → 分步计算mask - 禁止 张量直接索引 → 使用tl.load/tl.store +**Ascend后端** +- 禁止 `tl.where` → 使用if-else ### tl.constexpr 正确用法 - **仅在内核参数中使用**: `BLOCK_SIZE: tl.constexpr` @@ -53,19 +110,16 @@ exp_data = tl.exp(stable_data) ### Ascend 后端避免使用 tl.where 计算内存偏移 Ascend 后端对`tl.where`生成的复杂指针运算支持不完全。复杂条件判断可以采用if-else静态分支处理,而非在内存访问时动态计算。 -**推荐示例** -```python -if input_shape_0 == 1: - input_offsets = input_offsets_n - case1() -elif input_shape_1 == 1: - input_offsets = input_offsets_m * input_shape_1 - case2() -else: - case3() -``` +### 标量类型转换 +- **仅支持to(type)**: 如`scalar.to(tl.float16)`,禁止使用`tl.float16(scalar)` +- **tl.constexpr类型转换**: 将常量赋值给临时变量再转换,如`scalar = CONST_A` + +### 切分设置 +**Ascend后端** +- BLOCK_SIZE必须小于65536,并且线程块所占内存必须符合硬件限制 +- 若shape过大,单次切分后超过硬件缓存,并且BLOCK_SIZE超过限制,可以对循环进行多次切分 -## 4. 调试与排查清单 +## 5. 调试与排查清单 ### 内存访问问题 - [ ] 所有load/store是否都有mask或boundary_check? @@ -116,7 +170,7 @@ if M > max_grid_size: | 数据错位 | 计算结果错误 | 验证stride设置 | | 竞争条件 | 结果不确定 | 使用原子操作 | -## 6. 开发建议 +## 7. 开发建议 ### 代码风格 - 添加充分的注释说明计算逻辑 diff --git a/aikg/tests/resources/attention_op/attention_op_task.py b/aikg/tests/resources/attention_op/attention_op_task.py new file mode 100644 index 000000000..cea1a791c --- /dev/null +++ b/aikg/tests/resources/attention_op/attention_op_task.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn + +class Model(nn.Module): + def __init__(self, scale='auto'): + """ + 初始化Scaled Dot Product Attention模块 + + 参数: + scale: 缩放因子,可以是以下值之一: + - float: 自定义缩放因子 + - None: 使用默认缩放因子 1/sqrt(head_dim) + """ + super(Model, self).__init__() + self.scale = scale + + def forward(self, query, key, value): + """ + 带有可配置缩放因子的Scaled Dot Product Attention实现 + + 参数: + query: [batch_size, num_heads, seq_len, head_dim] + key: [batch_size, num_heads, seq_len, head_dim] + value: [batch_size, num_heads, seq_len, head_dim] + + 返回: + output: [batch_size, num_heads, seq_len, head_dim] + """ + # 1. 
确定缩放因子 + if self.scale is None: + scale_factor = 1.0 / (query.size(-1) ** 0.5) + else: + scale_factor = float(self.scale) + + # 2. 对query进行缩放 + LOG2E = 1.44269504 + scaled_query = query * scale_factor * LOG2E + + # 3. 计算注意力分数 (缩放后的query和key的点积) + # [batch_size, num_heads, seq_len, seq_len] + attn_scores = torch.matmul(scaled_query, key.transpose(-2, -1)) + + # 4. 应用softmax获取注意力权重 + attn_weights = torch.softmax(attn_scores, dim=-1) + + # 5. 计算输出(注意力权重与value的点积) + output = torch.matmul(attn_weights, value) + + return output + +def get_inputs(): + batch, num_heads, seq_len, head_dim = 32, 8, 1024, 64 + shape = (batch, num_heads, seq_len, head_dim) + + query = torch.randn(shape, dtype=torch.float16) + key = torch.randn(shape, dtype=torch.float16) + value = torch.randn(shape, dtype=torch.float16) + return [query, key, value] + +def get_init_inputs(): + scale = None + return [scale] \ No newline at end of file diff --git a/aikg/tests/resources/attention_op/flash_liner_attention_op_task.py b/aikg/tests/resources/attention_op/flash_liner_attention_op_task.py new file mode 100644 index 000000000..ad9bd044e --- /dev/null +++ b/aikg/tests/resources/attention_op/flash_liner_attention_op_task.py @@ -0,0 +1,103 @@ +import torch +import torch.nn as nn + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, query, key, value, g, beta, initial_state, orig_mask): + initial_dtype = query.dtype + chunk_size=64 + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + + # reshape to chunks + # (batch, heads, seq_len, dim) → (batch, heads, num_chunks, chunk_size, dim) + # 门控制向量:(batch, heads, seq_len) → (batch, heads, num_chunks, chunk_size) + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, 
key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + + g = g.cumsum(dim=-1) + # g.unsqueeze(-1):(batch, heads, num_chunks, chunk_size) → (batch, heads, num_chunks, chunk_size, 1) + # g.unsqueeze(-2):(batch, heads, num_chunks, chunk_size) → (batch, heads, num_chunks, 1, chunk_size) + # 减法:(batch, heads, num_chunks, chunk_size, chunk_size) + decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() + + mask = torch.triu(orig_mask, diagonal=0) + # [batch_size, num_heads, num_chunks, chunk_size, k_head_dim] @ [batch_size, num_heads, num_chunks, k_head_dim, chunk_size] + # = [batch_size, num_heads, num_chunks, chunk_size, chunk_size] + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + + # attn:[batch_size, num_heads, num_chunks, chunk_size, chunk_size] + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + # [batch_size, num_heads, num_chunks, chunk_size, chunk_size] @ [batch_size, num_heads, num_chunks, chunk_size, v_head_dim] + # = [batch_size, num_heads, num_chunks, chunk_size, v_head_dim] + value = attn @ v_beta + # [batch_size, num_heads, num_chunks, chunk_size, chunk_size] @ [batch_size, num_heads, num_chunks, chunk_size, k_head_dim] + # = [batch_size, num_heads, num_chunks, chunk_size, k_head_dim] + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + + # ​​循环状态初始化​,(batch, heads, k_head_dim, v_head_dim) + last_recurrent_state = initial_state + core_attn_out = torch.zeros_like(value) + mask = torch.triu(orig_mask, diagonal=1) + + # for each chunk + # 分块循环处理​ + for i in range(0, sequence_length // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + # [batch_size, num_heads, 1, chunk_size, k_head_dim] @ [batch_size, num_heads, 1, k_head_dim, chunk_size] + # = [batch_size, 
num_heads, 1, chunk_size, chunk_size] + attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + + # [batch_size, num_heads, 1, chunk_size, k_head_dim] @ [batch_size, num_heads, k_head_dim, v_head_dim] + # = [batch_size, num_heads, 1, chunk_size, v_head_dim] + v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + v_new = v_i - v_prime + + # [batch_size, num_heads, 1, chunk_size, k_head_dim] @ [batch_size, num_heads, k_head_dim, v_head_dim] + # = [batch_size, num_heads, 1, chunk_size, v_head_dim] + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + # [batch_size, num_heads, 1, chunk_size, chunk_size] @ [batch_size, num_heads, 1, chunk_size, v_head_dim] + # = [batch_size, num_heads, 1, chunk_size, v_head_dim] + core_attn_out[:, :, i] = attn_inter + attn @ v_new + + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new + ) + + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) + core_attn_out = core_attn_out[:, :, :sequence_length] + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + +def get_inputs(): + batch, num_heads, seq_len, head_dim = 32, 8, 1024, 64 + chunk_size = 64 + shape = (batch, num_heads, seq_len, head_dim) + + query = torch.randn(shape, dtype=torch.float32) + key = torch.randn(shape, dtype=torch.float32) + value = torch.randn(shape, dtype=torch.float32) + g = torch.randn((batch, num_heads, seq_len), dtype=torch.float32) + beta = torch.randn((batch, num_heads, seq_len), dtype=torch.float32) + + initial_state = torch.zeros(batch, num_heads, head_dim, head_dim, dtype=torch.float32) + mask = torch.ones(chunk_size, chunk_size, dtype=torch.bool) + return [query, key, value, g, beta, initial_state, mask] + +def get_init_inputs(): + return [] \ No newline at 
end of file diff --git a/aikg/tests/resources/attention_op/liner_attention_op_task.py b/aikg/tests/resources/attention_op/liner_attention_op_task.py new file mode 100644 index 000000000..9a1fdc3bb --- /dev/null +++ b/aikg/tests/resources/attention_op/liner_attention_op_task.py @@ -0,0 +1,75 @@ +import torch +import torch.nn as nn + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, query, key, value, g, beta, initial_state): + initial_dtype = query.dtype + # 转置时确保维度正确 + # query = query.transpose(1, 2).contiguous().to(torch.float32) + # key = key.transpose(1, 2).contiguous().to(torch.float32) + # value = value.transpose(1, 2).contiguous().to(torch.float32) + # beta = beta.transpose(1, 2).contiguous().to(torch.float32) + # g = g.transpose(1, 2).contiguous().to(torch.float32) + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim, device=value.device, dtype=torch.float32) + + # 正确初始化状态张量 + if initial_state is None: + last_recurrent_state = torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim, device=value.device, dtype=torch.float32) + else: + last_recurrent_state = initial_state.to(value.device).to(torch.float32) + + for i in range(sequence_length): + q_t = query[:, :, i] # [batch, heads, 1, dim] + k_t = key[:, :, i] # [batch, heads, 1, dim] + v_t = value[:, :, i] # [batch, heads, 1, dim] + g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) # [batch, heads, 1, 1] + beta_t = beta[:, :, i].unsqueeze(-1) # [batch, heads, 1] + + # 更新循环状态 + # batch_size, num_heads, k_head_dim, v_head_dim + last_recurrent_state = last_recurrent_state * g_t + + # 修复维度不匹配问题:确保k_t有正确的维度 + k_t_expanded = k_t.unsqueeze(-1) # [batch, heads, dim, 1] + + # 计算kv_mem + # batch_size, num_heads, 1, v_head_dim + kv_mem = (last_recurrent_state * 
k_t_expanded).sum(dim=-2) # [batch, heads, v_dim] + + # 计算delta并更新状态 + delta = (v_t - kv_mem) * beta_t + delta_expanded = delta.unsqueeze(-2) # [batch, heads, 1, v_dim] + # batch_size, num_heads, k_head_dim, v_head_dim + last_recurrent_state = last_recurrent_state + k_t_expanded * delta_expanded + + # 计算输出 + q_t_expanded = q_t.unsqueeze(-1) # [batch, heads, dim, 1] + # batch_size, num_heads, 1, v_head_dim + core_attn_out[:, :, i] = (last_recurrent_state * q_t_expanded).sum(dim=-2) + + # batch_size, sequence_length, num_heads, v_head_dim + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + +def get_inputs(): + batch, num_heads, seq_len, head_dim = 32, 8, 1024, 64 + shape = (batch, num_heads, seq_len, head_dim) + + query = torch.randn(shape, dtype=torch.float32) + key = torch.randn(shape, dtype=torch.float32) + value = torch.randn(shape, dtype=torch.float32) + g = torch.randn((batch, num_heads, seq_len), dtype=torch.float32) + beta = torch.randn((batch, num_heads, seq_len), dtype=torch.float32) + return [query, key, value, g, beta, None] + +def get_init_inputs(): + return [] \ No newline at end of file -- Gitee