diff --git a/docs/deeplearning_operators/flash_linear_attention.md b/docs/deeplearning_operators/flash_linear_attention.md
index 335feda2b53b63632911526f6aa1327dc49229d0..73240602b509d9042779ba4b6e9911e6ff48cf0e 100644
--- a/docs/deeplearning_operators/flash_linear_attention.md
+++ b/docs/deeplearning_operators/flash_linear_attention.md
@@ -1,2 +1,250 @@
-Flash Linear Attention
-======================
+# Flash Linear Attention
+
+Author: zhangzhangJ
+
+:::{warning}
+:class: myclass1 myclass2
+:name: a-tip-reference
+
+This document describes the implementation of **Flash Linear Attention** using mcTileLang on MACA GPUs.
+It implements a **Chunkwise Recurrent** kernel optimized for long-sequence modeling.
+:::
+
+Flash Linear Attention (FLA) is an efficient attention mechanism that reduces the quadratic complexity $O(N^2)$ of standard Transformers to linear complexity $O(N)$. By exploiting the associativity of matrix multiplication, $(Q K^T) V = Q (K^T V)$, it maintains a compact recurrent state $S \in \mathbb{R}^{D \times D}$ instead of a growing KV-Cache. This makes it well suited to long-context Large Language Models (LLMs) and to linear-attention and state-space architectures such as RetNet and Mamba.
+
+## Application Scenarios
+1. **Long-Sequence Modeling:** Handling contexts beyond 32k tokens, where the memory usage of standard attention explodes.
+2. **Efficient Inference:** Reducing the memory footprint by using a fixed-size recurrent state.
+3. **Recurrent Architectures:** Implementing the core operator for linear RNNs and SSMs.
+
+## Core Calculation Logic
+The fundamental advantage of Flash Linear Attention lies in changing the order of matrix multiplication to achieve linear complexity.
+
+**1. Mathematical Formulation**
+Standard attention computes $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}(Q K^T)\, V$, which implicitly materializes an $N \times N$ attention map. Linear attention removes the softmax (or replaces it with a linear approximation) and exploits the associativity of matrix multiplication:
+
+$$
+(Q K^T) V = Q (K^T V)
+$$
+
+* **LHS ($O(N^2)$)**: In standard attention, each query attends to all historic keys, producing an attention map whose size scales with $N^2$.
+* **RHS ($O(N)$)**: In linear attention, we pre-aggregate $K$ and $V$ into a fixed-size hidden state $S = \sum_t k_t^T v_t$.
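+
+The reordering is easy to check numerically. The short PyTorch sketch below is purely illustrative (it is not part of the kernel and ignores causality): it verifies that both association orders agree, and that the $D \times D$ state can equivalently be built as a running sum over tokens.
+
+```python
+import torch
+
+# Tiny non-causal example: both association orders give the same result.
+N, D = 128, 64
+q = torch.randn(N, D, dtype=torch.float64)
+k = torch.randn(N, D, dtype=torch.float64)
+v = torch.randn(N, D, dtype=torch.float64)
+
+# LHS: materialize the N x N attention map, then apply it to V  -> O(N^2 * D)
+lhs = (q @ k.T) @ v
+
+# RHS: pre-aggregate K and V into a D x D state, then project Q  -> O(N * D^2)
+state = k.T @ v                      # S = sum_t k_t^T v_t, shape (D, D)
+rhs = q @ state
+
+torch.testing.assert_close(lhs, rhs)
+
+# The same state can be built as a running sum over tokens (the recurrent view).
+running = torch.zeros(D, D, dtype=torch.float64)
+for t in range(N):
+    running += torch.outer(k[t], v[t])
+torch.testing.assert_close(running, state)
+```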
+
+**2. Chunkwise Recurrent Algorithm (Hardware Optimized)**
+To maximize GPU Tensor Core utilization, we do not process the sequence token by token (which has low parallelism). Instead, we process it in **Chunks**, with two GEMMs per chunk (sketched in code below):
+
+* **Step 1: State Update (Materialization)**
+  For each chunk, we add the chunk's contribution to the hidden state.
+  $$S_{new} = S_{old} + K_{chunk}^T \times V_{chunk}$$
+    * *Operation:* Matrix multiplication (GEMM).
+    * *Dimension:* $(D \times L) \times (L \times D) \rightarrow (D \times D)$.
+    * *Insight:* This step effectively "writes" the chunk's history into the state matrix.
+
+* **Step 2: State Projection (Retrieval)**
+  We query the updated state to compute the attention output.
+  $$O_{chunk} = Q_{chunk} \times S_{new}$$
+    * *Operation:* Matrix multiplication (GEMM).
+    * *Dimension:* $(L \times D) \times (D \times D) \rightarrow (L \times D)$.
+    * *Insight:* This step "reads" the relevant history back out of the state matrix.
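+
+The two steps above can be written down directly in a few lines of framework code. The following plain-PyTorch reference is an illustrative sketch (the function name is ours, not part of the kernel API): the state is updated with the whole current chunk before the projection, matching the kernel and the `torch_ref` check later in this document.
+
+```python
+import torch
+
+def chunkwise_linear_attention(q, k, v, chunk_size=64):
+    """Reference chunk-level recurrence. q, k, v: (B, H, L, D); returns (B, H, L, D)."""
+    B, H, L, D = q.shape
+    state = torch.zeros(B, H, D, D, dtype=torch.float32, device=q.device)
+    outputs = []
+    for start in range(0, L, chunk_size):
+        qc = q[:, :, start:start + chunk_size].float()
+        kc = k[:, :, start:start + chunk_size].float()
+        vc = v[:, :, start:start + chunk_size].float()
+        # Step 1 (state update): (D x L) @ (L x D) -> (D x D), accumulated into S
+        state = state + kc.transpose(-1, -2) @ vc
+        # Step 2 (state projection): (L x D) @ (D x D) -> (L x D)
+        outputs.append(qc @ state)
+    return torch.cat(outputs, dim=2).to(q.dtype)
+```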
+
+**3. Complexity Shift**
+* **Standard Attention:** $O(N^2 \cdot D)$; the bottleneck is the sequence length ($N$).
+* **Linear Attention:** $O(N \cdot D^2)$; the bottleneck is the head dimension ($D$).
+
+This makes FLA extremely efficient in long-context scenarios where $N \gg D$.
+
+## Interface Parameters
+
+* **Tensor Arguments:**
+    * `Q` (Query): Input tensor of shape `(Batch, Heads, Seq_Len, Dim)`.
+    * `K` (Key): Input tensor of shape `(Batch, Heads, Seq_Len, Dim)`.
+    * `V` (Value): Input tensor of shape `(Batch, Heads, Seq_Len, Dim)`.
+    * `Output`: Output tensor of shape `(Batch, Heads, Seq_Len, Dim)`.
+
+* **Configuration Parameters:**
+    * `block_L` (int): Chunk size for the sequence dimension (tiling size). Default: 64.
+    * `block_D` (int): Head dimension tile size. Default: 64.
+    * `dtype` (str): Data type of the input/output tensors (e.g., "float16").
+    * `accum_dtype` (str): Data type for the recurrent-state accumulation. Must be "float32" to avoid overflow and precision loss when accumulating over long sequences.
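+
+For orientation, a typical call with these arguments looks like the sketch below. It assumes the `flash_linear_attention_kernel` factory and the `tilelang.compile(..., out_idx=[3])` pattern from the implementation example in the next section, with the default benchmark shapes.
+
+```python
+import torch
+import tilelang
+
+B, H, L, D = 1, 16, 2048, 64                    # (Batch, Heads, Seq_Len, Dim)
+q = torch.randn(B, H, L, D, dtype=torch.float16, device="cuda") * (D ** -0.5)
+k = torch.randn(B, H, L, D, dtype=torch.float16, device="cuda")
+v = torch.randn(B, H, L, D, dtype=torch.float16, device="cuda")
+
+# Build and compile the kernel; argument index 3 (`Output`) is returned by the call.
+func = flash_linear_attention_kernel(B, H, L, D, block_L=64, block_D=64)
+kernel = tilelang.compile(func, out_idx=[3])
+out = kernel(q, k, v)                           # (B, H, L, D), float16
+```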
+
+## Implementation Example
+
+The following code demonstrates a fused kernel written with mcTileLang (imported as `tilelang`).
+
+**MACA Optimization Note:**
+To support the MACA backend efficiently and to resolve the layout conflict between the two GEMM operations, we use a **Shared Memory Scratchpad** (`S_scratch`).
+
+1. **Layout Transformation:** Converts the recurrent state $S$ from the accumulator layout (output of the first GEMM) to the operand layout expected by the second GEMM.
+2. **Backend Compatibility:** Ensures the second GEMM is performed as `Shared * Shared`, which maps onto MACA's `gemm_ss` instruction set.
+
+```python
+import torch
+import tilelang
+import tilelang.language as T
+
+def flash_linear_attention_kernel(
+    batch, heads, seq_len, dim,
+    block_L=64, block_D=64,
+    dtype="float16", accum_dtype="float"
+):
+    """
+    Flash Linear Attention Kernel (Chunkwise Recurrent).
+    Computes O = Q * CumSum(K^T * V) efficiently.
+    """
+    # Calculate loop extent externally to ensure it's a static integer
+    # This prevents symbolic variable parsing errors in the pipeline
+    total_steps = (seq_len + block_L - 1) // block_L
+
+    @T.prim_func
+    def main(
+        Q: T.Tensor((batch, heads, seq_len, dim), dtype),
+        K: T.Tensor((batch, heads, seq_len, dim), dtype),
+        V: T.Tensor((batch, heads, seq_len, dim), dtype),
+        Output: T.Tensor((batch, heads, seq_len, dim), dtype),
+    ):
+        # Parallelize over Heads and Batch
+        # threads=128 ensures good warp occupancy
+        with T.Kernel(heads, batch, threads=128) as (bx, by):
+            # 1. Shared Memory Allocations
+            # Double buffering is managed automatically by T.Pipelined
+            Q_shared = T.alloc_shared((block_L, block_D), dtype)
+            K_shared = T.alloc_shared((block_L, block_D), dtype)
+            V_shared = T.alloc_shared((block_L, block_D), dtype)
+            O_shared = T.alloc_shared((block_L, block_D), dtype)
+
+            # 2. Layout Transformation Scratchpad (Shared Memory)
+            # Critical for connecting two GEMMs on MACA backend (gemm_ss support)
+            S_scratch = T.alloc_shared((block_D, block_D), dtype)
+
+            # 3. Register Fragments
+            # S_accum: Recurrent State (FP32 Accumulator)
+            S_accum = T.alloc_fragment((block_D, block_D), accum_dtype)
+            # O_local: Output Accumulator (FP32 Fragment)
+            O_local = T.alloc_fragment((block_L, block_D), accum_dtype)
+
+            # Initialize State to 0
+            T.clear(S_accum)
+
+            # 4. Pipelined Loop over Sequence Chunks
+            for lo in T.Pipelined(0, total_steps, num_stages=2):
+
+                # Load Global -> Shared
+                T.copy(K[by, bx, lo * block_L, 0], K_shared)
+                T.copy(V[by, bx, lo * block_L, 0], V_shared)
+                T.copy(Q[by, bx, lo * block_L, 0], Q_shared)
+
+                # Step A: Update State S += K^T * V
+                # Input: Shared, Shared -> Output: Fragment (Accumulator)
+                # K is transposed to (D, L) logically
+                T.gemm(K_shared, V_shared, S_accum, transpose_A=True, transpose_B=False)
+
+                # Step B: Copy Accumulator to Shared Memory
+                # This performs: FP32->FP16 cast + Layout Transformation
+                # We keep data in Shared Memory to use gemm_ss later
+                T.copy(S_accum, S_scratch)
+
+                # Step C: Compute Output O = Q * S
+                T.clear(O_local)
+                # Input: Q_shared (Shared), S_scratch (Shared) -> Output: Fragment
+                # This utilizes the efficient gemm_ss instruction on MACA
+                T.gemm(Q_shared, S_scratch, O_local, transpose_A=False, transpose_B=False)
+
+                # Store Result
+                T.copy(O_local, O_shared)
+                T.copy(O_shared, Output[by, bx, lo * block_L, 0])
+
+    return main
+
+def benchmark():
+    # Test Configuration
+    B, H, L, D = 1, 16, 2048, 64
+    block_L, block_D = 64, 64
+    print(f"Benchmarking Flash Linear Attention: B={B}, H={H}, L={L}, D={D}")
+
+    device = "cuda"
+    torch.manual_seed(42)
+    if not torch.cuda.is_available():
+        print("CUDA not available.")
+        return
+
+    q = torch.randn(B, H, L, D, dtype=torch.float16, device=device)
+    k = torch.randn(B, H, L, D, dtype=torch.float16, device=device)
+    v = torch.randn(B, H, L, D, dtype=torch.float16, device=device)
+
+    # Scale Q for numerical stability
+    q = q * (D ** -0.5)
+
+    print("Compiling Kernel...")
+    func = flash_linear_attention_kernel(B, H, L, D, block_L, block_D)
+    # The backend is automatically detected (MACA/CUDA compatible)
+    kernel = tilelang.compile(func, out_idx=[3])
+
+    out_tl = kernel(q, k, v)
+
+    print("Verifying correctness...")
+    def torch_ref(q, k, v):
+        out = []
+        state = torch.zeros((B, H, D, D), device=q.device, dtype=torch.float32)
+        q_chunk = q.view(B, H, L // block_L, block_L, D)
+        k_chunk = k.view(B, H, L // block_L, block_L, D)
+        v_chunk = v.view(B, H, L // block_L, block_L, D)
+
+        for i in range(L // block_L):
+            qi = q_chunk[:, :, i, :, :].float()
+            ki = k_chunk[:, :, i, :, :].float()
+            vi = v_chunk[:, :, i, :, :].float()
+            # S += K^T * V
+            state += torch.einsum('bhld,bhlx->bhdx', ki, vi)
+            # O = Q * S
+            out.append(torch.matmul(qi, state))
+        return torch.cat(out, dim=2).to(torch.float16)
+
+    out_ref = torch_ref(q, k, v)
+
+    try:
+        torch.testing.assert_close(out_tl, out_ref, rtol=0.05, atol=0.05)
+        print("✅ Correctness Passed!")
+    except AssertionError as e:
+        print(f"❌ Verification Failed: {e}")
+
+    # Performance Profiling
+    profiler = kernel.get_profiler()
+    latency = profiler.do_bench()
+    flops = 2 * B * H * L * D * D * 2  # Approx FLOPs: two GEMMs, each 2*L*D*D per (batch, head)
+    tflops = (flops / 1e12) / (latency / 1000)
+    print(f"Latency: {latency:.3f} ms")
+    print(f"Throughput: {tflops:.2f} TFLOPS")
+
+if __name__ == "__main__":
+    benchmark()
+```
+
+## Performance Analysis
+
+Performance benchmarks were conducted on a **XiYun C500 64G** environment.
+
+### 1. Benchmark Results
+
+| Metric         | Value            | Configuration          |
+| -------------- | ---------------- | ---------------------- |
+| **Batch Size** | 1                | B=1                    |
+| **Heads**      | 16               | H=16                   |
+| **Seq Length** | 2048             | L=2048                 |
+| **Head Dim**   | 64               | D=64                   |
+| **Latency**    | **~0.210 ms**    | -                      |
+| **Throughput** | **~2.55 TFLOPS** | block_L=64, block_D=64 |
+
+### 2. Recommended Configuration
+
+- **Block Size:** `block_D=64` and `block_L=64` provide the best balance between shared memory usage and warp occupancy on C500 hardware.
+- **Pipeline Stages:** `num_stages=2` effectively hides global memory access latency.
+- **Thread Block:** `threads=128` (4 warps) matches the tile size requirements. A simple way to re-validate these choices for other shapes is sketched below.
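+
+A minimal sweep over these knobs, reusing only the kernel factory and profiler calls from the example above, might look like the following sketch (the helper name and candidate list are ours; the best configuration may differ on other shapes or hardware):
+
+```python
+import tilelang
+
+def sweep_block_sizes(B=1, H=16, L=2048, D=64):
+    """Benchmark a few block_L candidates; the example keeps block_D equal to the head dimension."""
+    results = []
+    for block_L in [32, 64, 128]:
+        if L % block_L != 0:
+            continue  # keep the sequence tiling aligned with the chunk size
+        func = flash_linear_attention_kernel(B, H, L, D, block_L=block_L, block_D=D)
+        kernel = tilelang.compile(func, out_idx=[3])
+        latency = kernel.get_profiler().do_bench()
+        results.append((block_L, latency))
+        print(f"block_L={block_L:4d} -> {latency:.3f} ms")
+    return min(results, key=lambda r: r[1])      # (best block_L, latency in ms)
+```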
+
+### 3. Optimization Strategy
+
+- **GEMM Instruction:** The implementation explicitly routes data through shared memory (`S_scratch`) so that the second GEMM can use the high-performance `gemm_ss` instruction available on MACA, avoiding slower or unsupported register-based paths.
+- **Layout Management:** `S_scratch` acts as a layout transformer, converting the accumulator layout produced by the first GEMM into the operand layout required by the second GEMM without manual bit-level manipulation; the relevant lines are excerpted below.
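+
+Concretely, the shared-memory bridge is just the following three statements from the kernel above, repeated here with the data path annotated:
+
+```python
+# Step A: accumulate K^T * V for the current chunk into the FP32 fragment S_accum
+T.gemm(K_shared, V_shared, S_accum, transpose_A=True, transpose_B=False)
+
+# Step B: cast FP32 -> FP16 and change layout by staging the state in shared memory
+T.copy(S_accum, S_scratch)
+
+# Step C: both operands now reside in shared memory, so this lowers to gemm_ss on MACA
+T.gemm(Q_shared, S_scratch, O_local, transpose_A=False, transpose_B=False)
+```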