From 5fcf04fd02735bb6f34ddf375f448dc2a48d1580 Mon Sep 17 00:00:00 2001
From: zhu_yiyang
Date: Thu, 4 Sep 2025 16:01:12 +0800
Subject: [PATCH] add: triton2torch

---
 .../core/verifier/kernel_verifier.py          |  7 +-
 .../templates/kernel_verify_template.j2       | 25 ++++--
 .../ai_kernel_generator/utils/common_utils.py | 13 +++
 .../default_triton2torch_config.yaml          |  9 ++
 .../test_triton2torch_converter.py            | 61 +++++++++++++
 .../triton2torch/triton2torch_converter.py    | 80 +++++++++++++++++
 .../triton2torch_converter_template.j2        | 87 +++++++++++++++++++
 7 files changed, 275 insertions(+), 7 deletions(-)
 create mode 100644 aikg/tools/triton2torch/default_triton2torch_config.yaml
 create mode 100644 aikg/tools/triton2torch/test_triton2torch_converter.py
 create mode 100644 aikg/tools/triton2torch/triton2torch_converter.py
 create mode 100644 aikg/tools/triton2torch/triton2torch_converter_template.j2

diff --git a/aikg/python/ai_kernel_generator/core/verifier/kernel_verifier.py b/aikg/python/ai_kernel_generator/core/verifier/kernel_verifier.py
index 69c605113..a42c62a7e 100644
--- a/aikg/python/ai_kernel_generator/core/verifier/kernel_verifier.py
+++ b/aikg/python/ai_kernel_generator/core/verifier/kernel_verifier.py
@@ -53,6 +53,7 @@ class KernelVerifier:
                  backend: BackendType = "cuda",
                  arch: ArchType = "a100",
                  impl_func_name: Optional[str] = None,
+                 impl_model_name: Optional[str] = None,
                  config: Optional[Dict[str, Any]] = None):
         """
         Initialize the kernel verifier.
@@ -67,6 +68,7 @@ class KernelVerifier:
             backend (BackendType): compute backend, one of "cuda", "ascend"
             arch (ArchType): hardware architecture, one of "a100", "v100", "ascend910b4", "ascend310p3"
             impl_func_name (str, optional): name of the implementation function, defaults to op_name_dsl_framework
+            impl_model_name (str, optional): name of the implementation model class, defaults to None
         """
         self.op_name = op_name
         self.framework_code = framework_code
@@ -75,6 +77,7 @@ class KernelVerifier:
         self.backend = backend.lower()
         self.arch = arch.lower()
         self.task_id = task_id
+        self.impl_model_name = impl_model_name

         # Read log_dir from config
         if config:
@@ -204,6 +207,7 @@ class KernelVerifier:
             dsl=self.dsl,
             device_id=device_id,
             impl_func_name=self.impl_func_name,
+            impl_model_name=self.impl_model_name,
             backend=self.backend,
             arch=self.arch,
             is_dynamic_shape=is_dynamic_shape,
@@ -226,7 +230,7 @@
             os.chdir(verify_dir)
             python_cmd = ["python", f"verify_{self.op_name}.py"]
-            # Use run_command with timeout disabled so the verify script can run without limit
-            return run_command(python_cmd, f"verify_{self.op_name}", timeout=None)
+            # Use run_command with twice the per-case timeout as the process-level limit
+            return run_command(python_cmd, f"verify_{self.op_name}", timeout=timeout * 2)
         finally:
             try:
                 os.chdir(original_cwd)
@@ -260,6 +264,7 @@
             dsl=self.dsl,
             device_id=device_id,
             impl_func_name=self.impl_func_name,
+            impl_model_name=self.impl_model_name,
             backend=self.backend,
             arch=self.arch,
             warmup_times=warmup_times,
diff --git a/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template.j2 b/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template.j2
index 82a3d10e5..c147976fe 100644
--- a/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template.j2
+++ b/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template.j2
@@ -47,7 +47,9 @@ MS_TO_NP_DTYPE_MAP = {
 }
 {% endif %}

-{% if "triton" in dsl %}
+{% if "triton" in dsl and impl_model_name %}
+from {{ op_name }}_triton import {{ impl_model_name }}
+{% elif "triton" in dsl %}
 from {{ op_name }}_triton import {{ impl_func_name }}
 {% elif dsl == "swft" %}
 from {{ op_name }}_swft import {{ impl_func_name }}
@@ -379,12 +381,16 @@ def verify_implementations():
     def verify_single_case(inputs):
         """Shared logic for verifying a single case"""
-        {% if backend == "ascend" %}
-        torch.npu.manual_seed(0)
+        {% if framework == "torch" %}
+        torch.manual_seed(0)
         {% endif %}
-
+
         # Run the framework implementation
         framework_output = framework_model(*inputs)
+
+        {% if backend == "ascend" and framework == "torch" %}
+        torch.npu.manual_seed(0)
+        {% endif %}

 {% if dsl == "swft" %}
         # Run the SWFT implementation
@@ -398,6 +404,11 @@ def verify_implementations():
         # Load the SWFT output
         impl_output = load_binary_data(data_dir, framework_output)

+{% elif "triton" in dsl and impl_model_name %}
+        # Run the Triton implementation via the model class
+        model = {{ impl_model_name }}(*init_params)
+        model = model.to(device)
+        impl_output = model(*inputs)
 {% elif dsl in ["triton", "cuda_c", "cpp"] %}
         # Run the Triton implementation
         impl_output = {{ impl_func_name }}(*inputs)
@@ -466,7 +477,8 @@ def verify_implementations():
         print(f"Verifying dynamic-shape case {case_idx + 1}/{len(inputs_list)}")

         # Wrap the whole verification in the timeout decorator
-        @with_timeout({{ timeout }})
+        timeout = {{ timeout }}
+        @with_timeout(timeout)
         def verify_case():
             return verify_single_case(inputs)
@@ -488,7 +500,8 @@
     {% endif %}

     # Wrap the whole verification in the timeout decorator
-    @with_timeout({{ timeout }})
+    timeout = {{ timeout }}
+    @with_timeout(timeout)
     def verify_case():
         return verify_single_case(inputs)
diff --git a/aikg/python/ai_kernel_generator/utils/common_utils.py b/aikg/python/ai_kernel_generator/utils/common_utils.py
index 76b9ec5f0..5ca554a1b 100644
--- a/aikg/python/ai_kernel_generator/utils/common_utils.py
+++ b/aikg/python/ai_kernel_generator/utils/common_utils.py
@@ -106,6 +106,7 @@ class ParserFactory:
     }
     _feature_parser = None
     _api_parser = None
+    _converter_parser = None
     _sketch_parser = None
     _conductor_parser = None

@@ -226,6 +227,18 @@ class ParserFactory:
         )
         return cls._feature_parser

+    @classmethod
+    def get_converter_parser(cls):
+        """Get the converter output parser"""
+        if cls._converter_parser is None:
+            cls._converter_parser = cls.create_output_parser(
+                "ConverterBlock",
+                {
+                    "torch_code": (str, ...)
+                }
+            )
+        return cls._converter_parser
+
     @classmethod
     def get_sketch_parser(cls):
         """Get the sketch parser"""
diff --git a/aikg/tools/triton2torch/default_triton2torch_config.yaml b/aikg/tools/triton2torch/default_triton2torch_config.yaml
new file mode 100644
index 000000000..c8ef36738
--- /dev/null
+++ b/aikg/tools/triton2torch/default_triton2torch_config.yaml
@@ -0,0 +1,9 @@
+# Model preset configuration
+agent_model_config:
+  triton2torch_converter: vllm_deepseek_v31_default
+
+# Log configuration
+log_dir: "~/inductor_logs"
+
+# Verification configuration
+verify_timeout: 300  # Timeout for verification in seconds (default 5 minutes)
\ No newline at end of file
diff --git a/aikg/tools/triton2torch/test_triton2torch_converter.py b/aikg/tools/triton2torch/test_triton2torch_converter.py
new file mode 100644
index 000000000..f0addbcc5
--- /dev/null
+++ b/aikg/tools/triton2torch/test_triton2torch_converter.py
@@ -0,0 +1,61 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+from triton2torch_converter import CURR_DIR, Triton2TorchConverter
+from ai_kernel_generator.core.verifier.kernel_verifier import KernelVerifier
+from ai_kernel_generator.config.config_validator import ConfigValidator
+
+
+@pytest.mark.level0
+@pytest.mark.use_model
+@pytest.mark.asyncio
+async def test_triton2torch_converter():
+    framework = "torch"
+    dsl = "triton"
+    backend = "ascend"
+    arch = "ascend910b4"
+
+    validator = ConfigValidator(CURR_DIR / "default_triton2torch_config.yaml")
+    config = validator.config
+    validator.validate_llm_models()
+    validator.validate_log_dir()
+
+    op_name = "op_name"
+    triton_code_path = f"path/to/{op_name}/{op_name}_triton.py"
+    with open(triton_code_path, "r", encoding="utf-8") as f:
+        triton_code = f.read()
+
+    converter = Triton2TorchConverter(
+        triton_code=triton_code,
+        model_config=config.get("agent_model_config", {})
+    )
+    converter_content, _, _ = await converter.run()
+    parsed_content = converter.converter_parser.parse(converter_content)
+    torch_code = parsed_content.torch_code
+
+    impl_model_name = "Model"
+    verifier = KernelVerifier(
+        op_name=op_name,
+        framework_code=torch_code,
+        framework=framework,
+        dsl=dsl,
+        backend=backend,
+        arch=arch,
+        impl_model_name=impl_model_name,
+        config=config
+    )
+    task_info = {"coder_code": triton_code}
+    verify_res, verify_log = verifier.run(task_info)
+    assert verify_res, f"Verification failed: {verify_log}"
diff --git a/aikg/tools/triton2torch/triton2torch_converter.py b/aikg/tools/triton2torch/triton2torch_converter.py
new file mode 100644
index 000000000..cb398344e
--- /dev/null
+++ b/aikg/tools/triton2torch/triton2torch_converter.py
@@ -0,0 +1,80 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Tuple
+from pathlib import Path
+from langchain.prompts import PromptTemplate
+from ai_kernel_generator.utils.common_utils import ParserFactory, get_md5_hash
+from ai_kernel_generator.core.agent.agent_base import AgentBase
+from ai_kernel_generator import get_project_root
+
+logger = logging.getLogger(__name__)
+
+CURR_DIR = Path(get_project_root()).parent.parent / "tools" / "triton2torch"
+
+
+class Triton2TorchConverter(AgentBase):
+    def __init__(self, triton_code: str, model_config: dict):
+        """
+        Convert Triton code into PyTorch code.
+
+        Args:
+            triton_code (str): the Triton source code
+            model_config (dict): the model configuration
+        """
+        self.triton_code = triton_code
+        self.model_config = model_config
+
+        context = {
+            "agent_name": "triton2torch_converter"
+        }
+        super().__init__(context=context)
+
+        self.converter_parser = ParserFactory.get_converter_parser()
+        self.format_instructions = self.converter_parser.get_format_instructions()
+
+        # Initialize the conversion prompt template
+        template_str = self.read_file(str(CURR_DIR / "triton2torch_converter_template.j2"))
+        self.converter_prompt = PromptTemplate(
+            template=template_str,
+            template_format="jinja2"
+        )
+        # Build the prompt inputs
+        self.converter_input = {
+            "triton_code": self.triton_code,
+            "format_instructions": self.format_instructions,
+        }
+
+    async def run(self) -> Tuple[str, str, str]:
+        """Run the Triton-to-Torch code conversion.
+
+        Returns:
+            Tuple[str, str, str]: the converted Torch code, the prompt, and the reasoning trace
+        """
+        try:
+            # Update the context before invoking the LLM
+            to_update_details = {
+                "agent_name": "triton2torch_converter",
+                "hash": get_md5_hash(triton_code=self.triton_code)
+            }
+            self.context.update(to_update_details)
+
+            # Invoke the LLM
+            return await self.run_llm(
+                self.converter_prompt,
+                self.converter_input,
+                self.model_config.get("triton2torch_converter", "vllm_deepseek_r1_default")
+            )
+        except Exception as e:
+            logger.error(f"Exception in Triton2TorchConverter.run: {type(e).__name__}: {e}")
+            raise
diff --git a/aikg/tools/triton2torch/triton2torch_converter_template.j2 b/aikg/tools/triton2torch/triton2torch_converter_template.j2
new file mode 100644
index 000000000..eb62f446d
--- /dev/null
+++ b/aikg/tools/triton2torch/triton2torch_converter_template.j2
@@ -0,0 +1,87 @@
+# Triton-to-PyTorch Code Conversion
+
+You are an expert AI-kernel engineering agent, experienced in writing kernels for computation frameworks (such as Torch) and kernel DSLs (such as Triton).
+Your task is to convert the given Triton code into computationally equivalent PyTorch code.
+
+## Triton code to convert
+```python
+{{ triton_code }}
+```
+
+## Conversion requirements
+
+1. **Principles**:
+   - Exactness: the PyTorch code must be fully computationally equivalent to the Triton code
+   - Readability: produce clear PyTorch code that is easy to understand and maintain
+   - Simplicity: avoid unnecessary complexity
+   - Practicality: prefer PyTorch built-ins and high-level APIs; for standard operations use the optimized implementations directly (e.g., functions in torch.nn.functional) instead of re-implementing low-level logic
+
+2. **Key points**:
+   - **Kernel functions**: convert functions decorated with `@triton.jit` into PyTorch functions or modules
+   - **Memory management**: convert Triton pointer operations into PyTorch tensor operations
+   - **Inputs/outputs**: the number, shapes, and dtypes of inputs and outputs must stay consistent
+     [Note] There may be no inputs at all; judge this accurately
+     [Note] There may be multiple outputs, each with its own computation logic; work them out correctly
+     [Note] A tensor may be both an input and an output; recognize and handle this case correctly
+   - **Computation**: implement the Triton computation logic with PyTorch's function interfaces and features
+   - **Other**: ignore Triton's grid/block and other parallelization or tiling scheduling strategies; convert only the computation logic
+
+3. **Output format**:
+   - Produce complete PyTorch code, organized as follows:
+     - the necessary import statements
+     - the PyTorch model class: it must be named Model, inherit from nn.Module, and implement forward
+     - a get_inputs function: creates the model's input tensors, with shapes and dtypes derived from the Triton code; do not pin the tensors to a specific backend device
+     - a get_init_inputs function: returns the model's initialization parameters, or an empty list if there are none
+   - Comment the code well, explaining the conversion approach and key decisions
+   - The code must not reference the original Triton functions or interfaces
+   - The generated code should run and be testable as-is
+
+## Example code
+```python
+import torch
+import torch.nn as nn
+
+class Model(nn.Module):
+    """
+    Simple model that performs sum reduction over a specified dimension.
+    """
+    def __init__(self, dim: int):
+        """
+        Initializes the model with the dimension to reduce over.
+
+        Args:
+            dim (int): Dimension to reduce over.
+        """
+        super(Model, self).__init__()
+        self.dim = dim
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Applies sum reduction over the specified dimension.
+
+        Args:
+            x (torch.Tensor): Input tensor of shape (..., dim, ...).
+
+        Returns:
+            torch.Tensor: Output tensor after sum reduction, shape (..., 1, ...).
+        """
+        return torch.sum(x, dim=self.dim, keepdim=True)
+
+batch_size = 128
+dim1 = 4096
+dim2 = 4095
+reduce_dim = 1
+
+def get_inputs():
+    x = torch.rand(batch_size, dim1, dim2)
+    return [x]
+
+def get_init_inputs():
+    return [reduce_dim]
+```
+Follow the format, imports, and calling conventions of the PyTorch example above, so that the generated code plugs directly into the verification flow.
+
+**Please think and analyze in Chinese where possible.**
+
+**Output your result in the following format; return JSON only, with no explanation or commentary outside the JSON:**
+{{ format_instructions }}
\ No newline at end of file
--
Gitee
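
---

For context, here is a minimal hand-written sketch of the transformation this patch targets. It is illustrative only, not converter output; the kernel, its launcher, and all names below are hypothetical. The input side is a standard Triton element-wise add kernel:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the tensors.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Launch wrapper: one program per 1024-element block.
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out
```

Following the prompt's rules, the grid/mask scheduling is dropped and only the computation survives, wrapped in the Model / get_inputs / get_init_inputs contract that kernel_verify_template.j2 consumes:

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    """Element-wise addition, computationally equivalent to add_kernel."""
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # The masked pointer loads/stores collapse to a single tensor op.
        return x + y

def get_inputs():
    # Shapes and dtypes would be derived from the Triton launch site;
    # no device is pinned, per the template's output rules.
    return [torch.rand(1024), torch.rand(1024)]

def get_init_inputs():
    return []  # Model takes no constructor arguments
```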