From e747666bc6e6b017b74a5aaa14fb43e878e32004 Mon Sep 17 00:00:00 2001 From: ZhaoShijie Date: Wed, 12 Nov 2025 00:02:28 +0800 Subject: [PATCH 1/2] add convolution doc and examples --- docs/deeplearning_operators/convolution.md | 250 +++++++++++++++++- examples/convolution/example_convolution.py | 4 + .../example_convolution_autotune.py | 4 +- 3 files changed, 254 insertions(+), 4 deletions(-) diff --git a/docs/deeplearning_operators/convolution.md b/docs/deeplearning_operators/convolution.md index 7477c56..9b045d6 100644 --- a/docs/deeplearning_operators/convolution.md +++ b/docs/deeplearning_operators/convolution.md @@ -1,2 +1,248 @@ -Convolution -=========== +# Convolution +============= + +
+ Contributor: @Shijie +
+ +:::{warning} + This document is still **experimental** and may be incomplete. + Suggestions and improvements are highly encouraged—please submit a PR! +::: + +:::{tip} +Example code can be found at [`examples/convolution/example_convolution.py`](../../examples/convolution/example_convolution.py). +::: + +Convolution is a fundamental operation in deep learning, widely used for feature extraction in image processing, computer vision tasks, and neural networks like Convolutional Neural Networks (CNNs). It applies a kernel (filter) to an input tensor, computing dot products over sliding windows to produce an output feature map. + +## Operator Functionality + +### Application Scenarios +Convolution operators are essential in deep learning for tasks such as: +- **Image Feature Extraction**: Detecting edges, textures, and patterns in images. +- **Computer Vision**: Object detection, image classification, and segmentation in models like ResNet, VGG, and YOLO. +- **Signal Processing**: Applied to 1D signals (e.g., audio) or 3D data (e.g., video). + +### Core Computation Logic +The convolution operation computes the output feature map by sliding a kernel over the input tensor and performing element-wise multiplication followed by summation. For a 2D convolution, the output at position (i, j) is calculated as: + +$$ +\text{output}[i][j] = \sum_{m=0}^{k_h - 1} \sum_{n=0}^{k_w - 1} \text{input}[i \cdot s_h + m][j \cdot s_w + n] \cdot \text{kernel}[m][n] +$$ + +Where: +- \( k_h, k_w \): Kernel height and width. +- \( s_h, s_w \): Stride in height and width. +- Padding is applied to handle boundaries. + +This can be extended to multi-channel inputs and outputs with bias addition. + +## Interface Parameters + +### Input Parameters +- **Input Tensor**: Shape `(batch_size, in_channels, height, width)`, data type typically `float16` or `float32`. +- **Kernel (Weights)**: Shape `(out_channels, in_channels, kernel_height, kernel_width)`, data type matching input. +- **Bias (Optional)**: Shape `(out_channels,)`, data type matching input. + +### Output Parameters +- **Output Tensor**: Shape `(batch_size, out_channels, out_height, out_width)`, where: + - `out_height = (height + 2 * padding_h - kernel_height) // stride_h + 1` + - `out_width = (width + 2 * padding_w - kernel_width) // stride_w + 1` +- Data type matches input. + +### Optional Parameters +- **Kernel Size**: `(kernel_height, kernel_width)`, e.g., `(3, 3)`. +- **Stride**: `(stride_h, stride_w)`, e.g., `(1, 1)`. +- **Padding**: `(padding_h, padding_w)`, e.g., `(1, 1)`. +- **Dilation**: `(dilation_h, dilation_w)`, default `(1, 1)`. +- **Groups**: Integer for grouped convolution, default 1. + +## Usage Example + +Below is an example of using the convolution operator in `TileLang` within a MACA environment. The example is divided into sections for clarity: imports and parameter setup, data construction, operator call (compilation and execution), and result verification. + +### Imports and Parameter Setup + +```python +import torch +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import argparse + +def check_hopper(): + if not torch.cuda.is_available(): + return None + props = torch.cuda.get_device_properties(0) + compute_capability = props.major, props.minor + return compute_capability == (9, 0) + +def ref_program(stride, padding, dilation): + def main(A, B): + A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W + B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W + C = torch.conv2d(A, B, stride=stride, padding=padding, dilation=dilation) + C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C + return C + return main + +@tilelang.jit(out_idx=[2]) +def convolution(N, + C, + H, + W, + F, + K, + S, + D, + P, + block_M, + block_N, + block_K, + num_stages, + threads, + dtype="float16", + accum_dtype="float"): + # Define kernel dimensions + KH, KW = K, K + # Compute output dimensions + OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 + OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 + dtype = "float16" + accum_dtype = "float" + # Check if running on Hopper GPU for optimized im2col + is_hopper = check_hopper() + + @T.prim_func + def main( + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), + ): + # Define kernel with block dimensions + with T.Kernel( + T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), + threads=threads) as (bx, by): + # Allocate shared memory for data, kernel, and output + data_shared = T.alloc_shared((block_M, block_K), dtype) + kernel_shared = T.alloc_shared((block_K, block_N), dtype) + out_local = T.alloc_fragment((block_M, block_N), accum_dtype) + out_shared = T.alloc_shared((block_M, block_N), dtype) + + # Flatten kernel for efficient access + kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) + out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) + + # Annotate layouts for swizzled memory access + T.annotate_layout({ + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + data_shared: tilelang.layout.make_swizzled_layout(data_shared), + kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), + }) + + # Clear output accumulator + T.clear(out_local) + # Pipelined loop over K dimension + for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): + # Load data using im2col for Hopper or manual indexing otherwise + if is_hopper: + T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P) + else: + for i, j in T.Parallel(block_M, block_K): + k = k_iter * block_K + j + m = by * block_M + i + # Compute input coordinates + access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P + access_w = m % OW * S + k // C % KW * D - P + in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and + (access_w < W)) + data_shared[i, j] = T.if_then_else( + in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + # Copy kernel to shared memory + T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) + # Perform GEMM operation + T.gemm(data_shared, kernel_shared, out_local) + + # Copy results back to global memory + T.copy(out_local, out_shared) + T.copy(out_shared, out_flat[by * block_M, bx * block_N]) + + return main +``` + +### Data Construction + +```python +def main(argv=None): + # Parse command-line arguments + parser = argparse.ArgumentParser() + parser.add_argument('--n', type=int, default=128, help='n') + parser.add_argument('--c', type=int, default=128, help='c') + parser.add_argument('--h', type=int, default=64, help='h') + parser.add_argument('--w', type=int, default=64, help='w') + parser.add_argument('--f', type=int, default=128, help='f') + parser.add_argument('--k', type=int, default=3, help='k') + parser.add_argument('--s', type=int, default=1, help='s') + parser.add_argument('--d', type=int, default=1, help='d') + parser.add_argument('--p', type=int, default=1, help='p') + + args = parser.parse_args(argv) + N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p + # Generate random input and kernel tensors + a = torch.randn(N, H, W, C).cuda().half() + b = torch.randn(K, K, C, F).cuda().half() + + a = a.contiguous() + b = b.contiguous() +``` + +### Operator Call (Compilation and Execution) + +```python + # Set block and thread parameters + block_m = 64 + block_n = 128 + block_k = 32 + num_stages = 3 + threads = 256 + # Compile the convolution kernel + kernel = convolution(N, C, H, W, F, K, S, D, P, block_m, block_n, block_k, num_stages, threads) + + # Execute the kernel + out_c = kernel(a, b) +``` + +### Result Verification + +```python + # Generate reference output using PyTorch + ref_c = ref_program(S, P, D)(a, b) + # Verify results match within tolerance + torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2) + print("All checks passed.✅") + + +if __name__ == "__main__": + main() +``` + +This example demonstrates an optimized convolution kernel using im2col and GEMM. For production use, further optimizations and hardware-specific tuning refer to the example code at [`examples/convolution/example_convolution_autotune.py`](../../examples/convolution/example_convolution_autotune.py ). + +## Performance Notes + +### Recommended Configurations on Specific Devices + +- **GPU Model**: 曦云 C500-16GB +- **Optimal Config**: {'block_M': 64, 'block_N': 128, 'block_K': 32, 'num_stages': 0, 'thread_num': 256, 'enable_rasteration': True} + - **Note:** This optimal configuration was obtained by running `python examples/convolution/example_convolution_autotune.py` with the example's default convolution parameters. It applies only to those default parameters. To obtain the best configuration for other convolution settings, re-run the autotuner and specify the desired convolution arguments via command-line options (for example: `--n`, `--c`, `--h`, `--w`, `--f`, `--k`, `--s`, `--d`, `--p`). + + +### Performance Optimization Suggestions +- **Tiling**: Divide input and output into tiles to fit in shared memory, reducing redundant loads. +- **Vectorized Loads**: Use `tl.vectorized` for loading kernel and input data in chunks (e.g., float4 for float16). +- **Autotuning**: Use `tilelang.autotune` to search optimal block sizes, thread counts, and tiling factors. +- **Fusions**: Combine convolution with activation functions (e.g., ReLU) to reduce kernel launches. +- **Hardware Awareness**: Leverage tensor cores on NVIDIA GPUs or Metax GPUs for mixed-precision computations. + +For advanced implementations, refer to the example code at [`examples/convolution/example_convolution.py`](../../examples/convolution/example_convolution.py ). Contributions to improve this documentation are welcome! \ No newline at end of file diff --git a/examples/convolution/example_convolution.py b/examples/convolution/example_convolution.py index e37dac2..75422f6 100644 --- a/examples/convolution/example_convolution.py +++ b/examples/convolution/example_convolution.py @@ -111,6 +111,9 @@ def main(argv=None): a = torch.randn(N, H, W, C).cuda().half() b = torch.randn(K, K, C, F).cuda().half() + a = a.contiguous() + b = b.contiguous() + block_m = 64 block_n = 128 block_k = 32 @@ -123,6 +126,7 @@ def main(argv=None): out_c = kernel(a, b) ref_c = ref_program(S, P, D)(a, b) torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2) + print("All checks passed.✅") if __name__ == "__main__": diff --git a/examples/convolution/example_convolution_autotune.py b/examples/convolution/example_convolution_autotune.py index bb3a147..9aa4c8f 100644 --- a/examples/convolution/example_convolution_autotune.py +++ b/examples/convolution/example_convolution_autotune.py @@ -6,7 +6,7 @@ from tilelang.autotuner import * import tilelang.language as T from tilelang.autotuner import AutoTuner from tilelang.carver.template import ConvTemplate -from tilelang.carver.arch import CUDA +from tilelang.carver.arch import MACA from tilelang.carver.roller.rasterization import NoRasterization @@ -32,7 +32,7 @@ def ref_program(stride, padding, dilation): def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False, topk=15): if with_roller: - arch = CUDA("cuda") + arch = MACA("maca") carve_template = ConvTemplate( N=N, C=C, -- Gitee From 86baa023dd11ca2e64b6e1906f8142c9d7ea9c60 Mon Sep 17 00:00:00 2001 From: Zhao Shijie Date: Wed, 19 Nov 2025 22:14:08 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E5=9C=A864GB=20C500=E4=B8=8A=E8=BF=9B?= =?UTF-8?q?=E8=A1=8C=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/deeplearning_operators/convolution.md | 4 ++-- examples/convolution/example_convolution.py | 2 +- examples/convolution/example_convolution_autotune.py | 7 ++++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/deeplearning_operators/convolution.md b/docs/deeplearning_operators/convolution.md index 9b045d6..90b764e 100644 --- a/docs/deeplearning_operators/convolution.md +++ b/docs/deeplearning_operators/convolution.md @@ -233,8 +233,8 @@ This example demonstrates an optimized convolution kernel using im2col and GEMM. ### Recommended Configurations on Specific Devices -- **GPU Model**: 曦云 C500-16GB -- **Optimal Config**: {'block_M': 64, 'block_N': 128, 'block_K': 32, 'num_stages': 0, 'thread_num': 256, 'enable_rasteration': True} +- **GPU Model**: 曦云 C500-64GB +- **Optimal Config**: {'block_M': 64, 'block_N': 128, 'block_K': 32, 'num_stages': 1, 'thread_num': 256, 'enable_rasteration': True} - **Note:** This optimal configuration was obtained by running `python examples/convolution/example_convolution_autotune.py` with the example's default convolution parameters. It applies only to those default parameters. To obtain the best configuration for other convolution settings, re-run the autotuner and specify the desired convolution arguments via command-line options (for example: `--n`, `--c`, `--h`, `--w`, `--f`, `--k`, `--s`, `--d`, `--p`). diff --git a/examples/convolution/example_convolution.py b/examples/convolution/example_convolution.py index 75422f6..8a715ea 100644 --- a/examples/convolution/example_convolution.py +++ b/examples/convolution/example_convolution.py @@ -126,7 +126,7 @@ def main(argv=None): out_c = kernel(a, b) ref_c = ref_program(S, P, D)(a, b) torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2) - print("All checks passed.✅") + print("All checks passed.") if __name__ == "__main__": diff --git a/examples/convolution/example_convolution_autotune.py b/examples/convolution/example_convolution_autotune.py index 9aa4c8f..a012275 100644 --- a/examples/convolution/example_convolution_autotune.py +++ b/examples/convolution/example_convolution_autotune.py @@ -291,14 +291,15 @@ def main(n: int = 128, with_roller: bool = True): N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p ref_prog = ref_program(S, P, D) - use_autotune = True + # use_autotune = True if use_autotune: result = get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller) print(result.config) kernel = result.kernel else: config = get_heuristic_config() - kernel = tl.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_dix=[2]) + kernel = tl.compile(convolution(N, C, H, W, F, K, S, D, P, config["block_M"], config["block_N"], config["block_K"], + config["num_stages"], config["thread_num"]), out_idx=[2]) profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) tilelang_latency = profiler.do_bench() @@ -327,7 +328,7 @@ if __name__ == "__main__": parser.add_argument( "--with_roller", action="store_true", - default=True, + default=False, help="Whether to enable BitBLAS roller for search space") args = parser.parse_args() main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, -- Gitee