diff --git a/aikg/python/ai_kernel_generator/config/default_tilelang_ascendc_pto_config.yaml b/aikg/python/ai_kernel_generator/config/default_tilelang_ascendc_pto_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e76f1f410aadd72b98a15a24627c1d0a99e297e --- /dev/null +++ b/aikg/python/ai_kernel_generator/config/default_tilelang_ascendc_pto_config.yaml @@ -0,0 +1,30 @@ +# Model preset configuration +agent_model_config: + designer: vllm_deepseek_v31_default + coder: vllm_deepseek_v31_default + conductor: vllm_deepseek_v31_default + api_generator: vllm_deepseek_v31_default + example_compressor: vllm_deepseek_v31_default + feature_extractor: vllm_deepseek_v31_default + sketch: vllm_deepseek_v31_default + default: vllm_deepseek_v31_default + +# Log configuration +log_dir: "~/aikg_logs" + +# Workflow configuration +workflow_config_path: "config/default_workflow.yaml" + +# Documentation directory configuration +docs_dir: + designer: "resources/docs/sketch_docs" + coder: "resources/docs/tilelang_ascendc_pto_docs" + +# Performance analysis configuration +profile_settings: + run_times: 50 + warmup_times: 5 + +# Verification configuration +verify_timeout: 300 # Timeout for verification in seconds (default 5 minutes) + diff --git a/aikg/python/ai_kernel_generator/config/vllm_tilelang_ascendc_pto_coderonly_config.yaml b/aikg/python/ai_kernel_generator/config/vllm_tilelang_ascendc_pto_coderonly_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4c0f6612b702909a7ffb563e170c43ea5a42447a --- /dev/null +++ b/aikg/python/ai_kernel_generator/config/vllm_tilelang_ascendc_pto_coderonly_config.yaml @@ -0,0 +1,27 @@ +# Model preset configuration +agent_model_config: + designer: vllm_deepseek_v31_default + coder: vllm_deepseek_v31_default + conductor: vllm_deepseek_v31_default + api_generator: vllm_deepseek_v31_default + feature_extractor: vllm_deepseek_v31_default + sketch: vllm_deepseek_v31_default + +# Log configuration +log_dir: "~/aikg_logs" + +# Workflow configuration +workflow_config_path: "config/coder_only_workflow.yaml" + +# Documentation directory configuration +docs_dir: + designer: "resources/docs/sketch_docs" + coder: "resources/docs/tilelang_ascendc_pto_docs" + +# Performance analysis configuration +profile_settings: + run_times: 50 + warmup_times: 5 + +# Verification configuration +verify_timeout: 300 # Timeout for verification in seconds (default 5 minutes) \ No newline at end of file diff --git a/aikg/python/ai_kernel_generator/core/utils.py b/aikg/python/ai_kernel_generator/core/utils.py index e7e9ae351507013a5402b02be746cd09c121279d..538066706652d58ac0781226c61f2e16c5be3765 100644 --- a/aikg/python/ai_kernel_generator/core/utils.py +++ b/aikg/python/ai_kernel_generator/core/utils.py @@ -50,7 +50,7 @@ def normalize_dsl(dsl: str, backend: str = None) -> str: dsl = dsl.lower() # 如果已经是规范化的类型,直接返回 - if dsl in ["triton_cuda", "triton_ascend", "triton-russia", "swft", "cuda_c", "cpp", "tilelang_npuir", "tilelang_cuda", "ascendc"]: + if dsl in ["triton_cuda", "triton_ascend", "triton-russia", "swft", "cuda_c", "cpp", "tilelang_npuir", "tilelang_cuda", "tilelang_ascendc_pto", "ascendc"]: return dsl # 如果是通用的triton,需要根据backend转换 @@ -82,7 +82,7 @@ def check_dsl(dsl: str): Args: dsl: 实现类型(triton_cuda/triton_ascend/triton-russia/swft等) """ - valid_dsls = ["triton_cuda", "triton_ascend", "triton-russia", "swft", "cuda_c", "cpp", "tilelang_npuir", "tilelang_cuda", "ascendc"] + valid_dsls = ["triton_cuda", "triton_ascend", "triton-russia", "swft", 
"cuda_c", "cpp", "tilelang_npuir", "tilelang_cuda", "tilelang_ascendc_pto", "ascendc"] if dsl not in valid_dsls: raise ValueError( f"dsl must be one of {valid_dsls}. " @@ -116,11 +116,11 @@ VALID_CONFIGS = { }, "torch": { "ascend": { - "ascend910b1": ["triton_ascend", "triton-russia", "tilelang_npuir", "ascendc"], - "ascend910b2": ["triton_ascend", "triton-russia", "tilelang_npuir", "ascendc"], - "ascend910b2c": ["triton_ascend", "triton-russia", "tilelang_npuir", "ascendc"], - "ascend910b3": ["triton_ascend", "triton-russia", "tilelang_npuir", "ascendc"], - "ascend910b4": ["triton_ascend", "triton-russia", "tilelang_npuir", "ascendc"], + "ascend910b1": ["triton_ascend", "triton-russia", "tilelang_npuir", "tilelang_ascendc_pto", "ascendc"], + "ascend910b2": ["triton_ascend", "triton-russia", "tilelang_npuir", "tilelang_ascendc_pto", "ascendc"], + "ascend910b2c": ["triton_ascend", "triton-russia", "tilelang_npuir", "tilelang_ascendc_pto", "ascendc"], + "ascend910b3": ["triton_ascend", "triton-russia", "tilelang_npuir", "tilelang_ascendc_pto", "ascendc"], + "ascend910b4": ["triton_ascend", "triton-russia", "tilelang_npuir", "tilelang_ascendc_pto", "ascendc"], "ascend310p3": ["swft", "ascendc"] }, "cuda": { diff --git a/aikg/python/ai_kernel_generator/core/verifier/adapters/dsl/tilelang_ascendc_pto.py b/aikg/python/ai_kernel_generator/core/verifier/adapters/dsl/tilelang_ascendc_pto.py new file mode 100644 index 0000000000000000000000000000000000000000..0fddaa29631513eaacb692f4e467c95f46e4df36 --- /dev/null +++ b/aikg/python/ai_kernel_generator/core/verifier/adapters/dsl/tilelang_ascendc_pto.py @@ -0,0 +1,100 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TileLang AscendC PTO DSL adapter.""" + +from typing import Any, Optional + +from .base import DSLAdapter + + +class DSLAdapterTilelangAscendcPto(DSLAdapter): + """Adapter for TileLang AscendC PTO DSL.""" + + def get_import_statements(self, framework: str) -> str: + """Return TileLang AscendC PTO import statements.""" + code = "import tilelang\ntilelang.cache.clear_cache()\n" + if framework == "torch": + code += "import torch\nimport torch_npu\n" + return code + + def get_impl_import(self, op_name: str, impl_func_name: str) -> str: + """Return implementation function import. 
+
+        统一使用 ModelNew 类格式(KernelBench 风格)。
+        """
+        return f"from {op_name}_tilelang_ascendc_pto import ModelNew\n"
+
+    def create_impl_module(self, framework: str,
+                           framework_adapter: Any,
+                           init_params_var: str = "init_params",
+                           device_var: str = "device") -> str:
+        """生成创建 impl_model 的代码(只实例化一次)。
+
+        Args:
+            framework: Framework name (torch, mindspore, numpy)
+            framework_adapter: Framework adapter instance
+            init_params_var: Variable name for init_params (default: "init_params")
+            device_var: Variable name for device (default: "device")
+
+        Returns:
+            str: Code string to create impl_model
+        """
+        code = f"impl_model = ModelNew(*{init_params_var})\n"
+        if framework == "torch":
+            code += f"impl_model = impl_model.to({device_var})\n"
+
+        return code
+
+    def call_impl(self, impl_func_name: str, inputs: str, device_id: int,
+                  framework_adapter: Any, op_name: str,
+                  data_dir: Optional[str] = None,
+                  framework_output: Optional[str] = None) -> str:
+        """Return code string to call TileLang AscendC PTO implementation function.
+
+        调用已经实例化好的 impl_model(可以多次调用)。
+        """
+        return f"impl_output = impl_model(*{inputs})\n"
+
+    def needs_binary_io(self) -> bool:
+        """TileLang AscendC PTO doesn't need binary I/O."""
+        return False
+
+    def needs_compilation(self) -> bool:
+        """TileLang AscendC PTO doesn't need compilation."""
+        return False
+
+    def benchmark_impl(self, impl_func_name: str, inputs: str,
+                       warmup: int, runs: int, backend: str, op_name: str,
+                       case_idx: int = 0, framework_model: Optional[str] = None,
+                       framework_adapter: Optional[Any] = None,
+                       device_id: Optional[int] = None) -> str:
+        """Return code string to benchmark TileLang AscendC PTO implementation."""
+        # Similar to tilelang_npuir, use traditional timing
+        sync_code = "torch.npu.synchronize()" if backend == "ascend" else ""
+        code = f"""    # dsl:tilelang_ascendc_pto
+    import time
+    start_time = time.time()
+    for _ in range({warmup + runs}):
+        framework_output = {framework_model}(*{inputs})
+    {sync_code}
+    end_time = time.time()
+    execution_time_ms = (end_time - start_time) * 1000 / {warmup + runs}  # 转换为毫秒
+    method = "traditional_timing"
+"""
+        return code
+
+    def get_special_setup_code(self) -> str:
+        """Return special setup code for tilelang_ascendc_pto."""
+        return "import tilelang\ntilelang.cache.clear_cache()\n"
diff --git a/aikg/python/ai_kernel_generator/core/verifier/adapters/factory.py b/aikg/python/ai_kernel_generator/core/verifier/adapters/factory.py
index 72f8e7357f109fc5ded4614b88ec91c915cb0e84..589a368398b36a0edabceed86b4b586f3a91be99 100644
--- a/aikg/python/ai_kernel_generator/core/verifier/adapters/factory.py
+++ b/aikg/python/ai_kernel_generator/core/verifier/adapters/factory.py
@@ -74,6 +74,9 @@ def get_dsl_adapter(dsl: str):
     elif dsl_lower == "tilelang_cuda":
         from .dsl.tilelang_cuda import DSLAdapterTilelangCuda
         return DSLAdapterTilelangCuda()
+    elif dsl_lower == "tilelang_ascendc_pto":
+        from .dsl.tilelang_ascendc_pto import DSLAdapterTilelangAscendcPto
+        return DSLAdapterTilelangAscendcPto()
     else:
         raise ValueError(f"Unsupported DSL: {dsl}")
diff --git a/aikg/python/ai_kernel_generator/core/verifier/kernel_verifier.py b/aikg/python/ai_kernel_generator/core/verifier/kernel_verifier.py
index 7bb8a8a8f77a0dbbaf254fcea7c8577b87ae724a..8c67f318e7f939612da707f70cec9f65cd095eb3 100644
--- a/aikg/python/ai_kernel_generator/core/verifier/kernel_verifier.py
+++ b/aikg/python/ai_kernel_generator/core/verifier/kernel_verifier.py
@@ -47,7 +47,7 @@ RUN_TEMPLATE_PATH = os.path.join(get_project_root(), "utils", "compile_tools", "
 # 类型定义
 FrameworkType = Literal["torch", "mindspore", "numpy"]
-ImplType = Literal["triton_cuda", "triton_ascend", "triton-russia", "swft", "cuda_c", "cpp", "tilelang_npuir", "tilelang_cuda", "ascendc"]
+ImplType = Literal["triton_cuda", "triton_ascend", "triton-russia", "swft", "cuda_c", "cpp", "tilelang_npuir", "tilelang_cuda", "tilelang_ascendc_pto", "ascendc"]
 BackendType = Literal["cuda", "ascend", "cpu"]
 ArchType = Literal["a100", "v100", "h20", "l20", "rtx3090", "ascend910b4", "ascend310p3", "x86_64", "aarch64"]
 
@@ -158,6 +158,15 @@ class KernelVerifier:
                 "import tilelang",
                 "import tilelang.language as T"
             ]
+        elif self.dsl == "tilelang_ascendc_pto":
+            if self.framework == "torch":
+                import_lines = [
+                    "import torch",
+                    "import torch_npu",
+                    "import tilelang",
+                    "import tilelang.language as T",
+                    "from tilelang.intrinsics import make_zn_layout"
+                ]
         elif self.dsl == "swft":
             import_lines = [
                 "from swft.core import *",
@@ -276,7 +285,7 @@ class KernelVerifier:
             raise
 
         # 创建具体实现文件
-        if "ascendc" in self.dsl:
+        if self.dsl == "ascendc":
             logger.info(f"[{self.op_name}] 检测到AscendC DSL,生成编译项目")
             self.generate_ascendc_project(impl_code, verify_dir)
         else:
diff --git a/aikg/python/ai_kernel_generator/resources/docs/tilelang_ascendc_pto_docs/api/api.md b/aikg/python/ai_kernel_generator/resources/docs/tilelang_ascendc_pto_docs/api/api.md
new file mode 100644
index 0000000000000000000000000000000000000000..a9a14ae79c0653c544737eb0737cff5b84ffbb3c
--- /dev/null
+++ b/aikg/python/ai_kernel_generator/resources/docs/tilelang_ascendc_pto_docs/api/api.md
@@ -0,0 +1,379 @@
+# TileLang-Ascend API 速查
+
+## 1. 内核定义
+
+### @tilelang.jit 装饰器
+```python
+@tilelang.jit(out_idx=[-1], pass_configs={...})
+def kernel_func(M, N, K, ...):
+    @T.prim_func
+    def main(...):
+        ...
+    return main  # 返回主计算函数对象供调用
+```
+- **作用**: 标识函数会被JIT动态编译
+- `out_idx=[-1]`: 指定输出张量在参数列表中的索引,`-1`表示最后一个参数
+- `pass_configs`: 可选的Pass配置
+- `return main`: 返回主计算函数对象供调用
+
+### @T.prim_func 装饰器
+```python
+@T.prim_func
+def main(A: T.Tensor((M, N), dtype), ...):
+```
+- **作用**: 在模块内定义主计算函数
+- 函数名必须为`main`(否则kernel注册会失败)
+
+### T.Tensor
+```python
+A: T.Tensor((M, N), dtype)
+```
+- **作用**: 声明Tensor数据缓冲区,并指定其形状和数据类型
+- 形状用元组`(M, N)`定义
+- `dtype`: 数据类型,如`"float16"`, `"float32"`
+
+### T.Kernel 原语
+```python
+with T.Kernel(core_num, is_npu=True) as (cid, vid):
+```
+- **作用**: 触发算子kernel调用,对应Ascend C的kernel调用
+- `core_num`: AI Core数量(编译时Python整数),对应Ascend C中`<<<...>>>`的第一个参数
+- `is_npu=True`: 必须设置
+
+**cid 与 vid 的区别**:
+| 参数 | 范围 | 说明 |
+|------|------|------|
+| `cid` | [0, core_num) | 核心ID,标识当前运行在哪个AI Core上 |
+| `vid` | 0 或 1 | Vector核索引,因A2的Cube核与Vector核配比为1:2,用于指定当前Vector核的索引 |
+
+> **注意**: 如果不使用vid,可以用`_`占位,如`(cid, _)`
+
+## 2. 内存分配
+
+```python
+# L1缓存 (一级缓存)
+A_L1 = T.alloc_L1((S1, block_M, K_L1), dtype)
+
+# L0A矩阵A缓存
+A_L0 = T.alloc_L0A((S2, block_M, block_K), dtype)
+
+# L0B矩阵B缓存
+B_L0 = T.alloc_L0B((S2, block_K, block_N), dtype)
+
+# L0C累加器
+C_L0 = T.alloc_L0C((block_M, block_N), accum_dtype)
+
+# UB统一缓冲区 (向量操作)
+A_ub = T.alloc_ub((block_M // VEC_NUM, block_N), dtype)
+```
+
+| 内存层级 | 分配API | 用途 |
+|---------|---------|------|
+| L1 | `T.alloc_L1()` | 一级缓存,存储矩阵数据 |
+| L0A | `T.alloc_L0A()` | 矩阵A缓存 |
+| L0B | `T.alloc_L0B()` | 矩阵B缓存 |
+| L0C | `T.alloc_L0C()` | 矩阵累加器 |
+| UB | `T.alloc_ub()` | 统一缓冲区,向量操作 |
+
+## 3. 数据传输
+
+```python
+# 标准数据拷贝
+T.copy(A[bx * block_M, 0], A_L1[0, :, :])
+T.copy(A_L1[k % S1, 0, kk * block_K], A_L0[kk % S2, :, :])
+T.copy(C_L0, C[bx * block_M, by * block_N])
+```
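+
+结合第1~3节,下面给出一个只做分块数据搬运的最小内核示意(假设 M、N 能被分块大小整除;作用域与同步规则详见第6节):
+
+```python
+@tilelang.jit(out_idx=[-1])
+def tile_copy(M, N, block_M, block_N, dtype="float16"):
+    m_num = M // block_M  # 编译时Python整数
+    n_num = N // block_N
+
+    @T.prim_func
+    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
+        with T.Kernel(m_num * n_num, is_npu=True) as (cid, _):
+            bx = cid // n_num
+            by = cid % n_num
+            a_ub = T.alloc_ub((block_M, block_N), dtype)
+
+            with T.Scope("V"):
+                T.copy(A[bx * block_M, by * block_N], a_ub)  # GM -> UB
+                T.barrier_all()  # 搬入与搬出之间插入同步
+                T.copy(a_ub, B[bx * block_M, by * block_N])  # UB -> GM
+
+    return main
+```
+
+## 4. 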
矩阵运算 + +TileLang提供两种矩阵乘法API: + +| API | 特点 | 适用场景 | +|-----|------|----------| +| `T.gemm_v0` | 高级API,自动处理L0A/L0B | 快速原型开发、中小矩阵 | +| `T.mma` | 底层API,需手动管理L0A/L0B | 性能优化、大矩阵 | + +```python +# 高级API(自动处理L0A/L0B搬运) +T.gemm_v0(A_L1, B_L1, C_L0, init=(k == 0)) + +# 底层API(需手动搬运到L0A/L0B) +T.mma(A_L0, B_L0, C_L0, init=True) # C = A × B (首次初始化) +T.mma(A_L0, B_L0, C_L0) # C = C + A × B (累加) +``` +- `init=True`: 首次计算时初始化累加器,默认为累加模式 +- 详细对比见 `basic_docs.md` 中的 GEMM 实现对比分析章节 + +## 5. 向量运算 + +### 基础算术运算 +```python +# 二元运算 +T.add(c_ub, a_ub, b_ub) # c = a + b (向量加法) +T.sub(c_ub, a_ub, b_ub) # c = a - b (向量减法) +T.mul(c_ub, a_ub, b_ub) # c = a * b (向量乘法) +T.div(c_ub, a_ub, b_ub) # c = a / b (向量除法) +T.max(c_ub, a_ub, b_ub) # c = max(a, b) (逐元素取最大) +T.min(c_ub, a_ub, b_ub) # c = min(a, b) (逐元素取最小) + +# 标量运算 +T.mul(c_ub, a_ub, scalar) # c = a * scalar (与标量相乘) +``` + +### 一元数学函数 +```python +T.exp(b_ub, a_ub) # b = exp(a) (指数运算) +T.sqrt(b_ub, a_ub) # b = sqrt(a) (平方根) +T.rsqrt(b_ub, a_ub) # b = 1/sqrt(a) (逆平方根) +T.ln(b_ub, a_ub) # b = ln(a) (自然对数) +T.abs(b_ub, a_ub) # b = |a| (绝对值) +T.reciprocal(b_ub, a_ub) # b = 1/a (倒数) +T.relu(b_ub, a_ub) # b = max(0, a) (ReLU激活) +``` + +### 规约运算 +```python +# 沿指定维度规约 +T.reduce_max(out_ub, in_ub, tmp_ub, dim=-1) # 最大值规约 +T.reduce_sum(out_ub, in_ub, tmp_ub, dim=-1) # 求和规约 + +# 块规约运算(带详细参数控制) +T.block_reduce_sum(b_ub, a_ub, repeat, mask, dstRepStride, srcBlkStride, srcRepStride) +T.block_reduce_max(b_ub, a_ub, repeat, mask, dstRepStride, srcBlkStride, srcRepStride) +``` + +### 其他向量操作 +```python +T.fill(a_ub, value) # 用value填充张量 +T.copy(src, dst) # 数据拷贝(也可用于向量单元内部数据传输) +T.cast_tl(b_ub, a_ub, mode, count) # 类型转换,mode如"CAST_RINT" +``` + +## 6. 作用域与同步 + +### 计算作用域 +- **作用**: 标识代码的计算单元 +- `"C"`: 表示在Cube核上运行 +- `"V"`: 表示在Vector核上运行 + +```python +with T.Scope("C"): # 在Cube核上运行 + T.mma(...) + +with T.Scope("V"): # 在Vector核上运行 + T.add(...) + T.copy(...) +``` + +### 流水线同步标志 +```python +# 初始化标志 +@T.macro +def init_flag(): + T.set_flag("mte2", "mte1", 0) + T.set_flag("mte2", "mte1", 1) + T.set_flag("mte1", "m", 0) + T.set_flag("mte1", "m", 1) + T.set_flag("fix", "m", 0) + +# 设置标志 +T.set_flag("mte2", "mte1", 0) + +# 等待标志 +T.wait_flag("mte1", "mte2", 0) + +# 全局屏障 +T.barrier_all() +``` + +**管道标识符**: +- `mte1`: MTE1内存搬运单元 +- `mte2`: MTE2内存搬运单元 +- `m`: 矩阵计算单元 +- `fix`: 固定流水线 + +## 7. 循环控制 + +```python +# 串行循环 +for k in T.serial(loop_k): + ... + +# 向上取整除法 +n_iter = T.ceildiv(K, block_K) +``` + +## 8. 布局优化 + +### T.annotate_layout +```python +from tilelang.intrinsics import make_zn_layout + +T.annotate_layout({ + A_L1: make_zn_layout(A_L1), + B_L1: make_zn_layout(B_L1), +}) +``` +- **作用**: 为L1缓存中的张量指定内存布局格式,用于矩阵计算的数据排布优化 +- **参数**: 接收一个字典,键为张量变量,值为对应的布局函数返回值 +- **make_zn_layout**: 生成ZN格式布局,这是昇腾AI处理器矩阵运算的推荐布局格式,可以提升数据访问效率 +- **使用场景**: 在使用底层`T.mma`接口进行矩阵乘法时,需要配合布局标注以获得最佳性能 + +### T.use_swizzle +```python +# Swizzle优化 - 任务块到核心的映射优化 +cid = T.use_swizzle(i * core_num + cid, M, N, K, block_M, block_N, off=3, in_loop=True) +``` +- **作用**: 优化计算任务块(tile)到AI Core的映射策略,改善数据局部性和缓存命中率 +- **参数**: + - 第一个参数: 线性化的任务索引 + - `M, N, K`: 矩阵维度 + - `block_M, block_N`: 分块大小 + - `off`: 偏移参数,控制swizzle模式 + - `in_loop`: 是否在循环内使用 +- **返回值**: 重新映射后的核心ID(`cid`) + +## 9. 
宏定义 + +### @T.macro 装饰器 +```python +@T.macro +def init_flag(): + """初始化流水线同步标志""" + T.set_flag("mte1", "mte2", 0) + T.set_flag("mte1", "mte2", 1) + T.set_flag("m", "mte1", 0) + T.set_flag("m", "mte1", 1) + T.set_flag("fix", "m", 0) + +@T.macro +def clear_flag(): + """清理流水线同步标志""" + T.wait_flag("mte1", "mte2", 0) + T.wait_flag("mte1", "mte2", 1) + T.wait_flag("m", "mte1", 0) + T.wait_flag("m", "mte1", 1) + T.wait_flag("fix", "m", 0) +``` +- **作用**: 定义可复用的操作序列宏,编译时会将宏调用内联展开到调用位置 +- **使用方法**: + 1. 在`@T.prim_func`装饰的函数**外部**定义宏 + 2. 在`@T.prim_func`函数**内部**直接调用宏名称(如`init_flag()`) +- **典型用途**: + - 流水线同步标志的初始化和清理 + - 重复使用的数据搬运序列 + - 封装复杂的同步逻辑 + + +## 10. 高级配置 + +### 自动同步插入 +```python +pass_configs = { + tilelang.PassConfigKey.TL_ASCEND_AUTO_SYNC: True +} + +@tilelang.jit(out_idx=[-1], pass_configs=pass_configs) +def kernel(...): + @T.prim_func + def main(...): + T.func_attr({"enable_auto_sync": True}) # 启用自动同步 + ... +``` + +### 自动Buffer重用 +```python +pass_configs = { + tilelang.PassConfigKey.TL_ASCEND_MEMORY_PLANNING: True +} + +@tilelang.jit(out_idx=[-1], pass_configs=pass_configs) +def kernel(...): + # 无需手动配置 T.annotate_address + ... +``` + +## 11. API参考 + +### 11.1 算术运算 + +| API | 功能 | 示例 | +|-----|------|------| +| `T.add(dst, src1, src2)` | 加法 | `T.add(c_ub, a_ub, b_ub)` | +| `T.sub(dst, src1, src2)` | 减法 | `T.sub(c_ub, a_ub, b_ub)` | +| `T.mul(dst, src1, src2)` | 乘法 | `T.mul(c_ub, a_ub, b_ub)` | +| `T.div(dst, src1, src2)` | 除法 | `T.div(c_ub, a_ub, b_ub)` | +| `T.axpy(dst, src, scalar)` | 标量乘加 | `T.axpy(b_ub, a_ub, 2.0)` | + +### 11.2 数学函数 + +| API | 功能 | 示例 | +|-----|------|------| +| `T.abs(dst, src)` | 绝对值 | `T.abs(b_ub, a_ub)` | +| `T.exp(dst, src)` | 指数 | `T.exp(b_ub, a_ub)` | +| `T.sqrt(dst, src)` | 平方根 | `T.sqrt(b_ub, a_ub)` | +| `T.rsqrt(dst, src)` | 逆平方根 | `T.rsqrt(b_ub, a_ub)` | +| `T.reciprocal(dst, src)` | 倒数 | `T.reciprocal(b_ub, a_ub)` | +| `T.ln(dst, src)` | 自然对数 | `T.ln(b_ub, a_ub)` | + +### 11.3 激活函数 + +| API | 功能 | 示例 | +|-----|------|------| +| `T.relu(dst, src)` | ReLU | `T.relu(b_ub, a_ub)` | +| `T.leaky_relu(dst, src, alpha)` | LeakyReLU | `T.leaky_relu(b_ub, a_ub, 0.01)` | + +### 11.4 比较与选择 + +| API | 功能 | 示例 | +|-----|------|------| +| `T.max(dst, src1, src2)` | 取最大值 | `T.max(c_ub, a_ub, b_ub)` | +| `T.min(dst, src1, src2)` | 取最小值 | `T.min(c_ub, a_ub, b_ub)` | +| `T.compare(dst, src1, src2, mode)` | 比较 | `T.compare(c_ub, a_ub, b_ub, "LT")` | +| `T.select(dst, mask, src1, src2, mode)` | 选择 | `T.select(c_ub, mask, a_ub, b_ub, "VSEL_CMPMASK_SPR")` | + +### 11.5 归约操作 + +| API | 功能 | 示例 | +|-----|------|------| +| `T.reduce_max(dst, src, tmp, dim)` | 最大值归约 | `T.reduce_max(m_i, acc_s_ub, tmp_ub, dim=-1)` | +| `T.reduce_sum(dst, src, tmp, dim)` | 求和归约 | `T.reduce_sum(sum_ub, acc_s_ub, tmp_ub, dim=-1)` | +| `T.block_reduce_sum(dst, src, ...)` | 块归约求和 | `T.block_reduce_sum(b_ub, a_ub, repeat, mask, ...)` | +| `T.block_reduce_max(dst, src, ...)` | 块归约最大值 | `T.block_reduce_max(b_ub, a_ub, repeat, mask, ...)` | + +### 11.6 矩阵运算 + +| API | 功能 | 示例 | +|-----|------|------| +| `T.gemm_v0(A, B, C, init, transpose_B)` | 高级矩阵乘法 | `T.gemm_v0(A_L1, B_L1, C_L0, init=True)` | +| `T.mma(A, B, C, init)` | 底层矩阵乘加 | `T.mma(A_L0, B_L0, C_L0, init=True)` | + +### 11.7 数据操作 + +| API | 功能 | 示例 | +|-----|------|------| +| `T.copy(src, dst)` | 数据拷贝 | `T.copy(A[offset], a_ub)` | +| `T.fill(dst, value)` | 填充常量 | `T.fill(a_ub, 0.0)` | +| `T.transpose(dst, src)` | 转置 | `T.transpose(b_ub, a_ub)` | +| `T.gather(dst, src, indices, axis)` | 聚集 | `T.gather(c_ub, a_ub, indices_ub, 0)` | +| `T.cast_tl(dst, 
src, mode, count)` | 类型转换 | `T.cast_tl(b_ub, a_ub, "CAST_RINT", 4096)` | + +### 11.8 位运算 + +| API | 功能 | 示例 | +|-----|------|------| +| `T.and_tl(dst, src1, src2)` | 按位与 | `T.and_tl(c_ub, a_ub, b_ub)` | +| `T.or_tl(dst, src1, src2)` | 按位或 | `T.or_tl(c_ub, a_ub, b_ub)` | +| `T.not_tl(dst, src)` | 按位非 | `T.not_tl(b_ub, a_ub)` | +| `T.shiftleft(dst, src, bits)` | 左移 | `T.shiftleft(b_ub, a_ub, 2)` | +| `T.shiftright(dst, src, bits)` | 右移 | `T.shiftright(b_ub, a_ub, 2)` | + +### 11.9 同步操作 + +| API | 功能 | 示例 | +|-----|------|------| +| `T.barrier_all()` | 全局同步 | `T.barrier_all()` | +| `T.set_flag(producer, consumer, stage)` | 设置流水线标志 | `T.set_flag("mte2", "mte1", 0)` | +| `T.wait_flag(producer, consumer, stage)` | 等待流水线标志 | `T.wait_flag("mte1", "mte2", 0)` | +| `T.set_cross_flag(unit, flag)` | 设置跨核心标志 | `T.set_cross_flag("FIX", 0)` | +| `T.wait_cross_flag(flag)` | 等待跨核心标志 | `T.wait_cross_flag(0)` | diff --git a/aikg/python/ai_kernel_generator/resources/docs/tilelang_ascendc_pto_docs/basic_docs.md b/aikg/python/ai_kernel_generator/resources/docs/tilelang_ascendc_pto_docs/basic_docs.md new file mode 100644 index 0000000000000000000000000000000000000000..da950a4cda7f4363873681ea640932a455d6f0f2 --- /dev/null +++ b/aikg/python/ai_kernel_generator/resources/docs/tilelang_ascendc_pto_docs/basic_docs.md @@ -0,0 +1,783 @@ +# TileLang-Ascend 编程基础 + +本文档介绍TileLang-Ascend的核心概念和编程模式,专为华为Ascend NPU设计。TileLang-Ascend是Ascend TileLang适配器,提供类Python的DSL语法来编写高性能NPU内核。 + +## 1. 核心概念 + +### 1.1 内核定义 + +```python +import tilelang +import tilelang.language as T + +@tilelang.jit(out_idx=[-1]) +def create_kernel(M, N, K, block_M, block_N, block_K, dtype="float16"): + # 编译时计算核心数 + core_num = M // block_M * N // block_N + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(core_num, is_npu=True) as (cid, vid): + # 内核实现 + pass + + return main +``` + +**关键点**: +- `@tilelang.jit(out_idx=[-1])`: JIT编译装饰器,`out_idx`指定输出张量索引 +- `@T.prim_func`: 内核函数装饰器 +- `T.Tensor((M, K), dtype)`: 使用元组定义张量形状 +- `T.Kernel(core_num, is_npu=True)`: 启动NPU内核,必须设置`is_npu=True` +- `(cid, vid)`: 核心ID和向量核心ID(Vector ID) + +### 1.2 NPU内核启动 + +**关键约束**: +- 必须设置`is_npu=True` +- **只支持一维网格** +- **Grid Size必须是Python整数或Python整数运算**(最重要!) + +**正确模式** - 在`@T.prim_func`外部使用Python整数运算: +```python +def create_kernel(M, N, block_M, block_N, dtype="float32"): + # ✅ 在外部用Python整数运算(编译时确定) + m_num = M // block_M # Python整数除法 + n_num = N // block_N # Python整数除法 + core_num = m_num * n_num # Python整数乘法 + + @T.prim_func + def main(A: T.Tensor((M, N), dtype), ...): + # ✅ 直接使用外部的Python整数 + with T.Kernel(core_num, is_npu=True) as (cid, vid): + # cid: 核心索引 (core index) + # vid: 向量核心索引 (vector core id) + ... 
+``` + +**错误模式** - 使用TileLang DSL表达式(如`T.ceildiv`、`T.min`等): +```python +def create_kernel(M, N, block_size, dtype="float32"): + @T.prim_func + def main(A: T.Tensor((M, N), dtype), ...): + # ❌ T.ceildiv是DSL表达式,不是Python整数 + n_blocks = T.ceildiv(M * N, block_size) + + with T.Kernel(n_blocks, is_npu=True) as (cid, _): + # ❌ 导致编译错误 + +# 同样错误的写法: +with T.Kernel(T.ceildiv(M, block_M), is_npu=True) # ❌ 不能用T.ceildiv +with T.Kernel(T.min(M, 32), is_npu=True) # ❌ 不能用T.min +``` + +### 1.3 NPU内存层次 + +| 分配API | Ascend内存 | 用途 | +|---------|------------|------| +| 张量参数 | GM | 全局内存 | +| `T.alloc_L1()` | L1 | 一级缓存 | +| `T.alloc_ub()` | UB | 统一缓冲区 | +| `T.alloc_L0A()` | L0A | 矩阵A缓存 | +| `T.alloc_L0B()` | L0B | 矩阵B缓存 | +| `T.alloc_L0C()` | L0C | 累加器 | + +**内存速度**: L0 (最快) > L1 > UB > GM (最慢) + +### 1.4 双核心架构与作用域 + +Ascend NPU采用双核心架构,需要在不同作用域中执行不同类型的操作。 + +**Cube核心** (矩阵运算): +```python +with T.Scope("C"): # Cube scope + T.mma(A_L0, B_L0, C_L0, init=True) +``` + +**Vector核心** (向量运算): +```python +with T.Scope("V"): # Vector scope + T.copy(A[offset], a_ub) + T.add(c_ub, a_ub, b_ub) + T.copy(c_ub, C[offset]) +``` + +**作用域规则** (关键!): +```python +with T.Kernel(core_num, is_npu=True) as (cid, vid): + # ===== T.Kernel作用域 ===== + # ✅ 可以在这里执行: + # - 分配buffer: T.alloc_ub(), T.alloc_L1()等 + # - 计算索引: offset = cid * block_size + + a_ub = T.alloc_ub((block_M, block_N), dtype) + bx = cid // n_num + by = cid % n_num + + # ===== T.Scope作用域 ===== + with T.Scope("V"): # 或 "C" + # ✅ 必须在这里执行: + # - 数据传输: T.copy() + # - 所有计算: T.add(), T.mul()等 + + T.copy(A[bx * block_M, by * block_N], a_ub) + T.add(c_ub, a_ub, b_ub) + T.copy(c_ub, C[bx * block_M, by * block_N]) +``` + +## 2. 标准编程模式 + +### 2.1 向量操作模式 + +向量操作使用Vector核心执行,适用于逐元素操作。 + +```python +@tilelang.jit(out_idx=[-1]) +def vec_add(M, N, block_M, block_N, dtype="float"): + m_num = M // block_M + n_num = N // block_N + VEC_NUM = 2 + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(m_num * n_num, is_npu=True) as (cid, vid): + bx = cid // n_num + by = cid % n_num + + # 分配UB缓冲区 + a_ub = T.alloc_ub((block_M // VEC_NUM, block_N), dtype) + b_ub = T.alloc_ub((block_M // VEC_NUM, block_N), dtype) + c_ub = T.alloc_ub((block_M // VEC_NUM, block_N), dtype) + + with T.Scope("V"): + # 加载数据(使用vid分配工作) + T.copy(A[bx * block_M + vid * block_M // VEC_NUM, by * block_N], a_ub) + T.copy(B[bx * block_M + vid * block_M // VEC_NUM, by * block_N], b_ub) + + T.barrier_all() + T.add(c_ub, a_ub, b_ub) + T.barrier_all() + + # 存储结果 + T.copy(c_ub, C[bx * block_M + vid * block_M // VEC_NUM, by * block_N]) + + return main +``` + +### 2.2 矩阵乘法模式 (GEMM) + +矩阵乘法是Ascend NPU的核心操作,使用Cube核心执行。TileLang提供两种实现方式: + +#### 简单版本(使用 T.gemm_v0) + +适用于原型开发和中小矩阵场景: + +```python +@tilelang.jit(out_idx=[-1]) +def matmul(M, N, K, block_M, block_N, K_L1, dtype="float16", accum_dtype="float"): + m_num = M // block_M + n_num = N // block_N + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(m_num * n_num, is_npu=True) as (cid, _): + bx = cid // n_num + by = cid % n_num + + A_L1 = T.alloc_L1((block_M, K_L1), dtype) + B_L1 = T.alloc_L1((K_L1, block_N), dtype) + C_L0 = T.alloc_L0C((block_M, block_N), accum_dtype) + + with T.Scope("C"): + loop_k = T.ceildiv(K, K_L1) + for k in T.serial(loop_k): + T.copy(A[bx * block_M, k * K_L1], A_L1) + T.copy(B[k * K_L1, by * block_N], B_L1) + + T.barrier_all() + T.gemm_v0(A_L1, B_L1, C_L0, init=(k == 0)) + T.barrier_all() + + 
T.copy(C_L0, C[bx * block_M, by * block_N])
+
+    return main
+```
+
+#### 高级版本(使用 T.mma + 流水线优化)
+
+适用于性能优化和大矩阵场景,详见第4节流水线优化。
+
+#### 两种实现对比
+
+| 特性 | 简单版本 | 高级版本 |
+|------|----------|----------|
+| **内存层级** | L1 → L0C | L1 → L0A/L0B → L0C |
+| **矩阵乘法API** | `T.gemm_v0` | `T.mma` |
+| **同步方式** | `T.barrier_all()` | `T.set_flag`/`T.wait_flag` |
+| **Buffer策略** | 单buffer | 多buffer双缓冲(S1, S2) |
+| **布局优化** | 无 | `make_zn_layout` |
+| **工作分配** | 每核一个tile | `T.use_swizzle` 优化 |
+| **代码复杂度** | 低 | 高 |
+| **性能** | 中等 | 高 |
+| **适用规模** | 中小矩阵 | 大矩阵 |
+
+### 2.3 Cube-Vector融合模式
+
+当需要在矩阵乘法后进行向量操作时,使用Cube-Vector融合模式。
+
+```python
+@tilelang.jit(out_idx=[-2])
+def matmul_add(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+    m_num = M // block_M
+    n_num = N // block_N
+    VEC_NUM = 2
+
+    @T.prim_func
+    def main(
+        A: T.Tensor((M, K), dtype),
+        B: T.Tensor((K, N), dtype),
+        C: T.Tensor((M, N), dtype),
+        D: T.Tensor((M, N), dtype),
+    ):
+        with T.Kernel(m_num * n_num, is_npu=True) as (cid, vid):
+            bx = cid // n_num
+            by = cid % n_num
+
+            A_L1 = T.alloc_L1((block_M, block_K), dtype)
+            B_L1 = T.alloc_L1((block_K, block_N), dtype)
+            C_L0 = T.alloc_L0C((block_M, block_N), accum_dtype)
+
+            d_ub = T.alloc_ub((block_M // VEC_NUM, block_N), dtype)
+            c_ub = T.alloc_ub((block_M // VEC_NUM, block_N), dtype)
+
+            # Cube核心:矩阵乘法
+            with T.Scope("C"):
+                loop_k = T.ceildiv(K, block_K)
+                for k in T.serial(loop_k):
+                    T.copy(A[bx * block_M, k * block_K], A_L1)
+                    T.copy(B[k * block_K, by * block_N], B_L1)
+
+                    T.barrier_all()
+                    T.gemm_v0(A_L1, B_L1, C_L0, init=(k == 0))
+                    T.barrier_all()
+
+                T.copy(C_L0, C[bx * block_M, by * block_N])
+                T.set_cross_flag("FIX", 0)  # 通知Vector核心
+
+            # Vector核心:加法操作
+            with T.Scope("V"):
+                T.wait_cross_flag(0)  # 等待Cube核心完成
+
+                T.copy(C[bx * block_M + vid * block_M // VEC_NUM, by * block_N], c_ub)
+                T.copy(D[bx * block_M + vid * block_M // VEC_NUM, by * block_N], d_ub)
+
+                T.barrier_all()
+                T.add(c_ub, c_ub, d_ub)
+                T.barrier_all()
+
+                T.copy(c_ub, C[bx * block_M + vid * block_M // VEC_NUM, by * block_N])
+
+    return main
+```
+
+### 2.4 归约模式
+
+```python
+@tilelang.jit(out_idx=[-1])
+def block_reduce_sum(M, N, block_M, block_N, repeat, mask, dstRepStride, srcBlkStride, srcRepStride, dataBlockNum, dtype="float16"):
+    m_num = M // block_M
+    n_num = N // block_N
+    VEC_NUM = 2
+
+    @T.prim_func
+    def main(
+        A: T.Tensor((M, N), dtype),
+        B: T.Tensor((M, N // dataBlockNum), dtype),
+    ):
+        with T.Kernel(m_num * n_num, is_npu=True) as (cid, vid):
+            bx = cid // n_num
+            by = cid % n_num
+
+            a_ub = T.alloc_ub((block_M // VEC_NUM, block_N), dtype)
+            b_ub = T.alloc_ub((block_M // VEC_NUM, block_N // dataBlockNum), dtype)
+
+            with T.Scope("V"):
+                T.copy(A[bx * block_M + vid * block_M // VEC_NUM, by * block_N], a_ub)
+
+                T.barrier_all()
+                T.block_reduce_sum(b_ub, a_ub, repeat, mask, dstRepStride, srcBlkStride, srcRepStride)
+                T.barrier_all()
+
+                T.copy(b_ub, B[bx * block_M + vid * block_M // VEC_NUM, by * block_N // dataBlockNum])
+
+    return main
+```
+
+**高级归约API**:
+```python
+# 使用reduce_max进行最大值归约
+T.reduce_max(m_i, acc_s_ub, tmp_ub, dim=-1)
+
+# 使用reduce_sum进行求和归约
+T.reduce_sum(sumexp_i_ub, acc_s_ub, tmp_ub, dim=-1)
+```
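+
+下面是一个使用 `T.reduce_sum` 做逐行求和的最小示意(假设 M 能被 block_M 整除、N 维不再分块;输出与临时buffer的形状为示意性假设,实际约束以所用版本为准):
+
+```python
+@tilelang.jit(out_idx=[-1])
+def row_sum(M, N, block_M, dtype="float16"):
+    m_num = M // block_M  # 编译时Python整数
+
+    @T.prim_func
+    def main(
+        A: T.Tensor((M, N), dtype),
+        B: T.Tensor((M, 1), dtype),
+    ):
+        with T.Kernel(m_num, is_npu=True) as (cid, _):
+            a_ub = T.alloc_ub((block_M, N), dtype)
+            tmp_ub = T.alloc_ub((block_M, N), dtype)  # 归约用临时buffer(形状为假设)
+            b_ub = T.alloc_ub((block_M, 1), dtype)
+
+            with T.Scope("V"):
+                T.copy(A[cid * block_M, 0], a_ub)
+                T.barrier_all()
+                T.reduce_sum(b_ub, a_ub, tmp_ub, dim=-1)  # 沿最后一维求和
+                T.barrier_all()
+                T.copy(b_ub, B[cid * block_M, 0])
+
+    return main
+```
+
+## 3. 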
编程技巧详解 + +### 3.1 循环与索引计算 + +**串行循环 (T.serial)**: +```python +# K维度分块循环 +loop_k = T.ceildiv(K, K_L1) +for k in T.serial(loop_k): + # 加载当前tile + T.copy(A[bx * block_M, k * K_L1], A_L1) +``` + +**二维块分解**: +```python +# 计算2D网格的块索引 +m_num = M // block_M +n_num = N // block_N +core_num = m_num * n_num + +with T.Kernel(core_num, is_npu=True) as (cid, vid): + # 从1D索引恢复2D坐标 + bx = cid // n_num + by = cid % n_num + + m_offset = bx * block_M + n_offset = by * block_N +``` + +### 3.2 Buffer索引与切片 + +**标量索引**: +```python +# 单元素访问 +value = a_ub[row_id, col_id] # 访问(row_id, col_id)位置 +``` + +**起始位置索引(用于T.copy)**: +```python +# 从offset位置开始拷贝 +T.copy(A[bx * block_M, by * block_N], a_ub) + +# 使用vid分配工作 +T.copy(A[bx * block_M + vid * block_M // VEC_NUM, by * block_N], a_ub) +``` + +**多buffer索引(流水线)**: +```python +# 使用stage索引访问多buffer +T.copy(A[...], A_L1[k % S1, :, :]) # k % S1选择buffer +T.copy(A_L1[k % S1, 0, kk * block_K], A_L0[kk % S2, :, :]) +``` + +### 3.3 数据拷贝与格式转换 + +**GM ↔ UB (向量操作)**: +```python +# 加载数据 +T.copy(A[bx * block_M + vid * block_M // VEC_NUM, by * block_N], a_ub) +# 存储数据 +T.copy(c_ub, C[bx * block_M + vid * block_M // VEC_NUM, by * block_N]) +``` + +**GM → L1 → L0C (矩阵运算)**: +```python +# GM到L1 +T.copy(A[bx * block_M, k * K_L1], A_L1) +T.copy(B[k * K_L1, by * block_N], B_L1) + +# 使用gemm_v0自动处理L0 +T.gemm_v0(A_L1, B_L1, C_L0, init=True) + +# L0C到GM +T.copy(C_L0, C[bx * block_M, by * block_N]) +``` + +**GM → L1 → L0A/L0B → L0C (高级矩阵运算)**: +```python +# GM到L1 +T.copy(A[...], A_L1[k % S1, :, :]) + +# L1到L0A/L0B +T.copy(A_L1[k % S1, 0, kk * block_K], A_L0[kk % S2, :, :]) +T.copy(B_L1[k % S1, kk * block_K, 0], B_L0[kk % S2, :, :]) + +# 矩阵乘法 +T.mma(A_L0[kk % S2, :, :], B_L0[kk % S2, :, :], C_L0, init=...) +``` + +**数据格式说明**: +- **ND格式**: 标准内存布局(用于GM和UB) +- **NZ格式**: NPU优化布局(用于矩阵运算的L0/L1) +- **布局优化**: 使用 `make_zn_layout` 自动转换为NZ格式 +- T.copy 会自动处理边界对齐 + +### 3.4 VEC_NUM与vid使用 + +**VEC_NUM**: 向量并行数,通常为2 +**vid**: 向量核心ID (0或1) + +```python +VEC_NUM = 2 + +@T.prim_func +def main(...): + with T.Kernel(m_num * n_num, is_npu=True) as (cid, vid): + # 分配时考虑VEC_NUM + a_ub = T.alloc_ub((block_M // VEC_NUM, block_N), dtype) + + with T.Scope("V"): + # 使用vid分配工作,每个vid处理不同的行 + offset_m = bx * block_M + vid * block_M // VEC_NUM + T.copy(A[offset_m, by * block_N], a_ub) +``` + +### 3.5 边界处理 + +```python +# 使用T.ceildiv处理非整除情况 +loop_k = T.ceildiv(K, block_K) +for k in T.serial(loop_k): + # 最后一个块可能不完整,使用条件处理 + if k < loop_k - 1: + # 完整块处理 + T.copy(A[bx * block_M, k * block_K], A_L1) + else: + # 最后一个块,可能需要特殊处理 + T.copy(A[bx * block_M, k * block_K], A_L1) +``` + +**注意**: +- T.copy 会自动处理边界对齐 +- 对于固定分块,确保输入维度是 block_size 的整数倍 + +### 3.6 广播操作 + +支持自动广播: +```python +# 支持的广播模式 - 逐行操作 +for h_i in range(block_M // 2): + T.sub(acc_s_ub[h_i, :], acc_s_ub[h_i, :], m_i[h_i]) + T.mul(acc_o[h_i, :], acc_o[h_i, :], scale[h_i]) + T.div(acc_o[h_i, :], acc_o[h_i, :], sumexp[h_i]) + +# 标量运算(编译时常量) +sm_scale = (1.0 / dim)**0.5 # 在prim_func外定义 +T.mul(acc_s_ub, acc_s_ub, sm_scale) + +# 使用T.axpy进行标量乘法 +T.axpy(b_ub, a_ub, 2.0) # b = a * 2.0 +``` + +**重要**: 标量操作数应在`@T.prim_func`定义前定义为编译时常量: +```python +def create_kernel(M, N, dim, dtype="float32"): + # 在prim_func外定义标量 + sm_scale = (1.0 / dim)**0.5 + + @T.prim_func + def main(...): + with T.Kernel(core_num, is_npu=True) as (cid, vid): + with T.Scope("V"): + T.mul(acc_ub, acc_ub, sm_scale) # 正确:使用预定义标量 + # T.mul(acc_ub, acc_ub, 2.5) # 也支持:直接使用字面量 + + return main +``` + +## 4. 
流水线与同步 + +### 4.1 基础同步 + +**barrier同步**: +```python +with T.Scope("V"): + T.copy(A[...], a_ub) + T.barrier_all() # 等待所有操作完成 + T.add(c_ub, a_ub, b_ub) + T.barrier_all() + T.copy(c_ub, C[...]) +``` + +**跨核心同步** (Cube-Vector协作): +```python +# Cube核心完成后通知Vector核心 +with T.Scope("C"): + T.gemm_v0(A_L1, B_L1, C_L0, init=True) + T.copy(C_L0, C[bx * block_M, by * block_N]) + T.set_cross_flag("FIX", 0) + +# Vector核心等待Cube核心 +with T.Scope("V"): + T.wait_cross_flag(0) + T.copy(C[bx * block_M + vid * block_M // VEC_NUM, by * block_N], c_ub) + T.add(c_ub, c_ub, d_ub) +``` + +### 4.2 流水线标志 + +使用流水线标志实现计算与内存传输的重叠执行。 + +```python +@T.macro +def init_flag(): + """初始化流水线标志""" + T.set_flag("mte1", "mte2", 0) + T.set_flag("mte1", "mte2", 1) + T.set_flag("m", "mte1", 0) + T.set_flag("m", "mte1", 1) + T.set_flag("fix", "m", 0) + +@T.macro +def clear_flag(): + """清理流水线标志""" + T.wait_flag("mte1", "mte2", 0) + T.wait_flag("mte1", "mte2", 1) + T.wait_flag("m", "mte1", 0) + T.wait_flag("m", "mte1", 1) + T.wait_flag("fix", "m", 0) +``` + +**流水线标志说明**: +- `T.set_flag(producer, consumer, stage)`: 生产者通知消费者 +- `T.wait_flag(producer, consumer, stage)`: 消费者等待生产者 +- `stage`: 双缓冲索引(0/1) + +**管道单元**: +| 标识符 | 单元 | 功能 | +|--------|------|------| +| `mte1` | MTE1 | 内存传输引擎1 | +| `mte2` | MTE2 | 内存传输引擎2 | +| `m` | Cube | 矩阵计算单元 | +| `fix` | Fixpipe | 固定流水线 | + +### 4.3 宏定义 + +使用宏封装可复用的操作序列: + +```python +@T.macro +def load_and_compute(A_gm, A_L1, A_L0, B_L0, C_L0, offset, stage): + """加载并计算的宏""" + T.wait_flag("mte1", "mte2", stage) + T.copy(A_gm[offset], A_L1[stage, :, :]) + T.set_flag("mte2", "mte1", stage) + + T.wait_flag("mte2", "mte1", stage) + T.copy(A_L1[stage, :, :], A_L0[stage, :, :]) + T.mma(A_L0[stage, :, :], B_L0[stage, :, :], C_L0) + T.set_flag("mte1", "m", stage) + +# 使用宏 +with T.Scope("C"): + for k in T.serial(loop_k): + load_and_compute(A, A_L1, A_L0, B_L0, C_L0, k * block_K, k % 2) +``` + +## 5. 布局与内存优化 + +### 5.1 布局标注 + +```python +from tilelang.intrinsics import make_zn_layout + +@T.prim_func +def main(...): + with T.Kernel(core_num, is_npu=True) as (cid, _): + A_L1 = T.alloc_L1((S1, block_M, K_L1), dtype) + B_L1 = T.alloc_L1((S1, K_L1, block_N), dtype) + + # 布局优化 + T.annotate_layout({ + A_L1: make_zn_layout(A_L1), + B_L1: make_zn_layout(B_L1), + }) +``` + +### 5.2 地址标注 + +手动指定buffer地址,用于优化内存布局: +```python +T.annotate_address({ + # L1 address + q_l1: 0, + k_l1: block_M * dim * DataType(dtype).bits // 8, + + # L0C address + acc_s_l0c: 0, + acc_o_l0c: 0, + + # UB address + acc_o: 0, + sumexp: 65536, + m_i: 65664, +}) +``` + +### 5.3 Swizzle优化 + +```python +cid = T.use_swizzle(i * core_num + cid, M, N, K, block_M, block_N, off=3) +``` +- 优化L2缓存局部性 +- `off`: 偏移参数 + +## 6. 高级特性 + +### 6.1 自动同步插入 + +启用自动同步,编译器自动插入流水线同步指令: + +```python +pass_configs = { + tilelang.PassConfigKey.TL_ASCEND_AUTO_SYNC: True +} + +@tilelang.jit(out_idx=[-1], pass_configs=pass_configs) +def kernel(...): + @T.prim_func + def main(...): + T.func_attr({"enable_auto_sync": True}) + with T.Kernel(n_blocks, is_npu=True) as (cid, vid): + # 无需手动set_flag/wait_flag + with T.Scope("V"): + T.copy(A[...], a_ub) + T.add(c_ub, a_ub, b_ub) + T.copy(c_ub, C[...]) + return main +``` + +### 6.2 自动Buffer重用 + +启用自动内存规划,编译器自动复用buffer地址: + +```python +pass_configs = { + tilelang.PassConfigKey.TL_ASCEND_MEMORY_PLANNING: True +} + +@tilelang.jit(out_idx=[-1], pass_configs=pass_configs) +def kernel(...): + @T.prim_func + def main(...): + # 无需手动配置 T.annotate_address + ... +``` + +## 7. 
性能优化要点 + +### 7.1 Tile大小选择 + +- **矩阵乘法**: block_M/N 通常16-128, block_K 通常64-256 +- **向量操作**: 128-512 +- 确保对齐要求(通常是16或32的倍数) + +### 7.2 内存层次优化 + +1. 频繁访问数据放L0 +2. 中间结果保留在L0C +3. 使用L1作为GM和L0之间的缓冲 +4. 仅在必要时写回GM + +### 7.3 流水线优化 + +1. 使用双缓冲(S1=2, S2=4) +2. 正确配对set_flag/wait_flag +3. 考虑使用自动同步功能 + +### 7.4 双核心协作 + +- Cube核心: 矩阵乘法、卷积 +- Vector核心: 激活函数、归约、逐元素操作 +- 使用不同的Scope区分操作 + +### 7.5 版本选择建议 + +1. **原型开发阶段**:使用简单版本快速验证算法正确性 +2. **性能优化阶段**:切换到高级版本获取更好性能 +3. **小矩阵场景**:简单版本即可满足需求 +4. **大矩阵场景**:推荐使用高级版本充分利用流水线 + +## 8. 调试与验证 + +### 8.1 检查清单 + +**基本检查**: +- [ ] Tensor形状用元组 `(M, N)` +- [ ] 设置 `is_npu=True` +- [ ] core_num是编译时Python整数 + +**内存检查**: +- [ ] 正确使用各级内存(L1/L0/UB) +- [ ] Buffer大小不超过硬件限制 +- [ ] 数据对齐要求满足 + +**流水线检查**: +- [ ] set_flag和wait_flag正确配对 +- [ ] 循环结束调用barrier_all() +- [ ] 或使用自动同步功能 + +### 8.2 常见问题 + +1. **编译错误**: 检查核心数是否为Python整数 +2. **运行时错误**: 检查内存大小和对齐 +3. **性能问题**: 检查流水线配置和Tile大小 + +### 8.3 数值验证 + +```python +# 使用PyTorch进行数值验证 +ref_c = a @ b # PyTorch参考实现 +c = func(a, b) # TileLang内核输出 + +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) +print("Kernel Output Match!") +``` + +### 8.4 缓存管理 + +TileLang 会自动缓存编译后的内核以提升开发效率。在某些场景下需要手动清除缓存: + +```python +import tilelang + +# 清除编译缓存 +tilelang.cache.clear_cache() +``` + +**使用场景**: +- 修改内核代码后需要强制重新编译 +- 缓存文件损坏导致运行时错误 +- 磁盘空间清理 +- 调试编译问题时确保使用最新代码 + +**最佳实践**: +```python +import tilelang +import tilelang.language as T + +# 开发调试阶段建议在文件开头清除缓存 +tilelang.cache.clear_cache() + +@tilelang.jit(out_idx=[-1]) +def my_kernel(...): + # 内核实现 + pass +``` diff --git a/aikg/python/ai_kernel_generator/resources/docs/tilelang_ascendc_pto_docs/examples/torch_gemm.py b/aikg/python/ai_kernel_generator/resources/docs/tilelang_ascendc_pto_docs/examples/torch_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..a01f158aeef42a356b1d02043510d7acda1c549e --- /dev/null +++ b/aikg/python/ai_kernel_generator/resources/docs/tilelang_ascendc_pto_docs/examples/torch_gemm.py @@ -0,0 +1,72 @@ +import tilelang +import tilelang.language as T +import torch + +tilelang.cache.clear_cache() + +@tilelang.jit(out_idx=[-1]) +def matmul(M, N, K, block_M, block_N, K_L1, dtype="float16", accum_dtype="float"): + m_num = M // block_M + n_num = N // block_N + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(m_num * n_num, is_npu=True) as (cid, _): + bx = cid // n_num + by = cid % n_num + + A_L1 = T.alloc_L1((block_M, K_L1), dtype) + B_L1 = T.alloc_L1((K_L1, block_N), dtype) + + C_L0 = T.alloc_L0C((block_M, block_N), accum_dtype) + + with T.Scope("C"): + loop_k = T.ceildiv(K, K_L1) + for k in T.serial(loop_k): + T.copy(A[bx * block_M, k * K_L1], A_L1) + T.copy(B[k * K_L1, by * block_N], B_L1) + + T.barrier_all() + T.gemm_v0(A_L1, B_L1, C_L0, init=(k == 0)) + + T.barrier_all() + + T.copy(C_L0, C[bx * block_M, by * block_N]) + + return main + + +class ModelNew(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x0, x1): + M, K0 = x0.shape + K1, N = x1.shape + assert K0 == K1, f"矩阵维度不匹配: {K0} != {K1}" + + func = matmul(M, N, K0, 128, 256, 64) + c = func(x0, x1) + + return c + +# if __name__ == "__main__": +# M, N, K = 1024, 1024, 1024 +# torch.manual_seed(0) +# a = torch.randn(M, K).half().npu() +# b = torch.randn(K, N).half().npu() +# c = torch.empty(M, N).half().npu() +# torch.npu.synchronize() +# print("init successful!") + +# model = ModelNew() +# c = model(a, b) + +# ref_c = a @ b + +# torch.testing.assert_close(c, ref_c, rtol=1e-2, 
atol=1e-2) +# print("Kernel Output Match!") \ No newline at end of file diff --git a/aikg/python/ai_kernel_generator/resources/docs/tilelang_ascendc_pto_docs/examples/torch_reduce_sum.py b/aikg/python/ai_kernel_generator/resources/docs/tilelang_ascendc_pto_docs/examples/torch_reduce_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..9f59954cdfe8c8810ccd8da214a1bfbc60f25594 --- /dev/null +++ b/aikg/python/ai_kernel_generator/resources/docs/tilelang_ascendc_pto_docs/examples/torch_reduce_sum.py @@ -0,0 +1,85 @@ +import tilelang +import tilelang.language as T +import torch + +tilelang.cache.clear_cache() + +M = 2 +N = 512 +block_M = 2 +block_N = 128 +dataBlockHalfNum = 16 +mask = 128 +repeat = 1 +dstRepStride = 1 +srcBlkStride = 1 +srcRepStride = 8 + + +@tilelang.jit(out_idx=[-1]) +def block_reduce_sum(M, N, block_M, block_N, repeat, mask, dstRepStride, srcBlkStride, srcRepStride, dataBlockNum, dtype="float16"): + m_num = M // block_M + n_num = N // block_N + + VEC_NUM = 2 + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N // dataBlockNum), dtype), + ): + with T.Kernel(m_num * n_num, is_npu=True) as (cid, vid): + bx = cid // n_num + by = cid % n_num + + a_ub = T.alloc_ub((block_M // VEC_NUM, block_N), dtype) + b_ub = T.alloc_ub((block_M // VEC_NUM, block_N // dataBlockNum), dtype) + with T.Scope("V"): + T.copy(A[bx * block_M + vid * block_M // VEC_NUM, by * block_N], a_ub) + + T.barrier_all() + T.block_reduce_sum(b_ub, a_ub, repeat, mask, dstRepStride, srcBlkStride, srcRepStride) + T.barrier_all() + + T.copy(b_ub, B[bx * block_M + vid * block_M // VEC_NUM, by * block_N // dataBlockNum]) + + return main + + + +class ModelNew(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x0): + M, N = x0.shape + + func = block_reduce_sum(M, N, block_M, block_N, repeat, mask, dstRepStride, srcBlkStride, srcRepStride, dataBlockHalfNum) + c = func(x0) + + return c + +# if __name__ == "__main__": +# M, N = 1024, 1024 +# torch.manual_seed(0) +# a = torch.randn(M, N, dtype=torch.float16).npu() +# torch.npu.synchronize() +# print("init successful!") + +# model = ModelNew() +# b = model(a) + +# num_groups = M * N // dataBlockHalfNum +# ref_b = torch.zeros((1, num_groups)).to(torch.float16) +# a_flag = a.reshape(-1) +# for i in range(num_groups): +# start = i * dataBlockHalfNum +# end = start + dataBlockHalfNum +# group = a_flag[start:end] +# sum_val = torch.sum(group).item() +# ref_b[0, i] = sum_val +# ref_b = ref_b.reshape(M, N // dataBlockHalfNum) +# ref_b = ref_b.npu().to(dtype=torch.float16) + +# torch.testing.assert_close(b, ref_b, rtol=1e-2, atol=1e-2) +# print("Kernel Output Match!") \ No newline at end of file diff --git a/aikg/python/ai_kernel_generator/resources/docs/tilelang_ascendc_pto_docs/examples/torch_vec_add.py b/aikg/python/ai_kernel_generator/resources/docs/tilelang_ascendc_pto_docs/examples/torch_vec_add.py new file mode 100644 index 0000000000000000000000000000000000000000..0ba134e9820bf88ddf458d98e2b6fc16374456eb --- /dev/null +++ b/aikg/python/ai_kernel_generator/resources/docs/tilelang_ascendc_pto_docs/examples/torch_vec_add.py @@ -0,0 +1,70 @@ +import tilelang +import tilelang.language as T +import torch + +tilelang.cache.clear_cache() + + +@tilelang.jit(out_idx=[-1]) +def vec_add(M, N, block_M, block_N, dtype="float"): + m_num = M // block_M + n_num = N // block_N + + VEC_NUM = 2 + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), + ): + 
with T.Kernel(m_num * n_num, is_npu=True) as (cid, vid): + bx = cid // n_num + by = cid % n_num + + a_ub = T.alloc_ub((block_M // VEC_NUM, block_N), dtype) + b_ub = T.alloc_ub((block_M // VEC_NUM, block_N), dtype) + c_ub = T.alloc_ub((block_M // VEC_NUM, block_N), dtype) + with T.Scope("V"): + T.copy(A[bx * block_M + vid * block_M // VEC_NUM, by * block_N], a_ub) + T.copy(B[bx * block_M + vid * block_M // VEC_NUM, by * block_N], b_ub) + + T.barrier_all() + T.add(c_ub, a_ub, b_ub) + T.barrier_all() + + T.copy(c_ub, C[bx * block_M + vid * block_M // VEC_NUM, by * block_N]) + + return main + + +class ModelNew(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x0, x1): + M0, N0 = x0.shape + M1, N1 = x1.shape + assert M0 == M1, f"矩阵维度不匹配: {M0} != {M1}" + assert N0 == N1, f"矩阵维度不匹配: {N0} != {N1}" + + func = vec_add(M0, N0, 128, 256) + c = func(x0, x1) + + return c + +# if __name__ == "__main__": +# M, N = 1024, 1024 +# torch.manual_seed(0) +# a = torch.randn(M, N).npu() +# b = torch.randn(M, N).npu() +# torch.npu.synchronize() +# print("init successful!") + +# model = ModelNew() +# c = model(a, b) + +# ref_c = a + b + +# torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) +# print("Kernel Output Match!") \ No newline at end of file diff --git a/aikg/python/ai_kernel_generator/resources/docs/tilelang_ascendc_pto_docs/suggestion_docs.md b/aikg/python/ai_kernel_generator/resources/docs/tilelang_ascendc_pto_docs/suggestion_docs.md new file mode 100644 index 0000000000000000000000000000000000000000..145c52ac638762706699512878a5312744ef9082 --- /dev/null +++ b/aikg/python/ai_kernel_generator/resources/docs/tilelang_ascendc_pto_docs/suggestion_docs.md @@ -0,0 +1,726 @@ +# TileLang-Ascend 优化建议与常见错误 + +本文档提供TileLang-Ascend编程的关键规则、常见错误和性能优化建议。 + +--- + +## 关键规则速查 + +### 必须遵守的规则 +1. **函数名**: `@T.prim_func def main(...)` +2. **NPU标志**: `T.Kernel(..., is_npu=True)` +3. **Tensor形状**: 使用元组 `T.Tensor((M, N), dtype)` +4. **核心数**: 必须是编译时Python整数,在闭包外计算,且至少为1 +5. **作用域**: 矩阵运算用`T.Scope("C")`,向量运算用`T.Scope("V")` +6. **流水线同步**: set_flag和wait_flag必须正确配对,并且在正确的位置插入同步语句 +7. **内存层级**: 正确使用L1/L0A/L0B/L0C/UB +8. **浮点常量**: 避免整数形式(`5.5`而非`5.0`) +9. **科学计数法**: 不支持,使用Tensor传递或`0.0001` +10. **常量操作符**: 直接使用字面量常量,不要定义和传递,注意常量与向量的类型匹配 +11. **编译时常量**: 直接使用外部常量,不要在`@T.prim_func`内重新赋值 +12. **复合表达式**: 必须拆分为多个步骤,不能直接作为函数参数 +13. **切分限制**: 切分值小于shape大小 +14. **高阶API**: 计算较为复杂时,确认是否存在高阶API,如leaky_relu等 + +--- + +## 1. 常见错误 + +### 错误1: 函数名不是main + +```python +# ❌ 错误 +@T.prim_func +def kernel_copy(...): + ... + +# ✅ 正确 +@T.prim_func +def main(...): + ... +``` +**错误信息**: `rtFunctionRegister failed` + +### 错误2: 未设置is_npu=True + +```python +# ❌ 错误 +with T.Kernel(n_blocks) as (cid, _): + ... + +# ✅ 正确 +with T.Kernel(n_blocks, is_npu=True) as (cid, _): + ... +``` +**错误信息**: 编译可能通过,但运行时设备分配错误 + +### 错误3: 核心数在函数内部计算 + +```python +# ❌ 错误:在@T.prim_func内部计算 +@T.prim_func +def main(...): + core_num = T.ceildiv(M, block_M) # TVM表达式,不是Python整数 + with T.Kernel(core_num, is_npu=True) as (cid, _): + ... + +# ✅ 正确:在闭包外部用Python计算 +def create_kernel(M, N, block_M, block_N): + core_num = (M // block_M) * (N // block_N) # Python整数 + + @T.prim_func + def main(...): + with T.Kernel(core_num, is_npu=True) as (cid, _): + ... 
+``` +**错误信息**: `AttributeError: 'NoneType' object has no attribute 'group'` + +### 错误4: Tensor形状使用列表 + +```python +# ❌ 错误 +T.Tensor([M, N], dtype) + +# ✅ 正确 +T.Tensor((M, N), dtype) +``` +**错误信息**: `rtFunctionRegister failed` 或类型不匹配 + +### 错误5: 错误使用内存层级 + +```python +# ❌ 错误:向量操作使用L0内存 +with T.Scope("V"): + A_L0 = T.alloc_L0A((block_M, block_K), dtype) # L0是Cube专用 + T.add(A_L0, B_L0, C_L0) + +# ✅ 正确:向量操作使用UB +with T.Scope("V"): + a_ub = T.alloc_ub((block_M, block_N), dtype) + b_ub = T.alloc_ub((block_M, block_N), dtype) + c_ub = T.alloc_ub((block_M, block_N), dtype) + T.add(c_ub, a_ub, b_ub) +``` + +**内存层级与作用域对应**: +| 内存 | 作用域 | 用途 | +|------|--------|------| +| L0A/L0B/L0C | `T.Scope("C")` | Cube矩阵运算 | +| UB | `T.Scope("V")` | Vector向量运算 | +| L1 | 两者皆可 | 数据缓存 | + +### 错误6: 缺少流水线同步 + +```python +# ❌ 错误:缺少同步 +with T.Scope("V"): + T.copy(A[...], a_ub) + T.copy(B[...], b_ub) + T.add(c_ub, a_ub, b_ub) + T.copy(c_ub, C[...]) + +# ✅ 正确:添加同步 +with T.Scope("V"): + T.copy(A[...], a_ub) + T.copy(B[...], b_ub) + T.barrier_all() + T.add(c_ub, a_ub, b_ub) + T.barrier_all() + T.copy(c_ub, C[...]) +``` + +**同步位置**: 在搬入-计算,计算-搬出之间插入同步 + +### 错误7: set_flag和wait_flag不配对 + +```python +# ❌ 错误:标志不配对 +T.set_flag("mte2", "mte1", 0) +# 缺少对应的 wait_flag + +# ✅ 正确:正确配对 +T.set_flag("mte2", "mte1", 0) # 生产者设置 +... +T.wait_flag("mte2", "mte1", 0) # 消费者等待 +``` + +**流水线标志规则**: +- `set_flag(src, dst, eventId)`: src流水线完成后设置标志通知dst +- `wait_flag(src, dst, eventId)`: dst流水线等待src完成的标志 +- eventId用于双缓冲(0/1) + +### 错误8: 矩阵乘法初始化错误 + +```python +# ❌ 错误:累加时使用init=True +for k in T.serial(loop_k): + T.mma(A_L0[...], B_L0[...], C_L0, init=True) # 每次都清零 + +# ✅ 正确:首次初始化,后续累加 +for k in T.serial(loop_k): + if k == 0: + T.mma(A_L0[...], B_L0[...], C_L0, init=True) + else: + T.mma(A_L0[...], B_L0[...], C_L0) +``` + +### 错误9: 作用域缺失 + +```python +# ❌ 错误:在Kernel内直接操作 +with T.Kernel(core_num, is_npu=True) as (cid, _): + T.mma(A_L0, B_L0, C_L0) # 缺少Scope + +# ✅ 正确:使用正确的Scope +with T.Kernel(core_num, is_npu=True) as (cid, _): + with T.Scope("C"): + T.mma(A_L0, B_L0, C_L0) +``` + +### 错误10: 自动同步配置不完整 + +```python +# ❌ 错误:只设置pass_configs,缺少func_attr +pass_configs = {tilelang.PassConfigKey.TL_ASCEND_AUTO_SYNC: True} + +@tilelang.jit(out_idx=[-1], pass_configs=pass_configs) +def kernel(...): + @T.prim_func + def main(...): + # 缺少 T.func_attr({"enable_auto_sync": True}) + ... + +# ✅ 正确:同时设置 +pass_configs = {tilelang.PassConfigKey.TL_ASCEND_AUTO_SYNC: True} + +@tilelang.jit(out_idx=[-1], pass_configs=pass_configs) +def kernel(...): + @T.prim_func + def main(...): + T.func_attr({"enable_auto_sync": True}) # 必须添加 + ... 
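+        # (补充示意)启用自动同步后,编译器会在数据搬运与计算之间
+        # 自动插入同步指令,无需再手动配对 set_flag/wait_flag(见第2.3节"自动流水线")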
+``` + +### 错误11: 整数形式的浮点常量 + +```python +# ❌ 错误 +value = 5.0 +T.mul(buffer, value, result) + +# 方案1: 使用非整数形式(快速修复) +value = 5.5 # 或 5.00001 + +# 方案2: 在buffer内定义(推荐) +# 需要初始化为5.0的场景 +init_buffer: T.Tensor((N,), dtype) +# 在host端初始化为5.0后传入 +``` +**错误信息**: `unexpected decimal integer literal` + +### 错误12: 使用科学计数法 + +```python +# ❌ 错误 +eps = 1e-5 +eps = 0.00001 # 也会被自动转换为 1e-05 + +# 方案1: Tensor传递(推荐) +eps_tensor: T.Tensor((1, 1), dtype) +# 在host端初始化为1e-5后传入 + +# 方案2: 使用较大值(如果精度允许) +eps = 0.0001 # 0.0001不会被转换 +``` +**错误信息**: `custom op 'e' is unknown` + +### 错误13: 常量类型确定 + +```python +# ❌ 错误, 常量和向量类型不匹配, 常量默认是float +x_ub = T.alloc_ub((...), "float16") +result_ub = T.alloc_ub((...), "float16") +with T.Scope("V"): + T.copy(X[...], x_ub) + T.barrier_all() + T.add(result_ub, x_ub, 1.0) + +# ✅ 正确, 常量和向量类型匹配, +x_ub = T.alloc_ub((...), "float") +result_ub = T.alloc_ub((...), "float") +with T.Scope("V"): + T.copy(X[...], x_ub) + T.barrier_all() + T.add(result_ub, x_ub, 1.0) +``` +**错误信息**: `no matching function for call to 'Adds'` + +### 错误14: 直接使用字面量常量 + +```python +# ❌ 错误, 定义变量并传递 +x_ub = T.alloc_ub((...), "float") +result_ub = T.alloc_ub((...), "float") +zero_val = 0.5 +with T.Scope("V"): + T.copy(X[...], x_ub) + T.barrier_all() + T.add(result_ub, x_ub, zero_val) + +# ✅ 正确:直接使用字面量 +x_ub = T.alloc_ub((...), "float") +result_ub = T.alloc_ub((...), "float") +with T.Scope("V"): + T.copy(X[...], x_ub) + T.barrier_all() + T.add(result_ub, x_ub, 0.5) +``` +**错误信息**: `expected expression main_kernel_tiling(, zero_val);` + +### 错误15: 编译时常量重新赋值 + +```python +# ❌ 错误:重新赋值外部常量 +n_blocks_n = 128 # 外部定义 + +@T.prim_func +def main(...): + with T.Kernel(total_blocks, is_npu=True) as (cid, _): + n_blocks_n_int = n_blocks_n # ❌ 重新赋值 + block_m_id = cid // n_blocks_n_int + +# ✅ 正确:直接使用 +@T.prim_func +def main(...): + with T.Kernel(total_blocks, is_npu=True) as (cid, _): + block_m_id = cid // n_blocks_n # ✅ 直接使用外部常量 +``` +**错误信息**: `error: custom op 'v_' is unknown (tried 'func.v_' as well)` + +**关键规则**: +- ✅ 直接使用:`cid // n_blocks_n` +- ❌ 重新赋值:`temp = n_blocks_n; cid // temp` + +### 错误16: 复合表达式未拆分 + +```python +# ❌ 错误:T.min参数中使用复合表达式 +for k_idx in T.serial(T.ceildiv(K, block_k)): + k_offset = k_idx * block_k + tail_k = T.min(block_k, K - k_offset) # ❌ 复合表达式 + +# ✅ 正确:拆分为两步 +for k_idx in T.serial(T.ceildiv(K, block_k)): + k_offset = k_idx * block_k + k_remaining = K - k_offset # ✅ 先计算 + tail_k = T.min(block_k, k_remaining) # ✅ 再使用 +``` +**错误信息**: `error: expected SSA operand` + +**关键规则**: +- ✅ 拆分:`temp = K - k_offset; T.min(block_k, temp)` +- ❌ 直接使用:`T.min(block_k, K - k_offset)` +- 适用于`T.min/max`等所有内置函数的复合表达式参数 + +### 错误17: T.copy在T.Scope外部调用 + +```python +# ❌ 错误:数据传输操作在Scope外部 +with T.Kernel(n_blocks, is_npu=True) as (cid, _): + a_ub = T.alloc_ub((block_size,), dtype) + + T.copy(A[offset], a_ub) # ❌ 在Scope外 + + with T.Scope("V"): + T.add(c_ub, a_ub, b_ub) + +# ✅ 正确:所有数据操作在Scope内部 +with T.Kernel(n_blocks, is_npu=True) as (cid, _): + a_ub = T.alloc_ub((block_size,), dtype) # 分配可以在外部 + + with T.Scope("V"): + T.copy(A[offset], a_ub) # ✅ 在Scope内 + T.add(c_ub, a_ub, b_ub) + T.copy(c_ub, C[offset]) +``` +**错误信息**: `TVMError: x_ub should be a memref` + +**原因**: `T.copy`和所有计算操作必须在`T.Scope("V")`或`T.Scope("C")`内部执行 + +**作用域规则**: +```python +with T.Kernel(n_blocks, is_npu=True) as (cid, _): + # ✅ 可以在这里: 分配buffer, 计算索引 + x_ub = T.alloc_ub((size,), dtype) + offset = cid * block_size + + with T.Scope("V"): # 或 "C" + # ✅ 必须在这里: T.copy, T.add, T.mul 等 + T.copy(X[offset], x_ub) + T.mul(x_ub, y_ub, 2.5) + T.copy(y_ub, Y[offset]) +``` + 
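+
+把上述作用域规则串起来,下面是一个完整的最小内核示意(逐元素计算 y = x * 2.5,假设 N 能被 block_size 整除):
+
+```python
+@tilelang.jit(out_idx=[-1])
+def scale_kernel(N, block_size, dtype="float"):
+    n_blocks = N // block_size  # 编译时Python整数,在闭包外计算
+
+    @T.prim_func
+    def main(X: T.Tensor((N,), dtype), Y: T.Tensor((N,), dtype)):
+        with T.Kernel(n_blocks, is_npu=True) as (cid, _):
+            x_ub = T.alloc_ub((block_size,), dtype)  # 分配:Scope外
+            y_ub = T.alloc_ub((block_size,), dtype)
+            offset = cid * block_size  # 索引计算:Scope外
+
+            with T.Scope("V"):
+                T.copy(X[offset], x_ub)  # 搬运:必须在Scope内
+                T.barrier_all()
+                T.mul(y_ub, x_ub, 2.5)  # 字面量常量,类型与float buffer匹配(见错误13/14)
+                T.barrier_all()
+                T.copy(y_ub, Y[offset])
+
+    return main
+```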
+### 错误18: 切分值大于shape未做边界处理 + +```python +# ❌ 错误:block_size > N 时未做任何处理 +def create_kernel(M, N, block_size, dtype="float32"): + n_blocks = N // block_size # 当 block_size > N 时,n_blocks = 0 + + @T.prim_func + def main(X: T.Tensor((M, N), dtype), Y: T.Tensor((M, N), dtype)): + with T.Kernel(n_blocks, is_npu=True) as (cid, _): # ❌ n_blocks=0 导致错误 + ... + +# ✅ 正确方案1:确保切分值不大于shape +def create_kernel(M, N, block_size, dtype="float32"): + actual_block_size = min(block_size, N) # ✅ 限制切分值 + n_blocks = (N + actual_block_size - 1) // actual_block_size + + @T.prim_func + def main(X: T.Tensor((M, N), dtype), Y: T.Tensor((M, N), dtype)): + with T.Kernel(n_blocks, is_npu=True) as (cid, _): + ... +``` +**错误信息**: `AttributeError: undefined symbol: get_last_error` + +--- + +## 2. 性能优化 + +### 2.1 Tile大小选择 + +| 操作类型 | 参数 | 推荐范围 | 说明 | +|---------|------|----------|------| +| 矩阵乘法 | block_M | 16-128 | M维度分块 | +| 矩阵乘法 | block_N | 16-128 | N维度分块 | +| 矩阵乘法 | block_K | 64-256 | K维度分块 | +| 向量操作 | block_size | 128-512 | 向量长度 | + +**选择原则**: +- 必须满足硬件对齐要求(通常16或32的倍数) +- 平衡内存使用和计算效率 +- 考虑L1/L0内存大小限制 + +### 2.2 内存层次优化 + +**数据流优化**: +``` +GM → L1 → L0A/L0B → Cube计算 → L0C → GM + ↓ +GM → UB → Vector计算 → UB → GM +``` + +**优化建议**: +1. **L1缓存**: 使用多级缓冲减少GM访问 +2. **L0复用**: 尽量复用L0C中的中间结果 +3. **预取**: 使用双缓冲实现计算和传输重叠 + +```python +# 双缓冲示例 +S1 = 2 # L1双缓冲 +S2 = 4 # L0四缓冲 + +A_L1 = T.alloc_L1((S1, block_M, K_L1), dtype) +A_L0 = T.alloc_L0A((S2, block_M, block_K), dtype) + +for k in T.serial(loop_k): + # 使用 k % S1 和 kk % S2 索引实现双缓冲 + T.copy(A[...], A_L1[k % S1, :, :]) + T.copy(A_L1[k % S1, ...], A_L0[kk % S2, :, :]) +``` + +### 2.3 流水线优化 + +**手动流水线**: +```python +@T.macro +def init_flag(): + T.set_flag("mte1", "mte2", 0) + T.set_flag("mte1", "mte2", 1) + T.set_flag("m", "mte1", 0) + T.set_flag("m", "mte1", 1) + +with T.Scope("C"): + init_flag() + + for k in T.serial(loop_k): + # 阶段1: 数据加载 + T.wait_flag("mte1", "mte2", k % 2) + T.copy(A[...], A_L1[k % 2, :, :]) + T.set_flag("mte2", "mte1", k % 2) + + # 阶段2: L1→L0传输 + T.wait_flag("mte2", "mte1", k % 2) + T.copy(A_L1[k % 2, ...], A_L0[k % 2, :, :]) + T.set_flag("mte1", "m", k % 2) + + # 阶段3: 计算 + T.wait_flag("mte1", "m", k % 2) + T.mma(A_L0[k % 2, ...], B_L0[k % 2, ...], C_L0) + T.set_flag("m", "mte1", k % 2) +``` + +**自动流水线**(推荐): +```python +pass_configs = {tilelang.PassConfigKey.TL_ASCEND_AUTO_SYNC: True} + +@tilelang.jit(out_idx=[-1], pass_configs=pass_configs) +def kernel(...): + @T.prim_func + def main(...): + T.func_attr({"enable_auto_sync": True}) + # 编译器自动插入同步指令 +``` + +### 2.4 布局优化 + +```python +from tilelang.intrinsics import make_zn_layout + +# ZN布局优化(适合矩阵运算) +T.annotate_layout({ + A_L1: make_zn_layout(A_L1), + B_L1: make_zn_layout(B_L1), +}) + +# Swizzle优化(改善L2缓存局部性) +T.use_swizzle(i * core_num + cid, M, N, K, block_M, block_N, off=3, in_loop=True) +``` + +### 2.5 自动Buffer重用 + +启用自动内存规划,避免手动配置地址: +```python +pass_configs = { + tilelang.PassConfigKey.TL_ASCEND_MEMORY_PLANNING: True +} + +# 无需手动配置: +# T.annotate_address({ +# q_l1: 0, +# kv_l1: 65536, +# ... +# }) +``` + +--- + +## 3. 
调试清单 + +### 基本检查 +- [ ] 函数名为`main` +- [ ] 设置 `is_npu=True` +- [ ] Tensor用元组 `(M, N)` +- [ ] core_num在闭包外计算(Python整数) + +### 内存检查 +- [ ] Cube操作使用L0A/L0B/L0C +- [ ] Vector操作使用UB +- [ ] Buffer大小不超过硬件限制 +- [ ] 使用双缓冲优化 + +### 流水线检查 +- [ ] set_flag和wait_flag正确配对 +- [ ] 循环结束调用barrier_all() +- [ ] 或启用自动同步功能 + +### 常量检查 +- [ ] 直接使用字面量常量,不需要定义并传递 +- [ ] 计算时常量和向量的类型需要一致 +- [ ] 避免科学计数法 +- [ ] 浮点常量使用非整数值 +- [ ] 小常量(<0.0001)通过Tensor传递 +- [ ] 编译时常量直接使用,不要重新赋值 +- [ ] 复合表达式拆分为多个步骤 + +### Tile切分检查 +- [ ] 切分值不大于对应维度的shape +- [ ] 核心数(n_blocks)至少为1 +- [ ] 使用向上取整或min/max处理边界 + +### 性能检查 +- [ ] Tile大小满足对齐要求 +- [ ] 使用布局优化 +- [ ] 考虑使用自动内存规划 + +--- + +## 4. 典型错误信息速查表 + +| 错误信息 | 原因 | 解决方案 | 错误编号 | +|---------|------|----------|---------| +| `rtFunctionRegister failed` | Tensor用列表或函数名不是main | 用元组、函数名改为main | 错误1/4 | +| 设备分配错误 | 未设置is_npu=True | 添加is_npu=True | 错误2 | +| `'NoneType' object has no attribute 'group'` | core_num不是Python整数 | 在闭包外计算core_num | 错误3 | +| 内存访问错误 | 内存层级使用错误 | 检查L0/UB使用 | 错误5 | +| 结果错误 | 未正确使用同步 | 检查搬入-计算,计算-搬出之间是否存在同步语句 | 错误6 | +| 同步死锁 | 流水线标志不配对 | 检查set/wait_flag | 错误7 | +| 结果错误 | mma初始化问题 | 检查init参数 | 错误8 | +| `unexpected decimal integer literal` | 整数形式浮点数 | 使用`5.5`而非`5.0` | 错误11 | +| `custom op 'e' is unknown` | 科学计数法(如1e-5) | Tensor传递或`0.0001` | 错误12 | +| `no matching function for call to 'Adds'` | 常量和向量类型不匹配 | 修改类型 | 错误13 | +| `expected expression main_kernel_tiling(, zero_val);` | 定义变量传递常量 | 直接使用字面量常量 | 错误14 | +| `custom op 'v_' is unknown` | 编译时常量重新赋值 | 直接使用外部常量,不要重新赋值 | 错误15 | +| `expected SSA operand` | 复合表达式未拆分(如`K - k_offset`) | 将复合表达式拆分为多个步骤 | 错误16 | +| `TVMError: x_ub should be a memref` | T.copy在T.Scope外调用 | 所有数据操作放在Scope内 | 错误17 | +| `undefined symbol: get_last_error` | 切分值大于shape未处理 | 调整切分值 | 错误18 | + +--- + +## 5. 开发建议 + +### 渐进式开发 +1. **功能正确性**: 先实现基础版本,验证结果正确 +2. **添加流水线**: 逐步添加set_flag/wait_flag +3. **布局优化**: 添加annotate_layout和swizzle +4. **性能调优**: 调整Tile大小和缓冲级数 + +### 调试技巧 +1. **简化测试**: 使用小矩阵验证正确性 +2. **逐步添加**: 每次只添加一个优化特性 +3. **使用自动功能**: 优先使用AUTO_SYNC和MEMORY_PLANNING + +### 常量处理 +- 直接使用字面量常量,不需要定义并传递 +- 计算时常量和向量的类型需要一致 +- 避免科学计数法 +- 小常量通过Tensor传递 + +### 代码组织 +```python +# 推荐的代码结构 +@tilelang.jit(out_idx=[-1], pass_configs={...}) +def kernel(M, N, K, ...): + # 编译时参数计算 + core_num = ... + + # @T.macro 必须定义在 @tilelang.jit 内部 + @T.macro + def init_flag(): + ... + + @T.macro + def clear_flag(): + ... + + @T.macro + def load_data(...): + ... + + @T.prim_func + def main(...): + with T.Kernel(core_num, is_npu=True) as (cid, _): + # 内存分配 + ... + + # 布局优化 + T.annotate_layout({...}) + + with T.Scope("C"): # 或 "V" + # 初始化 + init_flag() + + # 主循环 + for k in T.serial(loop_k): + load_data(...) + T.mma(...) 
+
+---
+
+## Appendix: Complete Example
+
+### Matrix multiplication (with pipelining)
+
+```python
+import tilelang
+import tilelang.language as T
+
+@tilelang.jit(out_idx=[-1])
+def gemm_pipelined(M, N, K, block_M, block_N, block_K, dtype="float16"):
+    m_num = M // block_M
+    n_num = N // block_N
+    core_num = m_num * n_num
+    K_L1 = block_K * 4
+    S1, S2 = 2, 4
+
+    @T.macro
+    def init_flag():
+        T.set_flag("mte1", "mte2", 0)
+        T.set_flag("mte1", "mte2", 1)
+        T.set_flag("m", "mte1", 0)
+        T.set_flag("m", "mte1", 1)
+        T.set_flag("fix", "m", 0)
+
+    @T.macro
+    def clear_flag():
+        T.wait_flag("mte1", "mte2", 0)
+        T.wait_flag("mte1", "mte2", 1)
+        T.wait_flag("m", "mte1", 0)
+        T.wait_flag("m", "mte1", 1)
+        T.wait_flag("fix", "m", 0)
+
+    @T.prim_func
+    def main(
+        A: T.Tensor((M, K), dtype),
+        B: T.Tensor((K, N), dtype),
+        C: T.Tensor((M, N), dtype),
+    ):
+        with T.Kernel(core_num, is_npu=True) as (cid, _):
+            A_L1 = T.alloc_L1((S1, block_M, K_L1), dtype)
+            B_L1 = T.alloc_L1((S1, K_L1, block_N), dtype)
+            A_L0 = T.alloc_L0A((S2, block_M, block_K), dtype)
+            B_L0 = T.alloc_L0B((S2, block_K, block_N), dtype)
+            C_L0 = T.alloc_L0C((block_M, block_N), "float32")
+
+            with T.Scope("C"):
+                init_flag()
+
+                bx = cid // n_num
+                by = cid % n_num
+                loop_k = T.ceildiv(K, K_L1)
+
+                T.wait_flag("mte1", "mte2", 0)
+                T.copy(A[bx * block_M, 0], A_L1[0, :, :])
+                T.copy(B[0, by * block_N], B_L1[0, :, :])
+                T.set_flag("mte2", "mte1", 0)
+
+                for k in T.serial(loop_k):
+                    if k < loop_k - 1:
+                        T.wait_flag("mte1", "mte2", (k + 1) % S1)
+                        T.copy(A[bx * block_M, (k + 1) * K_L1], A_L1[(k + 1) % S1, :, :])
+                        T.copy(B[(k + 1) * K_L1, by * block_N], B_L1[(k + 1) % S1, :, :])
+                        T.set_flag("mte2", "mte1", (k + 1) % S1)
+
+                    loop_kk = T.ceildiv(K_L1, block_K)
+                    for kk in T.serial(loop_kk):
+                        if kk == 0:
+                            T.wait_flag("mte2", "mte1", k % S1)
+                        T.wait_flag("m", "mte1", kk % S2)
+                        T.copy(A_L1[k % S1, 0, kk * block_K], A_L0[kk % S2, :, :])
+                        T.copy(B_L1[k % S1, kk * block_K, 0], B_L0[kk % S2, :, :])
+                        T.set_flag("mte1", "m", kk % S2)
+                        T.wait_flag("mte1", "m", kk % S2)
+
+                        if k == 0 and kk == 0:
+                            T.mma(A_L0[kk % S2, :, :], B_L0[kk % S2, :, :], C_L0, init=True)
+                        else:
+                            T.mma(A_L0[kk % S2, :, :], B_L0[kk % S2, :, :], C_L0)
+
+                        T.set_flag("m", "mte1", kk % S2)
+
+                T.copy(C_L0, C[bx * block_M, by * block_N])
+
+                clear_flag()
+                T.barrier_all()
+
+    return main
+```
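+
+A quick host-side check of the kernel above might look like the following.
+This is a sketch: the shapes, tolerances, and the torch/torch_npu setup are
+illustrative assumptions, with shapes chosen so that every tile divides its
+dimension evenly (the example has no tail-block handling).
+
+```python
+import torch
+import torch_npu
+
+M, N, K = 512, 512, 1024
+a = torch.randn(M, K, dtype=torch.float16).npu()
+b = torch.randn(K, N, dtype=torch.float16).npu()
+
+# block_K * 4 = K_L1 must divide K; M % block_M == 0 and N % block_N == 0
+kernel = gemm_pipelined(M, N, K, block_M=128, block_N=128, block_K=64)
+c = kernel(a, b)  # out_idx=[-1]: the output C is allocated and returned
+
+ref = (a.float() @ b.float()).half()
+torch.testing.assert_close(c, ref, rtol=3e-2, atol=3e-2)
+```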
diff --git a/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template.j2 b/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template.j2
index 867b85f921bf09e62de5785d3c1da76f128c150a..8c7ed20eed3d217087f33f859939a7f7d32c87c7 100644
--- a/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template.j2
+++ b/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template.j2
@@ -64,6 +64,15 @@ try:
 except ImportError:
     pass
 from {{ op_name }}_tilelang_npuir import {{ impl_func_name }}
+{% elif dsl == "tilelang_ascendc_pto" %}
+import tilelang
+tilelang.cache.clear_cache()
+try:
+    from ai_kernel_generator.utils.tilelang_compile_patch import apply_tilelang_patches
+    apply_tilelang_patches()
+except ImportError:
+    pass
+from {{ op_name }}_tilelang_ascendc_pto import {{ impl_func_name }}
 {% elif dsl == "swft" %}
 from {{ op_name }}_swft import {{ impl_func_name }}
 {% elif dsl == "cuda_c" %}
@@ -424,7 +433,7 @@ def verify_implementations():
 
         # 加载SWFT输出
         impl_output = load_binary_data(data_dir, framework_output)
 
-    {% elif dsl in ["triton_cuda", "triton_ascend", "cuda_c", "cpp", "tilelang_npuir", "tilelang_cuda"] %}
+    {% elif dsl in ["triton_cuda", "triton_ascend", "cuda_c", "cpp", "tilelang_npuir", "tilelang_cuda", "tilelang_ascendc_pto"] %}
         # 运行实现
         impl_output = {{ impl_func_name }}(*inputs_for_impl)
 {% elif dsl == "ascendc" %}
diff --git a/aikg/tests/st/test_task_tilelang_ascendc_pto.py b/aikg/tests/st/test_task_tilelang_ascendc_pto.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c67fec91ac07b5eeee319117ee58dd31532054c
--- /dev/null
+++ b/aikg/tests/st/test_task_tilelang_ascendc_pto.py
@@ -0,0 +1,70 @@
+import pytest
+from ai_kernel_generator.core.task import Task
+from ai_kernel_generator.core.async_pool.task_pool import TaskPool
+from ai_kernel_generator.core.worker.manager import register_local_worker
+from ..utils import (
+    get_kernelbench_op_name, get_kernelbench_task_desc, add_op_prefix,
+    process_task_results, get_device_id
+)
+from ai_kernel_generator.config.config_validator import load_config
+from ai_kernel_generator.utils.environment_check import check_env_for_task
+
+
+device_id = get_device_id()
+
+
+@pytest.mark.level0
+@pytest.mark.torch
+@pytest.mark.tilelang_ascendc_pto
+@pytest.mark.ascend
+@pytest.mark.ascend910b4
+@pytest.mark.use_model
+@pytest.mark.asyncio
+async def test_kernelbench_torch_tilelang_ascendc_pto():
+    """KernelBench test: PyTorch, TileLang AscendC PTO, Ascend910B4."""
+    framework = "torch"
+    dsl = "tilelang_ascendc_pto"
+    backend = "ascend"
+    arch = "ascend910b4"
+    benchmark = "KernelBench"
+
+    task_pool = TaskPool()
+    # device_pool = DevicePool([device_id])  # legacy approach
+    # or load_config("/your-path-to-config/xxx_config.yaml")
+    config = load_config(config_path="./python/ai_kernel_generator/config/vllm_tilelang_ascendc_pto_coderonly_config.yaml")
+
+    check_env_for_task(framework, backend, dsl, config)
+
+    # New approach: register a LocalWorker
+    await register_local_worker([device_id], backend=backend, arch=arch)
+
+    # KernelBench: look up ops by index
+    benchmark_name = get_kernelbench_op_name(
+        task_index_list=[19, ], framework=framework)
+
+    if benchmark_name is None:
+        raise RuntimeError(f"benchmark '{benchmark}' is not supported")
+
+    for i in range(len(benchmark_name)):
+        task_desc = get_kernelbench_task_desc(
+            benchmark_name[i], framework=framework)
+        op_name = add_op_prefix(benchmark_name[i], benchmark=benchmark)
+
+        task = Task(
+            op_name=op_name,
+            task_desc=task_desc,
+            task_id=str(i),
+            backend=backend,
+            arch=arch,
+            dsl=dsl,
+            config=config,
+            framework=framework,
+            workflow="coder_only_workflow"
+        )
+        task_pool.create_task(task.run)
+
+    results = await task_pool.wait_all()
+
+    # Use the shared result-processing helper
+    success = process_task_results(results, print_summary=True)
+    assert success, "at least one test case failed"
\ No newline at end of file