From 311e0ba1032fe795177e488ae073180865eb82f3 Mon Sep 17 00:00:00 2001 From: Yanzhi_YI Date: Wed, 3 Dec 2025 20:14:16 +0800 Subject: [PATCH 1/8] enable triton-cuda to triton-ascend --- .../examples/run_cuda_to_ascend_conversion.py | 571 ++++++++++++++++++ .../core/verifier/kernel_verifier.py | 183 ++++++ .../core/worker/interface.py | 22 + .../core/worker/local_worker.py | 90 +++ .../core/worker/remote_worker.py | 63 ++ .../kernel_verify_template_refactored.j2 | 58 +- aikg/python/ai_kernel_generator/server/app.py | 7 + .../ai_kernel_generator/server/job_manager.py | 145 ++++- .../ai_kernel_generator/worker/server.py | 51 ++ aikg/tests/st/test_reference_generation.py | 389 ++++++++++++ 10 files changed, 1544 insertions(+), 35 deletions(-) create mode 100644 aikg/examples/run_cuda_to_ascend_conversion.py create mode 100644 aikg/tests/st/test_reference_generation.py diff --git a/aikg/examples/run_cuda_to_ascend_conversion.py b/aikg/examples/run_cuda_to_ascend_conversion.py new file mode 100644 index 0000000000..4bf00c24aa --- /dev/null +++ b/aikg/examples/run_cuda_to_ascend_conversion.py @@ -0,0 +1,571 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CUDA-to-Ascend 转换示例 + +本示例演示如何使用服务化架构实现 Triton-CUDA 到 Triton-Ascend 的自动转换: +1. 注册两个 Remote Worker:CUDA (a100) 和 Ascend (ascend910b4) +2. 提交任务时指定 source_backend=cuda, backend=ascend +3. Server 自动在 CUDA Worker 上生成参考数据 +4. 
NPU Worker 使用参考数据验证转换后的代码 + +使用方式: + +方式1: 通过 Server API(推荐) + # GPU 机器上启动 CUDA Worker + ./scripts/server_related/start_worker_service.sh cuda a100 0 9001 + + # NPU 机器上启动 Ascend Worker + ./scripts/server_related/start_worker_service.sh ascend ascend910b4 0 9001 + + # Server 机器上启动 AIKG Server + ./scripts/server_related/start_server.sh 8000 + + # 注册 Workers + ./scripts/server_related/register_worker_to_server.sh http://localhost:8000 http://gpu-server:9001 cuda a100 + ./scripts/server_related/register_worker_to_server.sh http://localhost:8000 http://npu-server:9001 ascend ascend910b4 + + # 运行此脚本 + python examples/run_cuda_to_ascend_conversion.py --server http://localhost:8000 + +方式2: 直接使用 Remote Workers + export CUDA_WORKER_URL=http://cuda-server:9001 + export ASCEND_WORKER_URL=http://ascend-server:9002 + python examples/run_cuda_to_ascend_conversion.py --direct + +方式3: 快速验证模式(仅测试参考数据生成和传输,不调用 LLM) + export CUDA_WORKER_URL=http://cuda-server:9001 + export ASCEND_WORKER_URL=http://ascend-server:9002 + python examples/run_cuda_to_ascend_conversion.py --verify +""" + +import asyncio +import os +import sys +import time + +os.environ['AIKG_STREAM_OUTPUT'] = 'on' + + +def get_op_name(): + return 'relu' + + +def get_task_desc(): + """ + task_desc(纯 PyTorch 代码,用于生成参考数据和转换目标) + """ + return ''' +import torch +import torch.nn as nn + + +class Model(nn.Module): + """ + ReLU激活函数模型 + """ + def __init__(self): + super(Model, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + 计算ReLU激活函数 + Args: + x: 输入张量 + Returns: + ReLU激活后的张量 + """ + return torch.relu(x) + + +batch_size = 16 +dim = 16384 + + +def get_inputs(): + x = torch.randn(batch_size, dim, dtype=torch.float16) + return [x] + + +def get_init_inputs(): + return [] # No special initialization inputs needed +''' + + +async def run_via_server_api(server_url: str): + """ + 方式1: 通过 Server API 提交任务(推荐) + + Server 会自动处理: + 1. 在 CUDA Worker 上生成参考数据 + 2. 将参考数据传递给 Ascend Worker + 3. 
执行 LLM 代码生成和验证 + """ + import httpx + + op_name = get_op_name() + task_desc = get_task_desc() + + print("=" * 60) + print("CUDA-to-Ascend 转换示例 (Server API 模式)") + print("=" * 60) + print(f"Server URL: {server_url}") + print() + + # 构建请求 + request_data = { + "op_name": op_name, + "task_desc": task_desc, + "job_type": "single", + "backend": "ascend", + "arch": "ascend910b4", + "dsl": "triton_ascend", + "framework": "torch", + "workflow": "coder_only_workflow", + # 关键:指定源后端,触发参考数据生成 + "source_backend": "cuda", + "source_arch": "a100", + } + + print("[Step 1] 提交任务到 Server...") + print(f" 算子: {op_name}") + print(f" 源后端: cuda (a100)") + print(f" 目标后端: ascend (ascend910b4)") + print() + + try: + async with httpx.AsyncClient(timeout=30) as client: + response = await client.post( + f"{server_url}/api/v1/jobs/submit", + json=request_data + ) + response.raise_for_status() + result = response.json() + job_id = result.get("job_id") + print(f" ✓ 任务提交成功,Job ID: {job_id}") + except Exception as e: + print(f" ✗ 任务提交失败: {e}") + return + + # 轮询任务状态 + print() + print("[Step 2] 等待任务完成...") + + max_wait = 600 # 最多等待 10 分钟 + start_time = time.time() + + try: + async with httpx.AsyncClient(timeout=10) as client: + while time.time() - start_time < max_wait: + response = await client.get(f"{server_url}/api/v1/jobs/{job_id}/status") + response.raise_for_status() + status = response.json() + + job_status = status.get("status") + print(f" 状态: {job_status}", end="\r") + + if job_status in ["completed", "failed", "error"]: + print() + break + + await asyncio.sleep(2) + else: + print(f"\n ✗ 任务超时 ({max_wait}秒)") + return + except Exception as e: + print(f"\n ✗ 查询状态失败: {e}") + return + + # 显示结果 + print() + print("[结果]") + if job_status == "completed": + result_data = status.get("result", {}) + if result_data.get("success"): + print(f" ✓ {op_name} 转换成功!") + code = result_data.get("code", "") + if code: + print(f"\n[生成的 Triton-Ascend 代码]") + print("-" * 40) + print(code[:800] + ("..." if len(code) > 800 else "")) + else: + print(f" ✗ {op_name} 转换失败") + else: + print(f" ✗ 任务状态: {job_status}") + if status.get("error"): + print(f" 错误: {status.get('error')[:300]}...") + + +async def run_direct_with_workers(): + """ + 方式2: 直接使用 Remote Workers(不通过 Server) + + 手动处理参考数据生成和传递 + """ + from ai_kernel_generator.config.config_validator import load_config + from ai_kernel_generator.core.worker.manager import register_remote_worker, get_worker_manager + from ai_kernel_generator.core.verifier.kernel_verifier import KernelVerifier + from ai_kernel_generator.core.task import Task + from ai_kernel_generator.core.async_pool.task_pool import TaskPool + + op_name = get_op_name() + task_desc = get_task_desc() + + # 从环境变量获取 Worker URL + cuda_worker_url = os.environ.get("CUDA_WORKER_URL", "http://localhost:9001") + ascend_worker_url = os.environ.get("ASCEND_WORKER_URL", "http://localhost:9002") + + print("=" * 60) + print("CUDA-to-Ascend 转换示例 (Direct Worker 模式)") + print("=" * 60) + print(f"CUDA Worker URL: {cuda_worker_url}") + print(f"Ascend Worker URL: {ascend_worker_url}") + print() + + # ========== 1. 
注册 Remote Workers ========== + print("[Step 1] 注册 Workers...") + + try: + await register_remote_worker( + backend="cuda", + arch="a100", + worker_url=cuda_worker_url + ) + print(f" ✓ CUDA Worker 注册成功") + except Exception as e: + print(f" ✗ CUDA Worker 注册失败: {e}") + return + + try: + await register_remote_worker( + backend="ascend", + arch="ascend910b4", + worker_url=ascend_worker_url + ) + print(f" ✓ Ascend Worker 注册成功") + except Exception as e: + print(f" ✗ Ascend Worker 注册失败: {e}") + return + + worker_manager = get_worker_manager() + print() + + # ========== 2. 在 CUDA Worker 上生成参考数据 ========== + print("[Step 2] 在 CUDA Worker 上生成参考数据...") + + config = load_config("triton_cuda", backend="cuda") + + cuda_worker = await worker_manager.select(backend="cuda", arch="a100") + if not cuda_worker: + print(" ✗ 无法获取 CUDA Worker") + return + + try: + verifier = KernelVerifier( + op_name=op_name, + framework_code=task_desc, + task_id="gen_ref_001", + framework="torch", + dsl="triton_cuda", + backend="cuda", + arch="a100", + config=config, + worker=cuda_worker + ) + + success, log, ref_bytes = await verifier.generate_reference_data(task_desc, timeout=120) + + if not success: + print(f" ✗ 参考数据生成失败: {log[:200]}...") + return + + print(f" ✓ 参考数据生成成功 ({len(ref_bytes)} bytes)") + finally: + await worker_manager.release(cuda_worker) + + print() + + # ========== 3. 在 Ascend Worker 上执行转换 ========== + print("[Step 3] 在 Ascend Worker 上执行转换...") + + ascend_config = load_config("triton_ascend", backend="ascend") + + # 注入参考数据 + ascend_config['use_reference_data'] = True + ascend_config['reference_data'] = ref_bytes + + task_pool = TaskPool() + + task = Task( + op_name=op_name, + task_desc=task_desc, + task_id="convert_001", + dsl="triton_ascend", + backend="ascend", + arch="ascend910b4", + config=ascend_config, + framework="torch", + workflow="coder_only_workflow" + ) + + task_pool.create_task(task.run) + results = await task_pool.wait_all() + + print() + print("[结果]") + for result_op_name, success, task_info in results: + if success: + print(f" ✓ {result_op_name} 转换成功!") + if task_info.get("coder_code"): + print(f"\n[生成的 Triton-Ascend 代码]") + print("-" * 40) + code = task_info.get("coder_code", "") + print(code[:800] + ("..." if len(code) > 800 else "")) + else: + print(f" ✗ {result_op_name} 转换失败") + if task_info.get("verifier_error"): + print(f" 错误: {task_info.get('verifier_error', '')[:200]}...") + + +async def run_quick_verify(): + """ + 快速验证模式:仅测试参考数据生成和传输,不调用 LLM + + 流程: + 1. CUDA Worker 生成参考数据 (.pt) + 2. 传输 .pt 到 Ascend Worker + 3. Ascend Worker 执行验证(使用现有的 PyTorch 代码,跳过 base 执行) + """ + from ai_kernel_generator.config.config_validator import load_config + from ai_kernel_generator.core.worker.manager import register_remote_worker, get_worker_manager + from ai_kernel_generator.core.verifier.kernel_verifier import KernelVerifier + + op_name = get_op_name() + task_desc = get_task_desc() + + # 从环境变量获取 Worker URL + cuda_worker_url = os.environ.get("CUDA_WORKER_URL", "http://localhost:9001") + ascend_worker_url = os.environ.get("ASCEND_WORKER_URL", "http://localhost:9002") + + print("=" * 60) + print("快速验证模式:测试参考数据生成和传输") + print("=" * 60) + print(f"CUDA Worker URL: {cuda_worker_url}") + print(f"Ascend Worker URL: {ascend_worker_url}") + print() + + # ========== 1. 
注册 Workers ========== + print("[Step 1] 注册 Workers...") + + try: + await register_remote_worker( + backend="cuda", + arch="a100", + worker_url=cuda_worker_url + ) + print(f" ✓ CUDA Worker 注册成功") + except Exception as e: + print(f" ✗ CUDA Worker 注册失败: {e}") + return False + + try: + await register_remote_worker( + backend="ascend", + arch="ascend910b4", + worker_url=ascend_worker_url + ) + print(f" ✓ Ascend Worker 注册成功") + except Exception as e: + print(f" ✗ Ascend Worker 注册失败: {e}") + return False + + worker_manager = get_worker_manager() + print() + + # ========== 2. CUDA Worker 生成参考数据 ========== + print("[Step 2] CUDA Worker 生成参考数据...") + + cuda_config = load_config("triton_cuda", backend="cuda") + + cuda_worker = await worker_manager.select(backend="cuda", arch="a100") + if not cuda_worker: + print(" ✗ 无法获取 CUDA Worker") + return False + + try: + verifier = KernelVerifier( + op_name=op_name, + framework_code=task_desc, + task_id="quick_verify_001", + framework="torch", + dsl="triton_cuda", + backend="cuda", + arch="a100", + config=cuda_config, + worker=cuda_worker + ) + + success, log, ref_bytes = await verifier.generate_reference_data(task_desc, timeout=120) + + if not success: + print(f" ✗ 参考数据生成失败:") + print(f" {log[:300]}...") + return False + + print(f" ✓ 参考数据生成成功 ({len(ref_bytes)} bytes)") + + # 解析并显示参考数据信息 + import tempfile + import torch + with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as f: + f.write(ref_bytes) + temp_path = f.name + + try: + ref_data = torch.load(temp_path) + print(f" 种子: {ref_data.get('seed', 'unknown')}") + print(f" 输出数量: {len(ref_data.get('outputs', []))}") + for i, out in enumerate(ref_data.get('outputs', [])): + if hasattr(out, 'shape'): + print(f" 输出[{i}]: shape={out.shape}, dtype={out.dtype}") + finally: + os.unlink(temp_path) + + finally: + await worker_manager.release(cuda_worker) + + print() + + # ========== 3. 
Ascend Worker 验证 ========== + print("[Step 3] Ascend Worker 使用参考数据验证...") + + ascend_config = load_config("triton_ascend", backend="ascend") + + # 注入参考数据 + ascend_config['use_reference_data'] = True + ascend_config['reference_data'] = ref_bytes + + ascend_worker = await worker_manager.select(backend="ascend", arch="ascend910b4") + if not ascend_worker: + print(" ✗ 无法获取 Ascend Worker") + return False + + try: + # 创建一个简单的验证:使用 PyTorch 原始代码作为 impl + # 这里只是验证参考数据传输和加载是否正常 + verifier = KernelVerifier( + op_name=op_name, + framework_code=task_desc, + task_id="quick_verify_002", + framework="torch", + dsl="triton_ascend", + backend="ascend", + arch="ascend910b4", + config=ascend_config, + worker=ascend_worker + ) + + # 构造 task_info,使用原始 PyTorch 代码作为 impl(只是为了验证流程) + # 这里的 coder_code 是一个简单的透传实现 + simple_impl_code = ''' +import torch +import torch.nn as nn + +class ModelNew(nn.Module): + """透传实现,用于验证参考数据流程""" + def __init__(self): + super(ModelNew, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.relu(x) +''' + task_info = {'coder_code': simple_impl_code} + + # 运行验证 + verify_result, verify_log = await verifier.run(task_info, current_step=0) + + if verify_result: + print(" ✓ Ascend Worker 验证成功!") + print(" 参考数据传输和加载正常") + else: + print(" ✗ Ascend Worker 验证失败:") + print(f" {verify_log[:300]}...") + return False + + finally: + await worker_manager.release(ascend_worker) + + print() + print("=" * 60) + print("快速验证完成!参考数据生成和传输流程正常") + print("=" * 60) + return True + + +def print_usage(): + print(""" +用法: + python run_cuda_to_ascend_conversion.py --server + 通过 Server API 提交任务(推荐) + + python run_cuda_to_ascend_conversion.py --direct + 直接使用 Remote Workers(需设置 CUDA_WORKER_URL 和 ASCEND_WORKER_URL 环境变量) + + python run_cuda_to_ascend_conversion.py --verify + 快速验证模式:仅测试参考数据生成和传输,不调用 LLM + +示例: + # Server API 模式(参考 scripts/server_related/ 中的脚本启动服务) + python run_cuda_to_ascend_conversion.py --server http://localhost:8000 + + # Direct Worker 模式(需先启动 Worker Service) + # GPU: ./scripts/server_related/start_worker_service.sh cuda a100 0 9001 + # NPU: ./scripts/server_related/start_worker_service.sh ascend ascend910b4 0 9001 + export CUDA_WORKER_URL=http://gpu-server:9001 + export ASCEND_WORKER_URL=http://npu-server:9001 + python run_cuda_to_ascend_conversion.py --direct + + # 快速验证模式(仅测试参考数据生成和传输) + export CUDA_WORKER_URL=http://gpu-server:9001 + export ASCEND_WORKER_URL=http://npu-server:9001 + python run_cuda_to_ascend_conversion.py --verify +""") + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print_usage() + sys.exit(1) + + if sys.argv[1] == "--server": + if len(sys.argv) < 3: + print("错误: 请指定 Server URL") + print_usage() + sys.exit(1) + server_url = sys.argv[2] + asyncio.run(run_via_server_api(server_url)) + + elif sys.argv[1] == "--direct": + asyncio.run(run_direct_with_workers()) + + elif sys.argv[1] == "--verify": + success = asyncio.run(run_quick_verify()) + sys.exit(0 if success else 1) + + else: + print(f"未知参数: {sys.argv[1]}") + print_usage() + sys.exit(1) + diff --git a/aikg/python/ai_kernel_generator/core/verifier/kernel_verifier.py b/aikg/python/ai_kernel_generator/core/verifier/kernel_verifier.py index 517ed094b3..ea9c5941ee 100644 --- a/aikg/python/ai_kernel_generator/core/verifier/kernel_verifier.py +++ b/aikg/python/ai_kernel_generator/core/verifier/kernel_verifier.py @@ -320,6 +320,165 @@ if __name__ == "__main__": # 清理临时目录 shutil.rmtree(check_dir, ignore_errors=True) + async def generate_reference_data(self, task_desc: str, timeout: int = 120) -> 
Tuple[bool, str, bytes]: + """ + 在 GPU 上执行 task_desc 并生成参考数据 + + 用于 CUDA-to-Ascend 转换场景:在 GPU Worker 上执行 Triton-CUDA 代码, + 保存输出作为参考数据,供 NPU Worker 验证转换后的代码正确性。 + + Args: + task_desc: task_desc 代码字符串(Triton-CUDA 代码) + timeout: 超时时间 + + Returns: + Tuple[bool, str, bytes]: (是否成功, 日志, 参考数据bytes) + - 成功时 bytes 为 .pt 文件内容 + - 失败时 bytes 为空 b'' + """ + # 1. 创建临时目录 + ref_dir = os.path.join(os.path.expanduser(self.log_dir), f"{self.op_name}_gen_ref_{self.task_id}") + os.makedirs(ref_dir, exist_ok=True) + + try: + # 2. 写入 task_desc 到 reference.py + ref_file = os.path.join(ref_dir, "reference.py") + with open(ref_file, "w", encoding="utf-8") as f: + f.write(task_desc) + + # 3. 生成参考数据脚本 + # 使用固定 seed=0 确保可复现性 + gen_ref_script = f''' +import torch +import sys +import os + +# Add current directory to sys.path +sys.path.append(os.getcwd()) + +def generate_reference(): + print("Starting reference data generation...") + try: + # Import from reference + try: + from reference import Model, get_inputs, get_init_inputs + except ImportError as e: + print(f"Import failed: {{e}}") + return False + + print("Successfully imported Model and helper functions.") + + # Determine device + device = "cpu" + if torch.cuda.is_available(): + device = "cuda" + elif hasattr(torch, 'npu') and torch.npu.is_available(): + device = "npu" + + print(f"Using device: {{device}}") + + # Fixed seed for reproducibility + torch.manual_seed(0) + print("[INFO] Random seed: 0") + + # Instantiate model + try: + init_inputs = get_init_inputs() + model = Model(*init_inputs) + if device != "cpu": + model = model.to(device) + model.eval() + except Exception as e: + print(f"Model instantiation failed: {{e}}") + return False + + # Get inputs with fixed seed + torch.manual_seed(0) + try: + inputs = get_inputs() + if device != "cpu": + inputs = [inp.to(device) if isinstance(inp, torch.Tensor) else inp for inp in inputs] + except Exception as e: + print(f"get_inputs failed: {{e}}") + return False + + # Run forward pass + try: + with torch.no_grad(): + outputs = model(*inputs) + print("Forward pass successful.") + except Exception as e: + print(f"Forward pass failed: {{e}}") + return False + + # Ensure outputs is a list + if not isinstance(outputs, (list, tuple)): + outputs = [outputs] + + # Move to CPU for saving + outputs_cpu = [x.cpu() if isinstance(x, torch.Tensor) else x for x in outputs] + + # Save reference data + ref_data = {{ + 'op_name': '{self.op_name}', + 'seed': 0, + 'outputs': outputs_cpu, + 'output_shapes': [x.shape if isinstance(x, torch.Tensor) else None for x in outputs_cpu], + }} + + ref_file = os.path.join(os.getcwd(), "{self.op_name}_reference.pt") + torch.save(ref_data, ref_file) + print(f"[INFO] Reference data saved to: {{ref_file}}") + print(f"[INFO] Output count: {{len(outputs_cpu)}}") + for i, out in enumerate(outputs_cpu): + if isinstance(out, torch.Tensor): + print(f" Output[{{i}}]: shape={{out.shape}}, dtype={{out.dtype}}") + + return True + + except Exception as e: + print(f"Unexpected error: {{e}}") + import traceback + traceback.print_exc() + return False + +if __name__ == "__main__": + success = generate_reference() + if success: + print("REFERENCE_GENERATION_SUCCESS") + sys.exit(0) + else: + print("REFERENCE_GENERATION_FAILED") + sys.exit(1) +''' + script_file = os.path.join(ref_dir, f"verify_{self.op_name}.py") + with open(script_file, "w", encoding="utf-8") as f: + f.write(gen_ref_script) + + # 4. 打包目录 + package_data = self._pack_directory(ref_dir) + + # 5. 
使用 Worker.generate_reference 执行 + if not self.worker: + raise RuntimeError("Worker not set for reference generation") + + # 直接调用 Worker 的 generate_reference 方法 + # 该方法会执行脚本并返回 .pt 文件的 bytes + success, log, ref_bytes = await self.worker.generate_reference( + package_data, f"{self.task_id}_gen_ref", self.op_name, timeout + ) + + if not success: + return False, f"Reference generation failed:\n{log}", b'' + + return True, log, ref_bytes + + except Exception as e: + return False, f"Reference generation exception: {str(e)}", b'' + finally: + # 清理临时目录 + shutil.rmtree(ref_dir, ignore_errors=True) + def _create_verify_dir(self, step_counter) -> str: """创建验证目录并返回目录路径""" expanded_log_dir = os.path.expanduser(self.log_dir) @@ -468,6 +627,27 @@ if __name__ == "__main__": """生成验证项目文件到指定目录""" logger.info(f"[{self.op_name}] 开始生成验证项目,目录: {verify_dir}, device_id={device_id}") + # ========== 处理参考数据模式 ========== + use_reference_data = self.config.get('use_reference_data', False) + reference_file = None + + if use_reference_data: + reference_data_bytes = self.config.get('reference_data') + if reference_data_bytes: + # 将参考数据写入验证目录 + reference_file = os.path.join(verify_dir, f"{self.op_name}_reference.pt") + try: + with open(reference_file, 'wb') as f: + f.write(reference_data_bytes) + logger.info(f"[{self.op_name}] 参考数据已写入: {reference_file} ({len(reference_data_bytes)} bytes)") + except Exception as e: + logger.error(f"[{self.op_name}] 参考数据写入失败: {e}") + use_reference_data = False + reference_file = None + else: + logger.warning(f"[{self.op_name}] use_reference_data=True 但未找到 reference_data") + use_reference_data = False + # 创建框架实现文件 framework_file = os.path.join(verify_dir, f"{self.op_name}_{self.framework}.py") try: @@ -612,6 +792,9 @@ if __name__ == "__main__": arch=self.arch, is_dynamic_shape=is_dynamic_shape, timeout=self.config.get('verify_timeout', 300), + # 参考数据模式(用于跨后端转换场景) + use_reference_data=use_reference_data, + reference_file=reference_file, # Adapter生成的代码 framework_imports=self._prepare_code_lines(framework_imports), framework_model_import=self._prepare_code_lines(framework_model_import), diff --git a/aikg/python/ai_kernel_generator/core/worker/interface.py b/aikg/python/ai_kernel_generator/core/worker/interface.py index 39f97290cd..49a8957352 100644 --- a/aikg/python/ai_kernel_generator/core/worker/interface.py +++ b/aikg/python/ai_kernel_generator/core/worker/interface.py @@ -49,3 +49,25 @@ class WorkerInterface(ABC): - artifacts: 执行过程中生成的文件内容,格式为 {relative_path: json_content} """ pass + + @abstractmethod + async def generate_reference(self, package_data: bytes, task_id: str, op_name: str, timeout: int = 120) -> Tuple[bool, str, bytes]: + """ + Execute task_desc and generate reference data. + + 用于 CUDA-to-Ascend 转换场景:在 GPU Worker 上执行 Triton-CUDA 代码, + 保存输出作为参考数据(.pt 文件),供 NPU Worker 验证转换后的代码正确性。 + + Args: + package_data: The compressed project (TAR bytes) containing reference.py and verify script. + task_id: Unique task identifier. + op_name: Operator name. + timeout: Execution timeout in seconds. 
+ + Returns: + Tuple[bool, str, bytes]: (success, log_output, reference_data_bytes) + - success: 是否成功生成参考数据 + - log_output: 执行日志 + - reference_data_bytes: .pt 文件的二进制内容(成功时),失败时为空 b'' + """ + pass diff --git a/aikg/python/ai_kernel_generator/core/worker/local_worker.py b/aikg/python/ai_kernel_generator/core/worker/local_worker.py index 469160cf77..bdddcaf061 100644 --- a/aikg/python/ai_kernel_generator/core/worker/local_worker.py +++ b/aikg/python/ai_kernel_generator/core/worker/local_worker.py @@ -302,3 +302,93 @@ class LocalWorker(WorkerInterface): except Exception as e: logger.error(f"[{task_id}] nsys profiling failed: {e}", exc_info=True) return float('inf'), float('inf') + + async def generate_reference(self, package_data: bytes, task_id: str, op_name: str, timeout: int = 120) -> Tuple[bool, str, bytes]: + """ + Execute task_desc and generate reference data locally. + + 用于 CUDA-to-Ascend 转换场景:执行 Triton-CUDA 代码,保存输出作为参考数据。 + + Args: + package_data: 验证包数据(bytes) + task_id: 任务ID + op_name: 算子名称 + timeout: 超时时间 + + Returns: + Tuple[bool, str, bytes]: (success, log, reference_data_bytes) + """ + try: + with tempfile.TemporaryDirectory() as temp_dir: + # Extract package + tar_path = os.path.join(temp_dir, "package.tar") + with open(tar_path, "wb") as f: + f.write(package_data) + + extract_dir = os.path.join(temp_dir, "extract") + os.makedirs(extract_dir, exist_ok=True) + + try: + with tarfile.open(tar_path, 'r') as tar_ref: + tar_ref.extractall(extract_dir) + except Exception as e: + return False, f"Failed to extract package: {e}", b'' + + # Find and run the verify script + script_name = f"verify_{op_name}.py" + script_path = os.path.join(extract_dir, script_name) + if not os.path.exists(script_path): + return False, f"Verification script {script_name} not found.", b'' + + env = os.environ.copy() + env['PYTHONUNBUFFERED'] = '1' + + python_exe = sys.executable + cmd = [python_exe, script_name] + logger.info(f"[{task_id}] Running reference generation for {op_name}") + + process = await asyncio.create_subprocess_exec( + *cmd, + cwd=extract_dir, + env=env, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + + try: + stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout) + returncode = process.returncode + + output_log = stdout.decode(errors='replace') + "\n" + stderr.decode(errors='replace') + success = (returncode == 0) + + if not success: + logger.error(f"[{task_id}] Reference generation failed with log:\n{output_log}") + return False, output_log, b'' + + # Check for success marker + if "REFERENCE_GENERATION_SUCCESS" not in output_log: + return False, f"Reference generation did not complete successfully:\n{output_log}", b'' + + # Read the generated .pt file + ref_file = os.path.join(extract_dir, f"{op_name}_reference.pt") + if not os.path.exists(ref_file): + return False, f"Reference file {ref_file} not found after generation.", b'' + + with open(ref_file, 'rb') as f: + ref_bytes = f.read() + + logger.info(f"[{task_id}] Reference generation succeeded, .pt file size: {len(ref_bytes)} bytes") + return True, output_log, ref_bytes + + except asyncio.TimeoutError: + try: + process.kill() + except ProcessLookupError: + pass + logger.error(f"[{task_id}] Reference generation timed out.") + return False, f"Reference generation timed out after {timeout} seconds.", b'' + + except Exception as e: + logger.error(f"[{task_id}] LocalWorker generate_reference failed: {e}", exc_info=True) + return False, str(e), b'' diff --git 
a/aikg/python/ai_kernel_generator/core/worker/remote_worker.py b/aikg/python/ai_kernel_generator/core/worker/remote_worker.py index 5db7d149df..e595defc46 100644 --- a/aikg/python/ai_kernel_generator/core/worker/remote_worker.py +++ b/aikg/python/ai_kernel_generator/core/worker/remote_worker.py @@ -146,3 +146,66 @@ class RemoteWorker(WorkerInterface): except Exception as e: logger.error(f"[{task_id}] Remote profiling failed: {e}") return {'artifacts': {}} + + async def generate_reference(self, package_data: bytes, task_id: str, op_name: str, timeout: int = 120) -> Tuple[bool, str, bytes]: + """ + Send reference generation task to remote worker. + + 用于 CUDA-to-Ascend 转换场景:在远程 GPU Worker 上执行 Triton-CUDA 代码, + 生成参考数据(.pt 文件)并返回其二进制内容。 + + Args: + package_data: 验证包数据(TAR bytes) + task_id: 任务ID + op_name: 算子名称 + timeout: 超时时间 + + Returns: + Tuple[bool, str, bytes]: (success, log, reference_data_bytes) + """ + import base64 + + generate_ref_url = f"{self.worker_url}/api/v1/generate_reference" + + try: + async with httpx.AsyncClient(timeout=timeout + 10) as client: + files = {'package': ('package.tar', package_data, 'application/x-tar')} + data = { + 'task_id': task_id, + 'op_name': op_name, + 'timeout': str(timeout) + } + + logger.info(f"[{task_id}] Sending generate_reference request to {generate_ref_url}") + + response = await client.post(generate_ref_url, files=files, data=data) + response.raise_for_status() + + result = response.json() + success = result.get('success', False) + log = result.get('log', '') + + if success: + # reference_data 以 base64 编码传输 + ref_data_b64 = result.get('reference_data', '') + if ref_data_b64: + ref_bytes = base64.b64decode(ref_data_b64) + logger.info(f"[{task_id}] Received reference data: {len(ref_bytes)} bytes") + return True, log, ref_bytes + else: + return False, f"No reference data in response:\n{log}", b'' + else: + return False, log, b'' + + except httpx.RequestError as e: + error_msg = f"Network error communicating with worker at {self.worker_url}: {e}" + logger.error(f"[{task_id}] {error_msg}") + return False, error_msg, b'' + except httpx.HTTPStatusError as e: + error_msg = f"Worker returned error status: {e.response.status_code} - {e.response.text}" + logger.error(f"[{task_id}] {error_msg}") + return False, error_msg, b'' + except Exception as e: + error_msg = f"Remote generate_reference failed: {e}" + logger.error(f"[{task_id}] {error_msg}") + return False, error_msg, b'' diff --git a/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template_refactored.j2 b/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template_refactored.j2 index f1ff0a4a15..7f0c16973d 100644 --- a/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template_refactored.j2 +++ b/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template_refactored.j2 @@ -85,6 +85,29 @@ def verify_implementations(): arch = "{{ arch }}" # 硬件架构 dsl = "{{ dsl }}" # 实现方式 + # 参考数据模式配置(用于跨后端转换场景) + use_reference_data = {{ 'True' if use_reference_data else 'False' }} + reference_file = {% if reference_file %}"{{ reference_file }}"{% else %}None{% endif %} + + # 加载参考数据(如果启用) + reference_outputs = None + if use_reference_data and reference_file and os.path.exists(reference_file): + print(f"[INFO] 使用参考数据模式: {reference_file}") + {% if framework == "torch" %} + import torch + reference_data = torch.load(reference_file, map_location='cpu') + reference_outputs = reference_data.get('outputs', []) + print(f"[INFO] 参考数据种子: {reference_data.get('seed', 
'unknown')}") + print(f"[INFO] 参考输出数量: {len(reference_outputs)}") + {% elif framework == "numpy" %} + import numpy as np + reference_data = np.load(reference_file, allow_pickle=True).item() + reference_outputs = reference_data.get('outputs', []) + print(f"[INFO] 参考输出数量: {len(reference_outputs)}") + {% endif %} + elif use_reference_data: + print(f"[WARN] 参考数据模式已启用,但参考文件不存在: {reference_file}") + # 设备设置 (generated by FrameworkAdapter) {% for line in device_setup_code %} {{ line }} @@ -95,8 +118,14 @@ def verify_implementations(): {{ line }} {% endfor %} - def verify_single_case(inputs_for_framework, inputs_for_impl): - """验证单个案例的公共逻辑""" + def verify_single_case(inputs_for_framework, inputs_for_impl, reference_output=None): + """验证单个案例的公共逻辑 + + Args: + inputs_for_framework: 框架模型的输入 + inputs_for_impl: 实现的输入 + reference_output: 参考输出(可选,用于跨后端转换场景) + """ # 设置随机种子 (generated by FrameworkAdapter) @@ -109,8 +138,14 @@ def verify_implementations(): {{ line }} {% endfor %} - # 运行原始实现 - framework_output = framework_model(*inputs_for_framework) + # 获取 framework 输出:使用参考数据或运行原始实现 + if reference_output is not None: + # 使用参考数据作为 ground truth(跨后端转换场景) + framework_output = reference_output + print("[INFO] 使用参考数据作为 ground truth(跳过 base model 执行)") + else: + # 运行原始实现 + framework_output = framework_model(*inputs_for_framework) if not isinstance(framework_output, (list, tuple)): framework_output = [framework_output] @@ -133,11 +168,11 @@ def verify_implementations(): return True, framework_output - def verify_with_timeout(inputs_for_framework, inputs_for_impl, timeout_msg=None, success_msg=None): + def verify_with_timeout(inputs_for_framework, inputs_for_impl, reference_output=None, timeout_msg=None, success_msg=None): """带超时控制的验证函数""" @with_timeout({{ timeout }}) def verify_case(): - return verify_single_case(inputs_for_framework, inputs_for_impl) + return verify_single_case(inputs_for_framework, inputs_for_impl, reference_output) try: verify_result, framework_output = verify_case() @@ -218,10 +253,21 @@ def verify_implementations(): inputs_for_framework = [process_input(x) for x in inputs_for_framework] {% endif %} + # 获取参考输出(如果使用参考数据模式) + ref_output_for_case = None + if use_reference_data and reference_outputs: + # 参考数据模式:使用预先生成的参考输出 + ref_output_for_case = reference_outputs + # 将参考输出移动到当前设备 + {% if framework == "torch" %} + ref_output_for_case = [x.to(device) if hasattr(x, 'to') else x for x in ref_output_for_case] + {% endif %} + # 使用带超时控制的验证函数 verify_result, framework_output = verify_with_timeout( inputs_for_framework, inputs_for_impl, + reference_output=ref_output_for_case, timeout_msg=f"静态shape验证超时({{timeout}}秒)" ) {% endif %} diff --git a/aikg/python/ai_kernel_generator/server/app.py b/aikg/python/ai_kernel_generator/server/app.py index 927e954c90..0e0f47bcef 100644 --- a/aikg/python/ai_kernel_generator/server/app.py +++ b/aikg/python/ai_kernel_generator/server/app.py @@ -23,6 +23,13 @@ class JobSubmitRequest(BaseModel): framework: str = "torch" workflow: Optional[str] = "coder_only_workflow" + # Cross-backend conversion params + # 当 source_backend 与 backend 不同时,表示跨平台转换场景 + # 例如: source_backend=cuda, backend=ascend 表示 Triton-CUDA 到 Triton-Ascend 转换 + source_backend: Optional[str] = None # 源后端(如 cuda) + source_arch: Optional[str] = None # 源架构(如 a100) + source_dsl: Optional[str] = None # 源 DSL(如 triton_cuda),不指定则根据 source_backend 推断 + # Evolve params max_rounds: int = 1 parallel_num: int = 1 diff --git a/aikg/python/ai_kernel_generator/server/job_manager.py 
b/aikg/python/ai_kernel_generator/server/job_manager.py index cc2aaa30f9..8aa0fa79cd 100644 --- a/aikg/python/ai_kernel_generator/server/job_manager.py +++ b/aikg/python/ai_kernel_generator/server/job_manager.py @@ -96,6 +96,8 @@ class ServerJobManager: raise ValueError(f"task_desc is required when submitting a job.\n\n{hint}") worker_manager = get_worker_manager() + + # 检查目标 backend 的 Worker 是否可用 worker_available = await worker_manager.has_worker( backend=backend, arch=arch @@ -117,6 +119,28 @@ class ServerJobManager: error_msg += "\nPlease register a compatible worker before submitting the job." raise RuntimeError(error_msg) + + # 检查是否是 CUDA-to-Ascend 转换场景 + source_backend = request_data.get("source_backend") + source_arch = request_data.get("source_arch") + + if source_backend and source_backend != backend: + # CUDA-to-Ascend 转换场景,需要检查 source_backend 的 Worker 是否可用 + source_worker_available = await worker_manager.has_worker( + backend=source_backend, + arch=source_arch + ) + if not source_worker_available: + all_workers = await worker_manager.get_status() + error_msg = f"No available worker found for source_backend='{source_backend}', source_arch='{source_arch}'.\n" + if all_workers: + worker_list_str = "\n".join([ + f"- backend='{w['backend']}', arch='{w['arch']}', capacity={w['capacity']}, tags={w['tags']}" + for w in all_workers + ]) + error_msg += f"Currently registered workers:\n{worker_list_str}" + error_msg += "\nPlease register a source backend worker for CUDA-to-Ascend conversion." + raise RuntimeError(error_msg) job_id = str(uuid.uuid4()) job_type = request_data.get("job_type", "single") @@ -143,6 +167,10 @@ class ServerJobManager: async def _check_task_desc_runtime_wrapper(self, job_id: str, data: dict, config: dict) -> bool: """ 在任务开始前执行运行时检查 + + 支持跨后端转换场景: + - 当 source_backend 与 backend 不同时,在源 Worker 上生成参考数据 + - 将参考数据存入 config['reference_data'],供目标 Worker 验证时使用 """ task_desc = data.get("task_desc", "") if not task_desc: @@ -151,42 +179,101 @@ class ServerJobManager: backend = data.get("backend") arch = data.get("arch") - - logger.info(f"[{job_id}] Starting runtime check for task description...") + source_backend = data.get("source_backend") + source_arch = data.get("source_arch") worker_manager = get_worker_manager() - # 获取 worker - worker = await worker_manager.select(backend=backend, arch=arch) - if not worker: - raise RuntimeError(f"No available worker found for runtime check (backend={backend}, arch={arch})") - - try: - # 创建 verifier - verifier = KernelVerifier( - op_name=data.get("op_name"), - framework_code=task_desc, # 这里的 framework_code 不重要,只要 task_desc 传对了 - task_id=job_id, - framework=data.get("framework"), - dsl=data.get("dsl"), - backend=backend, - arch=arch, - config=config, - worker=worker - ) + + # 检查是否需要生成参考数据(只要有 source_backend 就需要) + need_reference_data = (source_backend is not None and source_backend != backend) + + if need_reference_data: + # ========== 阶段1: 在源 Worker 上生成参考数据 ========== + logger.info(f"[{job_id}] Cross-backend conversion detected (source={source_backend} -> target={backend}). 
Generating reference data...") - # 执行检查 - valid, error = await verifier.check_task_desc_runtime(task_desc, timeout=60) + source_worker = await worker_manager.select(backend=source_backend, arch=source_arch) + if not source_worker: + raise RuntimeError(f"No available source worker found for reference generation (backend={source_backend}, arch={source_arch})") - if not valid: - hint = _get_task_desc_format_hint() - raise RuntimeError(f"Task description runtime check failed: {error}\n\n{hint}") + try: + # 根据 source_backend 决定 dsl + source_dsl = data.get("source_dsl") + if not source_dsl: + # 默认根据 source_backend 推断 + if source_backend == "cuda": + source_dsl = "triton_cuda" + elif source_backend == "ascend": + source_dsl = "triton_ascend" + else: + source_dsl = "triton" + + # 创建 verifier 用于生成参考数据 + verifier = KernelVerifier( + op_name=data.get("op_name"), + framework_code=task_desc, + task_id=job_id, + framework=data.get("framework"), + dsl=source_dsl, + backend=source_backend, + arch=source_arch, + config=config, + worker=source_worker + ) + + # 生成参考数据 + success, log, ref_bytes = await verifier.generate_reference_data(task_desc, timeout=120) + + if not success: + raise RuntimeError(f"Reference data generation failed on source worker:\n{log}") + + # 将参考数据存入 config + config['use_reference_data'] = True + config['reference_data'] = ref_bytes + logger.info(f"[{job_id}] Reference data generated successfully ({len(ref_bytes)} bytes)") - logger.info(f"[{job_id}] Task description runtime check passed.") + finally: + await worker_manager.release(source_worker) + + # ========== 阶段2: 不再需要在目标 Worker 上执行运行时检查 ========== + # 因为我们已经在源 Worker 上验证过 task_desc 可以正常运行 + logger.info(f"[{job_id}] Skipping target runtime check (reference data already generated)") return True + + else: + # ========== 普通场景: 在目标 Worker 上执行运行时检查 ========== + logger.info(f"[{job_id}] Starting runtime check for task description...") - finally: - # 释放 worker - await worker_manager.release(worker) + worker = await worker_manager.select(backend=backend, arch=arch) + if not worker: + raise RuntimeError(f"No available worker found for runtime check (backend={backend}, arch={arch})") + + try: + # 创建 verifier + verifier = KernelVerifier( + op_name=data.get("op_name"), + framework_code=task_desc, + task_id=job_id, + framework=data.get("framework"), + dsl=data.get("dsl"), + backend=backend, + arch=arch, + config=config, + worker=worker + ) + + # 执行检查 + valid, error = await verifier.check_task_desc_runtime(task_desc, timeout=60) + + if not valid: + hint = _get_task_desc_format_hint() + raise RuntimeError(f"Task description runtime check failed: {error}\n\n{hint}") + + logger.info(f"[{job_id}] Task description runtime check passed.") + return True + + finally: + # 释放 worker + await worker_manager.release(worker) async def _run_single_job(self, job_id: str, data: dict): self.jobs[job_id]["status"] = "running" diff --git a/aikg/python/ai_kernel_generator/worker/server.py b/aikg/python/ai_kernel_generator/worker/server.py index 0f7156f565..0c7d2c05ea 100644 --- a/aikg/python/ai_kernel_generator/worker/server.py +++ b/aikg/python/ai_kernel_generator/worker/server.py @@ -115,6 +115,57 @@ async def profile( logger.error(f"[{task_id}] Profiling request failed: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) +@app.post("/api/v1/generate_reference") +async def generate_reference( + package: Annotated[UploadFile, File(...)], + task_id: Annotated[str, Form(...)], + op_name: Annotated[str, Form(...)], + timeout: Annotated[int, 
Form(...)] = 120 +): + """ + Execute task_desc and generate reference data. + + 用于 CUDA-to-Ascend 转换场景:执行 Triton-CUDA 代码, + 保存输出作为参考数据(.pt 文件),并以 base64 编码返回。 + + Returns: + - success: 是否成功生成参考数据 + - log: 执行日志 + - reference_data: base64 编码的 .pt 文件内容 + """ + import base64 + + if worker is None: + raise HTTPException(status_code=503, detail="Worker not initialized") + + try: + logger.info(f"[{task_id}] Received generate_reference request for {op_name}") + + package_data = await package.read() + + success, log, ref_bytes = await worker.generate_reference( + package_data, task_id, op_name, timeout + ) + + if success: + # 以 base64 编码返回二进制数据 + ref_data_b64 = base64.b64encode(ref_bytes).decode('utf-8') + return { + "success": True, + "log": log, + "reference_data": ref_data_b64 + } + else: + return { + "success": False, + "log": log, + "reference_data": "" + } + + except Exception as e: + logger.error(f"[{task_id}] Generate reference request failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + @app.post("/api/v1/acquire_device") async def acquire_device( task_id: Annotated[str, Form(...)] diff --git a/aikg/tests/st/test_reference_generation.py b/aikg/tests/st/test_reference_generation.py new file mode 100644 index 0000000000..488738043f --- /dev/null +++ b/aikg/tests/st/test_reference_generation.py @@ -0,0 +1,389 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +测试参考数据生成功能 + +用于验证 CUDA-to-Ascend 转换场景中的参考数据生成和传输功能。 +在 CPU 后端上运行,验证基础功能的正确性。 +""" + +import pytest +import asyncio +import tarfile +import io +import os +import tempfile +import torch + +from ai_kernel_generator.core.worker.local_worker import LocalWorker +from ai_kernel_generator.core.async_pool.device_pool import DevicePool +from ai_kernel_generator.core.verifier.kernel_verifier import KernelVerifier + + +# 简单的 ReLU task_desc 用于测试 +RELU_TASK_DESC = ''' +import torch +import torch.nn as nn + +class Model(nn.Module): + """Simple ReLU model for testing.""" + def __init__(self): + super(Model, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.relu(x) + +batch_size = 4 +dim = 16 + +def get_inputs(): + torch.manual_seed(0) # 固定种子确保可复现 + x = torch.randn(batch_size, dim) + return [x] + +def get_init_inputs(): + return [] +''' + + +def create_reference_generation_package(op_name: str, task_desc: str) -> bytes: + """ + 创建参考数据生成的 TAR 包 + + Args: + op_name: 算子名称 + task_desc: task_desc 代码 + + Returns: + bytes: TAR 包数据 + """ + # 生成参考数据的脚本 + gen_ref_script = f''' +import torch +import sys +import os + +sys.path.append(os.getcwd()) + +def generate_reference(): + print("Starting reference data generation...") + try: + from reference import Model, get_inputs, get_init_inputs + print("Successfully imported Model and helper functions.") + + device = "cpu" + print(f"Using device: {{device}}") + + # Fixed seed + torch.manual_seed(0) + print("[INFO] Random seed: 0") + + # Instantiate model + init_inputs = get_init_inputs() + model = Model(*init_inputs) + model.eval() + + # Get inputs + torch.manual_seed(0) + inputs = get_inputs() + + # Run forward + with torch.no_grad(): + outputs = model(*inputs) + + if not isinstance(outputs, (list, tuple)): + outputs = [outputs] + + # Save reference data + ref_data = {{ + 'op_name': '{op_name}', + 'seed': 0, + 'outputs': outputs, + 'output_shapes': [x.shape if isinstance(x, torch.Tensor) else None for x in outputs], + }} + + ref_file = os.path.join(os.getcwd(), "{op_name}_reference.pt") + torch.save(ref_data, ref_file) + print(f"[INFO] Reference data saved to: {{ref_file}}") + print(f"[INFO] Output count: {{len(outputs)}}") + + return True + except Exception as e: + print(f"Error: {{e}}") + import traceback + traceback.print_exc() + return False + +if __name__ == "__main__": + success = generate_reference() + if success: + print("REFERENCE_GENERATION_SUCCESS") + sys.exit(0) + else: + print("REFERENCE_GENERATION_FAILED") + sys.exit(1) +''' + + # 创建 TAR 包 + tar_buffer = io.BytesIO() + with tarfile.open(fileobj=tar_buffer, mode='w') as tar_file: + # 添加 reference.py + ref_info = tarfile.TarInfo(name="reference.py") + ref_bytes = task_desc.encode('utf-8') + ref_info.size = len(ref_bytes) + tar_file.addfile(tarinfo=ref_info, fileobj=io.BytesIO(ref_bytes)) + + # 添加 verify_{op_name}.py + script_info = tarfile.TarInfo(name=f"verify_{op_name}.py") + script_bytes = gen_ref_script.encode('utf-8') + script_info.size = len(script_bytes) + tar_file.addfile(tarinfo=script_info, fileobj=io.BytesIO(script_bytes)) + + return tar_buffer.getvalue() + + +class TestReferenceGeneration: + """测试参考数据生成功能""" + + @pytest.mark.asyncio + async def test_local_worker_generate_reference_success(self): + """测试 LocalWorker 成功生成参考数据""" + op_name = "test_relu" + package_data = create_reference_generation_package(op_name, RELU_TASK_DESC) + + # 创建 Worker + device_pool = DevicePool([0]) + worker = LocalWorker(device_pool, backend="cpu") + + # 生成参考数据 + task_id = 
"test_gen_ref_001" + success, log, ref_bytes = await worker.generate_reference( + package_data, task_id, op_name, timeout=30 + ) + + # 验证结果 + assert success is True, f"Reference generation failed: {log}" + assert "REFERENCE_GENERATION_SUCCESS" in log + assert len(ref_bytes) > 0, "Reference data bytes should not be empty" + + # 验证 .pt 文件内容 + with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as f: + f.write(ref_bytes) + temp_path = f.name + + try: + ref_data = torch.load(temp_path) + assert 'op_name' in ref_data + assert 'seed' in ref_data + assert 'outputs' in ref_data + assert ref_data['seed'] == 0 + assert len(ref_data['outputs']) > 0 + + # 验证输出形状 + output = ref_data['outputs'][0] + assert output.shape == (4, 16) # batch_size=4, dim=16 + finally: + os.unlink(temp_path) + + @pytest.mark.asyncio + async def test_local_worker_generate_reference_invalid_task_desc(self): + """测试 LocalWorker 处理无效的 task_desc""" + op_name = "test_invalid" + invalid_task_desc = "this is not valid python code !!!" + package_data = create_reference_generation_package(op_name, invalid_task_desc) + + device_pool = DevicePool([0]) + worker = LocalWorker(device_pool, backend="cpu") + + success, log, ref_bytes = await worker.generate_reference( + package_data, "test_gen_ref_002", op_name, timeout=30 + ) + + assert success is False + assert ref_bytes == b'' + + @pytest.mark.asyncio + async def test_reference_data_reproducibility(self): + """测试参考数据的可复现性(使用相同 seed 应产生相同结果)""" + op_name = "test_repro" + package_data = create_reference_generation_package(op_name, RELU_TASK_DESC) + + device_pool = DevicePool([0]) + worker = LocalWorker(device_pool, backend="cpu") + + # 生成两次参考数据 + success1, log1, ref_bytes1 = await worker.generate_reference( + package_data, "test_repro_001", op_name, timeout=30 + ) + success2, log2, ref_bytes2 = await worker.generate_reference( + package_data, "test_repro_002", op_name, timeout=30 + ) + + assert success1 is True + assert success2 is True + + # 加载并比较输出 + with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as f1: + f1.write(ref_bytes1) + path1 = f1.name + with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as f2: + f2.write(ref_bytes2) + path2 = f2.name + + try: + data1 = torch.load(path1) + data2 = torch.load(path2) + + # 输出应该完全相同 + output1 = data1['outputs'][0] + output2 = data2['outputs'][0] + assert torch.allclose(output1, output2), "Outputs should be identical with same seed" + finally: + os.unlink(path1) + os.unlink(path2) + + +class TestReferenceDataTransfer: + """测试参考数据传输功能""" + + @pytest.mark.asyncio + async def test_reference_bytes_serialization(self): + """测试参考数据的序列化和反序列化""" + import base64 + + # 创建测试数据 + test_data = { + 'op_name': 'test_op', + 'seed': 0, + 'outputs': [torch.randn(4, 16)], + 'output_shapes': [(4, 16)], + } + + # 保存到 bytes + buffer = io.BytesIO() + torch.save(test_data, buffer) + original_bytes = buffer.getvalue() + + # Base64 编码(模拟 HTTP 传输) + encoded = base64.b64encode(original_bytes).decode('utf-8') + + # Base64 解码 + decoded_bytes = base64.b64decode(encoded) + + # 验证数据完整性 + assert original_bytes == decoded_bytes + + # 加载并验证内容 + buffer2 = io.BytesIO(decoded_bytes) + loaded_data = torch.load(buffer2) + + assert loaded_data['op_name'] == test_data['op_name'] + assert loaded_data['seed'] == test_data['seed'] + assert torch.allclose(loaded_data['outputs'][0], test_data['outputs'][0]) + + @pytest.mark.asyncio + async def test_config_reference_data_injection(self): + """测试将参考数据注入到 config 中""" + # 模拟生成参考数据 + op_name = "test_inject" + package_data 
= create_reference_generation_package(op_name, RELU_TASK_DESC) + + device_pool = DevicePool([0]) + worker = LocalWorker(device_pool, backend="cpu") + + success, log, ref_bytes = await worker.generate_reference( + package_data, "test_inject_001", op_name, timeout=30 + ) + + assert success is True + + # 模拟 JobManager 将参考数据注入 config + config = { + 'log_dir': '/tmp/aikg_test', + } + config['use_reference_data'] = True + config['reference_data'] = ref_bytes + + # 验证 config 中的数据 + assert config['use_reference_data'] is True + assert len(config['reference_data']) > 0 + + # 模拟 KernelVerifier 从 config 读取并写入文件 + with tempfile.TemporaryDirectory() as verify_dir: + ref_file = os.path.join(verify_dir, f"{op_name}_reference.pt") + with open(ref_file, 'wb') as f: + f.write(config['reference_data']) + + # 验证文件可以被正确加载 + assert os.path.exists(ref_file) + loaded = torch.load(ref_file) + assert 'outputs' in loaded + assert loaded['seed'] == 0 + + +class TestKernelVerifierGenerateReference: + """测试 KernelVerifier.generate_reference_data 方法""" + + @pytest.mark.asyncio + async def test_verifier_generate_reference_data(self): + """测试 KernelVerifier 的 generate_reference_data 方法""" + with tempfile.TemporaryDirectory() as log_dir: + config = {'log_dir': log_dir} + + # 创建 Worker + device_pool = DevicePool([0]) + worker = LocalWorker(device_pool, backend="cpu") + + # 创建 Verifier + verifier = KernelVerifier( + op_name="test_relu_verifier", + framework_code=RELU_TASK_DESC, + task_id="test_verifier_001", + framework="torch", + dsl="triton_cuda", # dsl 在这个测试中不重要 + backend="cpu", + arch="x86_64", + config=config, + worker=worker + ) + + # 生成参考数据 + success, log, ref_bytes = await verifier.generate_reference_data( + RELU_TASK_DESC, timeout=60 + ) + + # 验证成功 + assert success is True, f"generate_reference_data failed: {log}" + assert len(ref_bytes) > 0, "Reference bytes should not be empty" + + # 验证 .pt 文件内容 + with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as f: + f.write(ref_bytes) + temp_path = f.name + + try: + ref_data = torch.load(temp_path) + assert 'op_name' in ref_data + assert ref_data['op_name'] == "test_relu_verifier" + assert 'outputs' in ref_data + assert len(ref_data['outputs']) > 0 + finally: + os.unlink(temp_path) + + +if __name__ == "__main__": + # 直接运行测试 + pytest.main([__file__, "-v", "-s"]) + -- Gitee From c8028a26da5376c71083ae88d702395ad442a867 Mon Sep 17 00:00:00 2001 From: Yanzhi_YI Date: Wed, 3 Dec 2025 20:34:37 +0800 Subject: [PATCH 2/8] update cases --- .../examples/run_cuda_to_ascend_conversion.py | 68 +++++++++++++------ 1 file changed, 48 insertions(+), 20 deletions(-) diff --git a/aikg/examples/run_cuda_to_ascend_conversion.py b/aikg/examples/run_cuda_to_ascend_conversion.py index 4bf00c24aa..cca34364bb 100644 --- a/aikg/examples/run_cuda_to_ascend_conversion.py +++ b/aikg/examples/run_cuda_to_ascend_conversion.py @@ -13,7 +13,7 @@ # limitations under the License. """ -CUDA-to-Ascend 转换示例 +CUDA-to-Ascend 转换示例(多输出) 本示例演示如何使用服务化架构实现 Triton-CUDA 到 Triton-Ascend 的自动转换: 1. 注册两个 Remote Worker:CUDA (a100) 和 Ascend (ascend910b4) @@ -21,14 +21,24 @@ CUDA-to-Ascend 转换示例 3. Server 自动在 CUDA Worker 上生成参考数据 4. 
NPU Worker 使用参考数据验证转换后的代码 +示例算子:ReLU + Add 双输出 +- 输入: x, y 两个张量 +- 输出1: relu(x) +- 输出2: x + y + 使用方式: -方式1: 通过 Server API(推荐) - # GPU 机器上启动 CUDA Worker - ./scripts/server_related/start_worker_service.sh cuda a100 0 9001 - - # NPU 机器上启动 Ascend Worker - ./scripts/server_related/start_worker_service.sh ascend ascend910b4 0 9001 + +初始设置:启动多后端Workers + +# GPU 机器上启动 CUDA Worker +./scripts/server_related/start_worker_service.sh cuda a100 0 9001 + +# NPU 机器上启动 Ascend Worker +./scripts/server_related/start_worker_service.sh ascend ascend910b4 0 9001 + + +方式1: 通过 Server API # Server 机器上启动 AIKG Server ./scripts/server_related/start_server.sh 8000 @@ -60,34 +70,47 @@ os.environ['AIKG_STREAM_OUTPUT'] = 'on' def get_op_name(): - return 'relu' + return 'relu_add_dual_output' def get_task_desc(): """ task_desc(纯 PyTorch 代码,用于生成参考数据和转换目标) + + 本示例演示多输出场景: + - 输入: x, y 两个张量 + - 输出1: relu(x) + - 输出2: x + y """ return ''' import torch import torch.nn as nn +from typing import Tuple class Model(nn.Module): """ - ReLU激活函数模型 + ReLU + Add 双输出模型 + + 演示多输出场景的参考数据生成和验证 """ def __init__(self): super(Model, self).__init__() - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - 计算ReLU激活函数 + 计算 ReLU 和 Add 操作,返回两个输出 + Args: - x: 输入张量 + x: 输入张量1 + y: 输入张量2 Returns: - ReLU激活后的张量 + out1: ReLU激活后的张量 relu(x) + out2: 相加后的张量 x + y """ - return torch.relu(x) + out1 = torch.relu(x) + out2 = x + y + return out1, out2 batch_size = 16 @@ -96,7 +119,8 @@ dim = 16384 def get_inputs(): x = torch.randn(batch_size, dim, dtype=torch.float16) - return [x] + y = torch.randn(batch_size, dim, dtype=torch.float16) + return [x, y] def get_init_inputs(): @@ -435,9 +459,10 @@ async def run_quick_verify(): try: ref_data = torch.load(temp_path) + outputs = ref_data.get('outputs', []) print(f" 种子: {ref_data.get('seed', 'unknown')}") - print(f" 输出数量: {len(ref_data.get('outputs', []))}") - for i, out in enumerate(ref_data.get('outputs', [])): + print(f" 输出数量: {len(outputs)} {'(多输出)' if len(outputs) > 1 else ''}") + for i, out in enumerate(outputs): if hasattr(out, 'shape'): print(f" 输出[{i}]: shape={out.shape}, dtype={out.dtype}") finally: @@ -482,14 +507,17 @@ async def run_quick_verify(): simple_impl_code = ''' import torch import torch.nn as nn +from typing import Tuple class ModelNew(nn.Module): - """透传实现,用于验证参考数据流程""" + """透传实现,用于验证参考数据流程(双输出)""" def __init__(self): super(ModelNew, self).__init__() - def forward(self, x: torch.Tensor) -> torch.Tensor: - return torch.relu(x) + def forward(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + out1 = torch.relu(x) + out2 = x + y + return out1, out2 ''' task_info = {'coder_code': simple_impl_code} -- Gitee From 9932fcbb9c5281e19ccb877744d11b4ac23c5224 Mon Sep 17 00:00:00 2001 From: Yanzhi_YI Date: Wed, 3 Dec 2025 20:46:10 +0800 Subject: [PATCH 3/8] add triton-cuda case --- .../examples/run_cuda_to_ascend_conversion.py | 104 +++++++++++++----- 1 file changed, 76 insertions(+), 28 deletions(-) diff --git a/aikg/examples/run_cuda_to_ascend_conversion.py b/aikg/examples/run_cuda_to_ascend_conversion.py index cca34364bb..6459288747 100644 --- a/aikg/examples/run_cuda_to_ascend_conversion.py +++ b/aikg/examples/run_cuda_to_ascend_conversion.py @@ -21,10 +21,11 @@ CUDA-to-Ascend 转换示例(多输出) 3. Server 自动在 CUDA Worker 上生成参考数据 4. 
NPU Worker 使用参考数据验证转换后的代码 -示例算子:ReLU + Add 双输出 -- 输入: x, y 两个张量 +示例算子:ReLU + Add 双输出(Triton-CUDA 实现) +- 输入: x, y 两个张量,shape=(537, 32) - 输出1: relu(x) - 输出2: x + y +- grid=537, TILE=32 使用方式: @@ -32,13 +33,26 @@ CUDA-to-Ascend 转换示例(多输出) 初始设置:启动多后端Workers # GPU 机器上启动 CUDA Worker -./scripts/server_related/start_worker_service.sh cuda a100 0 9001 +./scripts/server_related/start_worker_service.sh cuda a100 0,1,2,3,4,5,6,7 9001 # NPU 机器上启动 Ascend Worker -./scripts/server_related/start_worker_service.sh ascend ascend910b4 0 9001 +./scripts/server_related/start_worker_service.sh ascend ascend910b4 0,1,2,3,4,5,6,7 9001 -方式1: 通过 Server API + +方式1: 快速验证模式(仅测试参考数据生成和传输,不调用 LLM) + export CUDA_WORKER_URL=http://cuda-server:9001 + export ASCEND_WORKER_URL=http://ascend-server:9002 + python examples/run_cuda_to_ascend_conversion.py --verify + + +方式2: 直接使用 Remote Workers + export CUDA_WORKER_URL=http://cuda-server:9001 + export ASCEND_WORKER_URL=http://ascend-server:9002 + python examples/run_cuda_to_ascend_conversion.py --direct + + +方式3: 通过 Server API # Server 机器上启动 AIKG Server ./scripts/server_related/start_server.sh 8000 @@ -50,15 +64,6 @@ CUDA-to-Ascend 转换示例(多输出) # 运行此脚本 python examples/run_cuda_to_ascend_conversion.py --server http://localhost:8000 -方式2: 直接使用 Remote Workers - export CUDA_WORKER_URL=http://cuda-server:9001 - export ASCEND_WORKER_URL=http://ascend-server:9002 - python examples/run_cuda_to_ascend_conversion.py --direct - -方式3: 快速验证模式(仅测试参考数据生成和传输,不调用 LLM) - export CUDA_WORKER_URL=http://cuda-server:9001 - export ASCEND_WORKER_URL=http://ascend-server:9002 - python examples/run_cuda_to_ascend_conversion.py --verify """ import asyncio @@ -75,51 +80,93 @@ def get_op_name(): def get_task_desc(): """ - task_desc(纯 PyTorch 代码,用于生成参考数据和转换目标) + task_desc(Triton-CUDA 代码,用于生成参考数据和转换目标) 本示例演示多输出场景: - - 输入: x, y 两个张量 + - 输入: x, y 两个张量,shape=(537, 32) - 输出1: relu(x) - 输出2: x + y + - grid=537, TILE=32 """ return ''' import torch import torch.nn as nn +import triton +import triton.language as tl from typing import Tuple +@triton.jit +def relu_add_kernel( + x_ptr, + y_ptr, + out1_ptr, + out2_ptr, + TILE: tl.constexpr, +): + """ + Triton-CUDA kernel: ReLU + Add 双输出 + 每个 program 处理一行(TILE 个元素) + """ + row_idx = tl.program_id(0) + + # 计算偏移 + offsets = row_idx * TILE + tl.arange(0, TILE) + + # 加载数据 + x = tl.load(x_ptr + offsets) + y = tl.load(y_ptr + offsets) + + # 计算 ReLU 和 Add + out1 = tl.maximum(x, 0.0) + out2 = x + y + + # 存储结果 + tl.store(out1_ptr + offsets, out1) + tl.store(out2_ptr + offsets, out2) + + class Model(nn.Module): """ - ReLU + Add 双输出模型 + ReLU + Add 双输出模型(Triton-CUDA 实现) 演示多输出场景的参考数据生成和验证 + grid=537, TILE=32 """ def __init__(self): super(Model, self).__init__() def forward(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - 计算 ReLU 和 Add 操作,返回两个输出 + 使用 Triton-CUDA kernel 计算 ReLU 和 Add 操作 Args: - x: 输入张量1 - y: 输入张量2 + x: 输入张量1, shape=(537, 32) + y: 输入张量2, shape=(537, 32) Returns: out1: ReLU激活后的张量 relu(x) out2: 相加后的张量 x + y """ - out1 = torch.relu(x) - out2 = x + y + N, TILE = x.shape + out1 = torch.empty_like(x) + out2 = torch.empty_like(x) + + grid = (N,) + relu_add_kernel[grid]( + x, y, out1, out2, + TILE=TILE, + ) + return out1, out2 -batch_size = 16 -dim = 16384 +N = 537 +TILE = 32 def get_inputs(): - x = torch.randn(batch_size, dim, dtype=torch.float16) - y = torch.randn(batch_size, dim, dtype=torch.float16) + x = torch.randn(N, TILE, dtype=torch.float16) + y = torch.randn(N, TILE, dtype=torch.float16) return [x, y] @@ -502,11 +549,12 @@ async def 
run_quick_verify(): worker=ascend_worker ) - # 构造 task_info,使用原始 PyTorch 代码作为 impl(只是为了验证流程) - # 这里的 coder_code 是一个简单的透传实现 + # 构造 task_info,使用 PyTorch 代码作为 impl(只是为了验证流程) + # 这里的 coder_code 是一个简单的透传实现,用于验证参考数据传输 simple_impl_code = ''' import torch import torch.nn as nn +import torch_npu from typing import Tuple class ModelNew(nn.Module): -- Gitee From 311b8d55f4d6445eb66110c32886d6d5f44d1b74 Mon Sep 17 00:00:00 2001 From: Yanzhi_YI Date: Wed, 3 Dec 2025 21:24:43 +0800 Subject: [PATCH 4/8] restore aikg/tests/resources/relu_op/relu_triton_cuda.py --- .../resources/relu_op/relu_triton_cuda.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 aikg/tests/resources/relu_op/relu_triton_cuda.py diff --git a/aikg/tests/resources/relu_op/relu_triton_cuda.py b/aikg/tests/resources/relu_op/relu_triton_cuda.py new file mode 100644 index 0000000000..eda8a78898 --- /dev/null +++ b/aikg/tests/resources/relu_op/relu_triton_cuda.py @@ -0,0 +1,67 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import triton +import triton.language as tl + + +@triton.jit +def relu_kernel( + x_ptr, # 输入指针 + output_ptr, # 输出指针 + n_elements, # 总元素数 + BLOCK_SIZE: tl.constexpr, # 每个block处理的元素数 +): + # 获取程序ID + pid = tl.program_id(axis=0) + # 计算这个block的起始位置 + block_start = pid * BLOCK_SIZE + # 创建偏移量 + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # 创建掩码,确保不越界 + mask = offsets < n_elements + + # 加载输入数据 + x = tl.load(x_ptr + offsets, mask=mask) + + # 执行ReLU: max(0, x) + output = tl.maximum(x, 0.0) + + # 存储结果 + tl.store(output_ptr + offsets, output, mask=mask) + + +class ModelNew(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Triton ReLU + """ + x = x.contiguous() + n_elements = x.numel() + output = torch.empty_like(x, device=x.device) + + BLOCK_SIZE = 1024 + grid = (triton.cdiv(n_elements, BLOCK_SIZE),) + + # 启动kernel + relu_kernel[grid]( + x, output, n_elements, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return output -- Gitee From f746cb1550ce66971ff0a6ade9523d7218501313 Mon Sep 17 00:00:00 2001 From: Yanzhi_YI Date: Wed, 3 Dec 2025 21:43:33 +0800 Subject: [PATCH 5/8] bugfix --- aikg/examples/run_cuda_to_ascend_conversion.py | 8 ++++---- .../templates/kernel_verify_template_refactored.j2 | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/aikg/examples/run_cuda_to_ascend_conversion.py b/aikg/examples/run_cuda_to_ascend_conversion.py index 6459288747..30ff7c8f64 100644 --- a/aikg/examples/run_cuda_to_ascend_conversion.py +++ b/aikg/examples/run_cuda_to_ascend_conversion.py @@ -42,13 +42,13 @@ CUDA-to-Ascend 转换示例(多输出) 方式1: 快速验证模式(仅测试参考数据生成和传输,不调用 LLM) export CUDA_WORKER_URL=http://cuda-server:9001 - export ASCEND_WORKER_URL=http://ascend-server:9002 + export ASCEND_WORKER_URL=http://ascend-server:9001 python examples/run_cuda_to_ascend_conversion.py --verify 方式2: 直接使用 Remote Workers export CUDA_WORKER_URL=http://cuda-server:9001 - export 
ASCEND_WORKER_URL=http://ascend-server:9002 + export ASCEND_WORKER_URL=http://ascend-server:9001 python examples/run_cuda_to_ascend_conversion.py --direct @@ -296,7 +296,7 @@ async def run_direct_with_workers(): # 从环境变量获取 Worker URL cuda_worker_url = os.environ.get("CUDA_WORKER_URL", "http://localhost:9001") - ascend_worker_url = os.environ.get("ASCEND_WORKER_URL", "http://localhost:9002") + ascend_worker_url = os.environ.get("ASCEND_WORKER_URL", "http://localhost:9001") print("=" * 60) print("CUDA-to-Ascend 转换示例 (Direct Worker 模式)") @@ -428,7 +428,7 @@ async def run_quick_verify(): # 从环境变量获取 Worker URL cuda_worker_url = os.environ.get("CUDA_WORKER_URL", "http://localhost:9001") - ascend_worker_url = os.environ.get("ASCEND_WORKER_URL", "http://localhost:9002") + ascend_worker_url = os.environ.get("ASCEND_WORKER_URL", "http://localhost:9001") print("=" * 60) print("快速验证模式:测试参考数据生成和传输") diff --git a/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template_refactored.j2 b/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template_refactored.j2 index 7f0c16973d..ff94cca04f 100644 --- a/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template_refactored.j2 +++ b/aikg/python/ai_kernel_generator/resources/templates/kernel_verify_template_refactored.j2 @@ -94,7 +94,6 @@ def verify_implementations(): if use_reference_data and reference_file and os.path.exists(reference_file): print(f"[INFO] 使用参考数据模式: {reference_file}") {% if framework == "torch" %} - import torch reference_data = torch.load(reference_file, map_location='cpu') reference_outputs = reference_data.get('outputs', []) print(f"[INFO] 参考数据种子: {reference_data.get('seed', 'unknown')}") -- Gitee From 552bc910d8b7a8f164f45134abdcbc32d151e400 Mon Sep 17 00:00:00 2001 From: Yanzhi_YI Date: Thu, 4 Dec 2025 00:50:23 +0800 Subject: [PATCH 6/8] support both ipv4 and ipv6 --- aikg/python/ai_kernel_generator/server/app.py | 37 ++++++++++-- .../ai_kernel_generator/worker/server.py | 25 ++++++-- .../scripts/server_related/check_e2e_setup.sh | 21 ++++++- .../server_related/check_worker_health.sh | 18 +++++- .../register_worker_to_server.sh | 24 +++++++- .../server_related/setup_ssh_tunnel.sh | 51 +++++++++++++++-- aikg/scripts/server_related/start_server.sh | 27 ++++++++- .../start_server_with_local_worker.sh | 57 ++++++++++++++++--- .../server_related/start_worker_service.sh | 25 ++++++-- 9 files changed, 251 insertions(+), 34 deletions(-) diff --git a/aikg/python/ai_kernel_generator/server/app.py b/aikg/python/ai_kernel_generator/server/app.py index 0e0f47bcef..3cafdc2767 100644 --- a/aikg/python/ai_kernel_generator/server/app.py +++ b/aikg/python/ai_kernel_generator/server/app.py @@ -1,3 +1,4 @@ +import os from fastapi import FastAPI, HTTPException from pydantic import BaseModel from typing import Optional, List, Dict, Any @@ -65,14 +66,22 @@ async def get_job_status(job_id: str): raise HTTPException(status_code=404, detail="Job not found") return status +def _is_loopback_url(url: str) -> bool: + """ + 检查 URL 是否为本地回环地址。 + 支持 IPv4 (localhost, 127.0.0.1) 和 IPv6 ([::1])。 + """ + loopback_patterns = ["localhost", "127.0.0.1", "[::1]"] + return any(pattern in url for pattern in loopback_patterns) + @app.post("/api/v1/workers/register") async def register_worker(req: WorkerRegisterRequest): """Worker 注册接口""" logger.info(f"Registering worker: {req.url} ({req.backend}/{req.arch})") - # 简单的 URL 检查提示 - if "localhost" in req.url or "127.0.0.1" in req.url: - logger.warning(f"Worker registered with localhost URL: 
{req.url}. " + # 简单的 URL 检查提示 (支持 IPv4 和 IPv6 loopback) + if _is_loopback_url(req.url): + logger.warning(f"Worker registered with loopback URL: {req.url}. " "Ensure the Server can access this URL (e.g. they are on the same host).") worker = RemoteWorker(req.url) @@ -90,8 +99,28 @@ async def get_workers_status(): """查询所有 Worker 状态""" return await get_worker_manager().get_status() -def start_server(host="0.0.0.0", port=8000): + +def start_server(host: Optional[str] = None, port: Optional[int] = None): + """ + 启动 AIKG Server。 + + Args: + host: 监听地址。可从环境变量 AIKG_SERVER_HOST 设置。 + - IPv4: "0.0.0.0" (所有接口), "127.0.0.1" (本地) + - IPv6: "::" (所有接口,双栈), "::1" (本地) + 默认: "0.0.0.0" + port: 监听端口。可从环境变量 AIKG_SERVER_PORT 设置。 + 默认: 8000 + """ import uvicorn + + # 从环境变量读取配置,参数优先 + if host is None: + host = os.environ.get("AIKG_SERVER_HOST", "0.0.0.0") + if port is None: + port = int(os.environ.get("AIKG_SERVER_PORT", "8000")) + + logger.info(f"Starting AIKG Server on {host}:{port}") uvicorn.run(app, host=host, port=port) if __name__ == "__main__": diff --git a/aikg/python/ai_kernel_generator/worker/server.py b/aikg/python/ai_kernel_generator/worker/server.py index 0c7d2c05ea..dc8f906704 100644 --- a/aikg/python/ai_kernel_generator/worker/server.py +++ b/aikg/python/ai_kernel_generator/worker/server.py @@ -220,10 +220,27 @@ async def status(): # "available_devices": worker.device_pool.qsize() # DevicePool uses Queue } -def start_server(host="0.0.0.0", port=9001): + +def start_server(host: Optional[str] = None, port: Optional[int] = None): + """ + 启动 AIKG Worker Service。 + + Args: + host: 监听地址。可从环境变量 WORKER_HOST 设置。 + - IPv4: "0.0.0.0" (所有接口), "127.0.0.1" (本地) + - IPv6: "::" (所有接口,双栈), "::1" (本地) + 默认: "0.0.0.0" + port: 监听端口。可从环境变量 WORKER_PORT 设置。 + 默认: 9001 + """ + # 从环境变量读取配置,参数优先 + if host is None: + host = os.environ.get("WORKER_HOST", "0.0.0.0") + if port is None: + port = int(os.environ.get("WORKER_PORT", "9001")) + + logger.info(f"Starting Worker Service on {host}:{port}") uvicorn.run(app, host=host, port=port) if __name__ == "__main__": - port = int(os.environ.get("WORKER_PORT", 9001)) - start_server(port=port) - + start_server() diff --git a/aikg/scripts/server_related/check_e2e_setup.sh b/aikg/scripts/server_related/check_e2e_setup.sh index 50c851eae3..408816917c 100755 --- a/aikg/scripts/server_related/check_e2e_setup.sh +++ b/aikg/scripts/server_related/check_e2e_setup.sh @@ -1,7 +1,25 @@ #!/bin/bash # 检查 Client-Server-Worker 端到端环境是否就绪 +# +# ======================================== +# IPv4/IPv6 配置说明: +# ======================================== +# 通过环境变量或参数来指定 URL: +# - 参数方式: ./check_e2e_setup.sh +# - 环境变量: AIKG_SERVER_URL, AIKG_WORKER_URL +# +# IPv4 示例: +# ./check_e2e_setup.sh http://192.168.1.100:8000 http://192.168.1.100:9001 +# +# IPv6 示例 (注意 IPv6 地址需要用方括号包围): +# ./check_e2e_setup.sh http://[2001:db8::1]:8000 http://[2001:db8::1]:9001 +# 或者: +# export AIKG_SERVER_URL=http://[::1]:8000 +# export AIKG_WORKER_URL=http://[::1]:9001 +# ./check_e2e_setup.sh +# ======================================== -SERVER_URL=${1:-http://localhost:8000} +SERVER_URL=${1:-${AIKG_SERVER_URL:-http://localhost:8000}} WORKER_URL=${2:-${AIKG_WORKER_URL:-http://localhost:9001}} echo "==========================================" @@ -53,4 +71,3 @@ echo "" echo "==========================================" echo "✅ 环境检查通过!可以运行 Client 测试了" echo "==========================================" - diff --git a/aikg/scripts/server_related/check_worker_health.sh b/aikg/scripts/server_related/check_worker_health.sh index 
9fe571846e..aeed1a43fd 100755 --- a/aikg/scripts/server_related/check_worker_health.sh +++ b/aikg/scripts/server_related/check_worker_health.sh @@ -1,5 +1,22 @@ #!/bin/bash # 检查 Worker Service 健康状态 +# +# ======================================== +# IPv4/IPv6 配置说明: +# ======================================== +# 通过环境变量或参数来指定 URL: +# - 参数方式: ./check_worker_health.sh +# - 环境变量: AIKG_WORKER_URL +# +# IPv4 示例: +# ./check_worker_health.sh http://192.168.1.100:9001 +# +# IPv6 示例 (注意 IPv6 地址需要用方括号包围): +# ./check_worker_health.sh http://[2001:db8::1]:9001 +# 或者: +# export AIKG_WORKER_URL=http://[::1]:9001 +# ./check_worker_health.sh +# ======================================== WORKER_URL=${1:-${AIKG_WORKER_URL:-http://localhost:9001}} @@ -20,4 +37,3 @@ else echo "Response: $body" exit 1 fi - diff --git a/aikg/scripts/server_related/register_worker_to_server.sh b/aikg/scripts/server_related/register_worker_to_server.sh index 6686f8f059..2465a75171 100755 --- a/aikg/scripts/server_related/register_worker_to_server.sh +++ b/aikg/scripts/server_related/register_worker_to_server.sh @@ -2,10 +2,29 @@ # 向 AIKG Server 注册 Worker Service # 用法: ./scripts/server_related/register_worker_to_server.sh [server_url] [worker_url] [backend] [arch] [capacity] # ./scripts/server_related/register_worker_to_server.sh http://localhost:8000 http://localhost:9001 cuda a100 1 -#./scripts/server_related/register_worker_to_server.sh http://localhost:8000 http://localhost:9001 ascend ascend910b4 1 +# ./scripts/server_related/register_worker_to_server.sh http://localhost:8000 http://localhost:9001 ascend ascend910b4 1 +# +# ======================================== +# IPv4/IPv6 配置说明: +# ======================================== +# 通过环境变量或参数来指定 URL: +# - 参数方式: ./register_worker_to_server.sh ... +# - 环境变量: AIKG_SERVER_URL, AIKG_WORKER_URL +# +# IPv4 示例: +# ./register_worker_to_server.sh http://192.168.1.100:8000 http://192.168.1.101:9001 cuda a100 1 +# +# IPv6 示例 (注意 IPv6 地址需要用方括号包围): +# ./register_worker_to_server.sh http://[2001:db8::1]:8000 http://[2001:db8::2]:9001 cuda a100 1 +# 或者: +# export AIKG_SERVER_URL=http://[::1]:8000 +# export AIKG_WORKER_URL=http://[::1]:9001 +# ./register_worker_to_server.sh "" "" cuda a100 1 +# ======================================== + set -e -SERVER_URL=${1:-http://localhost:8000} +SERVER_URL=${1:-${AIKG_SERVER_URL:-http://localhost:8000}} WORKER_URL=${2:-${AIKG_WORKER_URL:-http://localhost:9001}} BACKEND=${3:-cuda} ARCH=${4:-a100} @@ -39,4 +58,3 @@ echo "Worker 注册命令执行完成!" 
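For completeness, the same registration can be driven from Python instead of curl. The sketch below assumes the POST /api/v1/workers/register endpoint accepts the url, backend, arch and capacity fields that the shell script forwards (only url, backend and arch are visible in the handler, so the capacity field is an assumption), and it uses the requests library purely for illustration; the sample addresses are hypothetical. IPv6 literals keep their brackets inside the URLs, exactly as in the shell examples.

# Illustrative Python equivalent of register_worker_to_server.sh.
# Assumptions: the payload fields mirror the script's arguments; the
# "capacity" field and the sample addresses are hypothetical.
import requests

def register_worker(server_url: str, worker_url: str,
                    backend: str = "cuda", arch: str = "a100",
                    capacity: int = 1) -> dict:
    payload = {
        "url": worker_url,   # IPv6 literal stays bracketed, e.g. http://[::1]:9001
        "backend": backend,
        "arch": arch,
        "capacity": capacity,
    }
    resp = requests.post(f"{server_url}/api/v1/workers/register", json=payload, timeout=10)
    resp.raise_for_status()
    return resp.json()

def list_workers(server_url: str) -> dict:
    # Same endpoint the script queries with curl at the end.
    resp = requests.get(f"{server_url}/api/v1/workers/status", timeout=10)
    resp.raise_for_status()
    return resp.json()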
echo "" echo "📋 当前已注册的 Workers:" curl -s "$SERVER_URL/api/v1/workers/status" | python -m json.tool - diff --git a/aikg/scripts/server_related/setup_ssh_tunnel.sh b/aikg/scripts/server_related/setup_ssh_tunnel.sh index e0dadcbb4e..46ac43c1cf 100755 --- a/aikg/scripts/server_related/setup_ssh_tunnel.sh +++ b/aikg/scripts/server_related/setup_ssh_tunnel.sh @@ -2,6 +2,37 @@ # 建立 SSH 隧道连接到远程 Worker Service # 用法: ./scripts/server_related/setup_ssh_tunnel.sh [local_port] [remote_port] [ssh_host] [ssh_port] [ssh_user] # 或者通过环境变量设置: SSH_HOST, SSH_PORT, SSH_USER +# +# ======================================== +# IPv4/IPv6 配置说明: +# ======================================== +# SSH 本身支持 IPv4 和 IPv6,此脚本通过以下方式支持: +# +# SSH_HOST 可以是: +# - IPv4 地址: 192.168.1.100 +# - IPv6 地址: 2001:db8::1 +# - 主机名: server.example.com +# +# SSH 隧道绑定地址 (LOCAL_BIND_ADDR 环境变量): +# - IPv4 本地: 127.0.0.1 (默认) +# - IPv6 本地: ::1 +# - 所有接口: 0.0.0.0 或 :: +# +# 隧道远程端 (REMOTE_BIND_ADDR 环境变量): +# - 默认 localhost +# - IPv4: 127.0.0.1 +# - IPv6: ::1 +# +# IPv6 示例: +# export SSH_HOST=2001:db8::1 +# export LOCAL_BIND_ADDR=::1 +# export REMOTE_BIND_ADDR=::1 +# ./setup_ssh_tunnel.sh 9001 9001 +# +# 建立隧道后访问方式: +# - IPv4: http://127.0.0.1:$LOCAL_PORT +# - IPv6: http://[::1]:$LOCAL_PORT +# ======================================== LOCAL_PORT=${1:-9001} REMOTE_PORT=${2:-9001} @@ -9,25 +40,36 @@ SSH_HOST=${3:-${SSH_HOST}} SSH_PORT=${4:-${SSH_PORT:-22}} SSH_USER=${5:-${SSH_USER:-${USER}}} +# 本地和远程绑定地址,支持 IPv6 +LOCAL_BIND_ADDR=${LOCAL_BIND_ADDR:-localhost} +REMOTE_BIND_ADDR=${REMOTE_BIND_ADDR:-localhost} + # 检查必需的参数 if [ -z "$SSH_HOST" ]; then echo "错误: 未指定 SSH 主机地址" echo "用法: $0 [local_port] [remote_port] [ssh_host] [ssh_port] [ssh_user]" echo "或者设置环境变量: SSH_HOST, SSH_PORT, SSH_USER" + echo "" + echo "IPv6 示例:" + echo " export SSH_HOST=2001:db8::1" + echo " export LOCAL_BIND_ADDR=::1" + echo " $0 9001 9001" exit 1 fi echo "==========================================" echo "建立 SSH 隧道" echo "==========================================" -echo "本地端口: $LOCAL_PORT" -echo "远程端口: $REMOTE_PORT" +echo "本地绑定: $LOCAL_BIND_ADDR:$LOCAL_PORT" +echo "远程绑定: $REMOTE_BIND_ADDR:$REMOTE_PORT" echo "SSH 地址: $SSH_USER@$SSH_HOST:$SSH_PORT" echo "==========================================" echo "" echo "⚠️ 提示:" echo "1. 此脚本会建立 SSH 隧道(前台运行)" -echo "2. 隧道建立后,可以通过 http://localhost:$LOCAL_PORT 访问远程 Worker Service" +echo "2. 隧道建立后,可以通过以下地址访问远程 Worker Service:" +echo " - IPv4: http://127.0.0.1:$LOCAL_PORT" +echo " - IPv6: http://[::1]:$LOCAL_PORT" echo "3. 按 Ctrl+C 可以关闭隧道" echo "4. 
需要输入 SSH 密码或使用 SSH 密钥认证" echo "" @@ -38,7 +80,7 @@ echo "" # -N: 不执行远程命令 # -L: 本地端口转发 # -o ServerAliveInterval=60: 保持连接活跃 -ssh -N -L ${LOCAL_PORT}:localhost:${REMOTE_PORT} \ +ssh -N -L ${LOCAL_BIND_ADDR}:${LOCAL_PORT}:${REMOTE_BIND_ADDR}:${REMOTE_PORT} \ -p ${SSH_PORT} \ ${SSH_USER}@${SSH_HOST} \ -o ServerAliveInterval=60 \ @@ -46,4 +88,3 @@ ssh -N -L ${LOCAL_PORT}:localhost:${REMOTE_PORT} \ echo "" echo "SSH 隧道已关闭" - diff --git a/aikg/scripts/server_related/start_server.sh b/aikg/scripts/server_related/start_server.sh index 1fe23b3333..71b71f5a54 100755 --- a/aikg/scripts/server_related/start_server.sh +++ b/aikg/scripts/server_related/start_server.sh @@ -2,22 +2,43 @@ # 启动 AIKG Server # 用法: ./scripts/server_related/start_server.sh [port] # ./scripts/server_related/start_server.sh 8000 +# +# ======================================== +# IPv4/IPv6 配置说明: +# ======================================== +# 通过环境变量 AIKG_SERVER_HOST 来控制监听地址: +# - IPv4 监听所有接口: export AIKG_SERVER_HOST=0.0.0.0 (默认) +# - IPv6 监听所有接口: export AIKG_SERVER_HOST=:: +# - IPv4 本地回环: export AIKG_SERVER_HOST=127.0.0.1 +# - IPv6 本地回环: export AIKG_SERVER_HOST=::1 +# - 指定 IPv4 地址: export AIKG_SERVER_HOST=192.168.1.100 +# - 指定 IPv6 地址: export AIKG_SERVER_HOST=2001:db8::1 +# +# 注意: 使用 :: 可以同时监听 IPv4 和 IPv6 (dual-stack),但需要操作系统支持 +# ======================================== + set -e -PORT=${1:-8000} +# 从环境变量获取 host,默认为 0.0.0.0 (IPv4 全接口) +HOST=${AIKG_SERVER_HOST:-0.0.0.0} +PORT=${1:-${AIKG_SERVER_PORT:-8000}} echo "==========================================" echo "启动 AIKG Server" echo "==========================================" +echo "Host: $HOST" echo "Port: $PORT" echo "==========================================" cd "$(dirname "$0")/../.." source env.sh -echo "Starting AIKG Server on port $PORT..." +echo "Starting AIKG Server on $HOST:$PORT..." 
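The IPv4/IPv6 notes above all rely on one convention: a literal IPv6 address is wrapped in square brackets when it appears in a URL, while hostnames and IPv4 addresses are used as-is. A small self-contained sketch of that rule is shown below; it is illustrative only and not part of any script in this series.

# Illustrative helper for the URL convention described in the comments above:
# IPv6 literals are bracketed, IPv4 addresses and hostnames are left untouched.
import ipaddress

def build_service_url(host: str, port: int, scheme: str = "http") -> str:
    try:
        if isinstance(ipaddress.ip_address(host), ipaddress.IPv6Address):
            host = f"[{host}]"      # e.g. ::1 -> [::1]
    except ValueError:
        pass                        # not an IP literal (e.g. npu-server), keep as-is
    return f"{scheme}://{host}:{port}"

assert build_service_url("0.0.0.0", 8000) == "http://0.0.0.0:8000"
assert build_service_url("::1", 9001) == "http://[::1]:9001"
assert build_service_url("npu-server", 9001) == "http://npu-server:9001"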
# python -m ai_kernel_generator.server.app # 使用 uvicorn 直接启动以支持自定义端口 -uvicorn ai_kernel_generator.server.app:app --host 0.0.0.0 --port $PORT +# 设置环境变量供 Python 代码使用 +export AIKG_SERVER_HOST=$HOST +export AIKG_SERVER_PORT=$PORT +uvicorn ai_kernel_generator.server.app:app --host "$HOST" --port "$PORT" diff --git a/aikg/scripts/server_related/start_server_with_local_worker.sh b/aikg/scripts/server_related/start_server_with_local_worker.sh index 93f19fa922..322818ad38 100755 --- a/aikg/scripts/server_related/start_server_with_local_worker.sh +++ b/aikg/scripts/server_related/start_server_with_local_worker.sh @@ -3,18 +3,51 @@ # 用法: ./scripts/server_related/start_server_with_local_worker.sh [server_port] [worker_port] [backend] [arch] [devices] # 示例: ./scripts/server_related/start_server_with_local_worker.sh 8000 9001 ascend ascend910b4 0,1,2,3 # 示例: ./scripts/server_related/start_server_with_local_worker.sh 8000 9001 cuda a100 0,1,2,3 +# +# ======================================== +# IPv4/IPv6 配置说明: +# ======================================== +# 通过环境变量来控制监听地址和 URL 格式: +# +# 监听地址环境变量: +# - AIKG_SERVER_HOST: Server 监听地址,默认 0.0.0.0 +# - AIKG_WORKER_HOST: Worker 监听地址,默认 0.0.0.0 +# +# URL 环境变量 (用于服务发现和注册): +# - AIKG_SERVER_URL: Server 的访问地址,默认 http://localhost:$SERVER_PORT +# - AIKG_WORKER_URL: Worker 的访问地址,默认 http://localhost:$WORKER_PORT +# +# IPv4 示例: +# export AIKG_SERVER_HOST=0.0.0.0 +# export AIKG_SERVER_URL=http://192.168.1.100:8000 +# +# IPv6 示例 (注意 IPv6 地址需要用方括号包围): +# export AIKG_SERVER_HOST=:: +# export AIKG_SERVER_URL=http://[2001:db8::1]:8000 +# export AIKG_WORKER_HOST=:: +# export AIKG_WORKER_URL=http://[2001:db8::1]:9001 +# +# 双栈模式: +# 使用 :: 作为 host 可以同时监听 IPv4 和 IPv6 +# ======================================== set -e # 参数处理 -SERVER_PORT=${1:-8000} -WORKER_PORT=${2:-9001} +SERVER_PORT=${1:-${AIKG_SERVER_PORT:-8000}} +WORKER_PORT=${2:-${AIKG_WORKER_PORT:-9001}} BACKEND=${3:-cuda} ARCH=${4:-a100} DEVICES=${5:-0} -SERVER_URL="http://localhost:$SERVER_PORT" -WORKER_URL="http://localhost:$WORKER_PORT" +# 从环境变量获取 host,默认为 0.0.0.0 +SERVER_HOST=${AIKG_SERVER_HOST:-0.0.0.0} +WORKER_HOST=${AIKG_WORKER_HOST:-0.0.0.0} + +# URL 配置 - 支持从环境变量覆盖,以支持 IPv6 或自定义地址 +# 默认使用 localhost,IPv6 场景需要通过环境变量设置如 http://[::1]:8000 +SERVER_URL=${AIKG_SERVER_URL:-http://localhost:$SERVER_PORT} +WORKER_URL=${AIKG_WORKER_URL:-http://localhost:$WORKER_PORT} # 计算 capacity (device 数量) IFS=',' read -ra DEVICE_ARRAY <<< "$DEVICES" @@ -23,8 +56,12 @@ CAPACITY=${#DEVICE_ARRAY[@]} echo "==========================================" echo "启动 AIKG Server 和 Local Worker (全自动)" echo "==========================================" +echo "Server Host: $SERVER_HOST" echo "Server Port: $SERVER_PORT" +echo "Server URL: $SERVER_URL" +echo "Worker Host: $WORKER_HOST" echo "Worker Port: $WORKER_PORT" +echo "Worker URL: $WORKER_URL" echo "Backend: $BACKEND" echo "Arch: $ARCH" echo "Devices: $DEVICES (Capacity: $CAPACITY)" @@ -51,9 +88,12 @@ cleanup() { trap cleanup SIGINT SIGTERM # 1. 启动 Server -echo "🚀 Starting Server on port $SERVER_PORT..." -# 使用 uvicorn 启动 Server -uvicorn ai_kernel_generator.server.app:app --host 0.0.0.0 --port $SERVER_PORT > server.log 2>&1 & +echo "🚀 Starting Server on $SERVER_HOST:$SERVER_PORT..." +# 设置环境变量供 Python 代码使用 +export AIKG_SERVER_HOST=$SERVER_HOST +export AIKG_SERVER_PORT=$SERVER_PORT + +uvicorn ai_kernel_generator.server.app:app --host "$SERVER_HOST" --port $SERVER_PORT > server.log 2>&1 & SERVER_PID=$! echo "Server PID: $SERVER_PID" @@ -76,11 +116,12 @@ done echo " Server is UP!" # 2. 
启动 Worker -echo "🚀 Starting Worker on port $WORKER_PORT..." +echo "🚀 Starting Worker on $WORKER_HOST:$WORKER_PORT..." export WORKER_BACKEND=$BACKEND export WORKER_ARCH=$ARCH export WORKER_DEVICES=$DEVICES export WORKER_PORT=$WORKER_PORT +export WORKER_HOST=$WORKER_HOST # 使用 python -m 启动 Worker python -m ai_kernel_generator.worker.server > worker.log 2>&1 & diff --git a/aikg/scripts/server_related/start_worker_service.sh b/aikg/scripts/server_related/start_worker_service.sh index 7d38c760ff..ab19767520 100755 --- a/aikg/scripts/server_related/start_worker_service.sh +++ b/aikg/scripts/server_related/start_worker_service.sh @@ -4,6 +4,20 @@ # devices: 逗号分隔的设备ID列表,例如 "0,1,2,3,4,5" 或单个设备 "0" # ./scripts/server_related/start_worker_service.sh cuda a100 0,1,2,3 9001 # ./scripts/server_related/start_worker_service.sh ascend ascend910b4 0,1,2,3 9001 +# +# ======================================== +# IPv4/IPv6 配置说明: +# ======================================== +# 通过环境变量 AIKG_WORKER_HOST 来控制监听地址: +# - IPv4 监听所有接口: export AIKG_WORKER_HOST=0.0.0.0 (默认) +# - IPv6 监听所有接口: export AIKG_WORKER_HOST=:: +# - IPv4 本地回环: export AIKG_WORKER_HOST=127.0.0.1 +# - IPv6 本地回环: export AIKG_WORKER_HOST=::1 +# - 指定 IPv4 地址: export AIKG_WORKER_HOST=192.168.1.100 +# - 指定 IPv6 地址: export AIKG_WORKER_HOST=2001:db8::1 +# +# 注意: 使用 :: 可以同时监听 IPv4 和 IPv6 (dual-stack),但需要操作系统支持 +# ======================================== set -e @@ -12,11 +26,14 @@ set -e BACKEND=${1:-cuda} ARCH=${2:-a100} DEVICES=${3:-0} -PORT=${4:-9001} +PORT=${4:-${AIKG_WORKER_PORT:-9001}} +# 从环境变量获取 host,默认为 0.0.0.0 (IPv4 全接口) +HOST=${AIKG_WORKER_HOST:-0.0.0.0} echo "==========================================" echo "启动 AIKG Worker Service" echo "==========================================" +echo "Host: $HOST" echo "Backend: $BACKEND" echo "Arch: $ARCH" echo "Devices: $DEVICES" @@ -28,14 +45,14 @@ export WORKER_BACKEND=$BACKEND export WORKER_ARCH=$ARCH export WORKER_DEVICES=$DEVICES export WORKER_PORT=$PORT +export WORKER_HOST=$HOST # 启动服务 cd "$(dirname "$0")/../.." source env.sh -echo "Starting Worker Service on port $PORT..." +echo "Starting Worker Service on $HOST:$PORT..." 
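start_server_with_local_worker.sh launches the Server in the background and polls it until it answers before bringing up the Worker. A rough Python equivalent of that readiness wait is sketched below; it polls the /api/v1/workers/status endpoint, and the retry count and interval are arbitrary illustrative values rather than the ones used by the script.

# Rough sketch of the "wait until the Server is up" step.
# Assumption: /api/v1/workers/status responds once the Server is ready.
import time
import urllib.error
import urllib.request

def wait_for_server(server_url: str, retries: int = 30, interval: float = 1.0) -> bool:
    for _ in range(retries):
        try:
            with urllib.request.urlopen(f"{server_url}/api/v1/workers/status", timeout=2):
                return True             # the Server answered, it is up
        except (urllib.error.URLError, OSError):
            time.sleep(interval)        # not ready yet, retry
    return False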
python -m ai_kernel_generator.worker.server # 或者使用 uvicorn 直接启动(更多控制选项) -# uvicorn ai_kernel_generator.worker.server:app --host 0.0.0.0 --port $PORT - +# uvicorn ai_kernel_generator.worker.server:app --host "$HOST" --port $PORT -- Gitee From 244be753f9060d1d90c8ff83dacb616fba67d822 Mon Sep 17 00:00:00 2001 From: Yanzhi_YI Date: Thu, 4 Dec 2025 00:59:50 +0800 Subject: [PATCH 7/8] add num-concurrent for run_cuda_to_ascend_conversion --- .../examples/run_cuda_to_ascend_conversion.py | 83 ++++++++++++------- 1 file changed, 51 insertions(+), 32 deletions(-) diff --git a/aikg/examples/run_cuda_to_ascend_conversion.py b/aikg/examples/run_cuda_to_ascend_conversion.py index 30ff7c8f64..bb0f9ba25d 100644 --- a/aikg/examples/run_cuda_to_ascend_conversion.py +++ b/aikg/examples/run_cuda_to_ascend_conversion.py @@ -49,7 +49,7 @@ CUDA-to-Ascend 转换示例(多输出) 方式2: 直接使用 Remote Workers export CUDA_WORKER_URL=http://cuda-server:9001 export ASCEND_WORKER_URL=http://ascend-server:9001 - python examples/run_cuda_to_ascend_conversion.py --direct + python examples/run_cuda_to_ascend_conversion.py --direct --num-concurrent 2 方式3: 通过 Server API @@ -279,17 +279,21 @@ async def run_via_server_api(server_url: str): print(f" 错误: {status.get('error')[:300]}...") -async def run_direct_with_workers(): +async def run_direct_with_workers(num_concurrent: int = 4): """ 方式2: 直接使用 Remote Workers(不通过 Server) 手动处理参考数据生成和传递 + + Args: + num_concurrent: 同一 task 的并发数量,默认为 4,用于更快找到正确解 """ from ai_kernel_generator.config.config_validator import load_config from ai_kernel_generator.core.worker.manager import register_remote_worker, get_worker_manager from ai_kernel_generator.core.verifier.kernel_verifier import KernelVerifier from ai_kernel_generator.core.task import Task from ai_kernel_generator.core.async_pool.task_pool import TaskPool + from tests.utils import process_task_results op_name = get_op_name() task_desc = get_task_desc() @@ -369,7 +373,7 @@ async def run_direct_with_workers(): print() # ========== 3. 在 Ascend Worker 上执行转换 ========== - print("[Step 3] 在 Ascend Worker 上执行转换...") + print(f"[Step 3] 在 Ascend Worker 上执行转换({num_concurrent} 个并发)...") ascend_config = load_config("triton_ascend", backend="ascend") @@ -377,37 +381,38 @@ async def run_direct_with_workers(): ascend_config['use_reference_data'] = True ascend_config['reference_data'] = ref_bytes - task_pool = TaskPool() - - task = Task( - op_name=op_name, - task_desc=task_desc, - task_id="convert_001", - dsl="triton_ascend", - backend="ascend", - arch="ascend910b4", - config=ascend_config, - framework="torch", - workflow="coder_only_workflow" - ) - - task_pool.create_task(task.run) + task_pool = TaskPool(max_concurrency=num_concurrent) + + # 创建多个相同的 Task 并发运行,更快找到正确解 + for i in range(num_concurrent): + task = Task( + op_name=op_name, + task_desc=task_desc, + task_id=f"convert_{i:03d}", + dsl="triton_ascend", + backend="ascend", + arch="ascend910b4", + config=ascend_config, + framework="torch", + workflow="coder_only_workflow" + ) + task_pool.create_task(task.run) + results = await task_pool.wait_all() + # 使用通用的结果处理函数打印结果 print() print("[结果]") - for result_op_name, success, task_info in results: - if success: - print(f" ✓ {result_op_name} 转换成功!") - if task_info.get("coder_code"): - print(f"\n[生成的 Triton-Ascend 代码]") - print("-" * 40) - code = task_info.get("coder_code", "") - print(code[:800] + ("..." 
if len(code) > 800 else "")) - else: - print(f" ✗ {result_op_name} 转换失败") - if task_info.get("verifier_error"): - print(f" 错误: {task_info.get('verifier_error', '')[:200]}...") + success = process_task_results(results, print_summary=True) + + # 打印成功的代码 + for result_op_name, task_success, task_info in results: + if task_success and task_info.get("coder_code"): + print(f"\n[生成的 Triton-Ascend 代码 - {result_op_name}]") + print("-" * 40) + code = task_info.get("coder_code", "") + print(code[:800] + ("..." if len(code) > 800 else "")) + break # 只打印第一个成功的代码 async def run_quick_verify(): @@ -596,8 +601,9 @@ def print_usage(): python run_cuda_to_ascend_conversion.py --server 通过 Server API 提交任务(推荐) - python run_cuda_to_ascend_conversion.py --direct + python run_cuda_to_ascend_conversion.py --direct [--num-concurrent N] 直接使用 Remote Workers(需设置 CUDA_WORKER_URL 和 ASCEND_WORKER_URL 环境变量) + --num-concurrent N: 指定同一 task 的并发数量(默认 4),用于更快找到正确解 python run_cuda_to_ascend_conversion.py --verify 快速验证模式:仅测试参考数据生成和传输,不调用 LLM @@ -613,6 +619,9 @@ def print_usage(): export ASCEND_WORKER_URL=http://npu-server:9001 python run_cuda_to_ascend_conversion.py --direct + # 指定8个并发运行(更快找到正确解) + python run_cuda_to_ascend_conversion.py --direct --num-concurrent 8 + # 快速验证模式(仅测试参考数据生成和传输) export CUDA_WORKER_URL=http://gpu-server:9001 export ASCEND_WORKER_URL=http://npu-server:9001 @@ -634,7 +643,17 @@ if __name__ == "__main__": asyncio.run(run_via_server_api(server_url)) elif sys.argv[1] == "--direct": - asyncio.run(run_direct_with_workers()) + # 解析 --num-concurrent 参数 + num_concurrent = 4 # 默认值 + if "--num-concurrent" in sys.argv: + try: + idx = sys.argv.index("--num-concurrent") + num_concurrent = int(sys.argv[idx + 1]) + except (IndexError, ValueError): + print("错误: --num-concurrent 需要一个整数参数") + print_usage() + sys.exit(1) + asyncio.run(run_direct_with_workers(num_concurrent=num_concurrent)) elif sys.argv[1] == "--verify": success = asyncio.run(run_quick_verify()) -- Gitee From 1b50f9cfb2ea51a92660392e4585acafb9987fda Mon Sep 17 00:00:00 2001 From: Yanzhi_YI Date: Thu, 4 Dec 2025 01:06:28 +0800 Subject: [PATCH 8/8] update doc --- .../examples/run_cuda_to_ascend_conversion.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/aikg/examples/run_cuda_to_ascend_conversion.py b/aikg/examples/run_cuda_to_ascend_conversion.py index bb0f9ba25d..576f53913a 100644 --- a/aikg/examples/run_cuda_to_ascend_conversion.py +++ b/aikg/examples/run_cuda_to_ascend_conversion.py @@ -38,12 +38,19 @@ CUDA-to-Ascend 转换示例(多输出) # NPU 机器上启动 Ascend Worker ./scripts/server_related/start_worker_service.sh ascend ascend910b4 0,1,2,3,4,5,6,7 9001 +# [可选] IPv6 环境:通过环境变量设置监听地址 +# export AIKG_WORKER_HOST=:: # 监听所有 IPv6 接口(双栈模式) +# ./scripts/server_related/start_worker_service.sh cuda a100 0,1,2,3 9001 方式1: 快速验证模式(仅测试参考数据生成和传输,不调用 LLM) export CUDA_WORKER_URL=http://cuda-server:9001 export ASCEND_WORKER_URL=http://ascend-server:9001 python examples/run_cuda_to_ascend_conversion.py --verify + + # IPv6 场景(注意地址需要用方括号包围): + # export CUDA_WORKER_URL=http://[2001:db8::1]:9001 + # export ASCEND_WORKER_URL=http://[2001:db8::2]:9001 方式2: 直接使用 Remote Workers @@ -57,10 +64,17 @@ CUDA-to-Ascend 转换示例(多输出) # Server 机器上启动 AIKG Server ./scripts/server_related/start_server.sh 8000 + # [可选] IPv6 环境: + # export AIKG_SERVER_HOST=:: + # ./scripts/server_related/start_server.sh 8000 + # 注册 Workers ./scripts/server_related/register_worker_to_server.sh http://localhost:8000 http://gpu-server:9001 cuda a100 ./scripts/server_related/register_worker_to_server.sh 
http://localhost:8000 http://npu-server:9001 ascend ascend910b4 + # IPv6 环境下注册 Workers(URL 中 IPv6 地址需用方括号包围): + # ./scripts/server_related/register_worker_to_server.sh http://[::1]:8000 http://[2001:db8::1]:9001 cuda a100 + # 运行此脚本 python examples/run_cuda_to_ascend_conversion.py --server http://localhost:8000 @@ -626,6 +640,19 @@ def print_usage(): export CUDA_WORKER_URL=http://gpu-server:9001 export ASCEND_WORKER_URL=http://npu-server:9001 python run_cuda_to_ascend_conversion.py --verify + +IPv6 环境变量配置: + # 启动 Worker 时监听 IPv6(双栈模式) + export AIKG_WORKER_HOST=:: + ./scripts/server_related/start_worker_service.sh cuda a100 0 9001 + + # 启动 Server 时监听 IPv6 + export AIKG_SERVER_HOST=:: + ./scripts/server_related/start_server.sh 8000 + + # URL 中使用 IPv6 地址(注意方括号) + export CUDA_WORKER_URL=http://[2001:db8::1]:9001 + export ASCEND_WORKER_URL=http://[2001:db8::2]:9001 """) -- Gitee
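Across the series, the reference data exchanged between Workers is a torch.save'd dict carrying at least a 'seed' value and an 'outputs' list (see the --verify path and the verify template). A minimal sketch of checking a multi-output implementation against such a file is given below; the tolerances are illustrative and not the policy the verifier template actually applies.

# Minimal sketch of validating a multi-output model against saved reference data.
# Assumes the file layout shown in the patches: {'seed': ..., 'outputs': [tensor, ...]}.
# The rtol/atol values are illustrative; the real verifier defines its own policy.
import torch

def check_against_reference(model: torch.nn.Module, inputs, reference_file: str,
                            rtol: float = 1e-3, atol: float = 1e-3) -> bool:
    ref = torch.load(reference_file, map_location="cpu")
    expected = ref.get("outputs", [])
    outputs = model(*inputs)
    if not isinstance(outputs, (tuple, list)):
        outputs = [outputs]                          # single-output models
    if len(outputs) != len(expected):
        raise ValueError(f"output count mismatch: {len(outputs)} vs {len(expected)}")
    for i, (out, exp) in enumerate(zip(outputs, expected)):
        torch.testing.assert_close(out.cpu(), exp.cpu(), rtol=rtol, atol=atol)
        print(f"output[{i}] OK: shape={tuple(out.shape)}, dtype={out.dtype}")
    return True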