From 287c7bad769edbf28917ac5b209587e13f41a947 Mon Sep 17 00:00:00 2001 From: shenyue_ustc Date: Mon, 29 Sep 2025 09:30:28 +0800 Subject: [PATCH 1/4] add prompts and test generator --- README.md | 2 +- main.py | 67 +++++++ src/prompts/env_prompt.py | 86 ++++++++ src/prompts/eval_prompt.py | 11 ++ src/prompts/tool_prompt.py | 67 +++++++ src/prompts/val_prompt.py | 46 +++++ src/test_generator/TestGenerator.py | 297 ++++++++++++++++++++++++++++ 7 files changed, 575 insertions(+), 1 deletion(-) create mode 100644 main.py create mode 100644 src/prompts/env_prompt.py create mode 100644 src/prompts/eval_prompt.py create mode 100644 src/prompts/tool_prompt.py create mode 100644 src/prompts/val_prompt.py create mode 100644 src/test_generator/TestGenerator.py diff --git a/README.md b/README.md index 3c7ed10..72ab8e0 100644 --- a/README.md +++ b/README.md @@ -143,7 +143,7 @@ python main.py gen-cases --config xxx/mcp-config.json ### 5. 执行测试用例并验证 ```bash # 执行校验命令 -python main.py val-cases --config xxx/mcp-config.json --testpath ./logs/perf_mcp_2025-09-11T07-31-04-418670 +python main.py val-cases --config xxx/mcp-config.json --testpath ./logs/perf_mcp_2025-09-11T07-31-04-418670/testcases.json ``` - `--testpath`:指定步骤 4 生成的测试用例目录路径; - 执行结果:用例的执行结果将保存至步骤 4 输出的用例目录。 \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..23b4c04 --- /dev/null +++ b/main.py @@ -0,0 +1,67 @@ +import asyncio +import argparse + +def parse_args(): + parser = argparse.ArgumentParser( + description="MCP Server initialization args", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + subparsers = parser.add_subparsers(dest='command', required=True) + + # gen-cases 子命令 + gen_parser = subparsers.add_parser('gen-cases', help='Generate test cases') + gen_parser.add_argument( + "--config", + type=str, + default="./mcp-servers-perf.json", + help="Path to MCP Server config file" + ) + + #val-cases子命令 + val_parser = subparsers.add_parser('val-cases', help='Validate test cases') + val_parser.add_argument( + "--config", + type=str, + default="./mcp-servers-perf.json", + help="Path to MCP Server config file" + ) + val_parser.add_argument( + "--testpath", + type=str, + default=".logs/perf_mcp_2025-09-12T06-43-29-026631/testcases.json", + help="Path to get testcases" + ) + + # reporter 子命令 + rep_parser = subparsers.add_parser('rep-cases', help='report testing results') + rep_parser.add_argument( + "--valpath", + type=str, + default=".logs/perf_mcp_2025-09-11T07-31-04-418670/validation_results.json", + help="Path to MCP Server config file" + ) + + return parser.parse_args() + + +async def gen_cases(config_path): + from src.test_generator.TestGenerator import TestGenerator + generator = TestGenerator(config_path=config_path) + return await generator.run() + + +async def val_cases(config_path, testcase_path): + from src.validator.Response_validator_withenv import ResponseValidator_withenv + validator = ResponseValidator_withenv(config_path=config_path, testcase_path=testcase_path) + return await validator.run() + +async def main(): + args = parse_args() + if args.command == 'gen-cases': + await gen_cases(args.config) + if args.command == 'val-cases': + await val_cases(args.config, args.testpath) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/src/prompts/env_prompt.py b/src/prompts/env_prompt.py new file mode 100644 index 0000000..44437b2 --- /dev/null +++ b/src/prompts/env_prompt.py @@ -0,0 +1,86 @@ +env_prompt = """You have received a debugging task: a test case has failed its validation due to not meeting specific environmental requirements. +The specific reason is: {reason}. Your job is to write a Bash script that configures the environment to satisfy these special requirements. + +**Important Notes:** +- The default environment already has Anaconda, Python, and the following common dependencies pre-installed: + {dependencies} + **Do not include redundant commands for these in your script.** +- Carefully review the test case's description field to identify any additional environment setup needed beyond the standard baseline. +- **Target Specifics**: Focus only on additional setup needed to address the failure, as indicated in the test case description. Examples include: + - Enabling/disabling services (e.g., stopping Docker if the test requires it to be unavailable) + - Creating required directories/files with specific structures or contents + - Installing/removing packages not in the pre-installed list +- Minimal Steps: Only include core commands needed to meet the special requirements—no extra configurations, optimizations, or redundant checks. +- The target operating system is openEuler. Please ensure: + - Use yum or dnf instead of apt-get + - Employ commands suitable for RHEL/CentOS series + +### Input Example +Test case(s): +{testcases} + +Test output and validation results: +output: {output} +validation_results: {validation_results} + +### Output Format +Wrap your bash script between the following tags: +```plaintext + +``` +""" + +not_pass_judge_prompt = """# Debugging Task Instructions + +You have received a debugging task for a test case that failed its validation rules. Please follow these steps to identify the root cause: + +### 1. Identify the Test Case Type +- Before identification, check the `expect` field of the test case: + - If the value is `"Success"`, the test case is **expected to succeed** (i.e., it is a normal/happy-path scenario). + - If the value is `"Error"`, the test case is **intentionally designed to fail** (i.e., it validates error-handling logic). + +### 2. Confirm Environmental Requirements +- Standard Testing Environment Baseline: The default environment already includes pre-installed Anaconda, Python, and the common dependencies +{dependencies} +Check for Special Requirements: Read the test case's description field to determine if it needs additional environmental setup beyond the standard baseline. +For additional environmental setup, examples include: + - Enabling/disabling services (e.g., stopping Docker if the test requires it to be unavailable) + - Creating required directories/files with specific structures or contents + - Installing/removing packages not in the pre-installed list + +### 3. The definition of rule type + - **contains**: Verify the response includes an **exact, static substring** (max 4 words/fragments). + - `value`: `"[exact_text_fragment]"`(wrap in double quotes; use single quotes inside for clarity, e.g., `" 'Python 3.11 installed' "`) + - **equals**: Verify the **entire response** matches an exact, fixed value. (fixed numbers, short strings, booleans). + - `value`: For strings: `"[exact_string]"`; for numbers/booleans: `[exact_value]`(no quotes) + - **schema**: Validate JSON structure/data types (required fields, value types). + - `value`: `[valid_json_schema]` + - **llm**: For semantic validation requiring human-like judgment (e.g., summary accuracy). + - `value`: `Natural language specifying semantic criteria.` + +### 4. Analyze the Failure Cause +- Review the `output` (actual result of the test) and `validation_results` (details of failed/passed rules): + - Judge if the failure stems from **unmet special environmental requirements** (e.g., unstopped Docker service, or unconfigured simulation that the test depends on). + - Or if the failure is caused by **issues with the validation rule itself** (e.g., incorrect expected values, or too tightly strict validation rules). +- Base your judgment on the known standard testing environment (do not assume unstated configurations). + +### Input Example +Test case(s): +{testcases} + +Test output and validation results: +output: {output} +validation_results: {validation_results} + +### Output Format +Provide a clear, direct answer using the following JSON format (do not include extra content): + +```json +{{ + "option": "a. Unmet special environmental requirement" or "b. Issue with the validation rule itself", + "reason": "Concise, specific explanation (e.g., 'The test requires an nginx process running on port 80, but the standard environment lacks this setup' or 'The validation rule expects Python 3.8, but the test uses Python 3.9—an unreasonable mismatch')" +}} +``` +""" diff --git a/src/prompts/eval_prompt.py b/src/prompts/eval_prompt.py new file mode 100644 index 0000000..c456862 --- /dev/null +++ b/src/prompts/eval_prompt.py @@ -0,0 +1,11 @@ +eval_prompt = """Create a natural, conversational request for an AI assistant to perform this specific test scenario: + +Tool: {tool.name} - {tool.description} +Purpose: {test_case.description} + +Parameters to include: +{test_case_inputs} + +Craft a single, fluent sentence that naturally incorporates these parameter values as if you're asking for help with this specific task. Make it sound like a real user request rather than a technical specification. + +Natural language request:""" \ No newline at end of file diff --git a/src/prompts/tool_prompt.py b/src/prompts/tool_prompt.py new file mode 100644 index 0000000..c217a91 --- /dev/null +++ b/src/prompts/tool_prompt.py @@ -0,0 +1,67 @@ +tool_prompt = """ +You are an expert in generating comprehensive test cases for tools accessed through MCP (Model Context Protocol) servers on Linux operating systems. Your task is to create diverse, realistic test cases that thoroughly validate tool functionality. +## Tool Definition +Name: {tool.name} +Description: {tool.description} +Parameters: {input_properties} +Tool Source Code: {tool_function_str} +(since the tool may lack a formal outputSchema, the source code is the authoritative reference for output format/structure) + +## Instructions +1. Generate {tests_per_tool} diverse test cases covering these categories: + - Happy Path (80%): Normal, expected usage scenarios with valid inputs + - Error Cases (20%): Invalid or edge case inputs that should trigger proper error handling + +2. For each test case, provide these fields: + - `description`: A concise explanation of the scenario and its test intent. + - `input`: A JSON object with concrete and plausible parameter values. Avoid generic placeholders; use specific, realistic values as users would (e.g., use "3.11" for Python version). + - `expect`: + - `status`: "success" for happy path, or "error" if an error is expected. + - `validationRules`: An **array of assertion rules** to precisely check the tool's response. **Validation rules are critical for automated testing, so each rule must be clear, unambiguous, and directly translatable to Python test code.** + Every rule must include 3 key fields: + - `type`: One of [`contains`,`equals`,`schema`, `llm`] + - `value`: A **machine-parsable value** (no redundant natural language) in the specified format—see below for details. + - `message`: A helpful description to show if validation fails. + - **Choose the validation `type` and construct `value` as follows:** + - **contains**: Verify the response includes an **exact, static substring** (max 4 words/fragments). + - `value`: `"[exact_text_fragment]"`(wrap in double quotes; use single quotes inside for clarity, e.g., `" 'Python 3.11 installed' "`) + - **equals**: Verify the **entire response** matches an exact, fixed value. (fixed numbers, short strings, booleans). + - `value`: For strings: `"[exact_string]"`; for numbers/booleans: `[exact_value]`(no quotes) + - **schema**: Validate JSON structure/data types (required fields, value types). + - `value`: `[valid_json_schema]` + - **llm**: For semantic validation requiring human-like judgment (e.g., summary accuracy). + - `value`: `Natural language specifying semantic criteria.` + + - **For successful (“success”) test cases:** + - Prefer `schema` if the response is structured. Prefer `contains` for specific fragments, and `equals` only for fixed outputs. Use `llm` for **semantic validation requiring human judgment** (e.g., summary accuracy) + + - **For error (“error”) test cases:** + - Validate on the **presence and clarity of error indications** (e.g., specific error messages, error codes, required fields in the error response). + - Prefer `"contains"` for specific fragments/errors and `"schema"` only if error objects are structured. + - Make sure the rule can be programmatically checked. + +## Output Format + +Return a **pure JSON array** of test cases in the following structure: + +```json +[ + {{ + "description": "A brief but precise description of the test scenario.", + "input": {{ /* Concrete parameter values */ }}, + "expect": {{ + "status": "success|error", + "validationRules": [ + {{ + "type": "contains|equals|schema|llm", + "value": "xxx", // See format above; must be directly checkable + "message": "Custom failure explanation for this rule." + }} + /* ... more validation rules as appropriate ... */ + ] + }} + }} + /* ... more test cases ... */ +] +``` +""" \ No newline at end of file diff --git a/src/prompts/val_prompt.py b/src/prompts/val_prompt.py new file mode 100644 index 0000000..b1f78aa --- /dev/null +++ b/src/prompts/val_prompt.py @@ -0,0 +1,46 @@ +val_prompt_tool = """You are an expert evaluator responsible for assessing whether a specific MCP tool executed correctly for a given query. The validation rule serves as the expected criterion to validate the tool's output. + +**Tool:** {tool_name} +**Input:** {input} +**Validation Rule:** {validation_rule} +**Execution Output:** {output} + +Did the tool execute correctly and produce output that meets the expectations defined by the validation rule? Answer "yes" or "no" and provide a brief explanation. + +Output format: +```json +{{ + "answer": "yes" | "no", + "explanation": "Explanation of the result" +}} +```""" + +val_prompt_eval = """You are an expert evaluator specializing in assessing test cases for tools accessed via MCP (Model Context Protocol) servers on Linux operating systems. Your core responsibility is to verify whether the final output of a chat session (which executes MCP tools) aligns with the expected output of the corresponding test case. + +### 1. Test Case Category Definition +Test cases are divided into two types, and you need to first confirm the category of the current case: +- **Happy-path cases**: These cases represent normal, expected usage scenarios with valid inputs— the chat session should execute tools without errors and return results that match expected behavior. +- **Error cases**: These cases involve invalid inputs or edge cases— the chat session should trigger proper error handling (e.g., return error prompts, avoid abnormal crashes) instead of normal results. + +### 2. Current Test Case Information +- **User Query**: {{ query }} +- **Test Case Type**: {% if expect_type == "success" %}Happy-path case (valid inputs, expect normal execution){% else %}Error case (invalid/edge inputs, expect proper error handling){% endif %} +{{ expected_output }} +- **Chat Session's Final Output**: {{ output }} + +### 3. Evaluation Task +For the above test case, focus on two key points to evaluate: +1. Whether the chat session's tool execution process matches the case type (e.g., happy-path cases should have no execution errors; error cases should trigger error handling). +2. Whether the chat session's final output is consistent with the "Expected Output" (including result content for happy-path cases, or error prompt logic for error cases). + +Answer "yes" if the final output meets the test case's expectations; answer "no" otherwise. Provide a brief explanation to support your judgment. + +### 4. Output Format +```json +{{ '{{' }} + "answer": "yes" | "no", + "explanation": "Clear explanation of why the final output meets/does not meet expectations" +{{ '}}' }} +``` +""" + diff --git a/src/test_generator/TestGenerator.py b/src/test_generator/TestGenerator.py new file mode 100644 index 0000000..011f99c --- /dev/null +++ b/src/test_generator/TestGenerator.py @@ -0,0 +1,297 @@ +import json +import datetime +import uuid +import re +import os +from typing import List, Dict, Any, Optional, Protocol +from ..llm.LLM import LLMClient +from ..type.types_def import ToolDefinition, TestCase +from ..prompts.tool_prompt import tool_prompt_0903 +from ..prompts.eval_prompt import eval_prompt +from ..client.MCPClient import Configuration, Server +from ..utils.read_source_code import ReadSourceCode + +class TestGenerator: + """ + Generator for test cases using Large Language Model + """ + + def __init__(self, api_key: str = None, config_path: str = None): + """ + Create a new test generator + + Args: + api_key: API key for the language model + """ + Config_class = Configuration() + self.config = Config_class.load_config(config_path) + self.llm = LLMClient(api_key) + self.readsc = ReadSourceCode(config_path) + async def run(self): + # load config + + servers = [Server(name, srv_config) for name, srv_config in self.config["mcpServers"].items()] + tests_per_tool = self.config["numTestsPerTool"] + for server in servers: + + # connect server + print("\n========================================") + print(f"Testing server: {server.name}") + print("========================================\n") + await server.initialize() + + # Get available tools + tools = await server.list_tools() + if not tools: + Warning('No tools found in the MCP server. Nothing to test.') + print(f"Found {len(tools)} tools:") + print("\n".join([f"{tool.format_for_llm()}" for tool in tools])) + + # Generate tests + print(f"Generating {tests_per_tool} tests per tool...") + test_cases = await self.generate_tests_for_each_server(tools, tests_per_tool,server.name) + print(f"Generated {len(test_cases)} test cases in total.") + + self.save_to_file(server.name, test_cases) + + await server.cleanup() + + + async def generate_tests_for_each_server( + self, + tools: List[ToolDefinition], + tests_per_tool: int, + server_name: str + ) -> List[TestCase]: + """ + Generate test cases for the given tools + + Args: + server_name: Name of the MCP server + tools: Tool definitions to generate tests for + config: Tester configuration + + Returns: + List of generated test cases + """ + all_tests: List[TestCase] = [] + tool_functions = self.readsc.get_code(server_name) + for tool in tools: + try: + print(f"Generating tests for tool interface: {tool.name}") + # import ipdb; ipdb.set_trace() + + tool_prompt_formatted = self.create_tool_prompt(tool, tests_per_tool, tool_functions[tool.name]) + # print(prompt) + + + response = self.llm.get_response( + [{"role": "user", "content": tool_prompt_formatted}] + + ) + + if response: + test_cases = self.parse_response(response, tool.name) + + # Generate natural language queries for each test case + for test_case in test_cases: + print(f"Generating natural language query for {tool.name}") + eval_prompt_formatted = self.create_eval_prompt(tool, test_case) + try: + test_case.query = self.llm.get_response( + [{"role":"user","content": eval_prompt_formatted}] + ) + except Exception as err: + print(f"Failed to generate natural language query for {tool.name}: {err}") + test_case.query = '' + + all_tests.extend(test_cases) + print(f"Generated {len(test_cases)} tests for {tool.name}") + + else: + print(f"No response received for {tool.name}") + + except Exception as error: + print(f"Error generating tests for tool {tool.name}: {error}") + + return all_tests + + def create_tool_prompt(self, tool: ToolDefinition, tests_per_tool: int, tool_function_str: str) -> str: + """ + Create a prompt for the LLM to generate test cases for testing tool exeucation + + Args: + tool: Tool definition to generate tests for + tests_per_tool: Number of tests to generate per tool + + Returns: + Formatted prompt string + """ + # Extract input schema properties safely + input_properties = {} + if tool.input_schema and hasattr(tool.input_schema, 'properties'): + input_properties = json.dumps(input_properties, indent=2) + + formatted_prompt = tool_prompt_0903.format( + tool=tool, + input_properties=input_properties, + tests_per_tool=tests_per_tool, + tool_function_str=tool_function_str + ) + return formatted_prompt + + def create_eval_prompt(self, tool: ToolDefinition, test_case: TestCase) -> str: + """ + Create a prompt for the LLM to generate test cases for end-to-end agent running + + Args: + tool: Tool definition to generate tests for + tests_per_tool: Number of tests to generate per tool + + Returns: + Formatted prompt string + """ + + input = {} + if test_case.input: + input = json.dumps(test_case.input, indent=2) + + formatted_prompt = eval_prompt.format(tool=tool, test_case=test_case, test_case_inputs = input) + return formatted_prompt + + + def parse_response(self, response_text: str, tool_name: str) -> List[TestCase]: + """ + Parse LLM's response into test cases + + Args: + response_text: LLM's response text + tool_name: Name of the tool being tested + + Returns: + List of parsed test cases + """ + json_content = response_text + + try: + # Extract JSON content between backticks if present + json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', response_text) + if json_match and json_match.group(1): + json_content = json_match.group(1) + else: + # Attempt to gracefully handle cases where the LLM might forget the backticks + print(f"[{tool_name}] LLM response did not contain JSON within backticks. Attempting to parse directly.") + + # Parse JSON + try: + parsed_json = json.loads(json_content) + except json.JSONDecodeError as parse_error: + print(f"[{tool_name}] Failed to parse JSON from LLM response. Error: {parse_error}") + print(f"[{tool_name}] Raw response text was: {response_text}") + return [] # Return empty if JSON parsing fails + + # Ensure parsed_json is a list + if not isinstance(parsed_json, list): + print(f"[{tool_name}] Parsed JSON is not an array. LLM response might be malformed. Raw response: {response_text}") + # If it's a single object that looks like a test case, wrap it in an array + if (isinstance(parsed_json, dict) and + 'description' in parsed_json and + 'input' in parsed_json and + 'expect' in parsed_json): + print(f"[{tool_name}] Attempting to recover by wrapping single test case object in an array.") + parsed_json = [parsed_json] + else: + return [] + + valid_test_cases: List[TestCase] = [] + + for index, test in enumerate(parsed_json): + # Basic validation for essential fields + if not isinstance(test, dict): + print(f"[{tool_name}] Test case at index {index} is not a valid object. Skipping.") + continue + + if not test.get('description') or not isinstance(test['description'], str): + print(f"[{tool_name}] Test case at index {index} is missing or has an invalid 'description'. Skipping: {json.dumps(test)}") + continue + + if 'input' not in test or not isinstance(test['input'], dict): + print(f"[{tool_name}] Test case \"{test['description']}\" is missing or has invalid 'inputs'. Skipping: {json.dumps(test)}") + continue + + if not test.get('expect') or not isinstance(test['expect'], dict): + print(f"[{tool_name}] Test case \"{test['description']}\" is missing or has invalid 'expect'. Skipping: {json.dumps(test)}") + continue + + expected_outcome = test['expect'] + if (not expected_outcome.get('status') or + expected_outcome['status'] not in ['success', 'error']): + print(f"[{tool_name}] Test case \"{test['description']}\" has missing or invalid 'expectedOutcome.status'. Skipping: {json.dumps(test)}") + continue + + # Create test case + test_case = TestCase( + id=str(uuid.uuid4()), + toolName=tool_name, + description=test['description'], + query='', + input=test['input'], + expect={ + "status":expected_outcome['status'], + "validation_rules": expected_outcome.get('validationRules', []) or [] + } + ) + valid_test_cases.append(test_case) + return valid_test_cases + + except Exception as error: + # Catch any other unexpected errors during processing + print(f"[{tool_name}] Unexpected error in parse_response: {error}") + print(f"[{tool_name}] Response text was: {response_text}") + return [] + + + def testcases_to_dict(self, testcases: List[TestCase])-> List: + res = [] + for case in testcases: + res.append( { + "id": case.id, + "toolName": case.toolName, + "description": case.description, + "query": case.query, + "input": case.input, + "expect": case.expect + }) + return res + + def save_to_file(self, server_name: str, testcases: List[TestCase]): + """ + save test cases (array of JSON) to file + """ + testcases = self.testcases_to_dict(testcases) + try: + if not isinstance(testcases, list): + raise ValueError("input data should be an array of JSON") + + current_timestamp = datetime.datetime.utcnow().isoformat() + safe_timestamp = current_timestamp.replace(":", "-").replace(".", "-") + + folerpath = os.path.join(".logs",f'{server_name}_{safe_timestamp}') + if not os.path.exists(folerpath): + os.mkdir(folerpath) + + filename = f"testcases.json" + filepath = os.path.join(folerpath,filename) + with open(filepath, 'w', encoding='utf-8') as file: + json.dump(testcases, file, ensure_ascii=False, indent=4) + print(f"{server_name} test cases are successfully saved into {filepath}") + return True + + except IOError as e: + print(f"文件操作错误: {e}") + except ValueError as e: + print(f"数据格式错误: {e}") + except Exception as e: + print(f"发生未知错误: {e}") + + return False \ No newline at end of file -- Gitee From 9ea15afe556ea6e3b1cc0dd8e28cb9789dbabc14 Mon Sep 17 00:00:00 2001 From: shenyue_ustc Date: Mon, 29 Sep 2025 14:34:12 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E5=88=A0=E9=99=A4=E8=B0=83=E8=AF=95?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/test_generator/TestGenerator.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/test_generator/TestGenerator.py b/src/test_generator/TestGenerator.py index 011f99c..9e97fab 100644 --- a/src/test_generator/TestGenerator.py +++ b/src/test_generator/TestGenerator.py @@ -3,12 +3,13 @@ import datetime import uuid import re import os -from typing import List, Dict, Any, Optional, Protocol +from typing import List from ..llm.LLM import LLMClient from ..type.types_def import ToolDefinition, TestCase -from ..prompts.tool_prompt import tool_prompt_0903 +from ..prompts.tool_prompt import tool_prompt from ..prompts.eval_prompt import eval_prompt -from ..client.MCPClient import Configuration, Server +from ..client.Client import Configuration +from ..client.MCPClient import MCPClient from ..utils.read_source_code import ReadSourceCode class TestGenerator: @@ -30,7 +31,7 @@ class TestGenerator: async def run(self): # load config - servers = [Server(name, srv_config) for name, srv_config in self.config["mcpServers"].items()] + servers = [MCPClient(name, srv_config) for name, srv_config in self.config["mcpServers"].items()] tests_per_tool = self.config["numTestsPerTool"] for server in servers: @@ -79,11 +80,8 @@ class TestGenerator: for tool in tools: try: print(f"Generating tests for tool interface: {tool.name}") - # import ipdb; ipdb.set_trace() tool_prompt_formatted = self.create_tool_prompt(tool, tests_per_tool, tool_functions[tool.name]) - # print(prompt) - response = self.llm.get_response( [{"role": "user", "content": tool_prompt_formatted}] @@ -132,7 +130,7 @@ class TestGenerator: if tool.input_schema and hasattr(tool.input_schema, 'properties'): input_properties = json.dumps(input_properties, indent=2) - formatted_prompt = tool_prompt_0903.format( + formatted_prompt = tool_prompt.format( tool=tool, input_properties=input_properties, tests_per_tool=tests_per_tool, -- Gitee From 94630e1e0a49a02d612740a50c057f1a297bb783 Mon Sep 17 00:00:00 2001 From: shenyue_ustc Date: Mon, 29 Sep 2025 16:30:42 +0800 Subject: [PATCH 3/4] add client --- main.py | 2 - src/client/DockerRegistry.py | 187 +++++++++ src/client/MCPClient.py | 582 ++++++++++++++++++++++++++++ src/client/Session.py | 148 +++++++ src/test_generator/TestGenerator.py | 2 +- src/utils/parse_json.py | 31 ++ src/utils/read_source_code.py | 105 +++++ 7 files changed, 1054 insertions(+), 3 deletions(-) create mode 100644 src/client/DockerRegistry.py create mode 100644 src/client/MCPClient.py create mode 100644 src/client/Session.py create mode 100644 src/utils/parse_json.py create mode 100644 src/utils/read_source_code.py diff --git a/main.py b/main.py index 23b4c04..a3b71cb 100644 --- a/main.py +++ b/main.py @@ -49,8 +49,6 @@ async def gen_cases(config_path): from src.test_generator.TestGenerator import TestGenerator generator = TestGenerator(config_path=config_path) return await generator.run() - - async def val_cases(config_path, testcase_path): from src.validator.Response_validator_withenv import ResponseValidator_withenv validator = ResponseValidator_withenv(config_path=config_path, testcase_path=testcase_path) diff --git a/src/client/DockerRegistry.py b/src/client/DockerRegistry.py new file mode 100644 index 0000000..5eb0506 --- /dev/null +++ b/src/client/DockerRegistry.py @@ -0,0 +1,187 @@ +import logging +import asyncio +import atexit +from typing import Set +import signal +import subprocess +import sys +import os + +class DockerContainerRegistry: + """全局Docker容器注册表,确保程序退出时清理所有容器""" + _instance = None + _containers: Set[str] = set() + _cleanup_lock = asyncio.Lock() + _initialized = False + _cleanup_in_progress = False # 添加清理状态标志 + _signal_count = 0 # 信号计数器 + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + @classmethod + def initialize(cls): + """初始化全局清理机制""" + if cls._initialized: + return + + instance = cls() + + atexit.register(instance._sync_cleanup_all) + + def signal_handler(signum, frame): + instance._handle_signal(signum) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + cls._initialized = True + logging.info("Docker容器注册表已初始化") + + def _handle_signal(self, signum): + """处理信号,避免重复清理""" + self._signal_count += 1 + + if self._cleanup_in_progress: + if self._signal_count <= 2: + logging.info(f"清理正在进行中,请稍候... (信号计数: {self._signal_count})") + return + elif self._signal_count <= 5: + logging.warning(f"强制中断清理过程... (信号计数: {self._signal_count})") + return + else: + logging.error("多次中断信号,强制退出程序") + os._exit(1) + + logging.info(f"接收到信号 {signum},开始清理Docker容器...") + self._cleanup_in_progress = True + + try: + self._sync_cleanup_all() + except Exception as e: + logging.error(f"清理过程中出错: {e}") + finally: + logging.info("程序退出") + sys.exit(0) + + def register_container(self, container_name: str): + """注册容器""" + self._containers.add(container_name) + logging.debug(f"注册Docker容器: {container_name}") + + def unregister_container(self, container_name: str): + """注销容器""" + self._containers.discard(container_name) + logging.debug(f"注销Docker容器: {container_name}") + + def _sync_cleanup_all(self): + """同步清理所有注册的容器""" + if not self._containers or self._cleanup_in_progress: + return + + self._cleanup_in_progress = True + + try: + logging.info(f"开始清理 {len(self._containers)} 个Docker容器...") + containers_to_clean = self._containers.copy() + + for container_name in containers_to_clean: + try: + result = subprocess.run( + ["docker", "kill", container_name], + capture_output=True, + text=True, + timeout=3 # 减少超时时间 + ) + if result.returncode == 0: + logging.info(f"成功清理容器: {container_name}") + self._containers.discard(container_name) + else: + logging.warning(f"清理容器失败: {container_name}, {result.stderr}") + except subprocess.TimeoutExpired: + logging.warning(f"清理容器 {container_name} 超时,跳过") + except Exception as e: + logging.error(f"清理容器 {container_name} 出错: {e}") + + # 如果还有容器未清理,尝试强制清理 + if self._containers: + logging.info("尝试强制清理剩余容器...") + for container_name in list(self._containers): + try: + subprocess.run( + ["docker", "rm", "-f", container_name], + capture_output=True, + text=True, + timeout=2 + ) + self._containers.discard(container_name) + logging.info(f"强制清理容器: {container_name}") + except Exception as e: + logging.error(f"强制清理容器 {container_name} 失败: {e}") + + logging.info("Docker容器清理完成") + + finally: + self._cleanup_in_progress = False + + async def async_cleanup_all(self): + """异步清理所有容器""" + if not self._containers: + return + + async with self._cleanup_lock: + if self._cleanup_in_progress: + return + + self._cleanup_in_progress = True + + try: + containers_to_clean = list(self._containers) + + # 并发清理所有容器,但添加超时限制 + tasks = [] + for container_name in containers_to_clean: + task = asyncio.create_task(self._async_kill_container(container_name)) + tasks.append(task) + + if tasks: + # 设置总体超时时间 + try: + results = await asyncio.wait_for( + asyncio.gather(*tasks, return_exceptions=True), + timeout=10 # 总体超时10秒 + ) + for container_name, result in zip(containers_to_clean, results): + if isinstance(result, Exception): + logging.error(f"异步清理容器 {container_name} 失败: {result}") + else: + self.unregister_container(container_name) + except asyncio.TimeoutError: + logging.warning("异步清理容器超时") + finally: + self._cleanup_in_progress = False + + async def _async_kill_container(self, container_name: str): + """异步终止单个容器""" + loop = asyncio.get_event_loop() + try: + result = await loop.run_in_executor( + None, + lambda: subprocess.run( + ["docker", "kill", container_name], + capture_output=True, + text=True, + timeout=5 # 减少单个容器的超时时间 + ) + ) + if result.returncode == 0: + logging.info(f"异步清理容器成功: {container_name}") + return True + else: + logging.warning(f"异步清理容器失败: {container_name}, {result.stderr}") + return False + except Exception as e: + logging.error(f"异步清理容器 {container_name} 出错: {e}") + return False \ No newline at end of file diff --git a/src/client/MCPClient.py b/src/client/MCPClient.py new file mode 100644 index 0000000..bfb5b4d --- /dev/null +++ b/src/client/MCPClient.py @@ -0,0 +1,582 @@ +import asyncio +import sys +import logging +import subprocess +import os +import time +import shutil +import docker +from docker.errors import NotFound, APIError +from contextlib import AsyncExitStack +from typing import Any, Optional +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client +from ..utils.parse_json import parse_evaluation_json +from .DockerRegistry import DockerContainerRegistry + +class MCPClient: + """MCP Client, 支持可靠的Docker生命周期管理""" + + def __init__(self, name: str, config: dict[str, Any], env_script: str = "", use_docker: bool = False) -> None: + self.name: str = name + self.config: dict[str, Any] = config + self.session: Optional[ClientSession] = None + self._cleanup_lock: asyncio.Lock = asyncio.Lock() + self.exit_stack: AsyncExitStack = AsyncExitStack() + + self.env_script = env_script + + + # 状态管理 + self._is_initialized = False + self._is_cleaning_up = False + self._cleanup_completed = asyncio.Event() + + # Docker相关配置 + self.use_docker = use_docker + self.abs_script_path = self.get_command_script_path() + self.host_mcp_path = self.abs_script_path.split('src')[0] if self.abs_script_path else "" + self.container_mcp_path = "/app/" + self.server_port = config.get("port", 8080) + + # Docker进程管理 + self.docker_process = None + self.container_id = None + self.container_name = None + + self.client = docker.from_env() + + # 初始化全局容器注册表 + if use_docker: + DockerContainerRegistry.initialize() + self._registry = DockerContainerRegistry() + + async def initialize(self) -> None: + """初始化服务器""" + if self._is_initialized: + logging.warning(f"服务器 {self.name} 已经初始化") + return + + try: + logging.info(f"开始初始化服务器 {self.name}") + + if self.use_docker: + await self._initialize_docker() + else: + await self._initialize_host_server() + + self._is_initialized = True + + except Exception as e: + logging.error(f"初始化失败: {e}") + await self.cleanup() + raise + + async def _initialize_host_server(self) -> None: + """在主机上启动MCP服务器""" + command = shutil.which("npx") if self.config["command"] == "npx" else self.config["command"] + if command is None: + raise ValueError(f"主机命令不存在: {self.config['command']}") + + server_params = StdioServerParameters( + command=command, + args=self.config["args"], + env={**os.environ, **self.config["env"]} if self.config.get("env") else None, + ) + + try: + stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params)) + read, write = stdio_transport + session = await self.exit_stack.enter_async_context(ClientSession(read, write)) + await session.initialize() + self.session = session + logging.info(f"主机上的MCP服务器 {self.name} 已初始化") + except Exception as e: + logging.error(f"主机服务器初始化失败: {e}") + raise + + def _build_docker_command(self) -> list[str]: + """构建Docker运行命令""" + self.container_name = f"mcp-server-{self.name}-{int(time.time())}" + + docker_cmd = [ + "docker", "run", + "--rm", + "-i", + "--name", self.container_name, + "--workdir", self.container_mcp_path, + ] + + # 挂载主机MCP代码目录到容器 + docker_cmd.extend([ + "-v", f"{self.host_mcp_path}:{self.container_mcp_path}" + ]) + + # 添加环境变量 + env_vars = { + "PYTHONPATH": self.container_mcp_path, + "PYTHONUNBUFFERED": "1", + "PIP_ROOT_USER_ACTION": "ignore", + } + env_vars.update(self.config.get("env", {})) + + for key, value in env_vars.items(): + docker_cmd.extend(["-e", f"{key}={value}"]) + + docker_cmd.extend(["-a", "stdout", "-a", "stderr"]) + + # 添加Docker镜像 + self.docker_image = "val:latest" + docker_cmd.append(self.docker_image) + + startup_script = self._build_correct_bash_script() + docker_cmd.extend(["bash", "-c", startup_script]) + + return docker_cmd + def get_command_script_path(self) -> str: + """获取命令脚本路径""" + try: + server_args = self.config['args'] + source_file = None + work_dir = os.getcwd() + for i, arg in enumerate(server_args): + if arg == "--directory" and i + 1 < len(server_args): + work_dir = server_args[i + 1] + work_dir = os.path.abspath(work_dir) + break + + for arg in server_args: + if arg.endswith(".py"): + source_file = arg + break + + if not source_file: + logging.warning("未在 args 中找到 .py 源代码文件") + return None + + if os.path.isabs(source_file): + absolute_path = source_file + else: + absolute_path = os.path.join(work_dir, source_file) + + if os.path.exists(absolute_path): + return absolute_path + else: + logging.error(f"源代码文件不存在:{absolute_path}") + return None + except Exception as e: + logging.error(f"获取脚本路径出错: {e}") + return None + def _build_correct_bash_script(self) -> str: + """构建启动脚本""" + container_command = self._get_container_command() + script = f'''set -e + + echo "=== 快速环境检查 ===" + echo "Python版本: $(python --version)" + echo "工作目录: $(pwd)" + echo "文件列表:" + ls -la + echo "" + + # 只安装项目特定的新依赖 + echo "=== 检查并安装项目依赖 ===" + if [ -f requirements.txt ]; then + echo "发现requirements.txt文件" + pip install -qq --upgrade-strategy only-if-needed -r requirements.txt + if [ $? -eq 0 ]; then + echo "依赖安装完成(无异常)" + else + echo "依赖安装失败!(可去掉 -qq 参数重新执行查看详细错误)" + fi + else + echo "未找到requirements.txt文件" + fi + + echo "=== 执行自定义环境部署 ===" + {self.env_script} + + echo "=== 启动MCP服务器 ===" + echo "执行命令: {container_command}" + exec {container_command}''' + + return script + + def _get_container_command(self) -> str: + """获取容器内的命令字符串""" + command = self.config.get("command", "python") + if command == "uv": + command = "uv run" + script_rel_path = 'src'+self.abs_script_path.split('src')[-1] + return command + " " + script_rel_path + + + async def _initialize_docker(self): + """初始化Docker中的MCP服务器,支持输出显示""" + original_docker_command = self._build_docker_command() + + # 修改Docker命令,使用tee同时输出到终端和MCP + docker_command = self._add_output_redirection(original_docker_command) + + logging.info(f"启动Docker命令: {' '.join(docker_command)}") + + # 注册容器到全局注册表 + if self.container_name: + self._registry.register_container(self.container_name) + + # 清理可能存在的同名容器 + await self._cleanup_existing_container() + + # 使用修改后的命令建立MCP连接(只启动一个进程) + server_params = StdioServerParameters( + command=docker_command[0], + args=docker_command[1:], + env=None + ) + + try: + stdio_transport = await self.exit_stack.enter_async_context( + stdio_client(server_params)) + + read, write = stdio_transport + session = await self.exit_stack.enter_async_context( + ClientSession(read, write) + ) + + # 等待容器稳定 + await asyncio.sleep(3) + self.search_for_container() + + # 启动监控任务 + monitor_task = await self._start_container_monitoring(session) + + # 注册清理回调 + self._register_cleanup_callback(monitor_task) + + self.session = session + logging.info(f"Docker MCP服务器 {self.name} 已初始化") + + return session + + except Exception as e: + logging.error(f"Docker MCP服务器初始化失败: {str(e)}") + raise + + def _add_output_redirection(self, docker_command): + """为Docker命令添加输出重定向""" + # 找到bash脚本部分 + if len(docker_command) >= 3 and docker_command[-2] == "-c": + # 修改bash脚本,添加tee命令 + original_script = docker_command[-1] + + # 将输出同时发送到stderr(显示在终端)和stdout(给MCP) + modified_script = f''' + # 设置输出重定向 + exec > >(tee /dev/stderr) + exec 2>&1 + + # 原始脚本 + {original_script} + ''' + new_command = docker_command[:-1] + [modified_script] + return new_command + + return docker_command + + async def _cleanup_existing_container(self): + """清理可能存在的同名容器""" + if not self.container_name: + return + + try: + import subprocess + stop_cmd = f"docker stop {self.container_name} 2>/dev/null || true" + remove_cmd = f"docker rm {self.container_name} 2>/dev/null || true" + + subprocess.run(stop_cmd, shell=True, capture_output=True) + subprocess.run(remove_cmd, shell=True, capture_output=True) + + logging.debug(f"已清理可能存在的容器: {self.container_name}") + except Exception as e: + logging.warning(f"清理现有容器时出错: {str(e)}") + + async def _start_container_monitoring(self, session): + """启动容器监控和会话初始化""" + try: + # 创建监控任务 + async def monitor(): + """容器状态监控循环""" + try: + while True: + await asyncio.sleep(2) + self.search_for_container() + except asyncio.CancelledError: + logging.debug("监控任务被取消") + raise + except Exception as e: + logging.error(f"监控任务错误: {str(e)}") + raise + + # 启动初始化和监控任务 + init_task = asyncio.create_task(session.initialize()) + monitor_task = asyncio.create_task(monitor()) + + try: + # 等待任务完成 + done, pending = await asyncio.wait( + [init_task, monitor_task], + return_when=asyncio.FIRST_COMPLETED + ) + + # 取消还在运行的任务 + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + # 检查完成的任务是否有异常 + for task in done: + await task + + # 返回监控任务引用 + return monitor_task + + except Exception as e: + # 清理任务 + for task in [init_task, monitor_task]: + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + raise + + except Exception as e: + logging.error(f"启动容器监控失败: {str(e)}") + raise + + def _register_cleanup_callback(self, monitor_task): + """注册清理回调函数""" + def cleanup(): + """清理所有资源""" + try: + # 清理监控任务 + if monitor_task and not monitor_task.done(): + monitor_task.cancel() + except Exception as cleanup_error: + logging.warning(f"清理过程中出现错误: {str(cleanup_error)}") + + self.exit_stack.callback(cleanup) + + def search_for_container(self): + try: + container = self.client.containers.get(self.container_name) + if container.status != "running": + raise RuntimeError(f"容器{self.container_name}未处于运行状态,当前状态: {container.status}") + except NotFound: + raise RuntimeError(f"容器{self.container_name}不存在") + except APIError as e: + raise RuntimeError(f"Docker API错误: {str(e)}") + + async def list_tools(self) -> list[Any]: + """列出可用工具""" + if not self.session: + raise RuntimeError(f"Server {self.name} not initialized") + + tools_response = await self.session.list_tools() + tools = [] + + for item in tools_response: + if isinstance(item, tuple) and item[0] == "tools": + tools.extend(Tool(tool.name, tool.description, tool.inputSchema) for tool in item[1]) + + return tools + + async def execute_tool( + self, + tool_name: str, + arguments: dict[str, Any], + retries: int = 1, + delay: float = 1.0, + ) -> list[Any]: + """执行工具""" + if not self.session: + raise RuntimeError(f"Server {self.name} not initialized") + + attempt = 0 + while attempt < retries: + try: + logging.info(f"Executing {tool_name}...") + result = await self.session.call_tool(tool_name, arguments) + tool_result = [] + for rc in result.content: + if rc.type == "text": + if '{' and '}' in rc.text: + try: + # 假设parse_evaluation_json函数存在 + rc_text_json = parse_evaluation_json(rc.text) + tool_result.append(rc_text_json) + except: + tool_result.append(rc.text) + else: + tool_result.append(rc.text) + elif rc.type == "image": + logging.warning("Image result is not supported yet") + elif rc.type == "resource": + logging.warning("Resource result is not supported yet") + return tool_result + except Exception as e: + attempt += 1 + logging.warning(f"Error executing tool: {e}. Attempt {attempt} of {retries}.") + if attempt < retries: + await asyncio.sleep(delay) + else: + logging.error("Max retries reached. Failing.") + raise + + async def _force_kill_docker_container_async(self) -> bool: + """异步强制终止Docker容器""" + if not self.container_name: + return True + + loop = asyncio.get_event_loop() + try: + # 使用线程池执行同步的docker命令 + result = await loop.run_in_executor( + None, + lambda: subprocess.run( + ["docker", "kill", self.container_name], + capture_output=True, + text=True, + timeout=10 + ) + ) + + if result.returncode == 0: + logging.info(f"成功强制终止容器: {self.container_name}") + return True + else: + logging.warning(f"终止容器失败: {result.stderr}") + return False + + except Exception as e: + logging.error(f"强制终止容器出错: {e}") + return False + + async def cleanup(self) -> None: + """清理服务器资源""" + async with self._cleanup_lock: + if self._is_cleaning_up: + # 等待之前的清理完成 + await self._cleanup_completed.wait() + return + + self._is_cleaning_up = True + self._cleanup_completed.clear() + + try: + logging.info(f"开始清理服务器 {self.name}") + + # 1. 标记为未初始化 + self._is_initialized = False + + # 2. 如果是Docker模式,先强制终止容器 + if self.use_docker and self.container_name: + success = await self._force_kill_docker_container_async() + if success: + # 从注册表中移除 + self._registry.unregister_container(self.container_name) + await asyncio.sleep(0.5) + + self.session = None + self.stdio_context = None + + try: + await self.exit_stack.aclose() + logging.debug("exit_stack清理完成") + except Exception as e: + logging.warning(f"exit_stack清理出错: {e}") + finally: + self.exit_stack = AsyncExitStack() + + self.docker_process = None + self.container_id = None + if self.use_docker: + self.container_name = None # 重置容器名,允许下次重新创建 + + logging.info(f"服务器 {self.name} 清理完成") + + except Exception as e: + logging.error(f"清理过程出错: {e}") + finally: + self._is_cleaning_up = False + self._cleanup_completed.set() + + async def wait_for_cleanup(self) -> None: + """等待清理完成""" + if self._is_cleaning_up: + await self._cleanup_completed.wait() + + def is_ready_for_reuse(self) -> bool: + """检查是否可以重新使用""" + return not self._is_cleaning_up and not self._is_initialized + + def __del__(self): + """析构函数,确保Docker容器被清理""" + if self.use_docker and self.container_name: + try: + subprocess.run( + ["docker", "kill", self.container_name], + capture_output=True, + timeout=5 + ) + # 从注册表中移除 + self._registry.unregister_container(self.container_name) + except: + pass + +class Tool: + """Represents a tool with its properties and formatting.""" + + def __init__( + self, + name: str, + description: str, + input_schema: dict[str, Any], + title: str | None = None, + ) -> None: + self.name: str = name + self.title: str | None = title + self.description: str = description + self.input_schema: dict[str, Any] = input_schema + + def format_for_llm(self) -> str: + """Format tool information for LLM. + + Returns: + A formatted string describing the tool. + """ + args_desc = [] + if "properties" in self.input_schema: + for param_name, param_info in self.input_schema["properties"].items(): + arg_desc = f"- {param_name}: {param_info.get('description', 'No description')}" + if param_name in self.input_schema.get("required", []): + arg_desc += " (required)" + args_desc.append(arg_desc) + + # Build the formatted output with title as a separate field + output = f"Tool: {self.name}\n" + + # Add human-readable title if available + if self.title: + output += f"User-readable title: {self.title}\n" + + output += f"""Description: {self.description} +Arguments: +{chr(10).join(args_desc)} +""" + + return output \ No newline at end of file diff --git a/src/client/Session.py b/src/client/Session.py new file mode 100644 index 0000000..4e3d0c3 --- /dev/null +++ b/src/client/Session.py @@ -0,0 +1,148 @@ +import json +import logging +import os +from typing import Any +from dotenv import load_dotenv +from .MCPClient import MCPClient +from ..llm.LLM import LLMClient +from ..utils.parse_json import parse_evaluation_json + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + +class Configuration: + """Manages configuration and environment variables for the MCP client.""" + + def __init__(self) -> None: + """Initialize configuration with environment variables.""" + self.load_env() + self.api_key = os.getenv("LLM_API_KEY") + + @staticmethod + def load_env() -> None: + """Load environment variables from .env file.""" + load_dotenv() + + @staticmethod + def load_config(file_path: str) -> dict[str, Any]: + """Load server configuration from JSON file. + + Args: + file_path: Path to the JSON configuration file. + + Returns: + Dict containing server configuration. + + Raises: + FileNotFoundError: If configuration file doesn't exist. + JSONDecodeError: If configuration file is invalid JSON. + """ + with open(file_path, "r") as f: + return json.load(f) + + @property + def llm_api_key(self) -> str: + """Get the LLM API key. + + Returns: + The API key as a string. + + Raises: + ValueError: If the API key is not found in environment variables. + """ + if not self.api_key: + raise ValueError("LLM_API_KEY not found in environment variables") + return self.api_key + + +class ChatSession: + def __init__(self, server: MCPClient, llm_client: LLMClient) -> None: + self.server = server + self.llm_client = llm_client + + async def process_llm_response(self, llm_response: str) -> str: + """Process the LLM response and execute tools if needed. + + Args: + llm_response: The response from the LLM. + + Returns: + The result of tool execution or the original response. + """ + tool_info = { + "tool_name": "", + "arguments": {}, + } + try: + tool_call = parse_evaluation_json(llm_response) + if tool_call and "tool" in tool_call and "arguments" in tool_call: + print(f"Executing tool: {tool_call['tool']}") + print(f"With arguments: {tool_call['arguments']}") + tool_info["tool_name"] = tool_call["tool"] + tool_info["arguments"] = tool_call["arguments"] + + tools = await self.server.list_tools() + if any(tool.name == tool_call["tool"] for tool in tools): + try: + result = await self.server.execute_tool(tool_call["tool"], tool_call["arguments"]) + result_str = f"{result}" + if len(result_str) > 500: + logging.info(f"The output of tool execution is too long. Only show part of it: {result[:400]}... {result[-100:]}") + else: + logging.info(f"Tool execution result: {result}") + return tool_info, f"Tool execution result: {result}" ###这里有问题,把输出变成str了 + except Exception as e: + error_msg = f"Error executing tool: {str(e)}" + print(error_msg) + return tool_info, error_msg + + return tool_info, f"No server found with tool: {tool_call['tool']}" + return tool_info, f"tool call json decode error: {llm_response}" + except json.JSONDecodeError: + return tool_info, llm_response + + async def handle_query(self, query) -> None: + all_tools = [] + # for server in self.servers: + tools = await self.server.list_tools() + all_tools.extend(tools) + + tools_description = "\n".join([tool.format_for_llm() for tool in all_tools]) + + system_message = f"""You are a helpful assistant with access to these tools: +{tools_description} +Choose the appropriate tool based on the user's question. If no tool is needed, reply directly. +IMPORTANT: When you need to use a tool, you must ONLY respond with the exact JSON object format below, nothing else: +```json +{{ + "tool": "tool-name", + "arguments": {{ + "argument-name": "value" + }} +}} +``` +After receiving a tool's response: +1. Transform the raw data into a natural, conversational response +2. Keep responses concise but informative +3. Focus on the most relevant information +4. Use appropriate context from the user's question +5. Avoid simply repeating the raw data + +Please use only the tools that are explicitly defined above.""" + + + messages = [{"role": "system", "content": system_message}] + messages.append({"role": "user", "content": "User query:"+query}) + + llm_response = self.llm_client.get_response(messages) + print("\nAssistant: %s", llm_response) + + tool_info, tool_result = await self.process_llm_response(llm_response) + tool_included_or_not = True if tool_result != llm_response else False + if tool_included_or_not: + return tool_included_or_not, tool_info, tool_result + else: + return tool_included_or_not, tool_info, 'No tool was used, here is the direct response: '+ llm_response + + diff --git a/src/test_generator/TestGenerator.py b/src/test_generator/TestGenerator.py index 9e97fab..b200716 100644 --- a/src/test_generator/TestGenerator.py +++ b/src/test_generator/TestGenerator.py @@ -8,7 +8,7 @@ from ..llm.LLM import LLMClient from ..type.types_def import ToolDefinition, TestCase from ..prompts.tool_prompt import tool_prompt from ..prompts.eval_prompt import eval_prompt -from ..client.Client import Configuration +from ..client.Session import Configuration from ..client.MCPClient import MCPClient from ..utils.read_source_code import ReadSourceCode diff --git a/src/utils/parse_json.py b/src/utils/parse_json.py new file mode 100644 index 0000000..7237d24 --- /dev/null +++ b/src/utils/parse_json.py @@ -0,0 +1,31 @@ +import re +import json +def parse_evaluation_json(response_text): + """ + 从LLM的响应文本中解析评估结果的JSON对象 + + 参数: + response_text: LLM返回的原始文本 + tool_name: 工具名称,用于日志输出 + + 返回: + 解析后的评估结果字典,如果解析失败则返回None + """ + # 尝试提取反引号之间的JSON内容 + json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', response_text) + if json_match and json_match.group(1): + json_content = json_match.group(1) + else: + # 未找到反引号时,使用整个响应文本尝试解析 + json_content = response_text + + try: + # 解析JSON + parsed_json = json.loads(json_content) + return parsed_json + + except json.JSONDecodeError as parse_error: + print(f"JSON解析失败: {parse_error}") + print(f"尝试解析的内容: {json_content}") + + return None # 解析失败返回None \ No newline at end of file diff --git a/src/utils/read_source_code.py b/src/utils/read_source_code.py new file mode 100644 index 0000000..a0ea4d0 --- /dev/null +++ b/src/utils/read_source_code.py @@ -0,0 +1,105 @@ +from typing import List, Dict, Optional +import json +import os +import ast + +class ReadSourceCode: + def __init__(self, config_path: str = None): + with open(config_path, "r") as f: + self.config = json.load(f) + + def get_code(self, server_name: str) -> List[str]: + source_path = self.extract_source_code_path(server_name) + if not source_path: + return [] + tool_functions = self.get_mcp_tool_functions(source_path) + return tool_functions + + def extract_source_code_path(self, server_name: str) -> Optional[str]: + """ + 从 Server 的 args 中提取源代码文件(.py)的绝对路径 + :param self.config: Server 的 args 列表(如 ["--directory", "/path", "server.py"]) + :param command: Server 的 command(如 "uv"、"python3.11",辅助判断参数逻辑) + :return: 源代码绝对路径(None 表示未找到) + """ + try: + server_args = self.config["mcpServers"][server_name]['args'] + + source_file = None + work_dir = os.getcwd() # 默认工作目录(当前目录) + + for i, arg in enumerate(server_args): + if arg == "--directory" and i + 1 < len(server_args): + work_dir = server_args[i + 1] + work_dir = os.path.abspath(work_dir) + break + + for arg in server_args: + if arg.endswith(".py"): + source_file = arg + break + + if not source_file: + print("未在 args 中找到 .py 源代码文件") + return None + + if os.path.isabs(source_file): + # 若已是绝对路径,直接使用 + absolute_path = source_file + else: + # 相对路径 → 拼接工作目录 + absolute_path = os.path.join(work_dir, source_file) + + # 验证文件是否存在 + if os.path.exists(absolute_path): + return absolute_path + else: + print(f"源代码文件不存在:{absolute_path}") + return None + except Exception as e: + print(f"Error: {e}") + return None + + def get_mcp_tool_functions(self, source_path: str) -> Dict[str, str]: + """ + 解析源代码文件,提取被 @mcp.tool() 装饰的函数名和对应函数代码 + :param source_path: 源代码文件路径(.py) + :return: 字典,键为函数名,值为函数完整代码字符串 + $$$待修改 + """ + + tool_functions = {} # 改为字典存储函数名和代码 + + # 读取源代码内容 + with open(source_path, "r", encoding="utf-8") as f: + source_code = f.read() + + # 用 ast 解析抽象语法树 + tree = ast.parse(source_code) + # 遍历语法树,找到函数定义并检查装饰器 + for node in ast.walk(tree): + # 只处理函数定义节点(def 函数) + if isinstance(node, ast.FunctionDef): + # 检查函数是否有 @mcp.tool() 装饰器 + for decorator in node.decorator_list: + # 处理两种装饰器形式:@mcp.tool() 或 @mcp.tool(name="xxx") + is_mcp_tool = False + if isinstance(decorator, ast.Call): + # 装饰器带参数(如 @mcp.tool(name="test")) + if (isinstance(decorator.func, ast.Attribute) + and decorator.func.value.id == "mcp" + and decorator.func.attr == "tool"): + is_mcp_tool = True + elif isinstance(decorator, ast.Attribute): + # 装饰器不带参数(如 @mcp.tool) + if decorator.value.id == "mcp" and decorator.attr == "tool": + is_mcp_tool = True + + if is_mcp_tool: + # 提取函数完整代码(包括装饰器、文档字符串和函数体) + function_code = ast.get_source_segment(source_code, node) + if function_code: + tool_functions[node.name] = function_code.strip() + break # 找到匹配的装饰器后跳出循环 + + return tool_functions -- Gitee From ae876ad0c84fb78e11c78738b3d8491d79dfe8b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B2=88=E6=82=A6?= Date: Mon, 29 Sep 2025 08:48:51 +0000 Subject: [PATCH 4/4] Revert "add client" This reverts commit 94630e1e0a49a02d612740a50c057f1a297bb783. --- main.py | 2 + src/client/DockerRegistry.py | 187 --------- src/client/MCPClient.py | 582 ---------------------------- src/client/Session.py | 148 ------- src/test_generator/TestGenerator.py | 2 +- src/utils/parse_json.py | 31 -- src/utils/read_source_code.py | 105 ----- 7 files changed, 3 insertions(+), 1054 deletions(-) delete mode 100644 src/client/DockerRegistry.py delete mode 100644 src/client/MCPClient.py delete mode 100644 src/client/Session.py delete mode 100644 src/utils/parse_json.py delete mode 100644 src/utils/read_source_code.py diff --git a/main.py b/main.py index a3b71cb..23b4c04 100644 --- a/main.py +++ b/main.py @@ -49,6 +49,8 @@ async def gen_cases(config_path): from src.test_generator.TestGenerator import TestGenerator generator = TestGenerator(config_path=config_path) return await generator.run() + + async def val_cases(config_path, testcase_path): from src.validator.Response_validator_withenv import ResponseValidator_withenv validator = ResponseValidator_withenv(config_path=config_path, testcase_path=testcase_path) diff --git a/src/client/DockerRegistry.py b/src/client/DockerRegistry.py deleted file mode 100644 index 5eb0506..0000000 --- a/src/client/DockerRegistry.py +++ /dev/null @@ -1,187 +0,0 @@ -import logging -import asyncio -import atexit -from typing import Set -import signal -import subprocess -import sys -import os - -class DockerContainerRegistry: - """全局Docker容器注册表,确保程序退出时清理所有容器""" - _instance = None - _containers: Set[str] = set() - _cleanup_lock = asyncio.Lock() - _initialized = False - _cleanup_in_progress = False # 添加清理状态标志 - _signal_count = 0 # 信号计数器 - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - @classmethod - def initialize(cls): - """初始化全局清理机制""" - if cls._initialized: - return - - instance = cls() - - atexit.register(instance._sync_cleanup_all) - - def signal_handler(signum, frame): - instance._handle_signal(signum) - - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) - - cls._initialized = True - logging.info("Docker容器注册表已初始化") - - def _handle_signal(self, signum): - """处理信号,避免重复清理""" - self._signal_count += 1 - - if self._cleanup_in_progress: - if self._signal_count <= 2: - logging.info(f"清理正在进行中,请稍候... (信号计数: {self._signal_count})") - return - elif self._signal_count <= 5: - logging.warning(f"强制中断清理过程... (信号计数: {self._signal_count})") - return - else: - logging.error("多次中断信号,强制退出程序") - os._exit(1) - - logging.info(f"接收到信号 {signum},开始清理Docker容器...") - self._cleanup_in_progress = True - - try: - self._sync_cleanup_all() - except Exception as e: - logging.error(f"清理过程中出错: {e}") - finally: - logging.info("程序退出") - sys.exit(0) - - def register_container(self, container_name: str): - """注册容器""" - self._containers.add(container_name) - logging.debug(f"注册Docker容器: {container_name}") - - def unregister_container(self, container_name: str): - """注销容器""" - self._containers.discard(container_name) - logging.debug(f"注销Docker容器: {container_name}") - - def _sync_cleanup_all(self): - """同步清理所有注册的容器""" - if not self._containers or self._cleanup_in_progress: - return - - self._cleanup_in_progress = True - - try: - logging.info(f"开始清理 {len(self._containers)} 个Docker容器...") - containers_to_clean = self._containers.copy() - - for container_name in containers_to_clean: - try: - result = subprocess.run( - ["docker", "kill", container_name], - capture_output=True, - text=True, - timeout=3 # 减少超时时间 - ) - if result.returncode == 0: - logging.info(f"成功清理容器: {container_name}") - self._containers.discard(container_name) - else: - logging.warning(f"清理容器失败: {container_name}, {result.stderr}") - except subprocess.TimeoutExpired: - logging.warning(f"清理容器 {container_name} 超时,跳过") - except Exception as e: - logging.error(f"清理容器 {container_name} 出错: {e}") - - # 如果还有容器未清理,尝试强制清理 - if self._containers: - logging.info("尝试强制清理剩余容器...") - for container_name in list(self._containers): - try: - subprocess.run( - ["docker", "rm", "-f", container_name], - capture_output=True, - text=True, - timeout=2 - ) - self._containers.discard(container_name) - logging.info(f"强制清理容器: {container_name}") - except Exception as e: - logging.error(f"强制清理容器 {container_name} 失败: {e}") - - logging.info("Docker容器清理完成") - - finally: - self._cleanup_in_progress = False - - async def async_cleanup_all(self): - """异步清理所有容器""" - if not self._containers: - return - - async with self._cleanup_lock: - if self._cleanup_in_progress: - return - - self._cleanup_in_progress = True - - try: - containers_to_clean = list(self._containers) - - # 并发清理所有容器,但添加超时限制 - tasks = [] - for container_name in containers_to_clean: - task = asyncio.create_task(self._async_kill_container(container_name)) - tasks.append(task) - - if tasks: - # 设置总体超时时间 - try: - results = await asyncio.wait_for( - asyncio.gather(*tasks, return_exceptions=True), - timeout=10 # 总体超时10秒 - ) - for container_name, result in zip(containers_to_clean, results): - if isinstance(result, Exception): - logging.error(f"异步清理容器 {container_name} 失败: {result}") - else: - self.unregister_container(container_name) - except asyncio.TimeoutError: - logging.warning("异步清理容器超时") - finally: - self._cleanup_in_progress = False - - async def _async_kill_container(self, container_name: str): - """异步终止单个容器""" - loop = asyncio.get_event_loop() - try: - result = await loop.run_in_executor( - None, - lambda: subprocess.run( - ["docker", "kill", container_name], - capture_output=True, - text=True, - timeout=5 # 减少单个容器的超时时间 - ) - ) - if result.returncode == 0: - logging.info(f"异步清理容器成功: {container_name}") - return True - else: - logging.warning(f"异步清理容器失败: {container_name}, {result.stderr}") - return False - except Exception as e: - logging.error(f"异步清理容器 {container_name} 出错: {e}") - return False \ No newline at end of file diff --git a/src/client/MCPClient.py b/src/client/MCPClient.py deleted file mode 100644 index bfb5b4d..0000000 --- a/src/client/MCPClient.py +++ /dev/null @@ -1,582 +0,0 @@ -import asyncio -import sys -import logging -import subprocess -import os -import time -import shutil -import docker -from docker.errors import NotFound, APIError -from contextlib import AsyncExitStack -from typing import Any, Optional -from mcp import ClientSession, StdioServerParameters -from mcp.client.stdio import stdio_client -from ..utils.parse_json import parse_evaluation_json -from .DockerRegistry import DockerContainerRegistry - -class MCPClient: - """MCP Client, 支持可靠的Docker生命周期管理""" - - def __init__(self, name: str, config: dict[str, Any], env_script: str = "", use_docker: bool = False) -> None: - self.name: str = name - self.config: dict[str, Any] = config - self.session: Optional[ClientSession] = None - self._cleanup_lock: asyncio.Lock = asyncio.Lock() - self.exit_stack: AsyncExitStack = AsyncExitStack() - - self.env_script = env_script - - - # 状态管理 - self._is_initialized = False - self._is_cleaning_up = False - self._cleanup_completed = asyncio.Event() - - # Docker相关配置 - self.use_docker = use_docker - self.abs_script_path = self.get_command_script_path() - self.host_mcp_path = self.abs_script_path.split('src')[0] if self.abs_script_path else "" - self.container_mcp_path = "/app/" - self.server_port = config.get("port", 8080) - - # Docker进程管理 - self.docker_process = None - self.container_id = None - self.container_name = None - - self.client = docker.from_env() - - # 初始化全局容器注册表 - if use_docker: - DockerContainerRegistry.initialize() - self._registry = DockerContainerRegistry() - - async def initialize(self) -> None: - """初始化服务器""" - if self._is_initialized: - logging.warning(f"服务器 {self.name} 已经初始化") - return - - try: - logging.info(f"开始初始化服务器 {self.name}") - - if self.use_docker: - await self._initialize_docker() - else: - await self._initialize_host_server() - - self._is_initialized = True - - except Exception as e: - logging.error(f"初始化失败: {e}") - await self.cleanup() - raise - - async def _initialize_host_server(self) -> None: - """在主机上启动MCP服务器""" - command = shutil.which("npx") if self.config["command"] == "npx" else self.config["command"] - if command is None: - raise ValueError(f"主机命令不存在: {self.config['command']}") - - server_params = StdioServerParameters( - command=command, - args=self.config["args"], - env={**os.environ, **self.config["env"]} if self.config.get("env") else None, - ) - - try: - stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params)) - read, write = stdio_transport - session = await self.exit_stack.enter_async_context(ClientSession(read, write)) - await session.initialize() - self.session = session - logging.info(f"主机上的MCP服务器 {self.name} 已初始化") - except Exception as e: - logging.error(f"主机服务器初始化失败: {e}") - raise - - def _build_docker_command(self) -> list[str]: - """构建Docker运行命令""" - self.container_name = f"mcp-server-{self.name}-{int(time.time())}" - - docker_cmd = [ - "docker", "run", - "--rm", - "-i", - "--name", self.container_name, - "--workdir", self.container_mcp_path, - ] - - # 挂载主机MCP代码目录到容器 - docker_cmd.extend([ - "-v", f"{self.host_mcp_path}:{self.container_mcp_path}" - ]) - - # 添加环境变量 - env_vars = { - "PYTHONPATH": self.container_mcp_path, - "PYTHONUNBUFFERED": "1", - "PIP_ROOT_USER_ACTION": "ignore", - } - env_vars.update(self.config.get("env", {})) - - for key, value in env_vars.items(): - docker_cmd.extend(["-e", f"{key}={value}"]) - - docker_cmd.extend(["-a", "stdout", "-a", "stderr"]) - - # 添加Docker镜像 - self.docker_image = "val:latest" - docker_cmd.append(self.docker_image) - - startup_script = self._build_correct_bash_script() - docker_cmd.extend(["bash", "-c", startup_script]) - - return docker_cmd - def get_command_script_path(self) -> str: - """获取命令脚本路径""" - try: - server_args = self.config['args'] - source_file = None - work_dir = os.getcwd() - for i, arg in enumerate(server_args): - if arg == "--directory" and i + 1 < len(server_args): - work_dir = server_args[i + 1] - work_dir = os.path.abspath(work_dir) - break - - for arg in server_args: - if arg.endswith(".py"): - source_file = arg - break - - if not source_file: - logging.warning("未在 args 中找到 .py 源代码文件") - return None - - if os.path.isabs(source_file): - absolute_path = source_file - else: - absolute_path = os.path.join(work_dir, source_file) - - if os.path.exists(absolute_path): - return absolute_path - else: - logging.error(f"源代码文件不存在:{absolute_path}") - return None - except Exception as e: - logging.error(f"获取脚本路径出错: {e}") - return None - def _build_correct_bash_script(self) -> str: - """构建启动脚本""" - container_command = self._get_container_command() - script = f'''set -e - - echo "=== 快速环境检查 ===" - echo "Python版本: $(python --version)" - echo "工作目录: $(pwd)" - echo "文件列表:" - ls -la - echo "" - - # 只安装项目特定的新依赖 - echo "=== 检查并安装项目依赖 ===" - if [ -f requirements.txt ]; then - echo "发现requirements.txt文件" - pip install -qq --upgrade-strategy only-if-needed -r requirements.txt - if [ $? -eq 0 ]; then - echo "依赖安装完成(无异常)" - else - echo "依赖安装失败!(可去掉 -qq 参数重新执行查看详细错误)" - fi - else - echo "未找到requirements.txt文件" - fi - - echo "=== 执行自定义环境部署 ===" - {self.env_script} - - echo "=== 启动MCP服务器 ===" - echo "执行命令: {container_command}" - exec {container_command}''' - - return script - - def _get_container_command(self) -> str: - """获取容器内的命令字符串""" - command = self.config.get("command", "python") - if command == "uv": - command = "uv run" - script_rel_path = 'src'+self.abs_script_path.split('src')[-1] - return command + " " + script_rel_path - - - async def _initialize_docker(self): - """初始化Docker中的MCP服务器,支持输出显示""" - original_docker_command = self._build_docker_command() - - # 修改Docker命令,使用tee同时输出到终端和MCP - docker_command = self._add_output_redirection(original_docker_command) - - logging.info(f"启动Docker命令: {' '.join(docker_command)}") - - # 注册容器到全局注册表 - if self.container_name: - self._registry.register_container(self.container_name) - - # 清理可能存在的同名容器 - await self._cleanup_existing_container() - - # 使用修改后的命令建立MCP连接(只启动一个进程) - server_params = StdioServerParameters( - command=docker_command[0], - args=docker_command[1:], - env=None - ) - - try: - stdio_transport = await self.exit_stack.enter_async_context( - stdio_client(server_params)) - - read, write = stdio_transport - session = await self.exit_stack.enter_async_context( - ClientSession(read, write) - ) - - # 等待容器稳定 - await asyncio.sleep(3) - self.search_for_container() - - # 启动监控任务 - monitor_task = await self._start_container_monitoring(session) - - # 注册清理回调 - self._register_cleanup_callback(monitor_task) - - self.session = session - logging.info(f"Docker MCP服务器 {self.name} 已初始化") - - return session - - except Exception as e: - logging.error(f"Docker MCP服务器初始化失败: {str(e)}") - raise - - def _add_output_redirection(self, docker_command): - """为Docker命令添加输出重定向""" - # 找到bash脚本部分 - if len(docker_command) >= 3 and docker_command[-2] == "-c": - # 修改bash脚本,添加tee命令 - original_script = docker_command[-1] - - # 将输出同时发送到stderr(显示在终端)和stdout(给MCP) - modified_script = f''' - # 设置输出重定向 - exec > >(tee /dev/stderr) - exec 2>&1 - - # 原始脚本 - {original_script} - ''' - new_command = docker_command[:-1] + [modified_script] - return new_command - - return docker_command - - async def _cleanup_existing_container(self): - """清理可能存在的同名容器""" - if not self.container_name: - return - - try: - import subprocess - stop_cmd = f"docker stop {self.container_name} 2>/dev/null || true" - remove_cmd = f"docker rm {self.container_name} 2>/dev/null || true" - - subprocess.run(stop_cmd, shell=True, capture_output=True) - subprocess.run(remove_cmd, shell=True, capture_output=True) - - logging.debug(f"已清理可能存在的容器: {self.container_name}") - except Exception as e: - logging.warning(f"清理现有容器时出错: {str(e)}") - - async def _start_container_monitoring(self, session): - """启动容器监控和会话初始化""" - try: - # 创建监控任务 - async def monitor(): - """容器状态监控循环""" - try: - while True: - await asyncio.sleep(2) - self.search_for_container() - except asyncio.CancelledError: - logging.debug("监控任务被取消") - raise - except Exception as e: - logging.error(f"监控任务错误: {str(e)}") - raise - - # 启动初始化和监控任务 - init_task = asyncio.create_task(session.initialize()) - monitor_task = asyncio.create_task(monitor()) - - try: - # 等待任务完成 - done, pending = await asyncio.wait( - [init_task, monitor_task], - return_when=asyncio.FIRST_COMPLETED - ) - - # 取消还在运行的任务 - for task in pending: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - # 检查完成的任务是否有异常 - for task in done: - await task - - # 返回监控任务引用 - return monitor_task - - except Exception as e: - # 清理任务 - for task in [init_task, monitor_task]: - if not task.done(): - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - raise - - except Exception as e: - logging.error(f"启动容器监控失败: {str(e)}") - raise - - def _register_cleanup_callback(self, monitor_task): - """注册清理回调函数""" - def cleanup(): - """清理所有资源""" - try: - # 清理监控任务 - if monitor_task and not monitor_task.done(): - monitor_task.cancel() - except Exception as cleanup_error: - logging.warning(f"清理过程中出现错误: {str(cleanup_error)}") - - self.exit_stack.callback(cleanup) - - def search_for_container(self): - try: - container = self.client.containers.get(self.container_name) - if container.status != "running": - raise RuntimeError(f"容器{self.container_name}未处于运行状态,当前状态: {container.status}") - except NotFound: - raise RuntimeError(f"容器{self.container_name}不存在") - except APIError as e: - raise RuntimeError(f"Docker API错误: {str(e)}") - - async def list_tools(self) -> list[Any]: - """列出可用工具""" - if not self.session: - raise RuntimeError(f"Server {self.name} not initialized") - - tools_response = await self.session.list_tools() - tools = [] - - for item in tools_response: - if isinstance(item, tuple) and item[0] == "tools": - tools.extend(Tool(tool.name, tool.description, tool.inputSchema) for tool in item[1]) - - return tools - - async def execute_tool( - self, - tool_name: str, - arguments: dict[str, Any], - retries: int = 1, - delay: float = 1.0, - ) -> list[Any]: - """执行工具""" - if not self.session: - raise RuntimeError(f"Server {self.name} not initialized") - - attempt = 0 - while attempt < retries: - try: - logging.info(f"Executing {tool_name}...") - result = await self.session.call_tool(tool_name, arguments) - tool_result = [] - for rc in result.content: - if rc.type == "text": - if '{' and '}' in rc.text: - try: - # 假设parse_evaluation_json函数存在 - rc_text_json = parse_evaluation_json(rc.text) - tool_result.append(rc_text_json) - except: - tool_result.append(rc.text) - else: - tool_result.append(rc.text) - elif rc.type == "image": - logging.warning("Image result is not supported yet") - elif rc.type == "resource": - logging.warning("Resource result is not supported yet") - return tool_result - except Exception as e: - attempt += 1 - logging.warning(f"Error executing tool: {e}. Attempt {attempt} of {retries}.") - if attempt < retries: - await asyncio.sleep(delay) - else: - logging.error("Max retries reached. Failing.") - raise - - async def _force_kill_docker_container_async(self) -> bool: - """异步强制终止Docker容器""" - if not self.container_name: - return True - - loop = asyncio.get_event_loop() - try: - # 使用线程池执行同步的docker命令 - result = await loop.run_in_executor( - None, - lambda: subprocess.run( - ["docker", "kill", self.container_name], - capture_output=True, - text=True, - timeout=10 - ) - ) - - if result.returncode == 0: - logging.info(f"成功强制终止容器: {self.container_name}") - return True - else: - logging.warning(f"终止容器失败: {result.stderr}") - return False - - except Exception as e: - logging.error(f"强制终止容器出错: {e}") - return False - - async def cleanup(self) -> None: - """清理服务器资源""" - async with self._cleanup_lock: - if self._is_cleaning_up: - # 等待之前的清理完成 - await self._cleanup_completed.wait() - return - - self._is_cleaning_up = True - self._cleanup_completed.clear() - - try: - logging.info(f"开始清理服务器 {self.name}") - - # 1. 标记为未初始化 - self._is_initialized = False - - # 2. 如果是Docker模式,先强制终止容器 - if self.use_docker and self.container_name: - success = await self._force_kill_docker_container_async() - if success: - # 从注册表中移除 - self._registry.unregister_container(self.container_name) - await asyncio.sleep(0.5) - - self.session = None - self.stdio_context = None - - try: - await self.exit_stack.aclose() - logging.debug("exit_stack清理完成") - except Exception as e: - logging.warning(f"exit_stack清理出错: {e}") - finally: - self.exit_stack = AsyncExitStack() - - self.docker_process = None - self.container_id = None - if self.use_docker: - self.container_name = None # 重置容器名,允许下次重新创建 - - logging.info(f"服务器 {self.name} 清理完成") - - except Exception as e: - logging.error(f"清理过程出错: {e}") - finally: - self._is_cleaning_up = False - self._cleanup_completed.set() - - async def wait_for_cleanup(self) -> None: - """等待清理完成""" - if self._is_cleaning_up: - await self._cleanup_completed.wait() - - def is_ready_for_reuse(self) -> bool: - """检查是否可以重新使用""" - return not self._is_cleaning_up and not self._is_initialized - - def __del__(self): - """析构函数,确保Docker容器被清理""" - if self.use_docker and self.container_name: - try: - subprocess.run( - ["docker", "kill", self.container_name], - capture_output=True, - timeout=5 - ) - # 从注册表中移除 - self._registry.unregister_container(self.container_name) - except: - pass - -class Tool: - """Represents a tool with its properties and formatting.""" - - def __init__( - self, - name: str, - description: str, - input_schema: dict[str, Any], - title: str | None = None, - ) -> None: - self.name: str = name - self.title: str | None = title - self.description: str = description - self.input_schema: dict[str, Any] = input_schema - - def format_for_llm(self) -> str: - """Format tool information for LLM. - - Returns: - A formatted string describing the tool. - """ - args_desc = [] - if "properties" in self.input_schema: - for param_name, param_info in self.input_schema["properties"].items(): - arg_desc = f"- {param_name}: {param_info.get('description', 'No description')}" - if param_name in self.input_schema.get("required", []): - arg_desc += " (required)" - args_desc.append(arg_desc) - - # Build the formatted output with title as a separate field - output = f"Tool: {self.name}\n" - - # Add human-readable title if available - if self.title: - output += f"User-readable title: {self.title}\n" - - output += f"""Description: {self.description} -Arguments: -{chr(10).join(args_desc)} -""" - - return output \ No newline at end of file diff --git a/src/client/Session.py b/src/client/Session.py deleted file mode 100644 index 4e3d0c3..0000000 --- a/src/client/Session.py +++ /dev/null @@ -1,148 +0,0 @@ -import json -import logging -import os -from typing import Any -from dotenv import load_dotenv -from .MCPClient import MCPClient -from ..llm.LLM import LLMClient -from ..utils.parse_json import parse_evaluation_json - -# Configure logging -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") - - -class Configuration: - """Manages configuration and environment variables for the MCP client.""" - - def __init__(self) -> None: - """Initialize configuration with environment variables.""" - self.load_env() - self.api_key = os.getenv("LLM_API_KEY") - - @staticmethod - def load_env() -> None: - """Load environment variables from .env file.""" - load_dotenv() - - @staticmethod - def load_config(file_path: str) -> dict[str, Any]: - """Load server configuration from JSON file. - - Args: - file_path: Path to the JSON configuration file. - - Returns: - Dict containing server configuration. - - Raises: - FileNotFoundError: If configuration file doesn't exist. - JSONDecodeError: If configuration file is invalid JSON. - """ - with open(file_path, "r") as f: - return json.load(f) - - @property - def llm_api_key(self) -> str: - """Get the LLM API key. - - Returns: - The API key as a string. - - Raises: - ValueError: If the API key is not found in environment variables. - """ - if not self.api_key: - raise ValueError("LLM_API_KEY not found in environment variables") - return self.api_key - - -class ChatSession: - def __init__(self, server: MCPClient, llm_client: LLMClient) -> None: - self.server = server - self.llm_client = llm_client - - async def process_llm_response(self, llm_response: str) -> str: - """Process the LLM response and execute tools if needed. - - Args: - llm_response: The response from the LLM. - - Returns: - The result of tool execution or the original response. - """ - tool_info = { - "tool_name": "", - "arguments": {}, - } - try: - tool_call = parse_evaluation_json(llm_response) - if tool_call and "tool" in tool_call and "arguments" in tool_call: - print(f"Executing tool: {tool_call['tool']}") - print(f"With arguments: {tool_call['arguments']}") - tool_info["tool_name"] = tool_call["tool"] - tool_info["arguments"] = tool_call["arguments"] - - tools = await self.server.list_tools() - if any(tool.name == tool_call["tool"] for tool in tools): - try: - result = await self.server.execute_tool(tool_call["tool"], tool_call["arguments"]) - result_str = f"{result}" - if len(result_str) > 500: - logging.info(f"The output of tool execution is too long. Only show part of it: {result[:400]}... {result[-100:]}") - else: - logging.info(f"Tool execution result: {result}") - return tool_info, f"Tool execution result: {result}" ###这里有问题,把输出变成str了 - except Exception as e: - error_msg = f"Error executing tool: {str(e)}" - print(error_msg) - return tool_info, error_msg - - return tool_info, f"No server found with tool: {tool_call['tool']}" - return tool_info, f"tool call json decode error: {llm_response}" - except json.JSONDecodeError: - return tool_info, llm_response - - async def handle_query(self, query) -> None: - all_tools = [] - # for server in self.servers: - tools = await self.server.list_tools() - all_tools.extend(tools) - - tools_description = "\n".join([tool.format_for_llm() for tool in all_tools]) - - system_message = f"""You are a helpful assistant with access to these tools: -{tools_description} -Choose the appropriate tool based on the user's question. If no tool is needed, reply directly. -IMPORTANT: When you need to use a tool, you must ONLY respond with the exact JSON object format below, nothing else: -```json -{{ - "tool": "tool-name", - "arguments": {{ - "argument-name": "value" - }} -}} -``` -After receiving a tool's response: -1. Transform the raw data into a natural, conversational response -2. Keep responses concise but informative -3. Focus on the most relevant information -4. Use appropriate context from the user's question -5. Avoid simply repeating the raw data - -Please use only the tools that are explicitly defined above.""" - - - messages = [{"role": "system", "content": system_message}] - messages.append({"role": "user", "content": "User query:"+query}) - - llm_response = self.llm_client.get_response(messages) - print("\nAssistant: %s", llm_response) - - tool_info, tool_result = await self.process_llm_response(llm_response) - tool_included_or_not = True if tool_result != llm_response else False - if tool_included_or_not: - return tool_included_or_not, tool_info, tool_result - else: - return tool_included_or_not, tool_info, 'No tool was used, here is the direct response: '+ llm_response - - diff --git a/src/test_generator/TestGenerator.py b/src/test_generator/TestGenerator.py index b200716..9e97fab 100644 --- a/src/test_generator/TestGenerator.py +++ b/src/test_generator/TestGenerator.py @@ -8,7 +8,7 @@ from ..llm.LLM import LLMClient from ..type.types_def import ToolDefinition, TestCase from ..prompts.tool_prompt import tool_prompt from ..prompts.eval_prompt import eval_prompt -from ..client.Session import Configuration +from ..client.Client import Configuration from ..client.MCPClient import MCPClient from ..utils.read_source_code import ReadSourceCode diff --git a/src/utils/parse_json.py b/src/utils/parse_json.py deleted file mode 100644 index 7237d24..0000000 --- a/src/utils/parse_json.py +++ /dev/null @@ -1,31 +0,0 @@ -import re -import json -def parse_evaluation_json(response_text): - """ - 从LLM的响应文本中解析评估结果的JSON对象 - - 参数: - response_text: LLM返回的原始文本 - tool_name: 工具名称,用于日志输出 - - 返回: - 解析后的评估结果字典,如果解析失败则返回None - """ - # 尝试提取反引号之间的JSON内容 - json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', response_text) - if json_match and json_match.group(1): - json_content = json_match.group(1) - else: - # 未找到反引号时,使用整个响应文本尝试解析 - json_content = response_text - - try: - # 解析JSON - parsed_json = json.loads(json_content) - return parsed_json - - except json.JSONDecodeError as parse_error: - print(f"JSON解析失败: {parse_error}") - print(f"尝试解析的内容: {json_content}") - - return None # 解析失败返回None \ No newline at end of file diff --git a/src/utils/read_source_code.py b/src/utils/read_source_code.py deleted file mode 100644 index a0ea4d0..0000000 --- a/src/utils/read_source_code.py +++ /dev/null @@ -1,105 +0,0 @@ -from typing import List, Dict, Optional -import json -import os -import ast - -class ReadSourceCode: - def __init__(self, config_path: str = None): - with open(config_path, "r") as f: - self.config = json.load(f) - - def get_code(self, server_name: str) -> List[str]: - source_path = self.extract_source_code_path(server_name) - if not source_path: - return [] - tool_functions = self.get_mcp_tool_functions(source_path) - return tool_functions - - def extract_source_code_path(self, server_name: str) -> Optional[str]: - """ - 从 Server 的 args 中提取源代码文件(.py)的绝对路径 - :param self.config: Server 的 args 列表(如 ["--directory", "/path", "server.py"]) - :param command: Server 的 command(如 "uv"、"python3.11",辅助判断参数逻辑) - :return: 源代码绝对路径(None 表示未找到) - """ - try: - server_args = self.config["mcpServers"][server_name]['args'] - - source_file = None - work_dir = os.getcwd() # 默认工作目录(当前目录) - - for i, arg in enumerate(server_args): - if arg == "--directory" and i + 1 < len(server_args): - work_dir = server_args[i + 1] - work_dir = os.path.abspath(work_dir) - break - - for arg in server_args: - if arg.endswith(".py"): - source_file = arg - break - - if not source_file: - print("未在 args 中找到 .py 源代码文件") - return None - - if os.path.isabs(source_file): - # 若已是绝对路径,直接使用 - absolute_path = source_file - else: - # 相对路径 → 拼接工作目录 - absolute_path = os.path.join(work_dir, source_file) - - # 验证文件是否存在 - if os.path.exists(absolute_path): - return absolute_path - else: - print(f"源代码文件不存在:{absolute_path}") - return None - except Exception as e: - print(f"Error: {e}") - return None - - def get_mcp_tool_functions(self, source_path: str) -> Dict[str, str]: - """ - 解析源代码文件,提取被 @mcp.tool() 装饰的函数名和对应函数代码 - :param source_path: 源代码文件路径(.py) - :return: 字典,键为函数名,值为函数完整代码字符串 - $$$待修改 - """ - - tool_functions = {} # 改为字典存储函数名和代码 - - # 读取源代码内容 - with open(source_path, "r", encoding="utf-8") as f: - source_code = f.read() - - # 用 ast 解析抽象语法树 - tree = ast.parse(source_code) - # 遍历语法树,找到函数定义并检查装饰器 - for node in ast.walk(tree): - # 只处理函数定义节点(def 函数) - if isinstance(node, ast.FunctionDef): - # 检查函数是否有 @mcp.tool() 装饰器 - for decorator in node.decorator_list: - # 处理两种装饰器形式:@mcp.tool() 或 @mcp.tool(name="xxx") - is_mcp_tool = False - if isinstance(decorator, ast.Call): - # 装饰器带参数(如 @mcp.tool(name="test")) - if (isinstance(decorator.func, ast.Attribute) - and decorator.func.value.id == "mcp" - and decorator.func.attr == "tool"): - is_mcp_tool = True - elif isinstance(decorator, ast.Attribute): - # 装饰器不带参数(如 @mcp.tool) - if decorator.value.id == "mcp" and decorator.attr == "tool": - is_mcp_tool = True - - if is_mcp_tool: - # 提取函数完整代码(包括装饰器、文档字符串和函数体) - function_code = ast.get_source_segment(source_code, node) - if function_code: - tool_functions[node.name] = function_code.strip() - break # 找到匹配的装饰器后跳出循环 - - return tool_functions -- Gitee