From 270791a7a4b771fcb6b24b8be50c3b11c373cc0f Mon Sep 17 00:00:00 2001 From: shenyue_ustc Date: Fri, 10 Oct 2025 14:32:18 +0800 Subject: [PATCH 1/3] add calidator --- src/client/MCPClient.py | 17 +- src/test_generator/TestGenerator.py | 2 +- src/validator/Response_validator_withenv.py | 539 ++++++++++++++++++++ 3 files changed, 551 insertions(+), 7 deletions(-) create mode 100644 src/validator/Response_validator_withenv.py diff --git a/src/client/MCPClient.py b/src/client/MCPClient.py index 8f10172..a33f197 100644 --- a/src/client/MCPClient.py +++ b/src/client/MCPClient.py @@ -56,21 +56,25 @@ class MCPClient: """获取命令脚本路径""" try: server_args = self.config['args'] + work_dir = os.getcwd() source_file = None - work_dir = os.getcwd() + for i, arg in enumerate(server_args): if arg == "--directory" and i + 1 < len(server_args): - work_dir = server_args[i + 1] - work_dir = os.path.abspath(work_dir) - break + work_dir = os.path.abspath(server_args[i + 1]) + break - for arg in server_args: + for i, arg in enumerate(server_args): if arg.endswith(".py"): source_file = arg + if i > 0 and server_args[i-1] == "run": + break + elif arg == "run" and i + 1 < len(server_args) and server_args[i+1].endswith(".py"): + source_file = server_args[i+1] break if not source_file: - logging.warning("未在 args 中找到 .py 源代码文件") + logging.warning("未在args中找到.py源代码文件") return None if os.path.isabs(source_file): @@ -83,6 +87,7 @@ class MCPClient: else: logging.error(f"源代码文件不存在:{absolute_path}") return None + except Exception as e: logging.error(f"获取脚本路径出错: {e}") return None diff --git a/src/test_generator/TestGenerator.py b/src/test_generator/TestGenerator.py index 920ce33..366785b 100644 --- a/src/test_generator/TestGenerator.py +++ b/src/test_generator/TestGenerator.py @@ -128,7 +128,7 @@ class TestGenerator: # Extract input schema properties safely input_properties = {} if tool.input_schema and hasattr(tool.input_schema, 'properties'): - input_properties = json.dumps(input_properties, indent=2) + input_properties = json.dumps(tool.input_schema["properties"], indent=2) formatted_prompt = tool_prompt.format( tool=tool, diff --git a/src/validator/Response_validator_withenv.py b/src/validator/Response_validator_withenv.py new file mode 100644 index 0000000..793e8b5 --- /dev/null +++ b/src/validator/Response_validator_withenv.py @@ -0,0 +1,539 @@ +import json +import logging +import os +from pathlib import Path +import re +from typing import List, Dict, Any + +from jinja2 import Environment +from jsonschema import validate, ValidationError, SchemaError + +from ..llm.LLM import LLMClient +from ..client.MCPClient import MCPClient +from ..client.MakeConfig import Configuration +from ..client.Session import ChatSession +from ..utils.parse_json import parse_evaluation_json +from ..prompts.env_prompt import env_prompt, not_pass_judge_prompt +from ..prompts.val_prompt import val_prompt_tool, val_prompt_eval + +class ResponseValidator_withenv: + """ + Validator using LLM-generated rules + """ + + def __init__(self, api_key: str = None, config_path: str = None, testcase_path: str=None, max_attempts: int = 3): + """ + Create a new test validator + Args: + api_key: API key for the language model + """ + Config_class = Configuration() + self.config = Config_class.load_config(config_path) + self.llm = LLMClient(api_key) + self.testcase_path = testcase_path + self.max_attempts = max_attempts + with open(self.testcase_path, 'r', encoding='utf-8') as file: + self.testcases = json.load(file) + self.env_script = '' + self.jinja_env = 
Environment()
+
+    async def run(self):
+        # get test cases
+        if not self.testcases:
+            logging.warning(f'No testcase found in the file {self.testcase_path}. Nothing to validate.')
+
+        res = []
+        for case in self.testcases:
+            try:
+                server_name = self.testcase_path.split('/')[-2].split('_2025')[0]
+                # load config
+
+                print("\n========================================")
+                print(f"Validating Server: {server_name}")
+                print("========================================\n")
+
+                logging.info(f"\n--- Validating Test Case ID: {case['id']} for Tool: {case['toolName']} ---")
+                validation_log_tool = await self.tool_validation(case, server_name)
+
+                logging.info("\n--- Validating End-to-end Test Case ---")
+                validation_log_eval = await self.eval_validation(case, server_name)
+                result = {
+                    "id": case["id"],
+                    "toolName": case["toolName"],
+                    "input": case["input"],
+                    "description": case["description"],
+                    "query": case["query"],
+                    "expect": case["expect"]["status"],
+                    "env_script": self.env_script,
+                    "validation_tool": {
+                        "output": validation_log_tool["output"],
+                        "passed": validation_log_tool["passed"],
+                        "rule_results": validation_log_tool["message"],
+                    },
+                    "validation_eval": {
+                        "output": validation_log_eval["output"],
+                        "passed": validation_log_eval["passed"],
+                        "message": validation_log_eval["message"],
+                    },
+                }
+                res.append(result)
+
+            except Exception as e:
+                logging.error(f"Validation run failed: {e}")
+                raise
+
+        self.save_to_file(server_name, res)
+
+    async def tool_validation(self, case, server_name):
+        """
+        Validate whether the tool-level test case passes.
+        """
+        id = case["id"]
+        srv_config = self.config["mcpServers"][server_name]
+
+        attempt = 0
+        option_not_passed = ""  # option describing why the tool test did not pass: "a. Unmet special environmental requirement", "b. Issue with the validation rule itself"
+        env_script = ""
+        reason_not_passed = ""  # detailed reason why the tool test did not pass
+        server = None
+        validation_log_tool = None
+        history = []
+
+        while attempt < self.max_attempts:
+            print(f"Attempt {attempt + 1} of {self.max_attempts}")
+            server = None
+            try:
+                # instantiate the server
+                server = MCPClient(server_name, srv_config, env_script, use_docker=True)
+                await server.initialize()
+                validation_log_tool = await self.validate_toolcase(server, case)
+
+                # validation passed
+                if validation_log_tool.get("passed"):
+                    logging.info(f"✅ Test case [ID: {id}] passed the tool interface validation")
+                    self.env_script = env_script  ### keep the environment setup script for eval_validation
+                    return validation_log_tool
+
+                if not option_not_passed:
+                    option_not_passed = self.get_not_pass_option(server, validation_log_tool, case)
+                    if not option_not_passed:
+                        logging.warning("🔴 No valid judgment extracted from the LLM response, skipping this attempt")
+                        attempt += 1
+                        continue
+                    if option_not_passed.get("option","").startswith('b'):
+                        ### if the cause is the rule itself, rule validation is bound to fail, so break out early
+                        break
+                    reason_not_passed = option_not_passed.get("reason","")
+
+                # try to obtain an environment setup script
+                new_env_script = self.get_env_script(server, validation_log_tool, reason_not_passed, case, history)
+                if not new_env_script:
+                    logging.warning("🔴 No valid script extracted from the LLM response, skipping this attempt")
+                    attempt += 1
+                    continue
+                logging.info(f"Successfully extracted environment setup script (length: {len(new_env_script)} characters)")
+                env_script = new_env_script
+                history.append({"attempt": attempt, "env_script": new_env_script, "output": validation_log_tool["output"]})
+                attempt += 1
+
+            except Exception as e:
+                logging.error(f"❌ Attempt {attempt + 1} failed: {str(e)}")
+                attempt += 1
+
+            finally:
+                if server:
+                    try:
+                        await server.cleanup()
+                        logging.info("Server resources cleaned up")
+                    except Exception as cleanup_err:
+                        logging.error(f"Error while cleaning up server resources: {str(cleanup_err)}")
+        return validation_log_tool
+    def get_dependencies(self, server):
+        dependencies = ""
+        if not server.host_mcp_path:
+            print("Error: 
server.host_mcp_path is not set") + else: + root = Path(server.host_mcp_path) + for req_file in root.rglob("requirements.txt"): # 递归搜索所有子目录 + if req_file.is_file(): + try: + dependencies = req_file.read_text(encoding="utf-8") + print(f"在 {req_file} 中找到依赖项:\n{dependencies}") + except Exception as e: + print(f"读取文件失败 {req_file}: {e}") + return dependencies + def get_not_pass_option(self, server, validation_log_tool, case): + judge_prompt = "" + + dependencies = self.get_dependencies(server) + + judge_prompt = not_pass_judge_prompt.format( + dependencies=dependencies, + testcases = json.dumps(case), + output = json.dumps(validation_log_tool["output"]), + validation_results = json.dumps(validation_log_tool["message"])) + + judge_output = self.llm.get_response([{"role": "user", "content": judge_prompt}]) + judge_json = parse_evaluation_json(judge_output) + if not judge_json: + return None + logging.info(f"❌ The case is not passed due to {judge_json['reason']}") + return judge_json + + def get_env_script(self, server, validation_log_tool, reason_not_passed, case, history): + dependencies = self.get_dependencies(server) + + env_template = self.jinja_env.from_string(env_prompt) + + history_text = "" + for part in history: + history_text += f"Attempt {part['attempt']}:\nbash script:\n{part['env_script']}\nOutput:\n{json.dumps(part['output'])}\n\n" + + env_vars = { + "reason": reason_not_passed, + "dependencies": dependencies, + "testcases": json.dumps(case), + "output": json.dumps(validation_log_tool["output"]), + "validation_results": json.dumps(validation_log_tool["message"]), + "history": history_text, + } + + + env_prompt_formatted = env_template.render(**env_vars) + env_output = self.llm.get_response([{"role": "user", "content": env_prompt_formatted}]) + + # prompt = env_prompt.format( + # reason = reason_not_passed, + # dependencies=dependencies, + # testcases = json.dumps(case), + # output = json.dumps(validation_log_tool["output"]), + # validation_results = json.dumps(validation_log_tool["message"])) + + # env_output = self.llm.get_response([{"role": "user", "content": prompt}]) + script_pattern = re.compile(r'(.*?)', re.DOTALL | re.IGNORECASE) + matches = script_pattern.findall(env_output) + env_script = matches[0].strip() if matches else "" + return env_script + + async def validate_toolcase(self, server, case: json) -> Dict[str, Any]: + """ + Validate a single tool case + + Args: + tool_name: Name of the tool to be tested + expect_status: Expected status of the response + validation_rules: Rules for validating the response content + + Returns: + A dictionary with validation results + """ + + # Send request to the tool via MCP server + + output = await server.execute_tool(case["toolName"], case["input"]) + + if not output: + print(f"No output received for Test Case ID: {case['id']}") + return { + "id": case["id"], + "toolName": case["toolName"], + "input": case["input"], + "description": case["description"], + "query": case["query"], + "expect": case["expect"]['status'], + "message": f"No output received for Test Case ID: {case['id']}", + "passed": False, + "output": output, + } + + validation_rules = case["expect"]['validation_rules'] if 'validation_rules' in case['expect'] else [] + rule_results = [] + expect_status = case["expect"]['status'] if 'status' in case['expect'] else "success" + + + ## 如果工具输出结果是json列表,则使用schema rule + ## 否则使用contains rule + + type_expect = 'schema' + for o in output: + if not isinstance(o, dict): + type_expect = 'contains' + break + + if type_expect != 
'schema': + validation_rules = [rule for rule in validation_rules if rule["type"] in ['contains',"equals","llm"]] + + for i,rule in enumerate(validation_rules): + valid_per_rule, message_per_rule = self.validate_single_rule(output, rule, case) + + rule_for_log = rule.copy() + if rule_for_log.get('type')=="schema": + rule_for_log['value'] = json.dumps(rule_for_log["value"], ensure_ascii=False, separators=(',', ':')) + + if not valid_per_rule: + ### LLM parse失败,或者schema本身有问题,并不能说明是规则验证失败 + if message_per_rule!="Failed to parse LLM validation response." and \ + not message_per_rule.startswith("Invalid schema -"): + rule_results.append({ + f"rule": rule_for_log, + "rule_passed": False, + "rule_error": message_per_rule + } + ) + else: + rule_results.append({ + f"rule": rule_for_log, + "rule_passed": True, + "rule_error": "passed" + } + ) + passed_results = [r['rule_passed'] for r in rule_results if r['rule_passed']] + if rule_results and len(passed_results) >= len(rule_results)/2: + all_passed = True + print(f"All validations passed for Test Case ID: {case['id']}") + else: + all_passed = False + print(f"Validation result mismatch for Test Case ID: {case['id']}") + print(f"Expected status: {expect_status}") + print(f"Rule results: {rule_results}") + + # truncate output + output = self.truncate_output(output) + + return { + "id": case["id"], + "toolName": case["toolName"], + "input": case["input"], + "description": case["description"], + "query": case["query"], + "expect": case["expect"]['status'], + "message": rule_results, + "passed": all_passed, + "output": output, + } + + def validate_single_rule(self, output: List, rule, case): + if rule['type'] == 'contains': + if "contain" in rule['value']: + rule['value'] = rule['value'].split("contain")[-1] + if "contains" in rule['value']: + rule['value'] = rule['value'].split("contains")[-1] + rule["value"] = self.clean_value(rule['value']).lower() + if all(isinstance(o,dict) for o in output): + output_cat = json.dumps(output, ensure_ascii=False, separators=(',', ':')) + else: + output_cat = ' '.join(str(o) for o in output) + if rule['value'] in output_cat.lower(): + return True, f"" + else: + return False, f"Output does not contain expected substring: {rule['value']}" + + + elif rule['type'] == 'equals': + if not isinstance(rule['value'],list): + rule['value'] = [rule['value']] + if rule['value'] == output: + return True,f"" + else: + return False, f"Output does not exactly equal expected value: {rule['value']}" + + elif rule['type'] == 'schema': + if not isinstance(rule['value'],dict): + return False, f"Invalid schema" + if rule['value'].get('type') == 'array': + try: + validate(instance=output, schema=rule['value']) + return True, "Schema validation passed." + except ValidationError as e: + return False, f"Schema validation failed: {str(e)}" + except SchemaError as e: + return False, f"Invalid schema - {str(e)}" + + elif rule['value'].get('type') == 'object': + try: + validate(instance=output[0], schema=rule['value']) + return True, "Schema validation passed." 
+                except ValidationError as e:
+                    return False, f"Schema validation failed: {str(e)}"
+                except SchemaError as e:
+                    return False, f"Invalid schema - {str(e)}"
+
+        elif rule['type'] == 'llm':
+            output = '\n'.join(str(o) for o in output)
+            valid, message = self.llm_rule_validation(output, rule, case)
+            return valid, message
+
+
+    def clean_value(self, raw_value: str) -> str:
+        BAD_CHARS = {'"', ':', '!', '@', '$', '%', '&', ';', ',', '.', ' ', '\t', '\n', "'"}
+        cleaned = raw_value.strip(''.join(BAD_CHARS))
+
+        while cleaned and cleaned[-1] in BAD_CHARS:
+            cleaned = cleaned[:-1]
+
+        if not cleaned:
+            return ""
+
+        return cleaned
+
+
+    def llm_rule_validation(self, output: str, rule: json, case: json) -> Dict[str, Any]:
+
+        try:
+            val_prompt_tool_formatted = val_prompt_tool.format(
+                tool_name=case["toolName"],
+                input=case["input"],
+                validation_rule=rule['value'],
+                output=output
+            )
+            llm_response = self.llm.get_response([{"role": "user", "content": val_prompt_tool_formatted}])
+            llm_result = parse_evaluation_json(llm_response)
+            content_valid = llm_result.get('answer', '').lower() == 'yes'
+            message = llm_result.get('explanation', '')
+            return content_valid, message
+        except (json.JSONDecodeError, AttributeError):
+            logging.warning("Failed to parse LLM validation response.")
+            return False, "Failed to parse LLM validation response."
+
+    async def eval_validation(self, case, server_name):
+        """
+        Validate whether the test case passes the end-to-end test.
+        """
+        id = case["id"]
+        srv_config = self.config["mcpServers"][server_name]
+
+        server = None
+        validation_log_eval = None
+
+        # instantiate the server, passing in the global environment setup script
+        server = MCPClient(server_name, srv_config, self.env_script, use_docker=True)
+        await server.initialize()
+        validation_log_eval = await self.validate_evalcase(server, case)
+
+        if validation_log_eval.get("passed"):
+            logging.info(f"✅ Test case [ID: {id}] passed the end-to-end validation")
+        else:
+            logging.error(f"❌ Test case [ID: {id}] failed the end-to-end validation")
+        await server.cleanup()
+        return validation_log_eval
+
+    async def validate_evalcase(self, server, case):
+        query = case["query"]
+        self.session = ChatSession(server, self.llm)
+        tool_included_or_not, tool_info, session_result = await self.session.handle_query(query)
+        eval_results = []
+        if tool_included_or_not:
+            tool_name = tool_info["tool_name"]
+            expected_tool = case["toolName"]
+            ## first check whether the correct tool was called
+            if tool_name != expected_tool:
+                eval_results = {
+                    "passed": False,
+                    "message": f"Expected tool: {expected_tool}, but got {tool_name}",
+                }
+
+            else:
+                expect_rules = case["expect"]['validation_rules'] if 'validation_rules' in case['expect'] else []
+
+                res_expect = []
+                for rule in expect_rules:
+                    if rule["type"] in ["contains", "equals"]:
+                        res_expect.append(rule['message'])
+                    elif rule["type"] == "llm":
+                        res_expect.append(rule['value'])
+
+                expect_results = ' '.join(res_expect)
+
+                if len(session_result)>1000:
+                    session_result = session_result[:800] + '...' 
+ session_result[-200:] + + + template = self.jinja_env.from_string(val_prompt_eval) + if expect_results: + test_case_vars = { + "query": case["query"], + "expect_type": case["expect"]['status'], + "expected_output": "**Expected Output:** " + expect_results, + "output": session_result + } + else: + test_case_vars = { + "query": case["query"], + "expect_type": case["expect"]['status'], + "expected_output": '', + "output": session_result + } + + val_prompt_eval_formatted = template.render(**test_case_vars) + val_response = self.llm.get_response([{"role": "user", "content": val_prompt_eval_formatted}]) + + + val_result = parse_evaluation_json(val_response) + if val_result: + passed = val_result.get('answer', '').lower() == 'yes' + message = val_result.get('explanation', '') + eval_results= { + "passed": passed, + "message": message, + } + else: + eval_results= { + "passed": False, + "message": "Failed to parse LLM evaluation response." + } + + else: + eval_results= { + "passed": False, + "message": "No tool was called in the response." + } + # truncate output + if isinstance(session_result, str): + output = session_result[:800]+'...'+session_result[-200:] if len(session_result)>1000 else session_result + elif isinstance(session_result, list): + output = session_result[:10] if len(session_result)>10 else session_result + + return { + "id": case["id"], + "toolName": case["toolName"], + "input": case["input"], + "description": case["description"], + "query": case["query"], + "expect": case["expect"], + "message": eval_results["message"], + "passed": eval_results['passed'], + "output": output,} + + def truncate_output(self, output): + if isinstance(output, list): + output = output[:10] if len(output)>10 else output + if all(isinstance(o, dict) for o in output): + output = [json.dumps(o, ensure_ascii=False, separators=(',', ':')) for o in output] + elif isinstance(output, str): + output = output[:800] + '...' 
+ output[-200:] if len(output)>1000 else output + return output + + def save_to_file(self, server_name: str, validationlog: List) -> bool: + """ + save test cases (array of JSON) to file + """ + try: + if not isinstance(validationlog, list): + raise ValueError("input data should be an array of JSON") + + folerpath = os.path.dirname(self.testcase_path) + filename = "validation_results_eval_env.json" + filepath = os.path.join(folerpath,filename) + with open(filepath, 'w', encoding='utf-8') as file: + json.dump(validationlog, file, ensure_ascii=False, indent=4) + print(f"{server_name} validation results are successfully saved into {filepath}") + return True + except IOError as e: + print(f"文件操作错误: {e}") + except ValueError as e: + print(f"数据格式错误: {e}") + except Exception as e: + print(f"发生未知错误: {e}") + + return False + -- Gitee From 7bc0a8a773f4d7d85551e3f6917ae315e9fc6874 Mon Sep 17 00:00:00 2001 From: shenyue_ustc Date: Fri, 10 Oct 2025 14:37:23 +0800 Subject: [PATCH 2/3] add dockerfile --- Dockerfile | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ main.py | 21 ++++++++++++++++--- 2 files changed, 78 insertions(+), 3 deletions(-) create mode 100644 Dockerfile diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..847d83c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,60 @@ +FROM openeuler/openeuler + +LABEL maintainer="shenyue24@huawei.com" + +# 基础包安装 +RUN dnf update -y && \ + dnf install -y --setopt=install_weak_deps=False \ + wget \ + findutils \ + sudo \ + libtool-ltdl \ + container-selinux \ + libseccomp \ + glibc \ + lvm2 \ + docker-client && \ + dnf clean all && \ + rm -rf /var/cache/dnf/* && \ + find /var/log -type f -delete && \ + rm -rf /tmp/* /var/tmp/* && \ + rm -rf /var/lib/dnf/history/* + +# 安装Miniconda +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py311_25.7.0-2-Linux-x86_64.sh -O miniconda.sh && \ + bash miniconda.sh -b -p /opt/conda && \ + rm -f miniconda.sh && \ + /opt/conda/bin/conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main && \ + /opt/conda/bin/conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r && \ + /opt/conda/bin/conda clean -afy && \ + /opt/conda/bin/conda init bash && \ + echo "conda config --set auto_activate_base true" >> ~/.bashrc && \ + echo "conda config --set notify_outdated_conda false" >> ~/.bashrc + +ENV PATH="/opt/conda/bin:$PATH" + +# 预配置pip镜像源 +RUN mkdir -p /root/.pip && \ + printf '[global]\nindex-url = https://pypi.tuna.tsinghua.edu.cn/simple/\ntrusted-host = pypi.tuna.tsinghua.edu.cn\ntimeout = 300\nretries = 3\n' > /root/.pip/pip.conf && \ + echo "pip镜像源预配置完成" + +# 预安装常用的Python包,避免每次都重复安装 +RUN pip install --no-cache-dir \ + numpy \ + pandas \ + scipy \ + matplotlib \ + seaborn \ + scikit-learn \ + uv \ + mcp==1.4.0 \ + && echo "常用Python包预安装完成" + +RUN conda --version && \ + python --version && \ + pip --version + +# 设置工作目录 +WORKDIR /app + +CMD ["/bin/bash"] \ No newline at end of file diff --git a/main.py b/main.py index 23b4c04..9ebaabb 100644 --- a/main.py +++ b/main.py @@ -41,7 +41,18 @@ def parse_args(): default=".logs/perf_mcp_2025-09-11T07-31-04-418670/validation_results.json", help="Path to MCP Server config file" ) - + rep_parser.add_argument( + "--config", + type=str, + default=None, + help="Path to MCP Server config file" + ) + rep_parser.add_argument( + "--detailed", + type=bool, + default=False, + help="Output detailed report or not" + ) return parser.parse_args() @@ -55,13 +66,17 @@ async def val_cases(config_path, testcase_path): from 
src.validator.Response_validator_withenv import ResponseValidator_withenv validator = ResponseValidator_withenv(config_path=config_path, testcase_path=testcase_path) return await validator.run() - +async def rep_cases(valpath, config_path, detailed): + from src.reporter.Reporter import Reporter + Reporter = Reporter(valpath, config_path, detailed) + return await Reporter.run() async def main(): args = parse_args() if args.command == 'gen-cases': await gen_cases(args.config) if args.command == 'val-cases': await val_cases(args.config, args.testpath) - + if args.command == 'rep-cases': + await rep_cases(args.valpath, args.config, args.detailed) if __name__ == "__main__": asyncio.run(main()) \ No newline at end of file -- Gitee From 1e85e6e138a8f97d4031d281cdd9c0f9e0a832f5 Mon Sep 17 00:00:00 2001 From: shenyue_ustc Date: Fri, 10 Oct 2025 14:54:52 +0800 Subject: [PATCH 3/3] add utils --- src/utils/parse_json.py | 29 ++++++ src/utils/read_source_code.py | 98 +++++++++++++++++++++ src/validator/Response_validator_withenv.py | 8 -- 3 files changed, 127 insertions(+), 8 deletions(-) create mode 100644 src/utils/parse_json.py create mode 100644 src/utils/read_source_code.py diff --git a/src/utils/parse_json.py b/src/utils/parse_json.py new file mode 100644 index 0000000..81bc83d --- /dev/null +++ b/src/utils/parse_json.py @@ -0,0 +1,29 @@ +import re +import json +def parse_evaluation_json(response_text): + """ + 从LLM的响应文本中解析评估结果的JSON对象 + + 参数: + response_text: LLM返回的原始文本 + tool_name: 工具名称,用于日志输出 + + 返回: + 解析后的评估结果字典,如果解析失败则返回None + """ + # 尝试提取反引号之间的JSON内容 + json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', response_text) + if json_match and json_match.group(1): + json_content = json_match.group(1) + else: + json_content = response_text + + try: + parsed_json = json.loads(json_content) + return parsed_json + + except json.JSONDecodeError as parse_error: + print(f"JSON解析失败: {parse_error}") + print(f"尝试解析的内容: {json_content}") + + return None \ No newline at end of file diff --git a/src/utils/read_source_code.py b/src/utils/read_source_code.py new file mode 100644 index 0000000..1be6775 --- /dev/null +++ b/src/utils/read_source_code.py @@ -0,0 +1,98 @@ +from typing import List, Dict, Optional +import json +import os +import ast + +class ReadSourceCode: + def __init__(self, config_path: str = None): + with open(config_path, "r") as f: + self.config = json.load(f) + + def get_code(self, server_name: str) -> List[str]: + source_path = self.extract_source_code_path(server_name) + if not source_path: + return [] + tool_functions = self.get_mcp_tool_functions(source_path) + return tool_functions + + def extract_source_code_path(self, server_name: str) -> Optional[str]: + """ + 从 Server 的 args 中提取源代码文件(.py)的绝对路径 + :param self.config: Server 的 args 列表(如 ["--directory", "/path", "server.py"]) + :param command: Server 的 command(如 "uv"、"python3.11",辅助判断参数逻辑) + :return: 源代码绝对路径(None 表示未找到) + """ + try: + server_args = self.config["mcpServers"][server_name]['args'] + + source_file = None + work_dir = os.getcwd() # 默认工作目录(当前目录) + + for i, arg in enumerate(server_args): + if arg == "--directory" and i + 1 < len(server_args): + work_dir = server_args[i + 1] + work_dir = os.path.abspath(work_dir) + break + + for arg in server_args: + if arg.endswith(".py"): + source_file = arg + break + + if not source_file: + print("未在 args 中找到 .py 源代码文件") + return None + + if os.path.isabs(source_file): + # 若已是绝对路径,直接使用 + absolute_path = source_file + else: + # 相对路径 → 拼接工作目录 + absolute_path = os.path.join(work_dir, 
source_file)
+
+            # check that the source file exists
+            if os.path.exists(absolute_path):
+                return absolute_path
+            else:
+                print(f"Source code file does not exist: {absolute_path}")
+                return None
+        except Exception as e:
+            print(f"Error: {e}")
+            return None
+
+    def get_mcp_tool_functions(self, source_path: str) -> Dict[str, str]:
+        """
+        Parse the source file and extract the names and full code of functions decorated with @mcp.tool()
+        :param source_path: path to the source code file (.py)
+        :return: dict mapping each function name to its complete source code string
+        $$$ TODO: needs revision
+        """
+
+        tool_functions = {}
+
+        with open(source_path, "r", encoding="utf-8") as f:
+            source_code = f.read()
+
+        tree = ast.parse(source_code)
+        for node in ast.walk(tree):
+            if isinstance(node, ast.FunctionDef):
+                for decorator in node.decorator_list:
+                    # handle both decorator forms: @mcp.tool() and @mcp.tool(name="xxx")
+                    is_mcp_tool = False
+                    if isinstance(decorator, ast.Call):
+                        if (isinstance(decorator.func, ast.Attribute)
+                            and decorator.func.value.id == "mcp"
+                            and decorator.func.attr == "tool"):
+                            is_mcp_tool = True
+                    elif isinstance(decorator, ast.Attribute):
+                        if decorator.value.id == "mcp" and decorator.attr == "tool":
+                            is_mcp_tool = True
+
+                    if is_mcp_tool:
+                        # extract the complete function code (decorator, docstring and body)
+                        function_code = ast.get_source_segment(source_code, node)
+                        if function_code:
+                            tool_functions[node.name] = function_code.strip()
+                        break
+
+        return tool_functions
diff --git a/src/validator/Response_validator_withenv.py b/src/validator/Response_validator_withenv.py
index 793e8b5..5926ace 100644
--- a/src/validator/Response_validator_withenv.py
+++ b/src/validator/Response_validator_withenv.py
@@ -201,15 +201,7 @@ class ResponseValidator_withenv:
 
         env_prompt_formatted = env_template.render(**env_vars)
         env_output = self.llm.get_response([{"role": "user", "content": env_prompt_formatted}])
-
-        # prompt = env_prompt.format(
-        #     reason = reason_not_passed,
-        #     dependencies=dependencies,
-        #     testcases = json.dumps(case),
-        #     output = json.dumps(validation_log_tool["output"]),
-        #     validation_results = json.dumps(validation_log_tool["message"]))
 
-        # env_output = self.llm.get_response([{"role": "user", "content": prompt}])
         script_pattern = re.compile(r'(.*?)', re.DOTALL | re.IGNORECASE)
        matches = script_pattern.findall(env_output)
         env_script = matches[0].strip() if matches else ""
-- 
Gitee
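
For reference, a minimal driver sketch equivalent to the `val-cases` path wired up in main.py; the config and testcase paths below are illustrative placeholders, not files introduced by this series:

    # Hypothetical standalone driver; mirrors main.py's val_cases() wiring.
    import asyncio
    from src.validator.Response_validator_withenv import ResponseValidator_withenv

    async def _validate():
        # config_path / testcase_path are placeholder values for illustration only.
        validator = ResponseValidator_withenv(
            config_path="path/to/servers_config.json",
            testcase_path="path/to/testcases.json",
        )
        # run() writes validation_results_eval_env.json next to the testcase file.
        await validator.run()

    asyncio.run(_validate())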