diff --git a/ops/dsl/triton/sparsetoken_flash_attention_decode.py b/ops/dsl/triton/sparsetoken_flash_attention_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..e58a7c942da2ebd30005609c7e2fae91d1a75eed --- /dev/null +++ b/ops/dsl/triton/sparsetoken_flash_attention_decode.py @@ -0,0 +1,152 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""FlexAttention Triton kernel for Sparse Token Decoding with Online Softmax + +This module provides the flexible attention mechanism for sparse token decoding. +It implements an efficient Triton kernel that +handles sparse attention with online softmax normalization. +""" + +import triton +import triton.language as tl + + +@triton.jit +def sparsetoken_flash_attention_decode( + q_ptr, + k_ptr, + v_ptr, + sparse_ind_ptr, + sparse_nnz_ptr, + output_ptr, + batch_size, + qo_heads, + head_dim, + gqa_group_size, + softmax_scale, + stride_q_b, + stride_q_h, + stride_q_1, + stride_k_b, + stride_k_h, + stride_k_s, + stride_v_b, + stride_v_h, + stride_v_s, + stride_sparse_ind_b, + stride_sparse_ind_h, + stride_sparse_ind_l, + stride_sparse_nnz_b, + stride_sparse_nnz_h, + stride_output_b, + stride_output_h, + stride_output_1, + blokc_head: tl.constexpr, +): + """ + Triton kernel for sparse token decoding with online softmax. + Each program instance handles one (b, h) pair. + """ + pid = tl.program_id(0) + total_tasks = batch_size * qo_heads + if pid >= total_tasks: + return + + # Compute b and h from pid + b = pid // qo_heads + h = pid % qo_heads + k_h = h // gqa_group_size + + # Load nnz from sparse_nnz + nnz_offset = b * stride_sparse_nnz_b + h * stride_sparse_nnz_h + nnz = tl.load(sparse_nnz_ptr + nnz_offset).to(tl.int32) + # If nnz is 0, set to 1 + nnz = tl.where(nnz == 0, 1, nnz) + + # Load query vector + q_offset = ( + b * stride_q_b + h * stride_q_h + + 0 * stride_q_1 + tl.arange(0, blokc_head) + ) + q_vec = tl.load( + q_ptr + q_offset, + mask=tl.arange(0, blokc_head) < head_dim, + other=0.0, + ) + + # Initialize online softmax variables + max_score = -float("inf") + sum_exp = 0.0 + out_vec = tl.zeros([blokc_head], dtype=tl.float32) + + # Loop over nnz indices + for i in range(0, nnz): + # Load index from sparse_ind + ind_offset = ( + b * stride_sparse_ind_b + + h * stride_sparse_ind_h + + i * stride_sparse_ind_l + ) + idx = tl.load(sparse_ind_ptr + ind_offset).to(tl.int32) + + # Load key vector + k_offset = ( + b * stride_k_b + k_h * stride_k_h + + idx * stride_k_s + tl.arange(0, blokc_head) + ) + k_vec = tl.load( + k_ptr + k_offset, + mask=tl.arange(0, blokc_head) < head_dim, + other=0.0, + ) + + # Compute attention score + dot_product = tl.sum(q_vec * k_vec) + score = dot_product * softmax_scale + + # Load value vector + v_offset = ( + b * stride_v_b + k_h * stride_v_h + + idx * stride_v_s + tl.arange(0, blokc_head) + ) + v_vec = tl.load( + v_ptr + v_offset, + mask=tl.arange(0, blokc_head) < head_dim, + other=0.0, + ) + + # Online softmax update + new_max = tl.maximum(max_score, score) + exp_scale = tl.exp(max_score - new_max) + exp_score = tl.exp(score - new_max) + + sum_exp = sum_exp * exp_scale + exp_score + out_vec = out_vec * exp_scale + exp_score * v_vec + max_score = new_max + + # Normalize the output + out_vec = out_vec / sum_exp + + # Store the result + output_offset = ( + b * stride_output_b + h * stride_output_h + + 0 * stride_output_1 + tl.arange(0, blokc_head) + ) + tl.store( + output_ptr + output_offset, + out_vec, + mask=tl.arange(0, blokc_head) < head_dim, + ) diff --git a/ops/dsl/triton/sparsetoken_flash_attention_decode_paged.py b/ops/dsl/triton/sparsetoken_flash_attention_decode_paged.py new file mode 100644 index 0000000000000000000000000000000000000000..8eade0da60fd13cb063a64ef3a0f60d40eb6e813 --- /dev/null +++ b/ops/dsl/triton/sparsetoken_flash_attention_decode_paged.py @@ -0,0 +1,129 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""FlexAttention Triton Kernel for Sparse Token Decoding with Paged KV Cache + +This module provides the flexible attention mechanism for sparse token decoding +using a paged key-value cache. It implements an efficient Triton kernel that +handles sparse attention with online softmax normalization. +""" + +import triton +import triton.language as tl + + +@triton.jit +def sparsetoken_flash_attention_decode_paged( + q_ptr, + paged_kv_cache_ptr, + kv_page_indptr_ptr, + kv_page_indices_ptr, + sparse_ind_ptr, + sparse_nnz_ptr, + output_ptr, + h_q: tl.constexpr, + h_k: tl.constexpr, + dim: tl.constexpr, + page_size: tl.constexpr, + l_max: tl.constexpr, + sqrt_d: tl.constexpr, +): + """ + Triton 内核实现稀疏 token 分页注意力 + 每个程序处理一个 (batch, head) 对 + """ + pid = tl.program_id(0) + b = pid // h_q + h = pid % h_q + + # 加载 nnz + offset_nnz = b * h_q + h + nnz = tl.load(sparse_nnz_ptr + offset_nnz).to(tl.int32) + + # 计算输出偏移 + output_offset = b * h_q * dim + h * dim + tl.arange(0, dim) + + if nnz == 0: + # 如果 nnz 为 0,存储零输出 + tl.store(output_ptr + output_offset, 0.0) + return + + # 加载查询向量 q [1, dim] + q_offset = b * h_q * dim + h * dim + tl.arange(0, dim) + q = tl.load(q_ptr + q_offset) + + # 计算 GQA 组和 KV 头索引 + groups = h_q // h_k + h_k = h // groups + + # 初始化在线 softmax 状态 + m = -float("inf") + l = 0.0 + acc = tl.zeros([dim], dtype=tl.float32) + + # 循环处理每个 token + for i in range(0, nnz): + # 加载 token 索引 + offset_ind = b * h_q * l_max + h * l_max + i + token_idx = tl.load(sparse_ind_ptr + offset_ind).to(tl.int32) + + # 计算页面索引和偏移 + page_idx = token_idx // page_size + offset_in_page = token_idx % page_size + + # 加载页面起始指针 + ptr_start = tl.load(kv_page_indptr_ptr + b).to(tl.int32) + + # 加载页面 ID + page_id_offset = ptr_start + page_idx + page_id = tl.load(kv_page_indices_ptr + page_id_offset).to(tl.int32) + + # 加载 K 向量 [1, dim] + k_offset = ( + page_id * (2 * h_k * page_size * dim) + + 0 * (h_k * page_size * dim) + + h_k * (page_size * dim) + + offset_in_page * dim + + tl.arange(0, dim) + ) + k = tl.load(paged_kv_cache_ptr + k_offset) + + # 加载 V 向量 [1, dim] + v_offset = ( + page_id * (2 * h_k * page_size * dim) + + 1 * (h_k * page_size * dim) + + h_k * (page_size * dim) + + offset_in_page * dim + + tl.arange(0, dim) + ) + v = tl.load(paged_kv_cache_ptr + v_offset) + + # 计算注意力分数 + dot_product = tl.sum(q * k) + score = dot_product / sqrt_d # 使用预计算的 sqrt_d + + # 更新在线 softmax + m_new = tl.maximum(m, score) + alpha = tl.exp(m - m_new) + beta = tl.exp(score - m_new) + l_new = l * alpha + beta + acc_new = acc * alpha + beta * v + m = m_new + l = l_new + acc = acc_new + + # 计算最终输出 + out = acc / l + tl.store(output_ptr + output_offset, out)