diff --git a/docs/deeplearning_operators/flash_attention.md b/docs/deeplearning_operators/flash_attention.md
index 115f318c0a4fe76e9fe09ef992ad7f55d043ba83..2ff04fed76df7d8445d5981d42be1dba3b89d6a8 100644
--- a/docs/deeplearning_operators/flash_attention.md
+++ b/docs/deeplearning_operators/flash_attention.md
@@ -1,2 +1,429 @@
-Flash Attention
-==================
+# Flash Attention Operator
+=====================
+
+
+
+:::{warning}
+ This document is still **experimental** and may be incomplete.
+ Suggestions and improvements are highly encouraged—please submit a PR!
+:::
+
+:::{tip}
+Example code can be found at [`examples/flash_attention/example_mha_fwd_bshd_extended.py`](../../examples/flash_attention/example_mha_fwd_bshd_extended.py).
+:::
+
+Flash Attention is a memory-efficient attention mechanism designed to reduce the memory footprint of standard attention computations in Transformer models. It achieves this by computing attention in small tiles, avoiding the need to store the full attention matrix in memory.
+
+## Operator Functionality
+
+### Application Scenarios
+Flash Attention operators are essential in deep learning for tasks such as:
+- **Transformer Model Attention Layers**: Efficiently computing self-attention in encoder-decoder architectures.
+- **Large Language Model Training and Inference**: Scaling attention computations for models like GPT, BERT, and LLaMA by reducing memory usage from O(N²) to O(N).
+- **Sequence Modeling**: Handling long sequences in tasks like machine translation, text generation, and time-series prediction.
+
+### Core Computation Logic
+The Flash Attention operation computes attention by dividing the input sequences into smaller blocks and performing incremental softmax updates. This avoids storing the full attention matrix, significantly reducing memory usage while maintaining numerical stability.
+
+Key innovations include:
+- **Tiled Computation**: Process Q, K, and V in small blocks to fit in fast memory (e.g., GPU shared memory).
+- **Online Softmax**: Update maximum values and exponential sums incrementally for each tile, ensuring numerical stability without storing the entire matrix.
+- **Memory Efficiency**: Reduce memory from O(N²) to O(N) by recomputing attention weights on-the-fly.
+
+The core formula for attention is:
+
+$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V $$
+
+Where Flash Attention computes this in tiles with online softmax updates.
+
+## Interface Parameters
+
+### Input Parameters
+- **Q (Query)**: Tensor of shape `(batch_size, seq_len, num_heads, head_dim)`, data type typically `float16` or `float32`.
+- **K (Key)**: Tensor of shape `(batch_size, seq_len, num_heads, head_dim)`, data type matching Q.
+- **V (Value)**: Tensor of shape `(batch_size, seq_len, num_heads, head_dim)`, data type matching Q.
+
+### Output Parameters
+- **Output**: Tensor of shape `(batch_size, seq_len, num_heads, head_dim)`, data type matching inputs.
+
+### Optional Parameters
+
+These are the optional parameters provided in the example program:
+
+- **batch_size**: Integer, number of sequences in the batch.
+- **seq_len**: Integer, length of the input sequences.
+- **num_heads**: Integer, number of attention heads.
+- **head_dim**: Integer, dimension of each attention head.
+- **dropout_p**: Float, dropout probability (default 0.0).
+- **scale**: Float, scaling factor for attention scores (default `1/sqrt(head_dim)`).
+- **causal**: Boolean, whether to apply causal masking (default False).
+
+## Usage Example
+
+Below is an example of using the Flash Attention operator in TileLang within a MACA environment. This assumes a CUDA-compatible setup with mcTileLang installed. The example is divided into sections for clarity: imports and setup, kernel definition, reference implementation, data creation, and main function.
+
+### Imports and Setup
+
+```python
+import os
+import torch
+import torch.nn.functional as F
+import tilelang
+from tilelang.autotuner import *
+import tilelang.language as T
+import itertools
+import argparse
+from functools import partial
+import math
+
+def get_configs():
+ """Auto-tuning configuration parameters."""
+ iter_params = dict(
+ block_M=[64, 128],
+ block_N=[64, 128],
+ num_stages=[1, 2],
+ threads=[128, 256]
+ )
+ return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
+```
+
+### Kernel Definition
+
+```python
+@autotune(configs=get_configs(), warmup=10, rep=10)
+@tilelang.jit(out_idx=[3])
+def flash_attention(
+ batch: int,
+ heads: int,
+ seq_len: int,
+ dim: int,
+ dropout_p: float = 0.0,
+ scale: float = None,
+ causal: bool = False,
+ block_M: int = 64,
+ block_N: int = 64,
+ num_stages: int = 1,
+ threads: int = 128
+):
+ """
+ Inputs: tensors with shape [batch, seq_len, heads, dim] (BSHD layout).
+ Returns: tensor with same shape.
+ """
+ # Compute scale factor
+ if scale is None:
+ scale_factor = 1.0 / math.sqrt(dim)
+ else:
+ scale_factor = scale
+
+ shape = [batch, seq_len, heads, dim]
+ dtype = "float16"
+ accum_dtype = "float"
+
+ @T.macro
+ def QK_Attention(
+ K: T.Tensor(shape, dtype),
+ Q_shared: T.SharedBuffer([block_M, dim], dtype),
+ K_shared: T.SharedBuffer([block_N, dim], dtype),
+ acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
+ k: T.int32,
+ bx: T.int32,
+ by: T.int32,
+ bz: T.int32,
+ ):
+ """Compute Q*K^T scores and apply masks."""
+ # Load K block to shared memory
+ T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
+
+ if causal:
+ for i, j in T.Parallel(block_M, block_N):
+ acc_s[i, j] = T.if_then_else(
+ bx * block_M + i >= k * block_N + j,
+ 0,
+ -T.infinity(acc_s.dtype)
+ )
+ else:
+ for i, j in T.Parallel(block_M, block_N):
+ acc_s[i, j] = T.if_then_else(
+ k * block_N + j >= seq_len,
+ -T.infinity(acc_s.dtype),
+ 0
+ )
+
+ T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+
+ for i, j in T.Parallel(block_M, block_N):
+ acc_s[i, j] *= scale_factor
+
+ @T.macro
+ def Attention_Value(
+ V: T.Tensor(shape, dtype),
+ V_shared: T.SharedBuffer([block_N, dim], dtype),
+ acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+ acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
+ k: T.int32,
+ by: T.int32,
+ bz: T.int32,
+ ):
+ """Compute attention-weighted V contribution."""
+ T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
+ T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
+
+ @T.macro
+ def Online_Softmax(
+ acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
+ acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+ scores_max: T.FragmentBuffer([block_M], accum_dtype),
+ scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
+ scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+ scores_sum: T.FragmentBuffer([block_M], accum_dtype),
+ logsum: T.FragmentBuffer([block_M], accum_dtype),
+ ):
+ """Online softmax with numerical stability and optional dropout."""
+ T.copy(scores_max, scores_max_prev)
+ T.fill(scores_max, -T.infinity(accum_dtype))
+
+ # Compute row max for current block
+ T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+
+ for i in T.Parallel(block_M):
+ scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
+
+ # Compute scaling factors and exponentiate
+ for i in T.Parallel(block_M):
+ scores_scale[i] = T.exp(scores_max_prev[i] - scores_max[i])
+ for i, j in T.Parallel(block_M, block_N):
+ acc_s[i, j] = T.exp(acc_s[i, j] - scores_max[i])
+
+ keep_prob = 1.0 - dropout_p
+ keep_prob_threshold = int((1.0 - dropout_p) * 100)
+
+ for i, j in T.Parallel(block_M, block_N):
+ rand_val = (i * 31 + j * 17) % 100
+ dropout_mask = T.if_then_else(
+ rand_val < keep_prob_threshold,
+ 1.0 / keep_prob,
+ 0.0
+ )
+ acc_s[i, j] = acc_s[i, j] * dropout_mask
+
+ # Compute row sums and update global logsum
+ T.reduce_sum(acc_s, scores_sum, dim=1)
+ for i in T.Parallel(block_M):
+ logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
+
+ T.copy(acc_s, acc_s_cast)
+
+ @T.macro
+ def Rescale_Output(
+ acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
+ scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+ ):
+ """Rescale partial outputs by scores_scale."""
+ for i, j in T.Parallel(block_M, dim):
+ acc_o[i, j] *= scores_scale[i]
+
+ @T.prim_func
+ def flash_attention_kernel(
+ Q: T.Tensor(shape, dtype),
+ K: T.Tensor(shape, dtype),
+ V: T.Tensor(shape, dtype),
+ Output: T.Tensor(shape, dtype),
+ ):
+ """Flash Attention kernel."""
+ # Launch 3D thread grid: (seq_blocks, heads, batch)
+ with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
+ # Shared memory buffers
+ Q_shared = T.alloc_shared([block_M, dim], dtype)
+ K_shared = T.alloc_shared([block_N, dim], dtype)
+ V_shared = T.alloc_shared([block_N, dim], dtype)
+ O_shared = T.alloc_shared([block_M, dim], dtype)
+
+ # Register fragments and accumulators
+ acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
+ acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
+ acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
+ scores_max = T.alloc_fragment([block_M], accum_dtype)
+ scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
+ scores_scale = T.alloc_fragment([block_M], accum_dtype)
+ scores_sum = T.alloc_fragment([block_M], accum_dtype)
+ logsum = T.alloc_fragment([block_M], accum_dtype)
+
+ # Load Q block
+ T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
+
+ T.fill(acc_o, 0)
+ T.fill(logsum, 0)
+ T.fill(scores_max, -T.infinity(accum_dtype))
+ loop_range = T.ceildiv(seq_len, block_N) # Default value
+ if causal:
+ loop_range = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
+
+ # Main pipelined loop over K/V blocks
+ for k in T.Pipelined(loop_range, num_stages=num_stages):
+ QK_Attention(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
+ Online_Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
+ Rescale_Output(acc_o, scores_scale)
+ Attention_Value(V, V_shared, acc_s_cast, acc_o, k, by, bz)
+
+ # Final normalization and write back
+ for i, j in T.Parallel(block_M, dim):
+ acc_o[i, j] /= logsum[i]
+ T.copy(acc_o, O_shared)
+ T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
+
+ return flash_attention_kernel
+```
+
+### Reference Implementation
+
+```python
+def reference_attention(Q, K, V, dropout_p=0.0, scale=None, causal=False):
+ batch, seq_len, heads, dim = Q.shape
+ if scale is None:
+ scale = 1.0 / math.sqrt(dim)
+
+ scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) * scale
+
+ if causal:
+ mask = torch.triu(torch.ones(seq_len, seq_len, device=scores.device), diagonal=1).bool()
+ mask = mask.unsqueeze(0).unsqueeze(0)
+ scores = scores.masked_fill(mask, float('-inf'))
+
+ attention_weights = F.softmax(scores, dim=-1)
+ if dropout_p > 0.0:
+ attention_weights = F.dropout(attention_weights, p=dropout_p, training=False)
+
+ output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
+ return output
+```
+
+### Data Creation
+
+```python
+def create_test_data(batch, heads, seq_len, dim, device='meta', dtype=torch.float16):
+ """Create test Q/K/V tensors."""
+ torch.manual_seed(42)
+ Q = torch.randn(batch, seq_len, heads, dim, device=device, dtype=dtype) * 0.1
+ K = torch.randn(batch, seq_len, heads, dim, device=device, dtype=dtype) * 0.1
+ V = torch.randn(batch, seq_len, heads, dim, device=device, dtype=dtype) * 0.1
+ return Q, K, V
+```
+
+### Main Function
+
+```python
+def main(
+ batch: int = 4,
+ heads: int = 16,
+ seq_len: int = 1024,
+ dim: int = 64,
+ dropout_p: float = 0.0,
+ scale: float = None,
+ causal: bool = False,
+ tune: bool = False
+):
+ """Example usage and benchmarking of Enhanced Flash Attention."""
+ print("mcTileLang Flash Attention example (Auto-tuning enabled)")
+ print(f"Config: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}")
+ print(f"dropout_p={dropout_p}, scale={scale if scale is not None else f'default(1/sqrt({dim})={1.0/math.sqrt(dim):.4f})'}, causal={causal}")
+
+ flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
+ total_flops = 2 * flops_per_matmul
+ if causal:
+ total_flops *= 0.5
+ print(f"Theoretical FLOPs: {total_flops / 1e9:.2f} GFlops")
+
+ Q, K, V = create_test_data(batch, heads, seq_len, dim)
+ print(f"Q shape: {Q.shape}, dtype: {Q.dtype}")
+ print(f"K shape: {K.shape}, dtype: {K.dtype}")
+ print(f"V shape: {V.shape}, dtype: {V.dtype}")
+
+ print("Running auto-tuning and creating kernel...")
+
+ # Call flash_attention once to get tuning result
+ result = flash_attention(
+ batch=batch,
+ heads=heads,
+ seq_len=seq_len,
+ dim=dim,
+ dropout_p=dropout_p,
+ scale=scale,
+ causal=causal,
+ )
+
+ # Display tuning results (detailed if tune=True)
+ if tune:
+ print(f"\n{'='*60}")
+ print("Auto-tuning Results:")
+ print(f"{'='*60}")
+ print(f"Best latency: {result.latency:.2f} ms")
+ print(f"Best config: {result.config}")
+ print(f"Best TFlops: {total_flops / result.latency * 1e-9:.2f}")
+ if hasattr(result, 'ref_latency') and result.ref_latency:
+ print(f"Reference latency: {result.ref_latency:.2f} ms")
+ print(f"Speedup: {result.ref_latency / result.latency:.2f}x")
+
+ # Use tuning result, no need to recreate kernel
+ kernel = result
+
+ profiler = kernel.get_profiler()
+ ref_program_partial = partial(
+ reference_attention,
+ dropout_p=dropout_p,
+ scale=scale,
+ causal=causal,
+ )
+
+ # Check if dropout is enabled
+ if dropout_p > 0.0:
+ print(f"\nDropout enabled (p={dropout_p}), skipping correctness verification due to randomness")
+ print("Note: Dropout introduces randomness, making deterministic comparison impossible")
+ else:
+ # Only perform correctness verification when dropout is not enabled
+ print("\nVerifying correctness...")
+ try:
+ profiler.assert_allclose(ref_program_partial, rtol=0.01, atol=0.01)
+ print("✅ Correctness: PASSED")
+ except AssertionError as e:
+ print(f"❌ Correctness: FAILED: {e}")
+ return # Early return if correctness check fails
+
+ print("\nBenchmarking...")
+
+ ref_latency = profiler.do_bench(ref_program_partial, warmup=500)
+ print(f"Reference latency: {ref_latency:.2f} ms")
+ print(f"Reference: {total_flops / ref_latency * 1e-9:.2f} TFlops")
+
+ tl_latency = profiler.do_bench(warmup=500)
+ print(f"TileLang latency: {tl_latency:.2f} ms")
+ print(f"TileLang: {total_flops / tl_latency * 1e-9:.2f} TFlops")
+ print(f"Speedup: {ref_latency / tl_latency:.2f}x")
+
+ # Always display best config info
+ print(f"\nUsing optimized configuration: {result.config}")
+
+ # If dropout is enabled, provide additional notes
+ if dropout_p > 0.0:
+ print(f"\nNote: Results include dropout (p={dropout_p}) effects")
+ print("Performance comparison is still valid as both implementations use similar dropout")
+```
+
+## Performance Notes
+
+### Recommended Configurations on Specific Devices
+- **GPU Model**: Xiyun C500-64GB
+ - **Optimal Config**: {'block_M': 64, 'block_N': 64, 'num_stages': 1, 'threads': 256}
+ - **Note**: This configuration was determined via autotuning for default parameters (batch=4, heads=16, seq_len=2048, dim=64) within the manually set tuning range in the example program. Adjust for specific use cases.
+
+Recommended configurations can be obtained by setting tuning configurations to tune over a larger range and for other data scenarios. For example, expand the value range of `iter_params` in the `get_configs()` function (such as adding options for block_M and block_N), and run autotuning for different batch_size, seq_len, or heads values to obtain optimized configurations for specific hardware and data shapes.
+
+### Performance Optimization Suggestions
+- **Pipelining**: Overlap computation and memory accesses to hide latency and maximize compute utilization.
+- **Variable Sequence Lengths**: Support varying sequence lengths in batches to skip padding computations.
+- **Data Layout**: Try BHSD layout instead of BSHD for improved memory coalescing.
+- **Inference for Long Sequences**: Chunk K and V along sequence dimension for long-context handling.
+- **Causal Masking**: For causal attention, adjust loop ranges to skip unnecessary computations.
+
+The complete example code mentioned in the sections above is located at [`examples/flash_attention/example_mha_fwd_bshd_extended.py`](../../examples/flash_attention/example_mha_fwd_bshd_extended.py). This example can be considered an extended version of [`examples/flash_attention/example_mha_fwd_bshd.py`](../../examples/flash_attention/example_mha_fwd_bshd.py)., supporting additional optional parameters and enhanced features. Contributions to improve this documentation are welcome!
\ No newline at end of file
diff --git a/examples/flash_attention/example_mha_fwd_bshd_extended.py b/examples/flash_attention/example_mha_fwd_bshd_extended.py
new file mode 100644
index 0000000000000000000000000000000000000000..72598fda894ffd93a2c37670fdd40b52a9af10a1
--- /dev/null
+++ b/examples/flash_attention/example_mha_fwd_bshd_extended.py
@@ -0,0 +1,353 @@
+import os
+import torch
+import torch.nn.functional as F
+import tilelang
+from tilelang.autotuner import *
+import tilelang.language as T
+import itertools
+import argparse
+from functools import partial
+import math
+
+
+def get_configs():
+ """Auto-tuning configuration parameters."""
+ iter_params = dict(
+ block_M=[64, 128],
+ block_N=[64, 128],
+ num_stages=[1, 2],
+ threads=[128, 256]
+ )
+ return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
+
+
+@autotune(configs=get_configs(), warmup=10, rep=10)
+@tilelang.jit(out_idx=[3])
+def flash_attention(
+ batch: int,
+ heads: int,
+ seq_len: int,
+ dim: int,
+ dropout_p: float = 0.0,
+ scale: float = None,
+ causal: bool = False,
+ block_M: int = 64,
+ block_N: int = 64,
+ num_stages: int = 1,
+ threads: int = 128
+):
+ """
+ Inputs: tensors with shape [batch, seq_len, heads, dim] (BSHD layout).
+ Returns: tensor with same shape.
+ """
+ # Compute scale factor
+ if scale is None:
+ scale_factor = 1.0 / math.sqrt(dim)
+ else:
+ scale_factor = scale
+
+ shape = [batch, seq_len, heads, dim]
+ dtype = "float16"
+ accum_dtype = "float"
+
+ @T.macro
+ def QK_Attention(
+ K: T.Tensor(shape, dtype),
+ Q_shared: T.SharedBuffer([block_M, dim], dtype),
+ K_shared: T.SharedBuffer([block_N, dim], dtype),
+ acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
+ k: T.int32,
+ bx: T.int32,
+ by: T.int32,
+ bz: T.int32,
+ ):
+ """Compute Q*K^T scores and apply masks."""
+ # Load K block to shared memory
+ T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
+
+ if causal:
+ for i, j in T.Parallel(block_M, block_N):
+ acc_s[i, j] = T.if_then_else(
+ bx * block_M + i >= k * block_N + j,
+ 0,
+ -T.infinity(acc_s.dtype)
+ )
+ else:
+ for i, j in T.Parallel(block_M, block_N):
+ acc_s[i, j] = T.if_then_else(
+ k * block_N + j >= seq_len,
+ -T.infinity(acc_s.dtype),
+ 0
+ )
+
+ T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+
+ for i, j in T.Parallel(block_M, block_N):
+ acc_s[i, j] *= scale_factor
+
+ @T.macro
+ def Attention_Value(
+ V: T.Tensor(shape, dtype),
+ V_shared: T.SharedBuffer([block_N, dim], dtype),
+ acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+ acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
+ k: T.int32,
+ by: T.int32,
+ bz: T.int32,
+ ):
+ """Compute attention-weighted V contribution."""
+ T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
+ T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
+
+ @T.macro
+ def Online_Softmax(
+ acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
+ acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+ scores_max: T.FragmentBuffer([block_M], accum_dtype),
+ scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
+ scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+ scores_sum: T.FragmentBuffer([block_M], accum_dtype),
+ logsum: T.FragmentBuffer([block_M], accum_dtype),
+ ):
+ """Online softmax with numerical stability and optional dropout."""
+ T.copy(scores_max, scores_max_prev)
+ T.fill(scores_max, -T.infinity(accum_dtype))
+
+ # Compute row max for current block
+ T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+
+ for i in T.Parallel(block_M):
+ scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
+
+ # Compute scaling factors and exponentiate
+ for i in T.Parallel(block_M):
+ scores_scale[i] = T.exp(scores_max_prev[i] - scores_max[i])
+ for i, j in T.Parallel(block_M, block_N):
+ acc_s[i, j] = T.exp(acc_s[i, j] - scores_max[i])
+
+ keep_prob = 1.0 - dropout_p
+ keep_prob_threshold = int((1.0 - dropout_p) * 100)
+
+ for i, j in T.Parallel(block_M, block_N):
+ rand_val = (i * 31 + j * 17) % 100
+ dropout_mask = T.if_then_else(
+ rand_val < keep_prob_threshold,
+ 1.0 / keep_prob,
+ 0.0
+ )
+ acc_s[i, j] = acc_s[i, j] * dropout_mask
+
+ # Compute row sums and update global logsum
+ T.reduce_sum(acc_s, scores_sum, dim=1)
+ for i in T.Parallel(block_M):
+ logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
+
+ T.copy(acc_s, acc_s_cast)
+
+ @T.macro
+ def Rescale_Output(
+ acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
+ scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+ ):
+ """Rescale partial outputs by scores_scale."""
+ for i, j in T.Parallel(block_M, dim):
+ acc_o[i, j] *= scores_scale[i]
+
+ @T.prim_func
+ def flash_attention_kernel(
+ Q: T.Tensor(shape, dtype),
+ K: T.Tensor(shape, dtype),
+ V: T.Tensor(shape, dtype),
+ Output: T.Tensor(shape, dtype),
+ ):
+ """Flash Attention kernel."""
+ # Launch 3D thread grid: (seq_blocks, heads, batch)
+ with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
+ # Shared memory buffers
+ Q_shared = T.alloc_shared([block_M, dim], dtype)
+ K_shared = T.alloc_shared([block_N, dim], dtype)
+ V_shared = T.alloc_shared([block_N, dim], dtype)
+ O_shared = T.alloc_shared([block_M, dim], dtype)
+
+ # Register fragments and accumulators
+ acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
+ acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
+ acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
+ scores_max = T.alloc_fragment([block_M], accum_dtype)
+ scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
+ scores_scale = T.alloc_fragment([block_M], accum_dtype)
+ scores_sum = T.alloc_fragment([block_M], accum_dtype)
+ logsum = T.alloc_fragment([block_M], accum_dtype)
+
+ # Load Q block
+ T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
+
+ T.fill(acc_o, 0)
+ T.fill(logsum, 0)
+ T.fill(scores_max, -T.infinity(accum_dtype))
+ loop_range = T.ceildiv(seq_len, block_N) # Default value
+ if causal:
+ loop_range = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
+
+ # Main pipelined loop over K/V blocks
+ for k in T.Pipelined(loop_range, num_stages=num_stages):
+ QK_Attention(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
+ Online_Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
+ Rescale_Output(acc_o, scores_scale)
+ Attention_Value(V, V_shared, acc_s_cast, acc_o, k, by, bz)
+
+ # Final normalization and write back
+ for i, j in T.Parallel(block_M, dim):
+ acc_o[i, j] /= logsum[i]
+ T.copy(acc_o, O_shared)
+ T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
+
+ return flash_attention_kernel
+
+
+def reference_attention(Q, K, V, dropout_p=0.0, scale=None, causal=False):
+ batch, seq_len, heads, dim = Q.shape
+ if scale is None:
+ scale = 1.0 / math.sqrt(dim)
+
+ scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) * scale
+
+ if causal:
+ mask = torch.triu(torch.ones(seq_len, seq_len, device=scores.device), diagonal=1).bool()
+ mask = mask.unsqueeze(0).unsqueeze(0)
+ scores = scores.masked_fill(mask, float('-inf'))
+
+ attention_weights = F.softmax(scores, dim=-1)
+ if dropout_p > 0.0:
+ attention_weights = F.dropout(attention_weights, p=dropout_p, training=False)
+
+ output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
+ return output
+
+
+def create_test_data(batch, heads, seq_len, dim, device='meta', dtype=torch.float16):
+ """Create test Q/K/V tensors."""
+ torch.manual_seed(42)
+ Q = torch.randn(batch, seq_len, heads, dim, device=device, dtype=dtype) * 0.1
+ K = torch.randn(batch, seq_len, heads, dim, device=device, dtype=dtype) * 0.1
+ V = torch.randn(batch, seq_len, heads, dim, device=device, dtype=dtype) * 0.1
+ return Q, K, V
+
+def main(
+ batch: int = 4,
+ heads: int = 16,
+ seq_len: int = 1024,
+ dim: int = 64,
+ dropout_p: float = 0.0,
+ scale: float = None,
+ causal: bool = False,
+ tune: bool = False
+):
+ """Example usage and benchmarking of Enhanced Flash Attention."""
+ print("mcTileLang Flash Attention example (Auto-tuning enabled)")
+ print(f"Config: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}")
+ print(f"dropout_p={dropout_p}, scale={scale if scale is not None else f'default(1/sqrt({dim})={1.0/math.sqrt(dim):.4f})'}, causal={causal}")
+
+ flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
+ total_flops = 2 * flops_per_matmul
+ if causal:
+ total_flops *= 0.5
+ print(f"Theoretical FLOPs: {total_flops / 1e9:.2f} GFlops")
+
+ Q, K, V = create_test_data(batch, heads, seq_len, dim)
+ print(f"Q shape: {Q.shape}, dtype: {Q.dtype}")
+ print(f"K shape: {K.shape}, dtype: {K.dtype}")
+ print(f"V shape: {V.shape}, dtype: {V.dtype}")
+
+ print("Running auto-tuning and creating kernel...")
+
+ # Call flash_attention only once to get tuning results
+ result = flash_attention(
+ batch=batch,
+ heads=heads,
+ seq_len=seq_len,
+ dim=dim,
+ dropout_p=dropout_p,
+ scale=scale,
+ causal=causal,
+ )
+
+ # Display tuning results (if tune=True, show detailed info)
+ if tune:
+ print(f"\n{'='*60}")
+ print("Auto-tuning Results:")
+ print(f"{'='*60}")
+ print(f"Best latency: {result.latency:.2f} ms")
+ print(f"Best config: {result.config}")
+ print(f"Best TFlops: {total_flops / result.latency * 1e-9:.2f}")
+ if hasattr(result, 'ref_latency') and result.ref_latency:
+ print(f"Reference latency: {result.ref_latency:.2f} ms")
+ print(f"Speedup: {result.ref_latency / result.latency:.2f}x")
+
+ # Use tuning results, no need to recreate kernel
+ kernel = result
+
+ profiler = kernel.get_profiler()
+ ref_program_partial = partial(
+ reference_attention,
+ dropout_p=dropout_p,
+ scale=scale,
+ causal=causal,
+ )
+
+ # Check if dropout is enabled
+ if dropout_p > 0.0:
+ print(f"\nDropout enabled (p={dropout_p}), skipping correctness verification due to randomness")
+ print("Note: Dropout introduces randomness, making deterministic comparison impossible")
+ else:
+ # Only perform correctness verification when dropout is not enabled
+ print("\nVerifying correctness...")
+ try:
+ profiler.assert_allclose(ref_program_partial, rtol=0.01, atol=0.01)
+ print("✅ Correctness: PASSED")
+ except AssertionError as e:
+ print(f"❌ Correctness: FAILED: {e}")
+ return # Early return if correctness check fails
+
+ print("\nBenchmarking...")
+
+ ref_latency = profiler.do_bench(ref_program_partial, warmup=500)
+ print(f"Reference latency: {ref_latency:.2f} ms")
+ print(f"Reference: {total_flops / ref_latency * 1e-9:.2f} TFlops")
+
+ tl_latency = profiler.do_bench(warmup=500)
+ print(f"TileLang latency: {tl_latency:.2f} ms")
+ print(f"TileLang: {total_flops / tl_latency * 1e-9:.2f} TFlops")
+ print(f"Speedup: {ref_latency / tl_latency:.2f}x")
+
+ # Always display best config info
+ print(f"\nUsing optimized configuration: {result.config}")
+
+ # If dropout is enabled, provide additional notes
+ if dropout_p > 0.0:
+ print(f"\nNote: Results include dropout (p={dropout_p}) effects")
+ print("Performance comparison is still valid as both implementations use similar dropout")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='Enhanced Flash Attention example')
+ parser.add_argument('--batch', type=int, default=4)
+ parser.add_argument('--heads', type=int, default=16)
+ parser.add_argument('--seq_len', type=int, default=2048)
+ parser.add_argument('--dim', type=int, default=64)
+ parser.add_argument('--dropout_p', type=float, default=0.0)
+ parser.add_argument('--scale', type=float, default=None)
+ parser.add_argument('--causal', action='store_true')
+ parser.add_argument('--tune', action='store_true', help='Show detailed tuning results')
+ args = parser.parse_args()
+
+ main(
+ batch=args.batch,
+ heads=args.heads,
+ seq_len=args.seq_len,
+ dim=args.dim,
+ dropout_p=args.dropout_p,
+ scale=args.scale,
+ causal=args.causal,
+ tune=args.tune
+ )