diff --git a/docs/deeplearning_operators/matmul_dequant.md b/docs/deeplearning_operators/matmul_dequant.md
index cdbc3cfc87a54797ee308490c7863aefc5dc6d17..15dc5df9afb68ff982d0562bce9e1c9930df3c4d 100644
--- a/docs/deeplearning_operators/matmul_dequant.md
+++ b/docs/deeplearning_operators/matmul_dequant.md
@@ -1,2 +1,408 @@
-General Matrix-Matrix Multiplication with Dequantization
-=========================================================

### 🚀 High-Performance Matrix Multiplication with Dequantization (MatMul Dequant)

TileLang is designed to lower the barrier for building high-performance custom operators. In this document, we present the implementation of a Matrix Multiplication with Dequantization (MatMul Dequant) operator. The kernel is optimized for the MACA (MetaX) architecture and supports both Per-Tensor and Per-Channel quantization modes.

#### 1. Operator Function

##### 1.1 Application Scenario

In Large Language Model (LLM) inference (e.g., Llama-3, Qwen), memory bandwidth is the primary bottleneck. **Weight-Only Quantization (W8A16)** is the industry-standard way to address this:

* **Storage:** Weights (A) are compressed to `Int8`, halving memory usage.
* **Computation:** Activations (B) remain in `FP16`.
* **Process:** The operator dequantizes weights *on the fly* in registers and immediately performs the matrix multiplication.

##### 1.2 Core Computation Logic

This operator fuses dequantization with matrix multiplication. The computation proceeds in two steps:

**Step 1: Dequantization**

Restore the quantized integer matrix to a floating-point matrix:

$$
A_{float} = (A_{quantized} - ZeroPoint) \times Scale
$$

**Step 2: Matrix Multiplication**

Perform a standard matrix multiplication with the dequantized matrix:

$$
Output = A_{float} \times B
$$

**Fusion Advantages**

Traditional implementations first run dequantization to materialize a complete intermediate result in global memory, and only then perform the matrix multiplication. This operator fuses the two steps: dequantization is executed **on the fly** immediately after the data is loaded into registers. This avoids storing intermediate results explicitly, significantly reducing memory traffic and bandwidth usage.

#### 2. Interface Parameters

##### 2.1 Input Parameters

| Parameter | Type | Shape | Description |
| :--- | :--- | :--- | :--- |
| **A** | `int8` | `[M, K]` | The quantized weight matrix. |
| **B** | `float16` | `[K, N]` | The activation input matrix. |
| **Scale** | `float16` | `[1]` or `[K]` | Scaling factor. Scalar for Per-Tensor; vector for Per-Channel. |
| **ZeroPoint** | `float16` | `[1]` or `[K]` | Zero-point offset. Scalar for Per-Tensor; vector for Per-Channel. |

##### 2.2 Output Parameters

| Parameter | Type | Shape | Description |
| :--- | :--- | :--- | :--- |
| **Output** | `float16` | `[M, N]` | The result of the matrix multiplication. |

##### 2.3 Optional Parameters

| Parameter | Type | Default | Description |
| :--- | :--- | :--- | :--- |
| **quant_mode** | `str` | `"per_tensor"` | `"per_tensor"` (high performance) or `"per_channel"` (high accuracy). |
| **dtype** | `str` | `"float16"` | Output data precision. |
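To summarize the interface semantics, the sketch below shows an unfused PyTorch reference of the same computation (a minimal illustration only, not the TileLang kernel; the helper name `matmul_dequant_reference` is hypothetical). `Scale` and `ZeroPoint` broadcast along `K`, so the same code covers the `[1]` (Per-Tensor) and `[K]` (Per-Channel) shapes from the table above.

```
import torch

def matmul_dequant_reference(A_quant, B, scale, zero_point):
    """Unfused reference: dequantize A, then multiply.

    A_quant:          int8  [M, K]
    B:                fp16  [K, N]
    scale/zero_point: fp16  [1] (Per-Tensor) or [K] (Per-Channel)
    """
    # Step 1: dequantize (broadcasts over K for Per-Channel parameters)
    A_float = (A_quant.float() - zero_point.float()) * scale.float()
    # Step 2: matrix multiplication, returning fp16 like the fused kernel
    return torch.matmul(A_float.half(), B)
```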
#### 3. Usage Examples

##### 3.1 Code Example

The code below illustrates how to implement both **Per-Tensor** and **Per-Channel** quantization using TileLang. It leverages software-pipelining optimizations tuned for the MACA architecture.

###### 1. Per-Tensor test code

```
import torch
import tilelang as tl
import tilelang.language as T

def matmul_dequant_kernel(M, N, K, block_M=128, block_N=128, block_K=32):
    # Define data types
    dtype_a = "int8"
    dtype_b = "float16"
    dtype_c = "float16"
    dtype_accum = "float32"

    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype_a),
        B: T.Tensor((K, N), dtype_b),
        C: T.Tensor((M, N), dtype_c),
        # [Per-Tensor] Scale and ZeroPoint are scalars
        Scale: T.Tensor((1,), dtype_b),
        ZeroPoint: T.Tensor((1,), dtype_b),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):

            # Shared-memory allocation (must be managed explicitly on MACA)
            A_shared = T.alloc_shared((block_M, block_K), dtype_a)
            B_shared = T.alloc_shared((block_K, block_N), dtype_b)

            C_local = T.alloc_fragment((block_M, block_N), dtype_accum)
            T.clear(C_local)

            scale_val = Scale[0]
            zp_val = ZeroPoint[0]

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M : (by + 1) * block_M, k * block_K : (k + 1) * block_K], A_shared)
                T.copy(B[k * block_K : (k + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared)

                A_int8_frag = T.alloc_fragment((block_M, block_K), dtype_a)
                A_frag = T.alloc_fragment((block_M, block_K), dtype_b)
                T.copy(A_shared, A_int8_frag)

                # Dequantize on the fly in registers
                for i, j in T.Parallel(block_M, block_K):
                    A_frag[i, j] = (T.cast(A_int8_frag[i, j], dtype_b) - zp_val) * scale_val

                T.gemm(A_frag, B_shared, C_local)

            T.copy(C_local, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N])

    return main

def matmul_dequant(A, B, scale, zero_point, quant_mode="per_tensor"):
    M, K = A.shape
    K_b, N = B.shape
    assert K == K_b, "Matrix dimensions mismatch"

    # Select the kernel based on quant_mode (this script demonstrates Per-Tensor)
    if quant_mode == "per_tensor":
        tl_kernel = matmul_dequant_kernel(M, N, K)
    else:
        # A real project would dispatch to the per_channel kernel here;
        # for this demo script we fall back to per_tensor.
        tl_kernel = matmul_dequant_kernel(M, N, K)

    # Compile
    program = tl.compile(tl_kernel, target="maca")

    # Allocate the output
    C = torch.zeros(M, N, dtype=torch.float16, device=A.device)

    # Execute
    program(A, B, C, scale, zero_point)

    return C

def run_benchmark():
    M, N, K = 4096, 4096, 4096

    try:
        device = torch.device("maca")
    except Exception:
        device = torch.device("cuda")

    print(f"Running Per-Tensor Benchmark via Wrapper on {device}...")

    # Construct test data
    A_quant = torch.randint(-128, 127, (M, K), dtype=torch.int8, device=device)
    B = torch.randn(K, N, dtype=torch.float16, device=device)

    scale = torch.tensor([0.005], dtype=torch.float16, device=device)
    zero_point = torch.tensor([0.0], dtype=torch.float16, device=device)

    # 1. Functional check (via the wrapper)
    print("Function Call Check:")
    C = matmul_dequant(A_quant, B, scale, zero_point, quant_mode="per_tensor")
    print("Call Successful.")
    # 2. Performance test (the wrapper could be timed too, but it includes compilation
    #    overhead, so for accurate numbers we compile the kernel once and time a loop)
    tl_kernel = matmul_dequant_kernel(M, N, K)
    print("Compiling for performance test...")
    program = tl.compile(tl_kernel, target="maca")

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(100):
        program(A_quant, B, C, scale, zero_point)
    end.record()

    if hasattr(torch, "maca"): torch.maca.synchronize()
    else: torch.cuda.synchronize()

    avg_time_ms = start.elapsed_time(end) / 100
    tflops = (2 * M * N * K) / (avg_time_ms / 1000) / 1e12

    print(f"TileLang Time: {avg_time_ms:.3f} ms")
    print(f"TileLang Performance: {tflops:.2f} TFLOPS")

    # 3. Result verification
    print("Verifying results...")
    A_dq = (A_quant.float() - zero_point.item()) * scale.item()
    M_verify = 1024
    C_ref = torch.matmul(A_dq[:M_verify].half(), B)
    diff = torch.abs(C[:M_verify] - C_ref).max().item()
    print(f"Max Diff: {diff:.6f}")

    if diff < 1e-1:
        print("✅ SUCCESS (Per-Tensor Verified)")

if __name__ == "__main__":
    run_benchmark()
```

###### 2. Per-Channel test code

```
import torch
import tilelang as tl
import tilelang.language as T

def matmul_dequant_kernel(M, N, K, block_M=128, block_N=128, block_K=32):
    # Define data types
    dtype_a = "int8"
    dtype_b = "float16"
    dtype_c = "float16"
    dtype_accum = "float32"

    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype_a),
        B: T.Tensor((K, N), dtype_b),
        C: T.Tensor((M, N), dtype_c),
        # [Per-Channel] Scale and ZeroPoint are vectors of length K
        Scale: T.Tensor((K,), dtype_b),
        ZeroPoint: T.Tensor((K,), dtype_b),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):

            # Shared-memory allocation
            A_shared = T.alloc_shared((block_M, block_K), dtype_a)
            B_shared = T.alloc_shared((block_K, block_N), dtype_b)
            # [Per-Channel] Shared memory for the quantization parameters
            Scale_shared = T.alloc_shared((block_K,), dtype_b)
            ZP_shared = T.alloc_shared((block_K,), dtype_b)

            C_local = T.alloc_fragment((block_M, block_N), dtype_accum)
            T.clear(C_local)

            # Pipelined main loop
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):

                # [Stage 1] Global -> Shared loads
                T.copy(A[by * block_M : (by + 1) * block_M, k * block_K : (k + 1) * block_K], A_shared)
                T.copy(B[k * block_K : (k + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared)

                # [Per-Channel] Load the quantization-parameter slice for the current block
                T.copy(Scale[k * block_K : (k + 1) * block_K], Scale_shared)
                T.copy(ZeroPoint[k * block_K : (k + 1) * block_K], ZP_shared)

                # [Stage 2] Shared -> Register (matrix A)
                A_int8_frag = T.alloc_fragment((block_M, block_K), dtype_a)
                A_frag = T.alloc_fragment((block_M, block_K), dtype_b)
                T.copy(A_shared, A_int8_frag)

                # [Stage 3] Dequantize (Per-Channel logic)
                # j indexes the K dimension; each j uses its own Scale/ZeroPoint
                for i, j in T.Parallel(block_M, block_K):
                    scale_val = Scale_shared[j]
                    zp_val = ZP_shared[j]
                    A_frag[i, j] = (T.cast(A_int8_frag[i, j], dtype_b) - zp_val) * scale_val

                # [Stage 4] Matrix multiplication
                T.gemm(A_frag, B_shared, C_local)

            # [Stage 5] Write back the result
            T.copy(C_local, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N])

    return main

def matmul_dequant(A, B, scale, zero_point, quant_mode="per_channel"):
    """
    High-level Python wrapper.
    Args:
        quant_mode: "per_tensor" or "per_channel".
    """
    M, K = A.shape
    K_b, N = B.shape
    assert K == K_b, "Matrix dimensions mismatch"

    # A real project would dispatch on quant_mode here;
    # this script is dedicated to testing Per-Channel.
    tl_kernel = matmul_dequant_kernel(M, N, K)

    # Compile
    program = tl.compile(tl_kernel, target="maca")

    # Allocate the output
    C = torch.zeros(M, N, dtype=torch.float16, device=A.device)

    # Execute
    program(A, B, C, scale, zero_point)

    return C

def run_benchmark():
    M, N, K = 4096, 4096, 4096

    try:
        device = torch.device("maca")
    except Exception:
        device = torch.device("cuda")

    print(f"Running Per-Channel Benchmark via Wrapper on {device}...")

    # Construct test data
    A_quant = torch.randint(-128, 127, (M, K), dtype=torch.int8, device=device)
    B = torch.randn(K, N, dtype=torch.float16, device=device)

    # [Per-Channel] Vector parameters of length K
    scale = torch.rand(K, dtype=torch.float16, device=device) * 0.01 + 0.001
    zero_point = torch.randint(-10, 10, (K,), dtype=torch.float16, device=device)

    # 1. Functional check (via the wrapper)
    print("Function Call Check:")
    C = matmul_dequant(A_quant, B, scale, zero_point, quant_mode="per_channel")
    print("Call Successful.")

    # 2. Performance test (compile the kernel and time it directly)
    tl_kernel = matmul_dequant_kernel(M, N, K)
    print("Compiling for performance test...")
    program = tl.compile(tl_kernel, target="maca")

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(100):
        program(A_quant, B, C, scale, zero_point)
    end.record()

    if hasattr(torch, "maca"): torch.maca.synchronize()
    else: torch.cuda.synchronize()

    avg_time_ms = start.elapsed_time(end) / 100
    tflops = (2 * M * N * K) / (avg_time_ms / 1000) / 1e12

    print(f"TileLang Time: {avg_time_ms:.3f} ms")
    print(f"TileLang Performance: {tflops:.2f} TFLOPS")

    # 3. Result verification (Per-Channel logic)
    print("Verifying results...")
    # PyTorch broadcasting: (M, K) - (1, K) -> (M, K)
    A_dq = (A_quant.float() - zero_point.view(1, K)) * scale.view(1, K)

    M_verify = 1024
    C_ref = torch.matmul(A_dq[:M_verify].half(), B)

    diff = torch.abs(C[:M_verify] - C_ref).max().item()
    print(f"Max Diff: {diff:.6f}")

    if diff < 1e-1:
        print("✅ SUCCESS (Per-Channel Verified)")
    else:
        print("❌ FAIL")

if __name__ == "__main__":
    run_benchmark()
```

##### 3.2 Performance

###### 3.2.1 Per-Tensor Performance (Baseline)

*Figure 1: Per-Tensor mode achieving 28.32 TFLOPS (4.852 ms) with Max Diff 0.0625.*

###### 3.2.2 Per-Channel Performance (High Accuracy)

*Figure 2: Per-Channel mode achieving 23.38 TFLOPS (5.878 ms) with Max Diff 0.0625.*

#### 4. Performance Specifications

##### 4.1 Recommended Parameter Configuration

| Feature | Recommendation | Reason |
| :--- | :--- | :--- |
| **Block Size** | `128x128x32` | Optimal balance for C500 Compute Unit occupancy. |
| **Pipeline** | `num_stages=3` | Sufficient depth to hide global-memory latency. |
| **Memory Scope** | **B in Shared** | MACA intrinsics require operand B in shared memory; loading B into registers causes a compilation failure. |
| **Layout** | **Contiguous** | Ensure all input tensors are contiguous for vectorized loads. |
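These recommendations map directly onto the kernel-factory arguments used in Section 3. The snippet below is an illustrative sketch that reuses `matmul_dequant_kernel` and `tl.compile` from the examples above; `num_stages=3` is already hard-coded inside the kernels' pipelined loop.

```
# Illustrative sketch: build and compile a kernel with the recommended 128x128x32 blocking.
M, N, K = 4096, 4096, 4096
tl_kernel = matmul_dequant_kernel(M, N, K, block_M=128, block_N=128, block_K=32)
program = tl.compile(tl_kernel, target="maca")
```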
##### 4.2 Performance Optimization Recommendations

1. **Matrix Size Optimization**: Optimal performance is achieved when M, N, and K are all greater than 256; for small matrices, use a batch-processing mode.
2. **Memory Layout Optimization**: Ensure input matrices are contiguous in memory, and pay attention to channel alignment in Per-Channel mode (see the sketch after this list).
3. **Precision Selection**: For inference, float16 balances precision and performance; for training, float32 is recommended for numerical stability.
4. **Quantization Parameters**: Per-Tensor mode offers higher computational efficiency; Per-Channel mode provides better precision at slightly higher computational cost.
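The following sketch illustrates how the recommendations above might be applied when preparing inputs (the helper `prepare_inputs` is hypothetical and shown for illustration only; shapes and dtypes follow the interface in Section 2).

```
import torch

def prepare_inputs(A_quant, B, scale, zero_point):
    # Memory layout: vectorized loads expect contiguous tensors
    A_quant = A_quant.contiguous()
    B = B.contiguous()

    # Precision: int8 weights with fp16 activations (W8A16)
    assert A_quant.dtype == torch.int8 and B.dtype == torch.float16

    K = A_quant.shape[1]
    # Quantization parameters: shape [1] selects Per-Tensor (faster),
    # shape [K] selects Per-Channel (more accurate)
    assert scale.shape in ((1,), (K,)) and zero_point.shape == scale.shape
    return A_quant, B, scale.half(), zero_point.half()
```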