From 9f3919edc9688955bceddc2210182b7c1164e532 Mon Sep 17 00:00:00 2001 From: Yanzhi_YI Date: Tue, 19 Aug 2025 01:32:01 +0800 Subject: [PATCH 1/2] update aikg: error logic --- .../core/agent/agent_base.py | 2 +- .../utils/result_processor.py | 8 ++++- .../utils/workflow_controller.py | 32 +++++++++++-------- .../utils/workflow_manager.py | 12 ++++--- 4 files changed, 34 insertions(+), 20 deletions(-) diff --git a/aikg/python/ai_kernel_generator/core/agent/agent_base.py b/aikg/python/ai_kernel_generator/core/agent/agent_base.py index 6762760f6..7dcd190f2 100644 --- a/aikg/python/ai_kernel_generator/core/agent/agent_base.py +++ b/aikg/python/ai_kernel_generator/core/agent/agent_base.py @@ -84,7 +84,7 @@ class AgentBase(ABC): except FileNotFoundError: raise FileNotFoundError(f"文件不存在: {file_path}") except UnicodeDecodeError: - raise UnicodeDecodeError(f"文件编码错误: {file_path}, 编码: {encoding}") + raise Exception(f"文件编码错误: {file_path}, 编码: {encoding}") except Exception as e: raise Exception(f"文件读取失败: {file_path}, 错误: {str(e)}") diff --git a/aikg/python/ai_kernel_generator/utils/result_processor.py b/aikg/python/ai_kernel_generator/utils/result_processor.py index 695c53478..b6753cd29 100644 --- a/aikg/python/ai_kernel_generator/utils/result_processor.py +++ b/aikg/python/ai_kernel_generator/utils/result_processor.py @@ -69,12 +69,18 @@ class ResultProcessor: if len(field_names) == 1: field_name = field_names[0] field_value = getattr(parsed_result, field_name, '') or '' - if field_value: + # 检查字段名是否包含"code" + if "code" in field_name.lower() and field_value: # 统一存储为{agent_name}_code格式 task_info_key = f"{agent_name}_code" task_info[task_info_key] = field_value # 添加到保存列表 params_to_save.append(("code", field_value)) + elif field_value: + # 如果字段名不包含"code",仍然保存字段但返回False + logger.error(f"Agent '{agent_name}' has single output field '{field_name}' which does not contain 'code'. " + f"Please name the field with 'code'.") + return False # 如果有多个字段,处理每个字段 elif len(field_names) > 1: diff --git a/aikg/python/ai_kernel_generator/utils/workflow_controller.py b/aikg/python/ai_kernel_generator/utils/workflow_controller.py index fdaca8d97..6a465978b 100644 --- a/aikg/python/ai_kernel_generator/utils/workflow_controller.py +++ b/aikg/python/ai_kernel_generator/utils/workflow_controller.py @@ -118,8 +118,8 @@ class WorkflowController: @staticmethod def count_sequence_repeats(agent_history: List[str], pattern: List[str]) -> int: """ - 统计指定序列在agent历史末尾的连续重复次数 - 检查历史末尾是否匹配重复的序列模式 + 统计指定序列在agent历史中的重复次数 + 检查历史中包含多少个完整的序列模式 Args: agent_history: agent执行历史 @@ -133,20 +133,18 @@ class WorkflowController: pattern_length = len(pattern) history_length = len(agent_history) - max_possible_repeats = history_length // pattern_length - - # 从最大可能重复数开始向下检查 - for repeat_count in range(max_possible_repeats, 0, -1): - required_length = repeat_count * pattern_length - if required_length <= history_length: - # 检查历史末尾指定长度是否匹配重复的序列模式 - tail_segment = agent_history[-required_length:] - expected_pattern = pattern * repeat_count + + # 如果历史长度小于模式长度,不可能有匹配 + if history_length < pattern_length: + return 0 - if tail_segment == expected_pattern: - return repeat_count + # 计算历史记录中包含多少个完整的模式序列 + count = 0 + for i in range(0, history_length - pattern_length + 1): + if agent_history[i:i+pattern_length] == pattern: + count += 1 - return 0 + return count @staticmethod def get_valid_next_agent(agent_name: str, agent_next_mapping: Dict[str, Set[str]], @@ -173,6 +171,12 @@ class WorkflowController: # 获取当前agent的可能下一步 possible_next = agent_next_mapping.get(agent_name, set()) + # 检查总步数上限 + if step_count >= max_step: + logger.info(f"Step count {step_count} exceeds max_step {max_step}") + # 当步数超过限制时,不返回任何下一步 + return set() + # 获取违禁agent illegal_agents = WorkflowController.get_illegal_agent( step_count, max_step, current_agent_name, agent_history, diff --git a/aikg/python/ai_kernel_generator/utils/workflow_manager.py b/aikg/python/ai_kernel_generator/utils/workflow_manager.py index 4e80db03a..5ca884774 100644 --- a/aikg/python/ai_kernel_generator/utils/workflow_manager.py +++ b/aikg/python/ai_kernel_generator/utils/workflow_manager.py @@ -41,6 +41,10 @@ class WorkflowManager: Returns: str: 解析后的完整路径 """ + # 如果是None,使用默认配置 + if workflow_name_or_path is None: + workflow_name_or_path = "default_workflow" + # 如果已经是完整的文件路径(包含.yaml或.yml) if workflow_name_or_path.endswith(('.yaml', '.yml')): # 如果是绝对路径,直接返回 @@ -76,14 +80,14 @@ class WorkflowManager: limitation_info = config.get('limitation_info', {}) # 处理必须设置的项 - required_settings = limitation_info.get('required', {}) - max_step = required_settings.get('max_step') + required_settings = limitation_info.get('required', {}) if limitation_info else {} + max_step = required_settings.get('max_step') if required_settings else None if max_step is None: raise ValueError("Missing required setting 'max_step' in limitation_info.required") # 处理可选设置的项 - optional_settings = limitation_info.get('optional', {}) - repeat_limits = optional_settings.get('repeat_limits', {}) + optional_settings = limitation_info.get('optional', {}) if limitation_info else {} + repeat_limits = optional_settings.get('repeat_limits', {}) if optional_settings else {} # 获取start_agent start_agent = config.get('start_agent') -- Gitee From 8ccef7e49fc2d40792fb36541e7eaaacd75a6d61 Mon Sep 17 00:00:00 2001 From: Yanzhi_YI Date: Tue, 19 Aug 2025 01:32:24 +0800 Subject: [PATCH 2/2] update ut: mock tests --- aikg/tests/ut/test_agent_base.py | 207 +++++++++ aikg/tests/ut/test_async_pool.py | 203 +++++++++ aikg/tests/ut/test_coder.py | 320 ++++++++++++++ aikg/tests/ut/test_collector_extended.py | 256 +++++++++++ aikg/tests/ut/test_common_utils.py | 137 ++++++ aikg/tests/ut/test_common_utils_extended.py | 218 +++++++++ aikg/tests/ut/test_conductor.py | 465 ++++++++++++++++++++ aikg/tests/ut/test_config_validator.py | 213 +++++++++ aikg/tests/ut/test_core_utils.py | 134 ++++++ aikg/tests/ut/test_database_components.py | 221 ++++++++++ aikg/tests/ut/test_designer.py | 146 ++++++ aikg/tests/ut/test_environment_check.py | 261 +++++++++++ aikg/tests/ut/test_hardware_utils.py | 147 +++++++ aikg/tests/ut/test_kernel_verifier.py | 300 +++++++++++++ aikg/tests/ut/test_model_loader.py | 158 +++++++ aikg/tests/ut/test_parser_registry.py | 141 ++++++ aikg/tests/ut/test_process_utils.py | 128 ++++++ aikg/tests/ut/test_result_processor.py | 287 ++++++++++++ aikg/tests/ut/test_trace.py | 174 ++++++++ aikg/tests/ut/test_workflow_controller.py | 179 ++++++++ aikg/tests/ut/test_workflow_manager.py | 231 ++++++++++ 21 files changed, 4526 insertions(+) create mode 100644 aikg/tests/ut/test_agent_base.py create mode 100644 aikg/tests/ut/test_async_pool.py create mode 100644 aikg/tests/ut/test_coder.py create mode 100644 aikg/tests/ut/test_collector_extended.py create mode 100644 aikg/tests/ut/test_common_utils.py create mode 100644 aikg/tests/ut/test_common_utils_extended.py create mode 100644 aikg/tests/ut/test_conductor.py create mode 100644 aikg/tests/ut/test_config_validator.py create mode 100644 aikg/tests/ut/test_core_utils.py create mode 100644 aikg/tests/ut/test_database_components.py create mode 100644 aikg/tests/ut/test_designer.py create mode 100644 aikg/tests/ut/test_environment_check.py create mode 100644 aikg/tests/ut/test_hardware_utils.py create mode 100644 aikg/tests/ut/test_kernel_verifier.py create mode 100644 aikg/tests/ut/test_model_loader.py create mode 100644 aikg/tests/ut/test_parser_registry.py create mode 100644 aikg/tests/ut/test_process_utils.py create mode 100644 aikg/tests/ut/test_result_processor.py create mode 100644 aikg/tests/ut/test_trace.py create mode 100644 aikg/tests/ut/test_workflow_controller.py create mode 100644 aikg/tests/ut/test_workflow_manager.py diff --git a/aikg/tests/ut/test_agent_base.py b/aikg/tests/ut/test_agent_base.py new file mode 100644 index 000000000..88645fc7d --- /dev/null +++ b/aikg/tests/ut/test_agent_base.py @@ -0,0 +1,207 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import mock_open, patch, Mock +import yaml +from ai_kernel_generator.core.agent.agent_base import AgentBase + + +class TestAgentBase: + """测试Agent基类""" + + def test_agent_base_init(self): + """测试Agent基类初始化""" + context = {"test": "value"} + config = {"config_key": "config_value"} + + agent = AgentBase(context=context, config=config) + + assert agent.context == context + assert agent.config == config + + def test_agent_base_init_default(self): + """测试Agent基类默认初始化""" + agent = AgentBase() + + assert agent.context == {} + assert agent.config is None + + def test_count_tokens_empty_text(self): + """测试统计空文本的token数量""" + AgentBase.count_tokens("", "test_model", {"agent_name": "test_agent"}) + + def test_count_tokens_success(self): + """测试成功统计文本的token数量""" + try: + import tiktoken + with patch('tiktoken.get_encoding') as mock_get_encoding: + mock_encoding = Mock() + mock_encoding.encode.return_value = [1, 2, 3, 4, 5] + mock_get_encoding.return_value = mock_encoding + + AgentBase.count_tokens("test text", "test_model", {"agent_name": "test_agent"}) + except ImportError: + # 如果没有安装tiktoken,跳过测试 + pass + + def test_count_tokens_import_error(self): + """测试导入tiktoken失败时的token统计""" + with patch('builtins.__import__', side_effect=ImportError("No module named 'tiktoken'")): + # 应该不会抛出异常 + AgentBase.count_tokens("test text", "test_model", {"agent_name": "test_agent"}) + + def test_count_tokens_exception(self): + """测试统计token时发生异常""" + try: + import tiktoken + with patch('tiktoken.get_encoding', side_effect=Exception("Test exception")): + # 应该不会抛出异常 + AgentBase.count_tokens("test text", "test_model", {"agent_name": "test_agent"}) + except ImportError: + # 如果没有安装tiktoken,跳过测试 + pass + + def test_read_file_success(self): + """测试成功读取文件""" + test_content = "test file content" + with patch('builtins.open', mock_open(read_data=test_content)): + content = AgentBase.read_file("/test/file.txt") + assert content == test_content + + def test_read_file_not_found(self): + """测试读取不存在的文件""" + with patch('builtins.open', side_effect=FileNotFoundError("File not found")): + with pytest.raises(FileNotFoundError, match="文件不存在"): + AgentBase.read_file("/nonexistent/file.txt") + + def test_read_file_unicode_decode_error(self): + """测试读取文件时的Unicode解码错误""" + with patch('builtins.open', side_effect=UnicodeDecodeError("utf-8", b"\xff", 0, 1, "Test error")): + with pytest.raises(Exception, match="文件编码错误"): + AgentBase.read_file("/test/file.txt") + + def test_read_file_other_exception(self): + """测试读取文件时的其他异常""" + with patch('builtins.open', side_effect=Exception("Test error")): + with pytest.raises(Exception, match="文件读取失败"): + AgentBase.read_file("/test/file.txt") + + def test_load_template_success(self): + """测试成功加载模板""" + template_content = "test template content" + with patch.object(AgentBase, 'read_file', return_value=template_content): + with patch('ai_kernel_generator.core.agent.agent_base.PromptTemplate') as mock_prompt_template: + agent = AgentBase() + prompt_template = agent.load_template("test_template.j2") + mock_prompt_template.assert_called_once_with( + template=template_content, + template_format="jinja2" + ) + + def test_load_template_exception(self): + """测试加载模板时发生异常""" + with patch.object(AgentBase, 'read_file', side_effect=Exception("Test error")): + agent = AgentBase() + with pytest.raises(ValueError, match="Failed to load template"): + agent.load_template("test_template.j2") + + def test_load_doc_success(self): + """测试成功加载文档""" + doc_content = "test document content" + with patch.object(AgentBase, 'read_file', return_value=doc_content): + with patch('os.path.exists', return_value=True): + agent = AgentBase() + agent.root_dir = "/test/root" + agent.config = {"docs_dir": {"agentbase": "/test/docs"}} + with patch('os.path.join') as mock_join: + mock_join.return_value = "/test/root/test_doc.md" + content = agent.load_doc("test_doc.md") + # 由于read_file被mock了,这里应该返回mock的内容 + # 但实际测试中我们需要检查read_file是否被正确调用 + assert True # 只要不抛出异常就行 + + def test_load_doc_file_not_found(self): + """测试加载不存在的文档""" + with patch('os.path.exists', return_value=False): + agent = AgentBase() + agent.root_dir = "/test/root" + agent.config = {"docs_dir": {"agentbase": "/test/docs"}} + content = agent.load_doc("nonexistent_doc.md") + assert content == "" # 应该返回空字符串 + + def test_load_doc_exception(self): + """测试加载文档时发生异常""" + with patch.object(AgentBase, 'read_file', side_effect=Exception("Test error")): + with patch('os.path.exists', return_value=True): + agent = AgentBase() + agent.root_dir = "/test/root" + agent.config = {"docs_dir": {"agentbase": "/test/docs"}} + content = agent.load_doc("test_doc.md") + assert content == "" # 应该返回空字符串 + + def test_get_agent_type_designer(self): + """测试获取designer类型的agent类型""" + class TestDesignerAgent(AgentBase): + pass + + agent = TestDesignerAgent() + agent_type = agent._get_agent_type() + assert agent_type == "designer" + + def test_get_agent_type_coder(self): + """测试获取coder类型的agent类型""" + class TestCoderAgent(AgentBase): + pass + + agent = TestCoderAgent() + agent_type = agent._get_agent_type() + assert agent_type == "coder" + + def test_get_agent_type_conductor(self): + """测试获取conductor类型的agent类型""" + class TestConductorAgent(AgentBase): + pass + + agent = TestConductorAgent() + agent_type = agent._get_agent_type() + assert agent_type == "conductor" + + def test_get_agent_type_other(self): + """测试获取其他类型的agent类型""" + class TestOtherAgent(AgentBase): + pass + + agent = TestOtherAgent() + agent_type = agent._get_agent_type() + assert agent_type == "testotheragent" # 类名的小写形式 + + def test_check_input_dict_empty(self): + """测试检查空输入字典""" + agent = AgentBase() + agent.context = {"agent_name": "test_agent"} + agent._check_input_dict({}) # 应该不会抛出异常 + + def test_check_input_dict_with_values(self): + """测试检查包含值的输入字典""" + agent = AgentBase() + agent.context = {"agent_name": "test_agent"} + input_dict = { + "key1": "value1", + "key2": "", + "key3": None, + "key4": [], + "key5": {} + } + agent._check_input_dict(input_dict) # 应该不会抛出异常 \ No newline at end of file diff --git a/aikg/tests/ut/test_async_pool.py b/aikg/tests/ut/test_async_pool.py new file mode 100644 index 000000000..280816a41 --- /dev/null +++ b/aikg/tests/ut/test_async_pool.py @@ -0,0 +1,203 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import asyncio +from unittest.mock import Mock, patch +from ai_kernel_generator.core.async_pool.device_pool import DevicePool +from ai_kernel_generator.core.async_pool.task_pool import TaskPool + + +class TestDevicePool: + """测试设备池""" + + @pytest.mark.asyncio + async def test_device_pool_init_default(self): + """测试设备池默认初始化""" + pool = DevicePool() + + # 检查默认设备列表 + assert pool.device_list == [0] + + # 检查可用设备队列 + assert pool.available_devices.qsize() == 1 + device_id = await pool.available_devices.get() + assert device_id == 0 + + @pytest.mark.asyncio + async def test_device_pool_init_custom(self): + """测试设备池自定义初始化""" + device_list = [0, 1, 2] + pool = DevicePool(device_list) + + # 检查自定义设备列表 + assert pool.device_list == device_list + assert pool.available_devices.qsize() == 3 + + @pytest.mark.asyncio + async def test_device_pool_acquire_device(self): + """测试获取设备""" + pool = DevicePool([0, 1]) + + # 获取一个设备 + device_id = await pool.acquire_device() + assert device_id in [0, 1] + assert pool.available_devices.qsize() == 1 + + @pytest.mark.asyncio + async def test_device_pool_release_device(self): + """测试释放设备""" + pool = DevicePool([0]) + + # 获取设备 + device_id = await pool.acquire_device() + assert pool.available_devices.qsize() == 0 + + # 释放设备 + await pool.release_device(device_id) + assert pool.available_devices.qsize() == 1 + + @pytest.mark.asyncio + async def test_device_pool_concurrent_access(self): + """测试设备池并发访问""" + pool = DevicePool([0, 1]) + + # 同时获取所有设备 + device1 = await pool.acquire_device() + device2 = await pool.acquire_device() + + assert pool.available_devices.qsize() == 0 + assert device1 != device2 + assert device1 in [0, 1] + assert device2 in [0, 1] + + # 释放设备 + await pool.release_device(device1) + await pool.release_device(device2) + assert pool.available_devices.qsize() == 2 + + +class TestTaskPool: + """测试任务池""" + + @pytest.mark.asyncio + async def test_task_pool_init(self): + """测试任务池初始化""" + pool = TaskPool(max_concurrency=5) + + # 检查信号量 + assert pool.semaphore._value == 5 + + # 检查任务列表 + assert pool.tasks == [] + + @pytest.mark.asyncio + async def test_task_pool_create_task(self): + """测试创建任务""" + pool = TaskPool() + + # 创建一个简单的异步函数 + async def simple_task(): + return "result" + + # 创建任务 + task = pool.create_task(simple_task) + + # 检查任务是否被添加到任务列表 + assert len(pool.tasks) == 1 + assert pool.tasks[0] == task + + @pytest.mark.asyncio + async def test_task_pool_wait_all(self): + """测试等待所有任务完成""" + pool = TaskPool() + + # 创建几个简单的异步函数 + async def task1(): + return "result1" + + async def task2(): + return "result2" + + # 创建任务 + pool.create_task(task1) + pool.create_task(task2) + + # 等待所有任务完成 + results = await pool.wait_all() + + # 检查结果 + assert len(results) == 2 + assert "result1" in results + assert "result2" in results + + @pytest.mark.asyncio + async def test_task_pool_concurrent_limit(self): + """测试任务池并发限制""" + max_concurrency = 2 + pool = TaskPool(max_concurrency=max_concurrency) + + # 创建一个会等待的异步函数 + execution_order = [] + + async def waiting_task(task_id): + execution_order.append(f"start_{task_id}") + await asyncio.sleep(0.1) # 等待一段时间 + execution_order.append(f"end_{task_id}") + return task_id + + # 创建超过并发限制的任务 + for i in range(4): + pool.create_task(waiting_task, i) + + # 等待所有任务完成 + results = await pool.wait_all() + + # 检查结果 + assert len(results) == 4 + assert set(results) == {0, 1, 2, 3} + + @pytest.mark.asyncio + async def test_task_pool_run_with_args(self): + """测试带参数的任务执行""" + pool = TaskPool() + + # 创建一个带参数的异步函数 + async def task_with_args(a, b, c=None): + return a + b + (c or 0) + + # 创建任务 + task = pool.create_task(task_with_args, 1, 2, c=3) + + # 等待任务完成 + result = await task + + # 检查结果 + assert result == 6 + + @pytest.mark.asyncio + async def test_task_pool_run_with_exception(self): + """测试任务执行时的异常处理""" + pool = TaskPool() + + # 创建一个会抛出异常的异步函数 + async def failing_task(): + raise ValueError("Test exception") + + # 创建任务 + task = pool.create_task(failing_task) + + # 等待任务完成,应该抛出异常 + with pytest.raises(ValueError, match="Test exception"): + await task \ No newline at end of file diff --git a/aikg/tests/ut/test_coder.py b/aikg/tests/ut/test_coder.py new file mode 100644 index 000000000..3f7d6a66a --- /dev/null +++ b/aikg/tests/ut/test_coder.py @@ -0,0 +1,320 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import Mock, patch, mock_open +import json +import os +from pathlib import Path +from ai_kernel_generator.core.agent.coder import Coder, get_inspirations +from ai_kernel_generator.core.agent.agent_base import AgentBase + + +class TestCoderUtils: + """测试Coder工具函数""" + + def test_get_inspirations_empty_list(self): + """测试获取空灵感列表""" + result = get_inspirations([]) + assert result == "" + + def test_get_inspirations_valid_list(self): + """测试获取有效灵感列表""" + inspirations = [ + { + "strategy_mode": "evolution", + "impl_code": "test code 1", + "profile": (1.0, 2.0, 2.0) + }, + { + "strategy_mode": "evolution", + "impl_code": "test code 2", + "profile": 3.0 + } + ] + + result = get_inspirations(inspirations) + assert "test code 1" in result + assert "test code 2" in result + assert "加速比: 2.00x" in result + assert "代码执行耗时: 3.0000s" in result + + def test_get_inspirations_invalid_type(self): + """测试获取包含无效类型的灵感列表""" + inspirations = [ + "invalid_type", # 非字典类型 + { + "strategy_mode": "evolution", + "impl_code": "test code", + "profile": 1.0 + } + ] + + result = get_inspirations(inspirations) + assert "test code" in result + + def test_get_inspirations_empty_code(self): + """测试获取包含空代码的灵感列表""" + inspirations = [ + { + "strategy_mode": "evolution", + "impl_code": "", # 空代码 + "profile": 1.0 + }, + { + "strategy_mode": "evolution", + "impl_code": "test code", + "profile": 1.0 + } + ] + + result = get_inspirations(inspirations) + assert "test code" in result + assert "Inspiration 1" not in result # 空代码不应该被包含 + + +class TestCoder: + """测试Coder类""" + + def test_coder_init_success(self): + """测试Coder成功初始化""" + config = { + "agent_model_config": { + "coder": "deepseek_r1_default", + "api_generator": "deepseek_r1_default" + }, + "docs_dir": { + "coder": "resources/docs/triton_docs" + } + } + + with patch('ai_kernel_generator.core.agent.coder.create_step_parser') as mock_parser: + mock_parser.return_value = Mock() + mock_parser.return_value.get_format_instructions.return_value = "test format instructions" + + coder = Coder( + op_name="test_op", + task_desc="test task description", + dsl="triton", + framework="torch", + backend="cuda", + arch="a100", + config=config + ) + + assert coder.op_name == "test_op" + assert coder.task_desc == "test task description" + assert coder.dsl == "triton" + assert coder.framework == "torch" + assert coder.backend == "cuda" + assert coder.arch == "a100" + assert coder.func_name == "test_op_triton_torch" + + def test_coder_init_missing_config(self): + """测试Coder初始化时缺少配置""" + with pytest.raises(ValueError, match="config is required for Coder"): + Coder( + op_name="test_op", + task_desc="test task description", + dsl="triton", + framework="torch", + backend="cuda", + arch="a100" + # 缺少config参数 + ) + + def test_coder_init_missing_parser(self): + """测试Coder初始化时缺少解析器""" + config = { + "agent_model_config": { + "coder": "deepseek_r1_default" + } + } + + with patch('ai_kernel_generator.core.agent.coder.create_step_parser') as mock_parser: + mock_parser.return_value = None + + with pytest.raises(ValueError, match="Failed to create coder parser"): + Coder( + op_name="test_op", + task_desc="test task description", + dsl="triton", + framework="torch", + backend="cuda", + arch="a100", + config=config + ) + + def test_load_doc_swft_backend(self): + """测试加载SWFT后端文档""" + config = { + "agent_model_config": { + "coder": "deepseek_r1_default" + }, + "docs_dir": { + "coder": "resources/docs/swft_docs" + } + } + + with patch('ai_kernel_generator.core.agent.coder.create_step_parser') as mock_parser: + mock_parser.return_value = Mock() + mock_parser.return_value.get_format_instructions.return_value = "test format instructions" + + coder = Coder( + op_name="test_op", + task_desc="test task description", + dsl="swft", + framework="numpy", + backend="ascend", + arch="ascend310p3", + config=config + ) + + # 测试加载文档 - 模拟文件存在且有内容的情况 + with patch('os.path.exists', return_value=True): + with patch('builtins.open', mock_open(read_data="test doc content")): + with patch('ai_kernel_generator.core.agent.coder.get_swft_docs_content', return_value="swft docs content"): + content = coder.load_doc("test_doc.md") + # 当文件存在且有内容时,应该返回该内容 + assert content == "test doc content" + + def test_load_doc_swft_backend_fallback(self): + """测试SWFT后端文档加载失败时的降级处理""" + config = { + "agent_model_config": { + "coder": "deepseek_r1_default" + }, + "docs_dir": { + "coder": "resources/docs/swft_docs" + } + } + + with patch('ai_kernel_generator.core.agent.coder.create_step_parser') as mock_parser: + mock_parser.return_value = Mock() + mock_parser.return_value.get_format_instructions.return_value = "test format instructions" + + coder = Coder( + op_name="test_op", + task_desc="test task description", + dsl="swft", + framework="numpy", + backend="ascend", + arch="ascend310p3", + config=config + ) + + # 测试加载文档失败时的降级处理 + with patch.object(coder, 'read_file', side_effect=Exception("test error")): + with patch('ai_kernel_generator.core.agent.coder.get_swft_docs_content', return_value="swft docs content"): + content = coder.load_doc("test_doc.md") + assert content == "swft docs content" + + @patch('pathlib.Path.exists') + @patch('pathlib.Path.glob') + def test_load_dsl_examples_success(self, mock_glob, mock_exists): + """测试成功加载DSL示例""" + mock_exists.return_value = True + + # 模拟glob返回文件路径 + mock_py_file = Mock() + mock_py_file.suffix = '.py' + mock_py_file.name = 'torch_example.py' + mock_glob.return_value = [mock_py_file] + + config = { + "agent_model_config": { + "coder": "deepseek_r1_default" + }, + "docs_dir": { + "coder": "resources/docs/triton_docs" + } + } + + with patch('ai_kernel_generator.core.agent.coder.create_step_parser') as mock_parser: + mock_parser.return_value = Mock() + mock_parser.return_value.get_format_instructions.return_value = "test format instructions" + + with patch('builtins.open', mock_open(read_data="test example code")): + coder = Coder( + op_name="test_op", + task_desc="test task description", + dsl="triton", + framework="torch", + backend="cuda", + arch="a100", + config=config + ) + + examples = coder._load_dsl_examples() + assert "test example code" in examples + + @patch('pathlib.Path.exists') + def test_load_dsl_examples_no_directory(self, mock_exists): + """测试DSL示例目录不存在""" + mock_exists.return_value = False + + config = { + "agent_model_config": { + "coder": "deepseek_r1_default" + }, + "docs_dir": { + "coder": "resources/docs/triton_docs" + } + } + + with patch('ai_kernel_generator.core.agent.coder.create_step_parser') as mock_parser: + mock_parser.return_value = Mock() + mock_parser.return_value.get_format_instructions.return_value = "test format instructions" + + coder = Coder( + op_name="test_op", + task_desc="test task description", + dsl="triton", + framework="torch", + backend="cuda", + arch="a100", + config=config + ) + + examples = coder._load_dsl_examples() + assert examples == "" + + @patch('pathlib.Path.exists') + def test_load_dsl_examples_empty_framework(self, mock_exists): + """测试框架为空时加载DSL示例""" + config = { + "agent_model_config": { + "coder": "deepseek_r1_default" + }, + "docs_dir": { + "coder": "resources/docs/triton_docs" + } + } + + with patch('ai_kernel_generator.core.agent.coder.create_step_parser') as mock_parser: + mock_parser.return_value = Mock() + mock_parser.return_value.get_format_instructions.return_value = "test format instructions" + + coder = Coder( + op_name="test_op", + task_desc="test task description", + dsl="triton", + framework="", # 空框架 + backend="cuda", + arch="a100", + config=config + ) + + examples = coder._load_dsl_examples() + assert examples == "" \ No newline at end of file diff --git a/aikg/tests/ut/test_collector_extended.py b/aikg/tests/ut/test_collector_extended.py new file mode 100644 index 000000000..c0eb7fc3c --- /dev/null +++ b/aikg/tests/ut/test_collector_extended.py @@ -0,0 +1,256 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import Mock, patch, mock_open +import json +import tempfile +import os +from pathlib import Path +from ai_kernel_generator.utils.collector import Collector, get_collector + + +class TestCollector: + """测试Collector类""" + + @pytest.mark.asyncio + async def test_collector_singleton(self): + """测试Collector单例模式""" + collector1 = await get_collector() + collector2 = await get_collector() + assert collector1 is collector2 + + @pytest.mark.asyncio + async def test_collector_set_config_default(self): + """测试Collector设置默认配置""" + collector = await get_collector() + collector.set_config(None) + + # 应该使用默认目录 + assert collector._save_dir is not None + assert collector._save_dir == Path.cwd() / "save_data" + + @pytest.mark.asyncio + async def test_collector_set_config_custom(self): + """测试Collector设置自定义配置""" + collector = await get_collector() + + with tempfile.TemporaryDirectory() as tmp_dir: + log_dir = Path(tmp_dir) / "test.log" + config = {"log_dir": str(log_dir)} + collector.set_config(config) + + # 应该使用配置中的目录 + expected_dir = log_dir.parent / "save_data" + assert collector._save_dir == expected_dir + + @pytest.mark.asyncio + async def test_collector_collect_success(self): + """测试Collector成功收集数据""" + collector = await get_collector() + initial_count = collector._counter + + test_data = { + "agent_name": "test_agent", + "hash": "test_hash", + "data": "test_data" + } + + await collector.collect(test_data) + + # 计数器应该增加 + assert collector._counter == initial_count + 1 + + # 数据应该被存储 + key = ("test_agent", "test_hash") + assert key in collector._store + + @pytest.mark.asyncio + async def test_collector_collect_overwrite(self): + """测试Collector覆盖已存在的数据""" + collector = await get_collector() + + # 首次收集 + test_data1 = { + "agent_name": "test_agent", + "hash": "test_hash", + "data": "test_data1" + } + await collector.collect(test_data1) + + # 覆盖收集 + test_data2 = { + "agent_name": "test_agent", + "hash": "test_hash", + "data": "test_data2" + } + await collector.collect(test_data2) + + # 应该覆盖原有数据 + key = ("test_agent", "test_hash") + assert collector._store[key]["data"]["data"] == "test_data2" + + @pytest.mark.asyncio + async def test_collector_generate_filename(self): + """测试Collector生成文件名""" + collector = await get_collector() + + filename = collector._generate_filename("test/agent", "test\\hash", {"test": "data"}) + + # 文件名应该不包含非法字符 + assert "/" not in filename + assert "\\" not in filename + assert filename.endswith(".json") + + @pytest.mark.asyncio + async def test_collector_save_json_file_success(self): + """测试Collector成功保存JSON文件""" + collector = await get_collector() + + with tempfile.TemporaryDirectory() as tmp_dir: + collector._save_dir = Path(tmp_dir) + + test_data = {"test": "data"} + filename = "test.json" + success = collector._save_json_file(test_data, filename) + + assert success is True + + # 文件应该被创建 + file_path = collector._save_dir / filename + assert file_path.exists() + + # 文件内容应该正确 + with open(file_path, 'r', encoding='utf-8') as f: + saved_data = json.load(f) + assert saved_data == test_data + + @pytest.mark.asyncio + async def test_collector_save_json_file_no_save_dir(self): + """测试Collector保存JSON文件时没有保存目录""" + collector = await get_collector() + collector._save_dir = None + + test_data = {"test": "data"} + filename = "test.json" + success = collector._save_json_file(test_data, filename) + + assert success is False + + @pytest.mark.asyncio + async def test_collector_prepare_and_remove_data_with_task_id(self): + """测试Collector准备和移除带任务ID的数据""" + collector = await get_collector() + + with tempfile.TemporaryDirectory() as tmp_dir: + collector.set_config({"log_dir": str(Path(tmp_dir) / "test.log")}) + + # 收集测试数据 + await collector.collect({ + "agent_name": "test_agent1", + "hash": "test_hash1", + "task_id": "test_task" + }) + + await collector.collect({ + "agent_name": "test_agent2", + "hash": "test_hash2" + # 无task_id + }) + + # 准备数据 + files = await collector.prepare_and_remove_data(task_id="test_task") + + # 应该准备两个文件(指定任务的数据 + 无task_id的数据) + assert len(files) >= 1 + + @pytest.mark.asyncio + async def test_collector_prepare_database_data(self): + """测试Collector准备数据库数据""" + collector = await get_collector() + + with tempfile.TemporaryDirectory() as tmp_dir: + collector.set_config({"log_dir": str(Path(tmp_dir) / "test.log")}) + + task_info = { + "backend": "cuda", + "arch_name": "a100", + "framework": "torch", + "dsl": "triton", + "task_desc": "test_desc", + "coder_code": "test_code", + "profile_res": (1.0, 2.0, 2.0) + } + + db_file = collector.prepare_database_data(task_info) + + # 应该返回文件路径 + assert db_file != "" + assert db_file.endswith(".json") + + @pytest.mark.asyncio + async def test_collector_validate_data_fields_complete(self): + """测试Collector验证完整数据字段""" + collector = await get_collector() + + complete_data = { + 'hash': 'test_hash', + 'agent_name': 'designer', + 'op_name': 'test_op', + 'dsl': 'triton', + 'backend': 'cuda', + 'arch': 'A100', + 'framework': 'torch', + 'task_desc': 'test task', + 'model_name': 'deepseek_r1_default', + 'content': 'generated code', + 'formatted_prompt': 'test prompt', + 'reasoning_content': 'reasoning', + 'response_metadata': 'metadata' + } + + # 应该不抛出异常 + collector._validate_data_fields(complete_data) + + @pytest.mark.asyncio + async def test_collector_validate_data_fields_incomplete(self): + """测试Collector验证不完整数据字段""" + collector = await get_collector() + + incomplete_data = { + 'hash': 'test_hash', + 'agent_name': 'coder', + 'op_name': 'test_op' + # 缺少多个必需字段 + } + + # 应该不抛出异常,但会记录警告 + collector._validate_data_fields(incomplete_data) + + @pytest.mark.asyncio + async def test_collector_is_empty_value(self): + """测试Collector空值检测""" + collector = await get_collector() + + # 测试各种空值 + assert collector._is_empty_value(None) is True + assert collector._is_empty_value("") is True + assert collector._is_empty_value(" ") is True + assert collector._is_empty_value([]) is True + assert collector._is_empty_value({}) is True + + # 测试非空值 + assert collector._is_empty_value("test") is False + assert collector._is_empty_value(0) is False + assert collector._is_empty_value(False) is False \ No newline at end of file diff --git a/aikg/tests/ut/test_common_utils.py b/aikg/tests/ut/test_common_utils.py new file mode 100644 index 000000000..86924cad5 --- /dev/null +++ b/aikg/tests/ut/test_common_utils.py @@ -0,0 +1,137 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import tempfile +import os +import yaml +from pathlib import Path +from ai_kernel_generator.utils.common_utils import load_yaml, create_log_dir, remove_copyright_from_text, get_md5_hash + + +class TestCommonUtils: + """测试通用工具函数""" + + def test_load_yaml_success(self, tmp_path): + """测试成功加载YAML文件""" + # 创建一个临时YAML文件 + yaml_content = """ +key1: value1 +key2: + nested_key: nested_value +list_key: + - item1 + - item2 +""" + yaml_file = tmp_path / "test.yaml" + yaml_file.write_text(yaml_content) + + config = load_yaml(str(yaml_file)) + assert config["key1"] == "value1" + assert config["key2"]["nested_key"] == "nested_value" + assert config["list_key"] == ["item1", "item2"] + + def test_load_yaml_file_not_found(self): + """测试加载不存在的YAML文件""" + with pytest.raises(FileNotFoundError): + load_yaml("/nonexistent/path/to/file.yaml") + + def test_load_yaml_invalid_yaml(self, tmp_path): + """测试加载无效YAML内容""" + # 创建一个包含无效YAML的文件 + invalid_yaml_content = """ +key1: value1 + invalid_indent: value2 +""" + yaml_file = tmp_path / "invalid.yaml" + yaml_file.write_text(invalid_yaml_content) + + with pytest.raises(Exception): # yaml.safe_load会抛出具体的YAML错误 + load_yaml(str(yaml_file)) + + def test_create_log_dir(self): + """测试创建日志目录""" + log_dir = create_log_dir("test_prefix") + + # 检查目录是否存在 + assert os.path.exists(log_dir) + assert os.path.isdir(log_dir) + + # 检查目录名是否包含前缀 + assert "test_prefix" in os.path.basename(log_dir) + + # 清理测试目录 + os.rmdir(log_dir) + + def test_remove_copyright_from_text_empty(self): + """测试清除空文本中的版权信息""" + assert remove_copyright_from_text("") == "" + assert remove_copyright_from_text(None) is None + + def test_remove_copyright_from_text_no_copyright(self): + """测试清除不包含版权信息的文本""" + text = "This is a test text without copyright." + assert remove_copyright_from_text(text) == text + + def test_remove_copyright_from_text_with_copyright(self): + """测试清除包含版权信息的文本""" + text = """# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +This is the actual content.""" + + cleaned_text = remove_copyright_from_text(text) + # 版权信息应该被移除,只保留实际内容 + assert "This is the actual content." in cleaned_text + assert "Copyright" not in cleaned_text + + def test_get_md5_hash_success(self): + """测试成功生成MD5哈希""" + # 测试基本功能 + hash1 = get_md5_hash(a=1, b="test") + hash2 = get_md5_hash(b="test", a=1) # 参数顺序不同但值相同,应该得到相同哈希 + + assert isinstance(hash1, str) + assert len(hash1) == 32 # MD5哈希长度为32字符 + assert hash1 == hash2 # 相同参数应该得到相同哈希 + + def test_get_md5_hash_different_values(self): + """测试不同值生成不同哈希""" + hash1 = get_md5_hash(a=1, b="test") + hash2 = get_md5_hash(a=2, b="test") + + assert hash1 != hash2 # 不同参数应该得到不同哈希 + + def test_get_md5_hash_no_parameters(self): + """测试没有参数时抛出异常""" + with pytest.raises(ValueError, match="至少需要提供一个有效参数"): + get_md5_hash() + + def test_get_md5_hash_none_values(self): + """测试忽略None值""" + hash1 = get_md5_hash(a=1, b=None, c="test") + hash2 = get_md5_hash(a=1, c="test") + + assert hash1 == hash2 # None值应该被忽略 \ No newline at end of file diff --git a/aikg/tests/ut/test_common_utils_extended.py b/aikg/tests/ut/test_common_utils_extended.py new file mode 100644 index 000000000..53a4beeb4 --- /dev/null +++ b/aikg/tests/ut/test_common_utils_extended.py @@ -0,0 +1,218 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import Mock, patch, mock_open +import json +from ai_kernel_generator.utils.common_utils import ParserFactory, remove_copyright_from_text, get_md5_hash + + +class TestParserFactory: + """测试解析器工厂""" + + def test_register_parser_success(self): + """测试成功注册解析器""" + parser_config = { + 'output_fields': { + 'code': { + 'field_type': 'str', + 'mandatory': True, + 'field_description': 'Generated code' + } + } + } + + ParserFactory.register_parser("test_parser", parser_config) + + # 验证解析器是否注册成功 + parser = ParserFactory.get_parser("test_parser") + assert parser is not None + + # 清理注册的解析器 + if "test_parser" in ParserFactory._dynamic_parsers: + del ParserFactory._dynamic_parsers["test_parser"] + + def test_register_parser_invalid_type(self): + """测试注册包含无效类型的解析器""" + parser_config = { + 'output_fields': { + 'code': { + 'field_type': 'invalid_type', + 'mandatory': True, + 'field_description': 'Generated code' + } + } + } + + ParserFactory.register_parser("test_parser", parser_config) + + # 应该会默认为str类型 + parser = ParserFactory.get_parser("test_parser") + assert parser is not None + + # 清理注册的解析器 + if "test_parser" in ParserFactory._dynamic_parsers: + del ParserFactory._dynamic_parsers["test_parser"] + + def test_get_parser_not_found(self): + """测试获取不存在的解析器""" + with pytest.raises(ValueError, match="Parser 'nonexistent_parser' not found"): + ParserFactory.get_parser("nonexistent_parser") + + def test_list_parsers(self): + """测试列出所有解析器""" + # 先注册一个测试解析器 + parser_config = { + 'output_fields': { + 'code': { + 'field_type': 'str' + } + } + } + ParserFactory.register_parser("test_parser", parser_config) + + parsers = ParserFactory.list_parsers() + assert "test_parser" in parsers + + # 清理注册的解析器 + if "test_parser" in ParserFactory._dynamic_parsers: + del ParserFactory._dynamic_parsers["test_parser"] + + def test_get_api_parser(self): + """测试获取API解析器""" + parser = ParserFactory.get_api_parser() + assert parser is not None + assert hasattr(parser, 'pydantic_object') + + def test_get_feature_parser(self): + """测试获取特征解析器""" + parser = ParserFactory.get_feature_parser() + assert parser is not None + assert hasattr(parser, 'pydantic_object') + + def test_robust_parse_success(self): + """测试稳健解析成功""" + parser = ParserFactory.get_feature_parser() + content = '{"op_name": "test", "op_type": "test_type", "input_specs": "test", "output_specs": "test", "computation": "test", "schedule": "test", "description": "test"}' + + result = ParserFactory.robust_parse(content, parser) + assert result is not None + + def test_robust_parse_with_json_block(self): + """测试解析包含JSON代码块的内容""" + parser = ParserFactory.get_feature_parser() + content = 'Some text before\n```json\n{"op_name": "test", "op_type": "test_type", "input_specs": "test", "output_specs": "test", "computation": "test", "schedule": "test", "description": "test"}\n```\nSome text after' + + result = ParserFactory.robust_parse(content, parser) + assert result is not None + + def test_extract_json_comprehensive_success(self): + """测试全面提取JSON成功""" + text = 'Some text\n{"key": "value", "code": "test code"}\nMore text' + result = ParserFactory._extract_json_comprehensive(text) + assert result is not None + assert '"key": "value"' in result + + def test_extract_final_json_success(self): + """测试提取最终JSON成功""" + text = 'Some text\n{"key": "value"}' + result = ParserFactory._extract_final_json(text) + assert result is not None + assert '"key": "value"' in result + + def test_extract_function_code_success(self): + """测试提取函数代码成功""" + filename = "/test/file.py" + function_name = "test_function" + file_content = '''def test_function(): + """Test function""" + return "test" + +def another_function(): + return "another" +''' + + with patch('builtins.open', mock_open(read_data=file_content)): + # 从markdown_utils导入extract_function_code方法 + from ai_kernel_generator.utils.markdown_utils import MarkdownUtils + code = MarkdownUtils().extract_function_code(filename, function_name) + assert code is not None + assert "def test_function():" in code + assert 'return "test"' in code + + +class TestTextUtils: + """测试文本工具函数""" + + def test_remove_copyright_from_text_empty(self): + """测试从空文本中移除版权信息""" + assert remove_copyright_from_text("") == "" + assert remove_copyright_from_text(None) is None + + def test_remove_copyright_from_text_no_copyright(self): + """测试从不包含版权信息的文本中移除版权信息""" + text = "This is a test text without copyright." + assert remove_copyright_from_text(text) == text + + def test_remove_copyright_from_text_with_copyright(self): + """测试从包含版权信息的文本中移除版权信息""" + text = """# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +This is the actual content.""" + + cleaned_text = remove_copyright_from_text(text) + # 版权信息应该被移除,只保留实际内容 + assert "This is the actual content." in cleaned_text + assert "Copyright" not in cleaned_text + + def test_get_md5_hash_success(self): + """测试成功生成MD5哈希""" + # 测试基本功能 + hash1 = get_md5_hash(a=1, b="test") + hash2 = get_md5_hash(b="test", a=1) # 参数顺序不同但值相同,应该得到相同哈希 + + assert isinstance(hash1, str) + assert len(hash1) == 32 # MD5哈希长度为32字符 + assert hash1 == hash2 # 相同参数应该得到相同哈希 + + def test_get_md5_hash_different_values(self): + """测试不同值生成不同哈希""" + hash1 = get_md5_hash(a=1, b="test") + hash2 = get_md5_hash(a=2, b="test") + + assert hash1 != hash2 # 不同参数应该得到不同哈希 + + def test_get_md5_hash_no_parameters(self): + """测试没有参数时抛出异常""" + with pytest.raises(ValueError, match="至少需要提供一个有效参数"): + get_md5_hash() + + def test_get_md5_hash_none_values(self): + """测试忽略None值""" + hash1 = get_md5_hash(a=1, b=None, c="test") + hash2 = get_md5_hash(a=1, c="test") + + assert hash1 == hash2 # None值应该被忽略 \ No newline at end of file diff --git a/aikg/tests/ut/test_conductor.py b/aikg/tests/ut/test_conductor.py new file mode 100644 index 000000000..e9a09e7cf --- /dev/null +++ b/aikg/tests/ut/test_conductor.py @@ -0,0 +1,465 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import Mock, patch, mock_open +import yaml +import os +from ai_kernel_generator.core.agent.conductor import Conductor + + +class TestConductor: + """测试Conductor类""" + + def test_conductor_init_success(self): + """测试Conductor成功初始化""" + config = { + "agent_model_config": { + "conductor": "deepseek_r1_default" + }, + "log_dir": "/test/log/dir" + } + + with patch('os.path.exists', return_value=True): + with patch('builtins.open', mock_open(read_data="test workflow content")): + with patch('yaml.safe_load', return_value={ + 'agent_info': { + 'designer': {'possible_next_agent': ['coder']}, + 'coder': {'possible_next_agent': ['verifier']}, + 'verifier': {'possible_next_agent': ['finish']} + }, + 'limitation_info': {'required': {'max_step': 20}}, + 'start_agent': 'designer', + 'max_step': 20, + 'repeat_limits': {}, + 'agent_next_mapping': { + 'designer': {'coder'}, + 'coder': {'verifier'}, + 'verifier': {'finish'} + }, + 'mandatory_llm_analysis': [] + }): + conductor = Conductor( + op_name="test_op", + task_desc="test task description", + task_id="test_task", + dsl="triton", + framework="torch", + arch="a100", + workflow_config_path="/test/workflow/path", + config=config + ) + + assert conductor.op_name == "test_op" + assert conductor.task_desc == "test task description" + assert conductor.task_id == "test_task" + assert conductor.dsl == "triton" + assert conductor.framework == "torch" + assert conductor.arch == "a100" + + def test_conductor_init_missing_config(self): + """测试缺少配置时的Conductor初始化""" + with pytest.raises(ValueError, match="config is required for Conductor"): + Conductor( + op_name="test_op", + task_desc="test task description", + task_id="test_task", + dsl="triton", + framework="torch", + arch="a100" + ) + + def test_conductor_init_missing_workflow_config(self): + """测试缺少工作流配置时的Conductor初始化""" + config = { + "agent_model_config": { + "conductor": "deepseek_r1_default" + }, + "log_dir": "/test/log/dir" + } + + with pytest.raises(ValueError, match="workflow_config_path is required for Conductor"): + conductor = Conductor( + op_name="test_op", + task_desc="test task description", + task_id="test_task", + dsl="triton", + framework="torch", + arch="a100", + config=config + ) + + def test_conductor_set_task_info(self): + """测试设置任务信息""" + config = { + "agent_model_config": { + "conductor": "deepseek_r1_default" + }, + "log_dir": "/test/log/dir" + } + + with patch('os.path.exists', return_value=True): + with patch('builtins.open', mock_open(read_data="test workflow content")): + with patch('yaml.safe_load', return_value={ + 'agent_info': { + 'designer': {'possible_next_agent': ['coder']}, + 'coder': {'possible_next_agent': ['verifier']}, + 'verifier': {'possible_next_agent': ['finish']} + }, + 'limitation_info': {'required': {'max_step': 20}}, + 'start_agent': 'designer', + 'max_step': 20, + 'repeat_limits': {}, + 'agent_next_mapping': { + 'designer': {'coder'}, + 'coder': {'verifier'}, + 'verifier': {'finish'} + }, + 'mandatory_llm_analysis': [] + }): + conductor = Conductor( + op_name="test_op", + task_desc="test task description", + task_id="test_task", + dsl="triton", + framework="torch", + arch="a100", + workflow_config_path="/test/workflow/path", + config=config + ) + + base_doc = { + "api_docs": "test api docs", + "dsl_basic_docs": "test dsl docs" + } + + conductor.set_task_info(base_doc) + + assert "op_name" in conductor.task_info + assert "task_id" in conductor.task_info + assert "api_docs" in conductor.task_info + assert "dsl_basic_docs" in conductor.task_info + + def test_conductor_get_agent_history(self): + """测试获取agent历史记录""" + config = { + "agent_model_config": { + "conductor": "deepseek_r1_default" + }, + "log_dir": "/test/log/dir" + } + + with patch('os.path.exists', return_value=True): + with patch('builtins.open', mock_open(read_data="test workflow content")): + with patch('yaml.safe_load', return_value={ + 'agent_info': { + 'designer': {'possible_next_agent': ['coder']}, + 'coder': {'possible_next_agent': ['verifier']}, + 'verifier': {'possible_next_agent': ['finish']} + }, + 'limitation_info': {'required': {'max_step': 20}}, + 'start_agent': 'designer', + 'max_step': 20, + 'repeat_limits': {}, + 'agent_next_mapping': { + 'designer': {'coder'}, + 'coder': {'verifier'}, + 'verifier': {'finish'} + }, + 'mandatory_llm_analysis': [] + }): + conductor = Conductor( + op_name="test_op", + task_desc="test task description", + task_id="test_task", + dsl="triton", + framework="torch", + arch="a100", + workflow_config_path="/test/workflow/path", + config=config + ) + + # 添加一些测试记录 + mock_record1 = Mock() + mock_record1.agent_name = "designer" + mock_record2 = Mock() + mock_record2.agent_name = "coder" + conductor.trace.trace_list = [mock_record1, mock_record2] + + history = conductor.get_agent_history() + assert history == ["designer", "coder"] + + def test_conductor_get_current_agent_name(self): + """测试获取当前agent名称""" + config = { + "agent_model_config": { + "conductor": "deepseek_r1_default" + }, + "log_dir": "/test/log/dir" + } + + with patch('os.path.exists', return_value=True): + with patch('builtins.open', mock_open(read_data="test workflow content")): + with patch('yaml.safe_load', return_value={ + 'agent_info': { + 'designer': {'possible_next_agent': ['coder']}, + 'coder': {'possible_next_agent': ['verifier']}, + 'verifier': {'possible_next_agent': ['finish']} + }, + 'limitation_info': {'required': {'max_step': 20}}, + 'start_agent': 'designer', + 'max_step': 20, + 'repeat_limits': {}, + 'agent_next_mapping': { + 'designer': {'coder'}, + 'coder': {'verifier'}, + 'verifier': {'finish'} + }, + 'mandatory_llm_analysis': [] + }): + conductor = Conductor( + op_name="test_op", + task_desc="test task description", + task_id="test_task", + dsl="triton", + framework="torch", + arch="a100", + workflow_config_path="/test/workflow/path", + config=config + ) + + # 测试空历史记录 + assert conductor.get_current_agent_name() is None + + # 添加测试记录 + mock_record = Mock() + mock_record.agent_name = "designer" + conductor.trace.trace_list = [mock_record] + + assert conductor.get_current_agent_name() == "designer" + + def test_conductor_get_valid_next_agent(self): + """测试获取有效的下一个agent""" + config = { + "agent_model_config": { + "conductor": "deepseek_r1_default" + }, + "log_dir": "/test/log/dir" + } + + with patch('os.path.exists', return_value=True): + with patch('builtins.open', mock_open(read_data="test workflow content")): + with patch('yaml.safe_load', return_value={ + 'agent_info': { + 'designer': {'possible_next_agent': ['coder']}, + 'coder': {'possible_next_agent': ['verifier']}, + 'verifier': {'possible_next_agent': ['finish']} + }, + 'limitation_info': {'required': {'max_step': 20}}, + 'start_agent': 'designer', + 'max_step': 20, + 'repeat_limits': {}, + 'agent_next_mapping': { + 'designer': {'coder', 'verifier'}, + 'coder': {'verifier'}, + 'verifier': {'finish'} + }, + 'mandatory_llm_analysis': [] + }): + conductor = Conductor( + op_name="test_op", + task_desc="test task description", + task_id="test_task", + dsl="triton", + framework="torch", + arch="a100", + workflow_config_path="/test/workflow/path", + config=config + ) + + # 测试获取designer的下一个agent + valid_agents = conductor.get_valid_next_agent("designer") + assert valid_agents == {"coder"} + + def test_conductor_record_agent_execution_no_parser(self): + """测试记录没有解析器的agent执行""" + config = { + "agent_model_config": { + "conductor": "deepseek_r1_default" + }, + "log_dir": "/test/log/dir" + } + + with patch('os.path.exists', return_value=True): + with patch('builtins.open', mock_open(read_data="test workflow content")): + with patch('yaml.safe_load', return_value={ + 'agent_info': { + 'verifier': {} # verifier不需要解析器 + }, + 'limitation_info': {'required': {'max_step': 20}}, + 'start_agent': 'designer', + 'max_step': 20, + 'repeat_limits': {}, + 'agent_next_mapping': { + 'designer': {'coder'}, + 'coder': {'verifier'}, + 'verifier': {'finish'} + }, + 'mandatory_llm_analysis': [] + }): + conductor = Conductor( + op_name="test_op", + task_desc="test task description", + task_id="test_task", + dsl="triton", + framework="torch", + arch="a100", + workflow_config_path="/test/workflow/path", + config=config + ) + + # 记录verifier执行结果 + success = conductor.record_agent_execution( + agent_name="verifier", + result="True", + error_log="", + profile_res=(1.0, 2.0, 2.0) + ) + + assert success is True + assert len(conductor.trace.trace_list) == 1 + + def test_conductor_record_agent_execution_with_parser(self): + """测试记录有解析器的agent执行""" + config = { + "agent_model_config": { + "conductor": "deepseek_r1_default" + }, + "log_dir": "/tmp/test/log/dir" + } + + # 确保测试目录存在 + os.makedirs("/tmp/test/log/dir", exist_ok=True) + + with patch('os.path.exists', return_value=True): + with patch('builtins.open', mock_open(read_data="test workflow content")): + with patch('yaml.safe_load', return_value={ + 'agent_info': { + 'designer': { + 'output_format': { + 'parser_name': 'designer_parser', + 'parser_definition': { + 'output_fields': { + 'code': 'str' + } + } + } + } + }, + 'limitation_info': {'required': {'max_step': 20}}, + 'start_agent': 'designer', + 'max_step': 20, + 'repeat_limits': {}, + 'agent_next_mapping': { + 'designer': {'coder'}, + 'coder': {'verifier'}, + 'verifier': {'finish'} + }, + 'mandatory_llm_analysis': [] + }): + conductor = Conductor( + op_name="test_op", + task_desc="test task description", + task_id="test_task", + dsl="triton", + framework="torch", + arch="a100", + workflow_config_path="/test/workflow/path", + config=config + ) + + # 创建一个模拟的解析器 + mock_parser = Mock() + + # 直接设置agent_info,绕过工作流配置加载 + conductor.agent_info = { + 'designer': { + 'output_format': { + 'parser_definition': { + 'output_fields': { + 'code': 'str' + } + } + } + } + } + + with patch('ai_kernel_generator.utils.common_utils.ParserFactory.robust_parse') as mock_robust_parse: + mock_parsed_result = Mock() + mock_parsed_result.code = "test code" + mock_robust_parse.return_value = mock_parsed_result + + # 记录designer执行结果 + success = conductor.record_agent_execution( + agent_name="designer", + result='{"code": "test code"}' + ) + + assert success is True + assert len(conductor.trace.trace_list) == 1 + + def test_conductor_get_next_agent_no_trace(self): + """测试没有trace记录时获取下一个agent""" + config = { + "agent_model_config": { + "conductor": "deepseek_r1_default" + }, + "log_dir": "/test/log/dir" + } + + with patch('os.path.exists', return_value=True): + with patch('builtins.open', mock_open(read_data="test workflow content")): + with patch('yaml.safe_load', return_value={ + 'agent_info': { + 'designer': {'possible_next_agent': ['coder']}, + 'coder': {'possible_next_agent': ['verifier']}, + 'verifier': {'possible_next_agent': ['finish']} + }, + 'limitation_info': {'required': {'max_step': 20}}, + 'start_agent': 'designer', + 'max_step': 20, + 'repeat_limits': {}, + 'agent_next_mapping': { + 'designer': {'coder'}, + 'coder': {'verifier'}, + 'verifier': {'finish'} + }, + 'mandatory_llm_analysis': [] + }): + conductor = Conductor( + op_name="test_op", + task_desc="test task description", + task_id="test_task", + dsl="triton", + framework="torch", + arch="a100", + workflow_config_path="/test/workflow/path", + config=config + ) + + # 应该抛出异常,因为没有trace记录 + with pytest.raises(ValueError, match="get_next_agent"): + # 使用asyncio.run来运行异步方法 + import asyncio + asyncio.run(conductor.get_next_agent()) \ No newline at end of file diff --git a/aikg/tests/ut/test_config_validator.py b/aikg/tests/ut/test_config_validator.py new file mode 100644 index 000000000..65e64b29c --- /dev/null +++ b/aikg/tests/ut/test_config_validator.py @@ -0,0 +1,213 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import tempfile +import os +from pathlib import Path +from ai_kernel_generator.config.config_validator import ConfigValidator, load_config + + +class TestConfigValidator: + """测试配置验证器""" + + def test_init_with_valid_config(self, tmp_path): + """测试使用有效配置初始化""" + # 创建一个临时配置文件 + config_content = """ +agent_model_config: + designer: deepseek_r1_default + coder: deepseek_r1_default + verifier: deepseek_r1_default +log_dir: ~/tmp +docs_dir: + designer: resources/docs/triton_docs + coder: resources/docs/triton_docs +""" + config_file = tmp_path / "test_config.yaml" + config_file.write_text(config_content) + + validator = ConfigValidator(str(config_file)) + assert validator.config is not None + + def test_validate_llm_models_success(self, tmp_path): + """测试LLM模型验证成功""" + config_content = """ +agent_model_config: + designer: deepseek_r1_default + coder: deepseek_r1_default +""" + config_file = tmp_path / "test_config.yaml" + config_file.write_text(config_content) + + validator = ConfigValidator(str(config_file)) + # 应该不抛出异常 + validator.validate_llm_models() + + def test_validate_llm_models_invalid_model(self, tmp_path): + """测试LLM模型验证失败 - 无效模型""" + config_content = """ +agent_model_config: + designer: invalid_model_name +""" + config_file = tmp_path / "test_config.yaml" + config_file.write_text(config_content) + + validator = ConfigValidator(str(config_file)) + with pytest.raises(ValueError, match="非法的模型名称配置"): + validator.validate_llm_models() + + def test_validate_llm_models_missing_field(self, tmp_path): + """测试LLM模型验证失败 - 缺少字段""" + config_content = """ +log_dir: ~/tmp +""" + config_file = tmp_path / "test_config.yaml" + config_file.write_text(config_content) + + validator = ConfigValidator(str(config_file)) + with pytest.raises(ValueError, match="配置文件中缺少 agent_model_config 字段"): + validator.validate_llm_models() + + def test_validate_log_dir_success(self, tmp_path): + """测试日志目录验证成功""" + config_content = """ +log_dir: ~/tmp +""" + config_file = tmp_path / "test_config.yaml" + config_file.write_text(config_content) + + validator = ConfigValidator(str(config_file)) + validator.validate_log_dir() + assert "log_dir" in validator.config + + def test_validate_log_dir_missing_field(self, tmp_path): + """测试日志目录验证失败 - 缺少字段""" + config_content = """ +agent_model_config: + designer: deepseek_r1_default +""" + config_file = tmp_path / "test_config.yaml" + config_file.write_text(config_content) + + validator = ConfigValidator(str(config_file)) + with pytest.raises(ValueError, match="配置文件中缺少 log_dir 字段"): + validator.validate_log_dir() + + def test_validate_docs_dir_success(self, tmp_path): + """测试文档目录验证成功""" + # 创建模拟的文档目录结构 + docs_dir = tmp_path / "docs" + docs_dir.mkdir() + + config_content = f""" +docs_dir: + designer: {docs_dir} + coder: {docs_dir} +""" + config_file = tmp_path / "test_config.yaml" + config_file.write_text(config_content) + + validator = ConfigValidator(str(config_file)) + # 应该不抛出异常 + validator.validate_docs_dir() + + def test_validate_docs_dir_missing_field(self, tmp_path): + """测试文档目录验证失败 - 缺少字段""" + config_content = """ +log_dir: ~/tmp +""" + config_file = tmp_path / "test_config.yaml" + config_file.write_text(config_content) + + validator = ConfigValidator(str(config_file)) + with pytest.raises(ValueError, match="配置文件中缺少 docs_dir 字段"): + validator.validate_docs_dir() + + def test_validate_docs_dir_invalid_type(self, tmp_path): + """测试文档目录验证失败 - 无效类型""" + config_content = """ +docs_dir: invalid_type +""" + config_file = tmp_path / "test_config.yaml" + config_file.write_text(config_content) + + validator = ConfigValidator(str(config_file)) + with pytest.raises(ValueError, match="docs_dir 必须是一个字典"): + validator.validate_docs_dir() + + def test_validate_docs_dir_nonexistent_path(self, tmp_path): + """测试文档目录验证失败 - 不存在的路径""" + nonexistent_path = "/nonexistent/path" + + config_content = f""" +docs_dir: + designer: {nonexistent_path} +""" + config_file = tmp_path / "test_config.yaml" + config_file.write_text(config_content) + + validator = ConfigValidator(str(config_file)) + with pytest.raises(ValueError, match="docs_dir 中 designer 指定的目录不存在"): + validator.validate_docs_dir() + + def test_validate_all_success(self, tmp_path): + """测试整体验证成功""" + # 创建模拟的文档目录 + docs_dir = tmp_path / "docs" + docs_dir.mkdir() + + config_content = f""" +agent_model_config: + designer: deepseek_r1_default + coder: deepseek_r1_default +log_dir: ~/tmp +docs_dir: + designer: {docs_dir} + coder: {docs_dir} +""" + config_file = tmp_path / "test_config.yaml" + config_file.write_text(config_content) + + validator = ConfigValidator(str(config_file)) + # 应该不抛出异常 + validator.validate_all() + + def test_load_config_with_path(self, tmp_path): + """测试通过路径加载配置""" + config_content = """ +agent_model_config: + designer: deepseek_r1_default +log_dir: ~/tmp +docs_dir: + designer: resources/docs/triton_docs +""" + config_file = tmp_path / "test_config.yaml" + config_file.write_text(config_content) + + config = load_config(config_path=str(config_file)) + assert "agent_model_config" in config + assert "log_dir" in config + assert "docs_dir" in config + + def test_load_config_with_dsl(self): + """测试通过DSL加载默认配置""" + # 注意:这会尝试加载默认配置文件,可能不存在 + with pytest.raises(ValueError, match="No default config found for dsl"): + load_config(dsl="nonexistent") + + def test_load_config_missing_both(self): + """测试缺少dsl和config_path参数""" + with pytest.raises(ValueError, match="No default config found for dsl"): + load_config() \ No newline at end of file diff --git a/aikg/tests/ut/test_core_utils.py b/aikg/tests/ut/test_core_utils.py new file mode 100644 index 000000000..9b669a6c1 --- /dev/null +++ b/aikg/tests/ut/test_core_utils.py @@ -0,0 +1,134 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from ai_kernel_generator.core.utils import check_backend_arch, check_dsl, check_task_type, check_task_config + + +class TestCoreUtils: + """测试核心工具函数""" + + # 测试 check_backend_arch + def test_check_backend_arch_valid_ascend(self): + """测试有效的Ascend后端配置""" + # 应该不抛出异常 + check_backend_arch("ascend", "ascend910b4") + check_backend_arch("ascend", "ascend310p3") + + def test_check_backend_arch_valid_cuda(self): + """测试有效的CUDA后端配置""" + # 应该不抛出异常 + check_backend_arch("cuda", "a100") + check_backend_arch("cuda", "v100") + + def test_check_backend_arch_valid_cpu(self): + """测试有效的CPU后端配置""" + # 应该不抛出异常 + check_backend_arch("cpu", "x86_64") + check_backend_arch("cpu", "aarch64") + + def test_check_backend_arch_invalid_backend(self): + """测试无效的后端""" + with pytest.raises(ValueError, match="backend must be ascend, cuda or cpu"): + check_backend_arch("invalid_backend", "ascend910b4") + + def test_check_backend_arch_invalid_ascend_arch(self): + """测试无效的Ascend架构""" + with pytest.raises(ValueError, match="ascend backend only support"): + check_backend_arch("ascend", "invalid_arch") + + def test_check_backend_arch_invalid_cuda_arch(self): + """测试无效的CUDA架构""" + with pytest.raises(ValueError, match="cuda backend only support"): + check_backend_arch("cuda", "invalid_arch") + + def test_check_backend_arch_invalid_cpu_arch(self): + """测试无效的CPU架构""" + with pytest.raises(ValueError, match="cpu backend only support"): + check_backend_arch("cpu", "invalid_arch") + + # 测试 check_dsl + def test_check_dsl_valid(self): + """测试有效的DSL类型""" + # 应该不抛出异常 + check_dsl("triton") + check_dsl("triton-russia") + check_dsl("swft") + + def test_check_dsl_invalid(self): + """测试无效的DSL类型""" + with pytest.raises(ValueError, match="dsl must be triton or swft"): + check_dsl("invalid_dsl") + + # 测试 check_task_type + def test_check_task_type_valid(self): + """测试有效的任务类型""" + # 应该不抛出异常 + check_task_type("precision_only") + check_task_type("profile") + + def test_check_task_type_invalid(self): + """测试无效的任务类型""" + with pytest.raises(ValueError, match="task_type must be precision_only or profile"): + check_task_type("invalid_task_type") + + # 测试 check_task_config + def test_check_task_config_valid_combinations(self): + """测试有效的配置组合""" + valid_combinations = [ + ("mindspore", "ascend", "ascend910b4", "triton"), + ("mindspore", "ascend", "ascend310p3", "swft"), + ("torch", "ascend", "ascend910b4", "triton"), + ("torch", "ascend", "ascend310p3", "swft"), + ("torch", "cuda", "a100", "triton"), + ("numpy", "ascend", "ascend310p3", "swft"), + ] + + for framework, backend, arch, dsl in valid_combinations: + # 应该不抛出异常 + check_task_config(framework, backend, arch, dsl) + + def test_check_task_config_invalid_framework(self): + """测试无效的框架""" + with pytest.raises(ValueError, match="Unsupported framework"): + check_task_config("invalid_framework", "ascend", "ascend910b4", "triton") + + def test_check_task_config_invalid_backend(self): + """测试无效的后端""" + with pytest.raises(ValueError, match="does not support backend"): + check_task_config("mindspore", "invalid_backend", "ascend910b4", "triton") + + def test_check_task_config_invalid_arch(self): + """测试无效的架构""" + with pytest.raises(ValueError, match="does not support arch"): + check_task_config("mindspore", "ascend", "invalid_arch", "triton") + + def test_check_task_config_invalid_dsl(self): + """测试无效的DSL""" + with pytest.raises(ValueError, match="does not support dsl"): + check_task_config("mindspore", "ascend", "ascend910b4", "invalid_dsl") + + def test_check_task_config_mismatched_combination(self): + """测试不匹配的组合""" + # ascend910b4只支持triton,但使用了swft + with pytest.raises(ValueError, match="does not support dsl"): + check_task_config("mindspore", "ascend", "ascend910b4", "swft") + + # cuda只支持triton,但使用了swft + with pytest.raises(ValueError, match="does not support dsl"): + check_task_config("torch", "cuda", "a100", "swft") + + # numpy只支持swft,但使用了triton + with pytest.raises(ValueError, match="does not support dsl"): + check_task_config("numpy", "ascend", "ascend310p3", "triton") \ No newline at end of file diff --git a/aikg/tests/ut/test_database_components.py b/aikg/tests/ut/test_database_components.py new file mode 100644 index 000000000..368c643ed --- /dev/null +++ b/aikg/tests/ut/test_database_components.py @@ -0,0 +1,221 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import Mock, patch, mock_open +import json +from ai_kernel_generator.database.database import Database, RetrievalStrategy +from ai_kernel_generator.database.vector_store import VectorStore +from ai_kernel_generator.database.evolve_database import EvolveDatabase + + +class TestDatabase: + """测试数据库类""" + + def test_database_init_default(self): + """测试数据库默认初始化""" + with patch('ai_kernel_generator.database.database.DEFAULT_DATABASE_PATH') as mock_default_path: + mock_default_path.parent.parent.__truediv__.return_value = Mock() + db = Database() + assert db is not None + + def test_database_init_custom(self): + """测试数据库自定义初始化""" + config_path = "/test/config/path" + database_path = "/test/database/path" + + with patch('pathlib.Path.exists', return_value=True): + with patch('builtins.open', mock_open(read_data="test_config")): + db = Database(config_path=config_path, database_path=database_path) + assert db.database_path == database_path + + def test_database_init_random_mode(self): + """测试数据库随机模式初始化""" + with patch('ai_kernel_generator.database.database.DEFAULT_DATABASE_PATH') as mock_default_path: + mock_default_path.parent.parent.__truediv__.return_value = Mock() + db = Database(random_mode=True) + assert db.random_mode is True + + @patch('ai_kernel_generator.database.database.FeatureExtractor') + def test_extract_features_success(self, mock_feature_extractor): + """测试成功提取特征""" + # 模拟特征提取器 + mock_extractor_instance = Mock() + mock_feature_extractor.return_value = mock_extractor_instance + + mock_extractor_instance.run.return_value = ("test_content", "", "") + + class MockParsedContent: + def __init__(self): + self.op_name = "test_op" + self.op_type = "test_type" + self.input_specs = "test_input" + self.output_specs = "test_output" + self.computation = "test_computation" + self.schedule = "test_schedule" + self.description = "test_description" + + mock_extractor_instance.feature_parser.parse.return_value = MockParsedContent() + + config = { + "agent_model_config": { + "feature_extractor": "deepseek_r1_default" + } + } + + db = Database(config_path="/test/config", database_path="/test/database") + db.config = config + + features = db.extract_features( + impl_code="test_impl_code", + framework_code="test_framework_code", + backend="cuda", + arch="a100", + dsl="triton" + ) + + assert features is not None + assert features["op_name"] == "test_op" + + def test_randomicity_search_success(self): + """测试随机搜索成功""" + with patch('pathlib.Path.exists', return_value=True): + with patch('pathlib.Path.is_dir', return_value=True): + with patch('pathlib.Path.iterdir') as mock_iterdir: + # 模拟目录中的案例 + mock_case1 = Mock() + mock_case1.is_dir.return_value = True + mock_case2 = Mock() + mock_case2.is_dir.return_value = True + mock_iterdir.return_value = [mock_case1, mock_case2] + + with patch('random.sample') as mock_sample: + mock_sample.return_value = [mock_case1] + + db = Database(database_path="/test/database") + result = db.randomicity_search( + output_content=["impl_code"], + k=1, + backend="cuda", + arch="a100", + dsl="triton", + framework="torch" + ) + + assert len(result) == 1 + + def test_randomicity_search_invalid_params(self): + """测试随机搜索时无效参数""" + db = Database(database_path="/test/database") + + with pytest.raises(ValueError, match="arch and dsl must be provided"): + db.randomicity_search( + output_content=["impl_code"], + k=1, + backend="cuda" + # 缺少arch和dsl + ) + + def test_randomicity_search_path_not_exists(self): + """测试随机搜索时路径不存在""" + with patch('pathlib.Path.exists', return_value=False): + db = Database(database_path="/test/database") + + with pytest.raises(ValueError, match="Sample path"): + db.randomicity_search( + output_content=["impl_code"], + k=1, + backend="cuda", + arch="a100", + dsl="triton", + framework="torch" + ) + + +class TestVectorStore: + """测试向量存储类""" + + def test_vector_store_init(self): + """测试向量存储初始化""" + with patch('os.environ.__setitem__'): + with patch('ai_kernel_generator.database.vector_store.FAISS') as mock_faiss: + mock_faiss.load_local.return_value = Mock() + + with patch('pathlib.Path.exists', return_value=True): + with patch('pathlib.Path.__truediv__', return_value=Mock()): + with patch('ai_kernel_generator.database.vector_store.HuggingFaceEmbeddings'): + vector_store = VectorStore("/test/config", "/test/database") + assert vector_store is not None + + def test_vector_store_build_empty(self): + """测试构建空向量存储""" + with patch('os.environ.__setitem__'): + with patch('ai_kernel_generator.database.vector_store.HuggingFaceEmbeddings'): + with patch('pathlib.Path.rglob', return_value=[]): # 没有元数据文件 + with patch('ai_kernel_generator.database.vector_store.FAISS') as mock_faiss: + vector_store = VectorStore("/test/config", "/test/database") + # 应该成功创建空的向量存储 + assert vector_store is not None + + +class TestEvolveDatabase: + """测试进化数据库类""" + + def test_evolve_database_init(self): + """测试进化数据库初始化""" + with patch('ai_kernel_generator.database.evolve_database.DEFAULT_EVOLVE_DATABASE_PATH') as mock_default_path: + mock_default_path.parent.parent.__truediv__.return_value = Mock() + db = EvolveDatabase() + assert db is not None + # 应该使用默认的进化数据库路径 + assert db.database_path is not None + + def test_evolve_database_init_custom(self): + """测试进化数据库自定义初始化""" + config_path = "/test/config/path" + database_path = "/test/database/path" + + with patch('pathlib.Path.exists', return_value=True): + with patch('builtins.open', mock_open(read_data="test_config")): + db = EvolveDatabase(config_path=config_path, database_path=database_path) + assert db.database_path == database_path + + def test_optimality_search_success(self): + """测试最优性搜索成功""" + with patch('pathlib.Path.rglob') as mock_rglob: + # 模拟元数据文件 + mock_metadata_file = Mock() + mock_rglob.return_value = [mock_metadata_file] + + with patch('builtins.open', mock_open(read_data='{"profile": 1.0}')): + db = EvolveDatabase(database_path="/test/database") + result = db.optimality_search() + + # 应该返回一个包含文档的列表 + assert isinstance(result, list) + + def test_optimality_search_invalid_json(self): + """测试最优性搜索时无效JSON""" + with patch('pathlib.Path.rglob') as mock_rglob: + # 模拟元数据文件 + mock_metadata_file = Mock() + mock_rglob.return_value = [mock_metadata_file] + + with patch('builtins.open', mock_open(read_data='invalid json')): + with patch('json.load', side_effect=json.JSONDecodeError("test", "test", 0)): + db = EvolveDatabase(database_path="/test/database") + result = db.optimality_search() + + # 应该处理异常并返回空列表或默认值 + assert isinstance(result, list) \ No newline at end of file diff --git a/aikg/tests/ut/test_designer.py b/aikg/tests/ut/test_designer.py new file mode 100644 index 000000000..8cc006322 --- /dev/null +++ b/aikg/tests/ut/test_designer.py @@ -0,0 +1,146 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import Mock, patch, mock_open +import json +import os +from pathlib import Path +from ai_kernel_generator.core.agent.designer import Designer + + +class TestDesigner: + """测试Designer类""" + + def test_designer_init_success(self): + """测试Designer成功初始化""" + config = { + "agent_model_config": { + "designer": "deepseek_r1_default" + }, + "docs_dir": { + "designer": "resources/docs/triton_docs" + } + } + + with patch('ai_kernel_generator.core.agent.designer.create_step_parser') as mock_parser: + mock_parser.return_value = Mock() + mock_parser.return_value.get_format_instructions.return_value = "test format instructions" + + designer = Designer( + op_name="test_op", + task_desc="test task description", + dsl="triton", + backend="cuda", + arch="a100", + config=config + ) + + assert designer.op_name == "test_op" + assert designer.task_desc == "test task description" + assert designer.dsl == "triton" + assert designer.backend == "cuda" + assert designer.arch == "a100" + + def test_designer_init_missing_config(self): + """测试Designer初始化时缺少配置""" + with pytest.raises(ValueError, match="config is required for Designer"): + Designer( + op_name="test_op", + task_desc="test task description", + dsl="triton", + backend="cuda", + arch="a100" + # 缺少config参数 + ) + + def test_designer_init_missing_parser(self): + """测试Designer初始化时缺少解析器""" + config = { + "agent_model_config": { + "designer": "deepseek_r1_default" + } + } + + with patch('ai_kernel_generator.core.agent.designer.create_step_parser') as mock_parser: + mock_parser.return_value = None + + with pytest.raises(ValueError, match="Failed to create designer parser"): + Designer( + op_name="test_op", + task_desc="test task description", + dsl="triton", + backend="cuda", + arch="a100", + config=config + ) + + @patch('ai_kernel_generator.core.agent.designer.extract_function_details') + def test_designer_init_swft_with_api(self, mock_extract_function_details): + """测试Designer初始化SWFT实现类型时添加API支持""" + mock_extract_function_details.return_value = {"api": "test api details"} + + config = { + "agent_model_config": { + "designer": "deepseek_r1_default" + }, + "docs_dir": { + "designer": "resources/docs/swft_docs" + } + } + + with patch('ai_kernel_generator.core.agent.designer.create_step_parser') as mock_parser: + mock_parser.return_value = Mock() + mock_parser.return_value.get_format_instructions.return_value = "test format instructions" + + designer = Designer( + op_name="test_op", + task_desc="test task description", + dsl="swft", + backend="ascend", + arch="ascend310p3", + config=config + ) + + assert "supported_compute_api" in designer.base_doc + + @patch('ai_kernel_generator.core.agent.designer.extract_function_details') + def test_designer_init_swft_api_extraction_failure(self, mock_extract_function_details): + """测试Designer初始化SWFT时API提取失败""" + mock_extract_function_details.side_effect = Exception("API extraction failed") + + config = { + "agent_model_config": { + "designer": "deepseek_r1_default" + }, + "docs_dir": { + "designer": "resources/docs/swft_docs" + } + } + + with patch('ai_kernel_generator.core.agent.designer.create_step_parser') as mock_parser: + mock_parser.return_value = Mock() + mock_parser.return_value.get_format_instructions.return_value = "test format instructions" + + designer = Designer( + op_name="test_op", + task_desc="test task description", + dsl="swft", + backend="ascend", + arch="ascend310p3", + config=config + ) + + # 应该不会抛出异常,但supported_compute_api可能不在base_doc中 + assert True # 如果没有抛出异常,测试通过 \ No newline at end of file diff --git a/aikg/tests/ut/test_environment_check.py b/aikg/tests/ut/test_environment_check.py new file mode 100644 index 000000000..167dee74f --- /dev/null +++ b/aikg/tests/ut/test_environment_check.py @@ -0,0 +1,261 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import Mock, patch, mock_open +import json +from ai_kernel_generator.utils.environment_check import check_env, _check_llm_api, _load_llm_config + + +class TestEnvironmentCheck: + """测试环境检查工具""" + + @patch('importlib.import_module') + def test_check_env_success(self, mock_import): + """测试环境检查成功""" + # 模拟所有导入都成功 + mock_import.return_value = Mock() + + with patch('subprocess.run') as mock_subprocess: + # 模拟nvidia-smi成功 + mock_subprocess.return_value.returncode = 0 + + with patch('ai_kernel_generator.utils.environment_check._check_llm_api', return_value=True): + result = check_env( + framework="torch", + backend="cuda", + dsl="triton", + config={"agent_model_config": {"designer": "deepseek_r1_default"}} + ) + + assert result is True + + @patch('importlib.import_module') + def test_check_env_missing_base_package(self, mock_import): + """测试缺少基础包的环境检查""" + # 模拟导入numpy失败 + def import_side_effect(name): + if name == 'numpy': + raise ImportError("No module named 'numpy'") + return Mock() + + mock_import.side_effect = import_side_effect + + with patch('ai_kernel_generator.utils.environment_check._check_llm_api', return_value=True): + result = check_env( + framework="torch", + backend="cuda", + dsl="triton" + ) + + # 应该返回False,因为缺少基础包 + # 注意:实际输出可能因实现而异 + assert result is False + + @patch('importlib.import_module') + def test_check_env_missing_framework(self, mock_import): + """测试缺少框架的环境检查""" + # 模拟导入torch失败 + def import_side_effect(name): + if name == 'torch': + raise ImportError("No module named 'torch'") + return Mock() + + mock_import.side_effect = import_side_effect + + with patch('ai_kernel_generator.utils.environment_check._check_llm_api', return_value=True): + result = check_env( + framework="torch", + backend="cuda", + dsl="triton" + ) + + # 应该返回False,因为缺少框架 + assert result is False + + @patch('importlib.import_module') + def test_check_env_missing_dsl(self, mock_import): + """测试缺少DSL的环境检查""" + # 模拟导入triton失败 + def import_side_effect(name): + if name == 'triton': + raise ImportError("No module named 'triton'") + return Mock() + + mock_import.side_effect = import_side_effect + + with patch('ai_kernel_generator.utils.environment_check._check_llm_api', return_value=True): + result = check_env( + framework="torch", + backend="cuda", + dsl="triton" + ) + + # 应该返回False,因为缺少DSL + assert result is False + + @patch('importlib.import_module') + def test_check_env_nvidia_smi_not_found(self, mock_import): + """测试找不到nvidia-smi的环境检查""" + # 模拟所有导入都成功 + mock_import.return_value = Mock() + + with patch('subprocess.run') as mock_subprocess: + # 模拟nvidia-smi命令未找到 + mock_subprocess.side_effect = FileNotFoundError("No such file or directory") + + with patch('ai_kernel_generator.utils.environment_check._check_llm_api', return_value=True): + result = check_env( + framework="torch", + backend="cuda", + dsl="triton" + ) + + # 应该返回False,因为找不到nvidia-smi + assert result is False + + def test_load_llm_config_file_not_found(self): + """测试加载不存在的LLM配置文件""" + with patch('pathlib.Path.exists', return_value=False): + config = _load_llm_config() + assert config is None + + def test_load_llm_config_success(self): + """测试成功加载LLM配置文件""" + test_config = { + "deepseek_r1_default": { + "api_base": "http://test.api", + "model": "test_model", + "api_key_env": "TEST_API_KEY" + } + } + + with patch('pathlib.Path.exists', return_value=True): + with patch('builtins.open', mock_open(read_data=json.dumps(test_config))): + with patch('yaml.safe_load', return_value=test_config): + config = _load_llm_config() + assert config is not None + assert "deepseek_r1_default" in config + + def test_check_llm_api_no_config(self): + """测试没有配置时的LLM API检查""" + result = _check_llm_api() + assert result is True # 应该返回True,因为跳过了检查 + + def test_check_llm_api_with_config_missing_model(self): + """测试配置中缺少模型时的LLM API检查""" + config = { + "agent_model_config": { + "designer": "nonexistent_model" # 不存在的模型 + } + } + + llm_config = { + "deepseek_r1_default": { + "api_base": "http://test.api", + "model": "test_model", + "api_key_env": "TEST_API_KEY" + } + } + + with patch('ai_kernel_generator.utils.environment_check._load_llm_config', return_value=llm_config): + result = _check_llm_api(config=config) + assert result is False # 应该返回False,因为模型不存在 + + @patch('requests.get') + @patch('requests.post') + def test_check_llm_api_connection_success(self, mock_post, mock_get): + """测试LLM API连接成功""" + config = { + "agent_model_config": { + "designer": "deepseek_r1_default" + } + } + + llm_config = { + "deepseek_r1_default": { + "api_base": "http://test.api", + "model": "test_model", + "api_key_env": "TEST_API_KEY" + } + } + + # 模拟GET请求成功 + mock_response = Mock() + mock_response.status_code = 200 + mock_get.return_value = mock_response + mock_post.side_effect = Exception("Should not be called") + + with patch('ai_kernel_generator.utils.environment_check._load_llm_config', return_value=llm_config): + with patch('os.getenv', return_value="test_key"): + result = _check_llm_api(config=config) + assert result is True + + @patch('requests.get') + @patch('requests.post') + def test_check_llm_api_connection_failure(self, mock_post, mock_get): + """测试LLM API连接失败""" + config = { + "agent_model_config": { + "designer": "deepseek_r1_default" + } + } + + llm_config = { + "deepseek_r1_default": { + "api_base": "http://test.api", + "model": "test_model", + "api_key_env": "TEST_API_KEY" + } + } + + # 模拟所有请求都失败 + mock_get.side_effect = Exception("Connection failed") + mock_post.side_effect = Exception("Connection failed") + + with patch('ai_kernel_generator.utils.environment_check._load_llm_config', return_value=llm_config): + with patch('os.getenv', return_value="test_key"): + result = _check_llm_api(config=config) + assert result is False + + def test_check_env_for_task_success(self): + """测试为任务检查环境成功""" + from ai_kernel_generator.utils.environment_check import check_env_for_task + + with patch('ai_kernel_generator.utils.environment_check.check_env', return_value=True): + try: + check_env_for_task( + framework="torch", + backend="cuda", + dsl="triton", + config={"agent_model_config": {"designer": "deepseek_r1_default"}} + ) + # 如果没有异常,测试通过 + assert True + except RuntimeError: + # 如果抛出异常,测试失败 + assert False, "Should not raise RuntimeError when check_env returns True" + + def test_check_env_for_task_failure(self): + """测试为任务检查环境失败""" + from ai_kernel_generator.utils.environment_check import check_env_for_task + + with patch('ai_kernel_generator.utils.environment_check.check_env', return_value=False): + with pytest.raises(RuntimeError, match="环境检查失败"): + check_env_for_task( + framework="torch", + backend="cuda", + dsl="triton", + config={"agent_model_config": {"designer": "deepseek_r1_default"}} + ) \ No newline at end of file diff --git a/aikg/tests/ut/test_hardware_utils.py b/aikg/tests/ut/test_hardware_utils.py new file mode 100644 index 000000000..57c20e9c9 --- /dev/null +++ b/aikg/tests/ut/test_hardware_utils.py @@ -0,0 +1,147 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import patch, mock_open +from ai_kernel_generator.utils.hardware_utils import get_cpu_info, get_hardware_doc + + +class TestHardwareUtils: + """测试硬件工具函数""" + + @patch('platform.system') + @patch('subprocess.run') + def test_get_cpu_info_linux_success(self, mock_subprocess, mock_system): + """测试Linux系统成功获取CPU信息""" + mock_system.return_value = "Linux" + mock_subprocess.return_value.returncode = 0 + mock_subprocess.return_value.stdout = """Architecture: x86_64 +CPU(s): 8 +Thread(s) per core: 2 +Core(s) per socket: 4 +Socket(s): 1 +Model name: Intel(R) Core(TM) i7-8550U CPU @ 1.80GHz +CPU MHz: 2000.000 +CPU max MHz: 4000.000 +L1d cache: 128 KiB +L1i cache: 128 KiB +L2 cache: 1 MiB +L3 cache: 8 MiB +Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht""" + + cpu_info = get_cpu_info() + assert "# Linux CPU信息" in cpu_info + assert "Architecture" in cpu_info + assert "CPU(s)" in cpu_info + + @patch('platform.system') + @patch('subprocess.run') + def test_get_cpu_info_darwin_success(self, mock_subprocess, mock_system): + """测试macOS系统成功获取CPU信息""" + mock_system.return_value = "Darwin" + mock_subprocess.return_value.returncode = 0 + mock_subprocess.return_value.stdout = """hw.ncpu: 8 +hw.physicalcpu: 4 +hw.logicalcpu: 8 +hw.cpufrequency: 2000000000 +hw.cpufrequency_max: 4000000000 +hw.l1dcachesize: 131072 +hw.l1icachesize: 131072 +hw.l2cachesize: 1048576 +hw.l3cachesize: 8388608 +machdep.cpu.brand_string: Intel(R) Core(TM) i7-8550U CPU @ 1.80GHz +machdep.cpu.core_count: 4 +machdep.cpu.thread_count: 8""" + + cpu_info = get_cpu_info() + assert "# macOS CPU信息" in cpu_info + assert "hw.ncpu" in cpu_info + assert "machdep.cpu.brand_string" in cpu_info + + @patch('platform.system') + @patch('subprocess.run') + def test_get_cpu_info_windows_success(self, mock_subprocess, mock_system): + """测试Windows系统成功获取CPU信息""" + mock_system.return_value = "Windows" + mock_subprocess.return_value.returncode = 0 + mock_subprocess.return_value.stdout = """Name=Intel(R) Core(TM) i7-8550U CPU @ 1.80GHz +NumberOfCores=4 +NumberOfLogicalProcessors=8 +L2CacheSize=1024 +L3CacheSize=8192 +MaxClockSpeed=2000 +Architecture=9 +Family=197 +Manufacturer=GenuineIntel""" + + cpu_info = get_cpu_info() + assert "# Windows CPU信息" in cpu_info + assert "Name=" in cpu_info + assert "NumberOfCores=" in cpu_info + + @patch('platform.system') + @patch('subprocess.run') + def test_get_cpu_info_subprocess_failure(self, mock_subprocess, mock_system): + """测试子进程执行失败的情况""" + mock_system.return_value = "Linux" + mock_subprocess.side_effect = Exception("Subprocess failed") + + cpu_info = get_cpu_info() + assert cpu_info == "" # 应该返回空字符串 + + @patch('platform.system') + def test_get_cpu_info_unsupported_system(self, mock_system): + """测试不支持的操作系统""" + mock_system.return_value = "UnsupportedOS" + + cpu_info = get_cpu_info() + assert cpu_info == "" # 应该返回空字符串 + + @patch('ai_kernel_generator.utils.hardware_utils.get_project_root') + def test_get_hardware_doc_cpu_backend(self, mock_get_project_root): + """测试CPU后端的硬件文档获取""" + with patch('ai_kernel_generator.utils.hardware_utils.get_cpu_info') as mock_get_cpu_info: + mock_get_cpu_info.return_value = "# CPU Information" + + doc = get_hardware_doc("cpu", "any_arch") + assert doc == "# CPU Information" + + def test_get_hardware_doc_invalid_backend(self): + """测试无效后端的硬件文档获取""" + with pytest.raises(ValueError, match="不支持的backend"): + get_hardware_doc("invalid_backend", "any_arch") + + def test_get_hardware_doc_invalid_architecture(self): + """测试无效架构的硬件文档获取""" + with pytest.raises(ValueError, match="不支持的architecture"): + get_hardware_doc("ascend", "invalid_arch") + + @patch('ai_kernel_generator.utils.hardware_utils.get_project_root') + @patch('builtins.open', new_callable=mock_open, read_data="# Hardware Documentation") + def test_get_hardware_doc_success(self, mock_file, mock_get_project_root): + """测试成功获取硬件文档""" + mock_get_project_root.return_value = "/mock/project/root" + + doc = get_hardware_doc("ascend", "ascend910b4") + assert doc == "# Hardware Documentation" + mock_file.assert_called_once() + + @patch('ai_kernel_generator.utils.hardware_utils.get_project_root') + @patch('builtins.open', side_effect=FileNotFoundError("File not found")) + def test_get_hardware_doc_file_not_found(self, mock_file, mock_get_project_root): + """测试硬件文档文件不存在的情况""" + mock_get_project_root.return_value = "/mock/project/root" + + with pytest.raises(ValueError, match="硬件文档不存在"): + get_hardware_doc("ascend", "ascend910b4") \ No newline at end of file diff --git a/aikg/tests/ut/test_kernel_verifier.py b/aikg/tests/ut/test_kernel_verifier.py new file mode 100644 index 000000000..185aa46ac --- /dev/null +++ b/aikg/tests/ut/test_kernel_verifier.py @@ -0,0 +1,300 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import Mock, patch, mock_open +import json +import os +from pathlib import Path +from ai_kernel_generator.core.verifier.kernel_verifier import KernelVerifier + + +class TestKernelVerifier: + """测试Kernel验证器""" + + def test_kernel_verifier_init_success(self): + """测试Kernel验证器成功初始化""" + verifier = KernelVerifier( + op_name="test_op", + framework_code="test_framework_code", + task_id="test_task", + framework="torch", + dsl="triton", + backend="cuda", + arch="a100", + config={"log_dir": "/test/log/dir"} + ) + + assert verifier.op_name == "test_op" + assert verifier.framework_code == "test_framework_code" + assert verifier.task_id == "test_task" + assert verifier.framework == "torch" + assert verifier.dsl == "triton" + assert verifier.backend == "cuda" + assert verifier.arch == "a100" + assert verifier.impl_func_name == "test_op_triton_torch" + + def test_kernel_verifier_init_with_custom_func_name(self): + """测试Kernel验证器使用自定义函数名初始化""" + verifier = KernelVerifier( + op_name="test_op", + framework_code="test_framework_code", + task_id="test_task", + framework="torch", + dsl="triton", + backend="cuda", + arch="a100", + impl_func_name="custom_func_name", + config={"log_dir": "/test/log/dir"} + ) + + assert verifier.impl_func_name == "custom_func_name" + + def test_kernel_verifier_init_invalid_cuda_arch(self): + """测试Kernel验证器初始化时CUDA后端与无效架构的组合""" + with pytest.raises(ValueError, match="cuda后端只支持a100和v100架构"): + KernelVerifier( + op_name="test_op", + framework_code="test_framework_code", + framework="torch", + dsl="triton", + backend="cuda", + arch="invalid_arch", + config={"log_dir": "/test/log/dir"} + ) + + def test_kernel_verifier_init_invalid_ascend_arch(self): + """测试Kernel验证器初始化时Ascend后端与无效架构的组合""" + with pytest.raises(ValueError, match="ascend后端只支持ascend910b4和ascend310p3架构"): + KernelVerifier( + op_name="test_op", + framework_code="test_framework_code", + framework="torch", + dsl="triton", + backend="ascend", + arch="invalid_arch", + config={"log_dir": "/test/log/dir"} + ) + + def test_kernel_verifier_init_missing_config(self): + """测试Kernel验证器初始化时缺少配置""" + with pytest.raises(ValueError, match="config is required for KernelVerifier"): + KernelVerifier( + op_name="test_op", + framework_code="test_framework_code", + framework="torch", + dsl="triton", + backend="cuda", + arch="a100" + # 缺少config参数 + ) + + def test_generate_import_statements_triton_torch(self): + """测试生成Triton+PyTorch的import语句""" + verifier = KernelVerifier( + op_name="test_op", + framework_code="test_framework_code", + framework="torch", + dsl="triton", + backend="cuda", + arch="a100", + config={"log_dir": "/test/log/dir"} + ) + + import_statements = verifier._generate_import_statements() + assert "import torch" in import_statements + assert "import triton" in import_statements + assert "import triton.language as tl" in import_statements + + def test_generate_import_statements_triton_mindspore(self): + """测试生成Triton+MindSpore的import语句""" + verifier = KernelVerifier( + op_name="test_op", + framework_code="test_framework_code", + framework="mindspore", + dsl="triton", + backend="ascend", + arch="ascend910b4", + config={"log_dir": "/test/log/dir"} + ) + + import_statements = verifier._generate_import_statements() + assert "import torch" in import_statements + assert "import triton" in import_statements + assert "import triton.language as tl" in import_statements + assert "import mindspore as ms" in import_statements + + def test_generate_import_statements_swft(self): + """测试生成SWFT的import语句""" + verifier = KernelVerifier( + op_name="test_op", + framework_code="test_framework_code", + framework="numpy", + dsl="swft", + backend="ascend", + arch="ascend310p3", + config={"log_dir": "/test/log/dir"} + ) + + import_statements = verifier._generate_import_statements() + assert "from swft.core import *" in import_statements + assert "from swft.api import *" in import_statements + assert "import numpy as np" in import_statements + + def test_generate_import_statements_numpy(self): + """测试生成NumPy的import语句""" + verifier = KernelVerifier( + op_name="test_op", + framework_code="test_framework_code", + framework="numpy", + dsl="custom_dsl", + backend="cpu", + arch="x86_64", + config={"log_dir": "/test/log/dir"} + ) + + import_statements = verifier._generate_import_statements() + assert "import numpy as np" in import_statements + + @patch('os.makedirs') + @patch('os.path.exists') + @patch('builtins.open', new_callable=mock_open) + def test_gen_verify_project(self, mock_file, mock_exists, mock_makedirs): + """测试生成验证项目""" + mock_exists.return_value = True + + verifier = KernelVerifier( + op_name="test_op", + framework_code="test_framework_code", + task_id="test_task", + framework="torch", + dsl="triton", + backend="cuda", + arch="a100", + config={"log_dir": "/test/log/dir"} + ) + + with patch.object(verifier, '_generate_import_statements', return_value="import torch\n"): + verifier.gen_verify_project( + impl_code="test_impl_code", + verify_dir="/test/verify/dir" + ) + + # 验证文件被创建 + assert mock_file.call_count >= 3 # 至少创建3个文件 + + @patch('os.path.exists') + @patch('ai_kernel_generator.core.verifier.kernel_verifier.run_command') + def test_run_verify_success(self, mock_run_command, mock_exists): + """测试成功运行验证""" + mock_exists.return_value = True + mock_run_command.return_value = (True, "") # 验证成功 + + verifier = KernelVerifier( + op_name="test_op", + framework_code="test_framework_code", + task_id="test_task", + framework="torch", + dsl="triton", + backend="cuda", + arch="a100", + config={"log_dir": "/test/log/dir"} + ) + + with patch('os.getcwd', return_value="/original/dir"): + with patch('os.chdir'): + result, log = verifier.run_verify("/test/verify/dir") + + # 验证命令执行 + mock_run_command.assert_called_once() + assert result is True + + @patch('os.path.exists') + @patch('ai_kernel_generator.core.verifier.kernel_verifier.run_command') + def test_run_verify_failure(self, mock_run_command, mock_exists): + """测试运行验证失败""" + mock_exists.return_value = True + mock_run_command.return_value = (False, "test error") # 验证失败 + + verifier = KernelVerifier( + op_name="test_op", + framework_code="test_framework_code", + task_id="test_task", + framework="torch", + dsl="triton", + backend="cuda", + arch="a100", + config={"log_dir": "/test/log/dir"} + ) + + with patch('os.getcwd', return_value="/original/dir"): + with patch('os.chdir'): + result, log = verifier.run_verify("/test/verify/dir") + + # 验证命令执行 + mock_run_command.assert_called_once() + assert result is False + assert log == "test error" + + @patch('os.makedirs') + @patch('os.path.exists') + @patch('builtins.open', new_callable=mock_open) + def test_gen_profile_project(self, mock_file, mock_exists, mock_makedirs): + """测试生成性能分析项目""" + mock_exists.return_value = True + + verifier = KernelVerifier( + op_name="test_op", + framework_code="test_framework_code", + task_id="test_task", + framework="torch", + dsl="triton", + backend="cuda", + arch="a100", + config={"log_dir": "/test/log/dir"} + ) + + with patch('os.path.expanduser', return_value="/expanded/test/log/dir"): + verifier.gen_profile_project( + verify_dir="/test/verify/dir", + device_id=0, + warmup_times=5, + run_times=50 + ) + + # 验证文件被创建 + assert mock_file.call_count >= 2 # 创建2个性能分析文件 + + def test_save_speedup_result(self): + """测试保存加速比结果""" + verifier = KernelVerifier( + op_name="test_op", + framework_code="test_framework_code", + task_id="test_task", + framework="torch", + dsl="triton", + backend="cuda", + arch="a100", + config={"log_dir": "/test/log/dir"} + ) + + with patch('os.makedirs'): + with patch('builtins.open', mock_open()): + with patch('os.path.expanduser', return_value="/expanded/test/log/dir"): + verifier.save_speedup_result( + speedup=2.5, + base_time=100.0, + gen_time=40.0, + unique_dir="test_unique_dir" + ) \ No newline at end of file diff --git a/aikg/tests/ut/test_model_loader.py b/aikg/tests/ut/test_model_loader.py new file mode 100644 index 000000000..594571d34 --- /dev/null +++ b/aikg/tests/ut/test_model_loader.py @@ -0,0 +1,158 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import os +from unittest.mock import Mock, patch, mock_open, MagicMock +import yaml +from ai_kernel_generator.core.llm.model_loader import create_model, CONFIG_PATH + + +class TestModelLoader: + """测试LLM模型加载器""" + + def test_create_model_default(self): + """测试创建默认模型""" + # 由于需要配置文件,这里主要测试错误路径 + with patch('os.path.exists', return_value=False): + with pytest.raises(FileNotFoundError): + create_model() + + def test_create_model_invalid_preset(self, tmp_path): + """测试创建无效预设的模型""" + # 创建一个临时配置文件 + config_content = """ +default_preset: "test_preset" +test_preset: + api_base: "http://test.api" + model: "test_model" + api_key_env: "TEST_API_KEY" +""" + config_file = tmp_path / "test_config.yaml" + config_file.write_text(config_content) + + # 设置环境变量 + os.environ["TEST_API_KEY"] = "test_key" + + # 尝试创建不存在的预设 + with pytest.raises(ValueError, match="预设 'nonexistent_preset' 未找到"): + create_model(name="nonexistent_preset", config_path=str(config_file)) + + # 清理环境变量 + del os.environ["TEST_API_KEY"] + + def test_create_model_missing_config_file(self): + """测试缺少配置文件时创建模型""" + with pytest.raises(FileNotFoundError): + create_model(config_path="/nonexistent/path/to/config.yaml") + + @patch('ai_kernel_generator.core.llm.model_loader.ChatOllama') + def test_create_ollama_model(self, mock_chat_ollama): + """测试创建Ollama模型""" + # 创建一个临时配置文件 + config_content = """ +ollama_test_model: + model: "test_model" + temperature: 0.7 +""" + with patch('ai_kernel_generator.core.llm.model_loader.CONFIG_PATH', new=MagicMock()): + with patch('builtins.open', mock_open(read_data=config_content)): + with patch('yaml.safe_load', return_value=yaml.safe_load(config_content)): + # 模拟配置文件存在 + with patch('os.path.exists', return_value=True): + model = create_model(name="ollama_test_model") + # 验证是否创建了ChatOllama实例 + mock_chat_ollama.assert_called_once() + + @patch('ai_kernel_generator.core.llm.model_loader.ChatDeepSeek') + def test_create_deepseek_model(self, mock_chat_deepseek): + """测试创建DeepSeek模型""" + # 创建一个临时配置文件 + config_content = """ +deepseek_test_model: + api_base: "http://test.api" + model: "test_model" + api_key_env: "TEST_API_KEY" +""" + # 设置环境变量 + os.environ["TEST_API_KEY"] = "test_key" + + with patch('ai_kernel_generator.core.llm.model_loader.CONFIG_PATH', new=MagicMock()): + with patch('builtins.open', mock_open(read_data=config_content)): + with patch('yaml.safe_load', return_value=yaml.safe_load(config_content)): + # 模拟配置文件存在 + with patch('os.path.exists', return_value=True): + model = create_model(name="deepseek_test_model") + # 验证是否创建了ChatDeepSeek实例 + mock_chat_deepseek.assert_called_once() + + # 清理环境变量 + del os.environ["TEST_API_KEY"] + + def test_create_model_missing_api_key_env(self, tmp_path): + """测试缺少API密钥环境变量配置时创建模型""" + # 创建一个临时配置文件 + config_content = """ +test_model: + api_base: "http://test.api" + model: "test_model" + # 缺少api_key_env +""" + config_file = tmp_path / "test_config.yaml" + config_file.write_text(config_content) + + with pytest.raises(ValueError, match="未配置 api_key_env"): + create_model(name="test_model", config_path=str(config_file)) + + def test_create_model_missing_api_key(self, tmp_path): + """测试缺少API密钥时创建模型""" + # 创建一个临时配置文件 + config_content = """ +test_model: + api_base: "http://test.api" + model: "test_model" + api_key_env: "TEST_API_KEY" +""" + config_file = tmp_path / "test_config.yaml" + config_file.write_text(config_content) + + # 确保环境变量未设置 + if "TEST_API_KEY" in os.environ: + del os.environ["TEST_API_KEY"] + + with pytest.raises(ValueError, match="API密钥未找到"): + create_model(name="test_model", config_path=str(config_file)) + + @patch('ai_kernel_generator.core.llm.model_loader.ChatOllama') + def test_create_ollama_model_with_env_override(self, mock_chat_ollama): + """测试使用环境变量覆盖API基础URL创建Ollama模型""" + # 创建一个临时配置文件 + config_content = """ +ollama_test_model: + model: "test_model" +""" + # 设置环境变量 + os.environ["AIKG_OLLAMA_API_BASE"] = "http://custom.ollama.api" + + with patch('ai_kernel_generator.core.llm.model_loader.CONFIG_PATH', new=MagicMock()): + with patch('builtins.open', mock_open(read_data=config_content)): + with patch('yaml.safe_load', return_value=yaml.safe_load(config_content)): + # 模拟配置文件存在 + with patch('os.path.exists', return_value=True): + model = create_model(name="ollama_test_model") + # 验证是否创建了ChatOllama实例 + mock_chat_ollama.assert_called_once() + + # 清理环境变量 + del os.environ["AIKG_OLLAMA_API_BASE"] \ No newline at end of file diff --git a/aikg/tests/ut/test_parser_registry.py b/aikg/tests/ut/test_parser_registry.py new file mode 100644 index 000000000..aeee548f8 --- /dev/null +++ b/aikg/tests/ut/test_parser_registry.py @@ -0,0 +1,141 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import tempfile +from pathlib import Path +from ai_kernel_generator.utils.parser_registry import create_step_parser, create_conductor_parser, _convert_to_internal_format + + +class TestParserRegistry: + """测试解析器注册""" + + def test_convert_to_internal_format_simple(self): + """测试简单格式转换为内部格式""" + parser_definition = { + 'output_fields': { + 'code': 'str', + 'explanation': 'str' + } + } + + internal_format = _convert_to_internal_format(parser_definition) + + assert 'output_fields' in internal_format + assert len(internal_format['output_fields']) == 2 + assert internal_format['output_fields']['code']['field_type'] == 'str' + assert internal_format['output_fields']['code']['mandatory'] is True + assert internal_format['output_fields']['explanation']['field_type'] == 'str' + assert internal_format['output_fields']['explanation']['mandatory'] is True + + def test_convert_to_internal_format_detailed(self): + """测试详细格式转换为内部格式""" + parser_definition = { + 'output_fields': { + 'code': { + 'field_type': 'str', + 'mandatory': True, + 'field_description': 'Generated code' + }, + 'explanation': { + 'field_type': 'str', + 'mandatory': False, + 'field_description': 'Code explanation' + } + } + } + + internal_format = _convert_to_internal_format(parser_definition) + + assert 'output_fields' in internal_format + assert len(internal_format['output_fields']) == 2 + assert internal_format['output_fields']['code']['field_type'] == 'str' + assert internal_format['output_fields']['code']['mandatory'] is True + assert internal_format['output_fields']['code']['field_description'] == 'Generated code' + assert internal_format['output_fields']['explanation']['field_type'] == 'str' + assert internal_format['output_fields']['explanation']['mandatory'] is False + assert internal_format['output_fields']['explanation']['field_description'] == 'Code explanation' + + def test_create_conductor_parser(self): + """测试创建Conductor解析器""" + parser = create_conductor_parser() + + # 检查解析器是否正确创建 + assert parser is not None + assert hasattr(parser, 'pydantic_object') + + # 检查解析器的字段 + model = parser.pydantic_object + assert 'decision' in model.model_fields + assert 'suggestion' in model.model_fields + + def test_create_step_parser_no_output_format(self, tmp_path): + """测试为没有输出格式的步骤创建解析器""" + # 创建一个临时的workflow配置文件 + workflow_content = """ +agent_info: + designer: + # 没有output_format配置 + possible_next_agent: [coder] +start_agent: designer +limitation_info: + required: + max_step: 20 +""" + workflow_file = tmp_path / "test_workflow.yaml" + workflow_file.write_text(workflow_content) + + parser = create_step_parser("designer", str(workflow_file)) + assert parser is None + + def test_create_step_parser_no_parser_definition(self, tmp_path): + """测试为没有解析器定义的步骤创建解析器""" + # 创建一个临时的workflow配置文件 + workflow_content = """ +agent_info: + designer: + possible_next_agent: [coder] + output_format: + parser_name: designer_parser + # 没有parser_definition +start_agent: designer +limitation_info: + required: + max_step: 20 +""" + workflow_file = tmp_path / "test_workflow.yaml" + workflow_file.write_text(workflow_content) + + parser = create_step_parser("designer", str(workflow_file)) + assert parser is None + + def test_create_step_parser_invalid_workflow_file(self): + """测试为无效的工作流文件创建解析器""" + with pytest.raises(FileNotFoundError): + create_step_parser("designer", "/nonexistent/path/to/workflow.yaml") + + def test_create_step_parser_missing_agent_info(self, tmp_path): + """测试为缺少agent_info的工作流文件创建解析器""" + # 创建一个临时的无效workflow配置文件 + workflow_content = """ +start_agent: designer +limitation_info: + required: + max_step: 20 +""" + workflow_file = tmp_path / "invalid_workflow.yaml" + workflow_file.write_text(workflow_content) + + with pytest.raises(ValueError, match="No 'agent_info' found in workflow config"): + create_step_parser("designer", str(workflow_file)) \ No newline at end of file diff --git a/aikg/tests/ut/test_process_utils.py b/aikg/tests/ut/test_process_utils.py new file mode 100644 index 000000000..e2888e7ec --- /dev/null +++ b/aikg/tests/ut/test_process_utils.py @@ -0,0 +1,128 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import Mock, patch, mock_open +import json +from ai_kernel_generator.utils.process_utils import run_command + + +class TestProcessUtils: + """测试进程工具函数""" + + @patch('subprocess.Popen') + def test_run_command_success(self, mock_popen): + """测试成功运行命令""" + # 模拟子进程 + mock_process = Mock() + mock_process.communicate.return_value = ("stdout output", "") + mock_process.returncode = 0 + mock_popen.return_value = mock_process + + result, error = run_command(["echo", "test"], "test_command") + + assert result is True + assert error == "" + mock_popen.assert_called_once() + + @patch('subprocess.Popen') + def test_run_command_failure(self, mock_popen): + """测试运行命令失败""" + # 模拟子进程 + mock_process = Mock() + mock_process.communicate.return_value = ("", "error output") + mock_process.returncode = 1 + mock_popen.return_value = mock_process + + result, error = run_command(["false"], "test_command") + + assert result is False + assert error == "error output" + mock_popen.assert_called_once() + + @patch('subprocess.Popen') + def test_run_command_with_stdout(self, mock_popen): + """测试运行命令带标准输出""" + # 模拟子进程 + mock_process = Mock() + mock_process.communicate.return_value = ("test output\n", "") + mock_process.returncode = 0 + mock_popen.return_value = mock_process + + result, error = run_command(["echo", "test"], "test_command") + + assert result is True + assert error == "" + mock_popen.assert_called_once() + + @patch('subprocess.Popen') + def test_run_command_with_stderr(self, mock_popen): + """测试运行命令带标准错误输出""" + # 模拟子进程 + mock_process = Mock() + mock_process.communicate.return_value = ("", "test error") + mock_process.returncode = 1 + mock_popen.return_value = mock_process + + result, error = run_command(["false"], "test_command") + + assert result is False + assert error == "test error" + mock_popen.assert_called_once() + + @patch('subprocess.Popen') + def test_run_command_timeout(self, mock_popen): + """测试运行命令超时处理""" + # 模拟子进程超时 + mock_process = Mock() + mock_process.communicate.side_effect = [Exception("Timeout")] + mock_process.returncode = -1 + mock_popen.return_value = mock_process + + # 注意:这里的测试可能不完全准确,因为实际的超时处理在函数内部有特殊的逻辑 + # 我们主要测试函数不会抛出未处理的异常 + try: + result, error = run_command(["sleep", "10"], "test_command", timeout=1) + # 如果没有异常,说明超时处理正常工作 + except Exception: + # 如果有异常,说明测试失败 + assert False, "run_command should handle timeout without unhandled exceptions" + + @patch('subprocess.Popen') + def test_run_command_exception(self, mock_popen): + """测试运行命令时发生异常""" + mock_popen.side_effect = Exception("Test exception") + + result, error = run_command(["test"], "test_command") + + assert result is False + assert "Test exception" in error + + @patch('subprocess.Popen') + def test_run_command_with_env(self, mock_popen): + """测试运行命令带环境变量""" + # 模拟子进程 + mock_process = Mock() + mock_process.communicate.return_value = ("", "") + mock_process.returncode = 0 + mock_popen.return_value = mock_process + + test_env = {"TEST_VAR": "test_value"} + result, error = run_command(["echo", "test"], "test_command", env=test_env) + + assert result is True + # 验证环境变量被传递 + args, kwargs = mock_popen.call_args + assert "env" in kwargs + assert kwargs["env"]["TEST_VAR"] == "test_value" \ No newline at end of file diff --git a/aikg/tests/ut/test_result_processor.py b/aikg/tests/ut/test_result_processor.py new file mode 100644 index 000000000..6bfbeeefd --- /dev/null +++ b/aikg/tests/ut/test_result_processor.py @@ -0,0 +1,287 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import Mock, patch +from ai_kernel_generator.utils.result_processor import ResultProcessor +from ai_kernel_generator.utils.common_utils import ParserFactory +from ai_kernel_generator.core.trace import Trace + + +class TestResultProcessor: + """测试结果处理器""" + + def test_parse_and_update_code_no_parser(self): + """测试没有解析器时的代码解析和更新""" + task_info = {} + result = '{"code": "test code"}' + + # 创建一个模拟的trace对象 + trace = Mock(spec=Trace) + + success = ResultProcessor.parse_and_update_code( + agent_name="designer", + result=result, + task_info=task_info, + agent_parser=None, # 没有解析器 + trace=trace + ) + + assert success is False + assert len(task_info) == 0 + + def test_parse_and_update_code_empty_result(self): + """测试空结果时的代码解析和更新""" + task_info = {} + parser = Mock() + + # 创建一个模拟的trace对象 + trace = Mock(spec=Trace) + + success = ResultProcessor.parse_and_update_code( + agent_name="designer", + result="", # 空结果 + task_info=task_info, + agent_parser=parser, + trace=trace + ) + + assert success is False + assert len(task_info) == 0 + + def test_parse_and_update_code_single_field(self): + """测试单字段的代码解析和更新""" + task_info = {} + + # 创建一个模拟的解析器 + class MockParsedResult: + def __init__(self): + self.code = "test code" + + # 创建一个模拟的trace对象 + trace = Mock(spec=Trace) + trace.save_parsed_code = Mock() + + with patch.object(ParserFactory, 'robust_parse', return_value=MockParsedResult()): + parser = Mock() + + success = ResultProcessor.parse_and_update_code( + agent_name="designer", + result='{"code": "test code"}', + task_info=task_info, + agent_parser=parser, + trace=trace, + agent_info={ + "designer": { + "output_format": { + "parser_definition": { + "output_fields": { + "code": { + "field_type": "str", + "mandatory": True + } + } + } + } + } + } + ) + + assert success is True + assert "designer_code" in task_info + assert task_info["designer_code"] == "test code" + + def test_parse_and_update_code_multiple_fields(self): + """测试多字段的代码解析和更新""" + task_info = {} + + # 创建一个模拟的解析器 + class MockParsedResult: + def __init__(self): + self.code = "test code" + self.explanation = "test explanation" + + # 创建一个模拟的trace对象 + trace = Mock(spec=Trace) + trace.save_parsed_code = Mock() + + with patch.object(ParserFactory, 'robust_parse', return_value=MockParsedResult()): + parser = Mock() + + success = ResultProcessor.parse_and_update_code( + agent_name="designer", + result='{"code": "test code", "explanation": "test explanation"}', + task_info=task_info, + agent_parser=parser, + trace=trace, + agent_info={ + "designer": { + "output_format": { + "parser_definition": { + "output_fields": { + "code": { + "field_type": "str", + "mandatory": True + }, + "explanation": { + "field_type": "str", + "mandatory": False + } + } + } + } + } + } + ) + + assert success is True + assert "designer_code" in task_info + assert task_info["designer_code"] == "test code" + assert "designer_explanation" in task_info + assert task_info["designer_explanation"] == "test explanation" + + def test_parse_and_update_code_no_code_field(self): + """测试没有code字段的代码解析和更新""" + task_info = {} + + # 创建一个模拟的解析器 + class MockParsedResult: + def __init__(self): + self.explanation = "test explanation" + + # 创建一个模拟的trace对象 + trace = Mock(spec=Trace) + trace.save_parsed_code = Mock() + + with patch.object(ParserFactory, 'robust_parse', return_value=MockParsedResult()): + parser = Mock() + + success = ResultProcessor.parse_and_update_code( + agent_name="designer", + result='{"explanation": "test explanation"}', + task_info=task_info, + agent_parser=parser, + trace=trace, + agent_info={ + "designer": { + "output_format": { + "parser_definition": { + "output_fields": { + "explanation": { + "field_type": "str", + "mandatory": False + } + } + } + } + } + } + ) + + assert success is False # 应该失败,因为没有code字段 + + def test_update_verifier_result_true(self): + """测试更新verifier结果为True""" + task_info = {} + + ResultProcessor.update_verifier_result( + result="True", + error_log="", + task_info=task_info + ) + + assert task_info["verifier_result"] is True + assert task_info["verifier_error"] == "" + + def test_update_verifier_result_false(self): + """测试更新verifier结果为False""" + task_info = {} + + ResultProcessor.update_verifier_result( + result="False", + error_log="test error", + task_info=task_info + ) + + assert task_info["verifier_result"] is False + assert task_info["verifier_error"] == "test error" + + def test_update_verifier_result_with_profile(self): + """测试更新verifier结果包含性能数据""" + task_info = {} + profile_res = (1.0, 2.0, 3.0) + + ResultProcessor.update_verifier_result( + result="True", + error_log="", + task_info=task_info, + profile_res=profile_res + ) + + assert task_info["verifier_result"] is True + assert task_info["profile_res"] == profile_res + + def test_get_agent_parser_cache_hit(self): + """测试缓存命中的agent解析器获取""" + agent_parsers = { + "designer": Mock() # 已缓存的解析器 + } + + parser = ResultProcessor.get_agent_parser( + agent_name="designer", + workflow_config_path="/path/to/workflow.yaml", + agent_parsers=agent_parsers + ) + + assert parser == agent_parsers["designer"] + + def test_parse_conductor_decision_success(self): + """测试成功解析Conductor决策""" + # 创建一个模拟的解析器 + class MockParsedResult: + def __init__(self): + self.decision = "coder" + self.suggestion = "test suggestion" + + parser = Mock() + parser.parse.return_value = MockParsedResult() + + decision, suggestion = ResultProcessor.parse_conductor_decision( + content='{"decision": "coder", "suggestion": "test suggestion"}', + conductor_parser=parser, + valid_next_agents={"coder", "designer"} + ) + + assert decision == "coder" + assert suggestion == "test suggestion" + + def test_parse_conductor_decision_invalid_decision(self): + """测试解析无效的Conductor决策""" + # 创建一个模拟的解析器 + class MockParsedResult: + def __init__(self): + self.decision = "invalid_agent" + self.suggestion = "test suggestion" + + parser = Mock() + parser.parse.return_value = MockParsedResult() + + decision, suggestion = ResultProcessor.parse_conductor_decision( + content='{"decision": "invalid_agent", "suggestion": "test suggestion"}', + conductor_parser=parser, + valid_next_agents={"coder", "designer"} + ) + + assert decision is None # 无效决策应该返回None + assert suggestion == "test suggestion" \ No newline at end of file diff --git a/aikg/tests/ut/test_trace.py b/aikg/tests/ut/test_trace.py new file mode 100644 index 000000000..14c48a1f8 --- /dev/null +++ b/aikg/tests/ut/test_trace.py @@ -0,0 +1,174 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import Mock, patch, mock_open +import json +from ai_kernel_generator.core.trace import Trace, AgentRecord + + +class TestTrace: + """测试Trace类""" + + def test_trace_init(self): + """测试Trace初始化""" + trace = Trace("test_op", "test_task", "/test/log/dir") + + assert trace.op_name == "test_op" + assert trace.task_id == "test_task" + assert trace.log_dir == "/test/log/dir" + assert trace.trace_list == [] + + def test_agent_record_init(self): + """测试AgentRecord初始化""" + record = AgentRecord( + agent_name="test_agent", + result="test_result", + prompt="test_prompt", + reasoning="test_reasoning", + error_log="test_error", + profile_res=(1.0, 2.0, 3.0) + ) + + assert record.agent_name == "test_agent" + assert record.result == "test_result" + assert record.prompt == "test_prompt" + assert record.reasoning == "test_reasoning" + assert record.error_log == "test_error" + assert record.profile_res == (1.0, 2.0, 3.0) + + def test_agent_record_init_defaults(self): + """测试AgentRecord默认值初始化""" + record = AgentRecord(agent_name="test_agent") + + assert record.agent_name == "test_agent" + assert record.result == "" + assert record.prompt == "" + assert record.reasoning == "" + assert record.error_log == "" + assert record.profile_res == () + + def test_insert_agent_record(self): + """测试插入agent记录""" + trace = Trace("test_op", "test_task", "/test/log/dir") + + trace.insert_agent_record( + agent_name="test_agent", + result="test_result", + prompt="test_prompt", + reasoning="test_reasoning", + error_log="test_error", + profile_res=(1.0, 2.0, 3.0) + ) + + assert len(trace.trace_list) == 1 + record = trace.trace_list[0] + assert record.agent_name == "test_agent" + assert record.result == "test_result" + assert record.prompt == "test_prompt" + assert record.reasoning == "test_reasoning" + assert record.error_log == "test_error" + assert record.profile_res == (1.0, 2.0, 3.0) + + @patch('os.makedirs') + @patch('builtins.open', new_callable=mock_open) + def test_save_parameters_to_files(self, mock_file, mock_makedirs): + """测试保存参数到文件""" + trace = Trace("test_op", "test_task", "/test/log/dir") + + params = [ + ("param1", "content1"), + ("param2", "content2") + ] + + with patch('os.path.expanduser', return_value="/expanded/test/log/dir"): + trace.save_parameters_to_files("test_agent", params) + + # 验证目录创建 + mock_makedirs.assert_called_once() + + # 验证文件写入 + assert mock_file.call_count == 2 + + @patch('os.makedirs') + @patch('builtins.open', new_callable=mock_open) + def test_insert_agent_record_with_file_save(self, mock_file, mock_makedirs): + """测试插入agent记录时保存文件""" + trace = Trace("test_op", "test_task", "/test/log/dir") + + with patch('os.path.expanduser', return_value="/expanded/test/log/dir"): + trace.insert_agent_record( + agent_name="designer", + result='{"code": "test code"}', + prompt="test prompt", + reasoning="test reasoning" + ) + + # 验证文件被写入 + assert mock_file.called + + @patch('os.makedirs') + @patch('builtins.open', new_callable=mock_open) + def test_insert_agent_record_verifier_with_error(self, mock_file, mock_makedirs): + """测试插入verifier记录时保存错误日志文件""" + trace = Trace("test_op", "test_task", "/test/log/dir") + + with patch('os.path.expanduser', return_value="/expanded/test/log/dir"): + trace.insert_agent_record( + agent_name="verifier", + error_log="test error log" + ) + + # 验证错误日志文件被写入 + assert mock_file.called + + @patch('os.makedirs') + @patch('builtins.open', new_callable=mock_open) + def test_save_parsed_code(self, mock_file, mock_makedirs): + """测试保存解析后的代码""" + trace = Trace("test_op", "test_task", "/test/log/dir") + + params = [ + ("code", "test code"), + ("explanation", "test explanation") + ] + + with patch('os.path.expanduser', return_value="/expanded/test/log/dir"): + trace.save_parsed_code("test_agent", params) + + # 验证目录创建 + mock_makedirs.assert_called_once() + + # 验证文件写入 + mock_file.assert_called_once() + + @patch('os.makedirs') + @patch('builtins.open', new_callable=mock_open) + def test_insert_conductor_agent_record(self, mock_file, mock_makedirs): + """测试插入conductor记录""" + trace = Trace("test_op", "test_task", "/test/log/dir") + + with patch('os.path.expanduser', return_value="/expanded/test/log/dir"): + trace.insert_conductor_agent_record( + res="test result", + prompt="test prompt", + reasoning="test reasoning", + agent_name="decision" + ) + + # 验证目录创建 + mock_makedirs.assert_called_once() + + # 验证文件写入 - 应该调用3次open(result, prompt, reasoning) + assert mock_file.call_count == 3 \ No newline at end of file diff --git a/aikg/tests/ut/test_workflow_controller.py b/aikg/tests/ut/test_workflow_controller.py new file mode 100644 index 000000000..735d0cd40 --- /dev/null +++ b/aikg/tests/ut/test_workflow_controller.py @@ -0,0 +1,179 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from ai_kernel_generator.utils.workflow_controller import WorkflowController + + +class TestWorkflowController: + """测试工作流控制器""" + + def test_count_consecutive_repeats_empty_history(self): + """测试空历史记录的连续重复计数""" + count = WorkflowController.count_consecutive_repeats([], "designer") + assert count == 0 + + def test_count_consecutive_repeats_no_matches(self): + """测试没有匹配的历史记录""" + history = ["coder", "verifier", "designer"] + count = WorkflowController.count_consecutive_repeats(history, "coder") + assert count == 0 + + def test_count_consecutive_repeats_single_match(self): + """测试单个匹配的历史记录""" + history = ["designer", "coder", "designer"] + count = WorkflowController.count_consecutive_repeats(history, "designer") + assert count == 1 + + def test_count_consecutive_repeats_multiple_consecutive_matches(self): + """测试多个连续匹配的历史记录""" + history = ["designer", "coder", "verifier", "verifier", "verifier"] + count = WorkflowController.count_consecutive_repeats(history, "verifier") + assert count == 3 + + def test_count_consecutive_repeats_non_consecutive_matches(self): + """测试非连续匹配的历史记录""" + history = ["designer", "coder", "designer", "verifier", "designer"] + count = WorkflowController.count_consecutive_repeats(history, "designer") + assert count == 1 # 只计算末尾的连续重复 + + def test_count_sequence_repeats_empty_pattern(self): + """测试空模式的序列重复计数""" + history = ["designer", "coder", "verifier"] + count = WorkflowController.count_sequence_repeats(history, []) + assert count == 0 + + def test_count_sequence_repeats_empty_history(self): + """测试空历史记录的序列重复计数""" + pattern = ["designer", "coder"] + count = WorkflowController.count_sequence_repeats([], pattern) + assert count == 0 + + def test_count_sequence_repeats_single_occurrence(self): + """测试单次出现的序列重复计数""" + history = ["designer", "coder", "verifier"] + pattern = ["designer", "coder"] + count = WorkflowController.count_sequence_repeats(history, pattern) + assert count == 1 + + def test_count_sequence_repeats_multiple_occurrences(self): + """测试多次出现的序列重复计数""" + history = ["designer", "coder", "verifier", "designer", "coder", "verifier"] + pattern = ["designer", "coder", "verifier"] + count = WorkflowController.count_sequence_repeats(history, pattern) + assert count == 2 + + def test_count_sequence_repeats_partial_match(self): + """测试部分匹配的序列重复计数""" + history = ["designer", "coder", "verifier", "designer", "coder"] + pattern = ["designer", "coder", "verifier"] + count = WorkflowController.count_sequence_repeats(history, pattern) + assert count == 1 # 只有完整的一次匹配 + + def test_get_illegal_agent_step_limit_exceeded(self): + """测试步数限制超标的违法agent获取""" + illegal_agents = WorkflowController.get_illegal_agent( + step_count=25, # 超过限制 + max_step=20, + current_agent_name="designer", + agent_history=["designer", "coder"], + repeat_limits={}, + agent_info={"designer": {}, "coder": {}, "verifier": {}} + ) + + # 应该返回所有agent作为违法agent + assert len(illegal_agents) == 3 + assert "designer" in illegal_agents + assert "coder" in illegal_agents + assert "verifier" in illegal_agents + + def test_get_illegal_agent_consecutive_limit_exceeded(self): + """测试连续重复限制超标的违法agent获取""" + illegal_agents = WorkflowController.get_illegal_agent( + step_count=5, + max_step=20, + current_agent_name="designer", + agent_history=["designer", "designer", "designer", "designer"], # 4次连续 + repeat_limits={ + "single_agent": { + "designer": 3 # 限制为3次 + } + }, + agent_info={"designer": {}, "coder": {}, "verifier": {}} + ) + + # designer应该被标记为违法agent + assert "designer" in illegal_agents + + def test_get_illegal_agent_sequence_limit_exceeded(self): + """测试序列重复限制超标的违法agent获取""" + illegal_agents = WorkflowController.get_illegal_agent( + step_count=10, + max_step=20, + current_agent_name="coder", # 序列的最后一个agent + agent_history=["designer", "coder", "verifier", "designer", "coder"], # 两次重复 + repeat_limits={ + "sequences": { + "test_sequence": { + "pattern": ["designer", "coder"], # 限制模式 + "max_repeats": 1 # 限制为1次 + } + } + }, + agent_info={"designer": {}, "coder": {}, "verifier": {}} + ) + + # 序列的第一个agent应该被标记为违法agent + assert "designer" in illegal_agents + + def test_get_valid_next_agent_no_illegal_agents(self): + """测试没有违法agent时的有效下一步获取""" + agent_next_mapping = { + "designer": {"coder"}, + "coder": {"verifier"}, + "verifier": {"designer", "finish"} + } + + valid_agents = WorkflowController.get_valid_next_agent( + agent_name="designer", + agent_next_mapping=agent_next_mapping, + step_count=5, + max_step=20, + current_agent_name="designer", + agent_history=["coder", "designer"], + repeat_limits={}, + agent_info={"designer": {}, "coder": {}, "verifier": {}} + ) + + assert valid_agents == {"coder"} + + def test_get_valid_next_agent_with_illegal_agents(self): + """测试有违法agent时的有效下一步获取""" + agent_next_mapping = { + "verifier": {"designer", "finish"} + } + + valid_agents = WorkflowController.get_valid_next_agent( + agent_name="verifier", + agent_next_mapping=agent_next_mapping, + step_count=25, # 超过限制 + max_step=20, + current_agent_name="verifier", + agent_history=["designer", "coder"], + repeat_limits={}, + agent_info={"designer": {}, "coder": {}, "verifier": {}} + ) + + # 由于所有agent都是违法的,应该返回空集合 + assert valid_agents == set() \ No newline at end of file diff --git a/aikg/tests/ut/test_workflow_manager.py b/aikg/tests/ut/test_workflow_manager.py new file mode 100644 index 000000000..990777d72 --- /dev/null +++ b/aikg/tests/ut/test_workflow_manager.py @@ -0,0 +1,231 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import tempfile +import os +from pathlib import Path +from ai_kernel_generator.utils.workflow_manager import WorkflowManager + + +class TestWorkflowManager: + """测试工作流管理器""" + + def test_resolve_workflow_config_path_none(self): + """测试解析None工作流配置路径""" + # 这应该返回默认配置路径 + path = WorkflowManager.resolve_workflow_config_path(None) + assert "config" in path + assert path.endswith(".yaml") + + def test_resolve_workflow_config_path_filename(self): + """测试解析文件名工作流配置路径""" + path = WorkflowManager.resolve_workflow_config_path("test_workflow") + assert "config" in path + assert path.endswith("test_workflow.yaml") + + def test_resolve_workflow_config_path_relative_path(self): + """测试解析相对路径工作流配置路径""" + path = WorkflowManager.resolve_workflow_config_path("config/test_workflow.yaml") + assert path.endswith("config/test_workflow.yaml") + + def test_resolve_workflow_config_path_absolute_path(self): + """测试解析绝对路径工作流配置路径""" + absolute_path = "/absolute/path/to/test_workflow.yaml" + path = WorkflowManager.resolve_workflow_config_path(absolute_path) + assert path == absolute_path + + def test_load_workflow_config_file_not_found(self): + """测试加载不存在的工作流配置文件""" + with pytest.raises(FileNotFoundError): + WorkflowManager.load_workflow_config("/nonexistent/path/to/workflow.yaml") + + def test_load_workflow_config_invalid_yaml(self, tmp_path): + """测试加载无效YAML格式的工作流配置文件""" + # 创建一个包含无效YAML的配置文件 + invalid_yaml_content = """ +agent_info: + designer: + possible_next_agent: [coder] + invalid_indent: value +""" + yaml_file = tmp_path / "invalid_workflow.yaml" + yaml_file.write_text(invalid_yaml_content) + + with pytest.raises(Exception): # yaml.safe_load会抛出具体的YAML错误 + WorkflowManager.load_workflow_config(str(yaml_file)) + + def test_load_workflow_config_missing_agent_info(self, tmp_path): + """测试加载缺少agent_info的工作流配置文件""" + yaml_content = """ +start_agent: designer +limitation_info: + required: + max_step: 20 +""" + yaml_file = tmp_path / "missing_agent_info.yaml" + yaml_file.write_text(yaml_content) + + with pytest.raises(ValueError, match="No 'agent_info' found in workflow config"): + WorkflowManager.load_workflow_config(str(yaml_file)) + + def test_load_workflow_config_missing_max_step(self, tmp_path): + """测试加载缺少max_step的工作流配置文件""" + yaml_content = """ +agent_info: + designer: + possible_next_agent: [coder] +start_agent: designer +limitation_info: + required: + # missing max_step +""" + yaml_file = tmp_path / "missing_max_step.yaml" + yaml_file.write_text(yaml_content) + + with pytest.raises(ValueError, match="Missing required setting 'max_step'"): + WorkflowManager.load_workflow_config(str(yaml_file)) + + def test_load_workflow_config_missing_start_agent(self, tmp_path): + """测试加载缺少start_agent的工作流配置文件""" + yaml_content = """ +agent_info: + designer: + possible_next_agent: [coder] +limitation_info: + required: + max_step: 20 +""" + yaml_file = tmp_path / "missing_start_agent.yaml" + yaml_file.write_text(yaml_content) + + with pytest.raises(ValueError, match="Missing required 'start_agent'"): + WorkflowManager.load_workflow_config(str(yaml_file)) + + def test_initialize_task_info_fields_basic(self): + """测试基本任务信息字段初始化""" + agent_info = { + "designer": { + "output_format": { + "parser_definition": { + "output_fields": { + "code": { + "field_type": "str", + "mandatory": True + } + } + } + } + } + } + + task_info = WorkflowManager.initialize_task_info_fields( + agent_info=agent_info, + op_name="test_op", + task_id="test_task", + dsl="triton", + task_desc="test description" + ) + + # 检查基本字段 + assert task_info["op_name"] == "test_op" + assert task_info["task_id"] == "test_task" + assert task_info["dsl"] == "triton" + assert task_info["task_desc"] == "test description" + + # 检查从agent_info生成的字段 + assert "designer_code" in task_info + + # 检查verifier特殊字段 + assert task_info["verifier_result"] is False + assert task_info["verifier_error"] == "" + + def test_initialize_task_info_fields_multiple_fields(self): + """测试多字段任务信息初始化""" + agent_info = { + "designer": { + "output_format": { + "parser_definition": { + "output_fields": { + "code": { + "field_type": "str", + "mandatory": True + }, + "explanation": { + "field_type": "str", + "mandatory": False + } + } + } + } + }, + "coder": { + "output_format": { + "parser_definition": { + "output_fields": { + "code": { + "field_type": "str", + "mandatory": True + } + } + } + } + } + } + + task_info = WorkflowManager.initialize_task_info_fields( + agent_info=agent_info, + op_name="test_op", + task_id="test_task", + dsl="triton" + ) + + # 检查多个字段 + assert "designer_code" in task_info + assert "designer_explanation" in task_info + assert "coder_code" in task_info + + def test_initialize_task_info_fields_with_base_doc(self): + """测试带基础文档的任务信息初始化""" + agent_info = { + "designer": { + "output_format": { + "parser_definition": { + "output_fields": { + "code": { + "field_type": "str", + "mandatory": True + } + } + } + } + } + } + + base_doc = { + "api_docs": "test api docs", + "dsl_basic_docs": "test dsl docs" + } + + task_info = WorkflowManager.initialize_task_info_fields( + agent_info=agent_info, + op_name="test_op", + task_id="test_task", + dsl="triton", + base_doc=base_doc + ) + + # 检查基础文档字段被正确添加 + assert task_info["api_docs"] == "test api docs" + assert task_info["dsl_basic_docs"] == "test dsl docs" \ No newline at end of file -- Gitee