diff --git a/aikg/python/ai_kernel_generator/core/verifier/adapters/dsl/swft.py b/aikg/python/ai_kernel_generator/core/verifier/adapters/dsl/swft.py index 649ecab2a65420d5f1689adf5e1f14b11a98be0e..79748935a18101ace682d8b2be3cda386b85fbe3 100644 --- a/aikg/python/ai_kernel_generator/core/verifier/adapters/dsl/swft.py +++ b/aikg/python/ai_kernel_generator/core/verifier/adapters/dsl/swft.py @@ -28,7 +28,25 @@ class DSLAdapterSwft(DSLAdapter): def get_impl_import(self, op_name: str, impl_func_name: str) -> str: """Return implementation function import.""" - return f"from {op_name}_swft import {impl_func_name}\n" + return f"from {op_name}_swft import ModelNew\n" + + def create_impl_module(self, framework: str, + framework_adapter: Any, + init_params_var: str = "init_params", + device_var: str = "device") -> str: + """生成创建 impl_model 的代码(只实例化一次)。 + + Args: + framework: Framework name (torch, mindspore, numpy) + framework_adapter: Framework adapter instance + init_params_var: Variable name for init_params (default: "init_params") + device_var: Variable name for device (default: "device") + + Returns: + str: Code string to create impl_model + """ + code = f"impl_model = ModelNew(*{init_params_var})\n" + return code def call_impl(self, impl_func_name: str, inputs: str, device_id: int, framework_adapter: Any, op_name: str, @@ -50,7 +68,7 @@ class DSLAdapterSwft(DSLAdapter): gen_binary_data({inputs}, {framework_output}, data_dir) # 运行SWFT实现 - {impl_func_name}(device_id=int({device_id})) + impl_model(*{inputs}) # 加载SWFT输出 impl_output = load_binary_data(data_dir, {framework_output}) @@ -87,7 +105,7 @@ class DSLAdapterSwft(DSLAdapter): import time start_time = time.time() for _ in range({warmup + runs}): - {impl_func_name}(device_id=int({device_id})) + impl_model(*{inputs}) end_time = time.time() execution_time_ms = (end_time - start_time) * 1000 / {warmup + runs} # 转换为毫秒 method = "traditional_timing" diff --git a/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template_refactored.j2 b/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template_refactored.j2 index 667462011d33a0d2c021f7a866bc8062e48afefe..ef7a90fa484bab0032df01285a594dd3becaab93 100644 --- a/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template_refactored.j2 +++ b/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template_refactored.j2 @@ -214,14 +214,14 @@ def verify_implementations(): {{ line }} {% endfor %} + # 运行原始实现 + framework_output = framework_model(*inputs_for_framework) + # 运行实现 (generated by DSLAdapter) {% for line in call_impl_code %} {{ line }} {% endfor %} - # 运行原始实现 - framework_output = framework_model(*inputs_for_framework) - if not isinstance(framework_output, (list, tuple)): framework_output = [framework_output] if not isinstance(impl_output, (list, tuple)):