From 32d44664cbeddc7779e13b02a8defb21652ff6f8 Mon Sep 17 00:00:00 2001 From: Wang Kui Date: Thu, 7 Aug 2025 10:45:08 +0800 Subject: [PATCH] add apply patch and create pr --- servers/cvekit_mcp/mcp-rpm.yaml | 2 + servers/cvekit_mcp/src/cvekit/cli.py | 61 ++++- .../src/cvekit/utils/apply_patch.py | 225 ++++++++++++++++++ .../cvekit_mcp/src/cvekit/utils/create_pr.py | 100 ++++++++ servers/cvekit_mcp/src/requirements.txt | 1 + servers/cvekit_mcp/src/server.py | 78 +++++- 6 files changed, 463 insertions(+), 4 deletions(-) create mode 100644 servers/cvekit_mcp/src/cvekit/utils/apply_patch.py create mode 100644 servers/cvekit_mcp/src/cvekit/utils/create_pr.py diff --git a/servers/cvekit_mcp/mcp-rpm.yaml b/servers/cvekit_mcp/mcp-rpm.yaml index f394e29..f795be1 100644 --- a/servers/cvekit_mcp/mcp-rpm.yaml +++ b/servers/cvekit_mcp/mcp-rpm.yaml @@ -9,9 +9,11 @@ dependencies: - python3 - git - patch + - oegitext python: - requests - PyGithub + - curl_ffi files: required: diff --git a/servers/cvekit_mcp/src/cvekit/cli.py b/servers/cvekit_mcp/src/cvekit/cli.py index 94243b7..b26abd4 100644 --- a/servers/cvekit_mcp/src/cvekit/cli.py +++ b/servers/cvekit_mcp/src/cvekit/cli.py @@ -3,11 +3,14 @@ import sys import logging import os import json + from tabulate import tabulate from .utils.gitee import parse_gitee_issue_url, setup_repository from .utils.commits import get_vulnerability_commits from .utils.branches import process_branches from .utils.cache import get_cached_data, save_cache +from .utils.apply_patch import apply_patch +from .utils.create_pr import create_pr logger = logging.getLogger(__name__) @@ -27,7 +30,7 @@ def main(): # 操作模式参数 parser.add_argument('--action', type=str, - choices=['parse-issue', 'get-commits', 'analyze-branches', 'setup-env'], + choices=['parse-issue', 'get-commits', 'analyze-branches', 'setup-env', 'apply-patch', 'create-pr'], default='analyze-branches', help='''执行模式: parse-issue(解析issue), @@ -44,6 +47,8 @@ def main(): env_group = parser.add_argument_group('仓库环境参数') env_group.add_argument('--fork-repo-url', type=str, help='Fork仓库URL (也可通过FORK_REPO_URL环境变量设置)') + env_group.add_argument('--repo-url', type=str, + help='仓库URL (也可通过REPO_URL环境变量设置)') env_group.add_argument('--gitee-token', type=str, help='Gitee访问令牌(也可通过GITEE_TOKEN环境变量设置)') env_group.add_argument('--clone-dir', type=str, @@ -65,10 +70,16 @@ def main(): output_group.add_argument('--no-cache', action='store_true', help='禁用缓存') output_group.add_argument('--debug', action='store_true', help='开启调试模式') + # 提交pr参数 + pr_group = parser.add_argument_group('提交pr') + pr_group.add_argument('--patch-path', type=str, help='patch文件路径') + pr_group.add_argument('--branch', type=str, help='pr提交分支') + args = parser.parse_args() # 从环境变量获取参数默认值 args.fork_repo_url = args.fork_repo_url or os.environ.get('FORK_REPO_URL', "https://gitee.com/lw520203/kernel_4") + args.repo_url = args.repo_url or os.environ.get('REPO_URL', "https://gitee.com/openeuler/kernel") args.gitee_token = args.gitee_token or os.environ.get('GITEE_TOKEN') args.clone_dir = args.clone_dir or os.environ.get('CLONE_DIR', os.path.join(os.path.expanduser("~"), "Image")) args.branches = args.branches or os.environ.get('BRANCHES', "OLK-5.10,OLK-6.6,master") @@ -120,7 +131,17 @@ def main(): def handle_action(args): """路由到不同操作处理器""" if args.action == 'setup-env': - return setup_repository(args.fork_repo_url, args.gitee_token, args.clone_dir) + try: + repo, repo_path = setup_repository(args.fork_repo_url, args.gitee_token, args.clone_dir) + except Exception as e: + return { + 'status': 'failed', + 'message': str(e) + } + return { + 'status': 'success', + 'repo_path': repo_path + } cve_id = args.cve_id if args.cve_id else fetch_cve_id(args.issue_url, args.gitee_token, not args.no_cache) @@ -128,8 +149,42 @@ def handle_action(args): return handle_parse_issue(args) elif args.action == 'get-commits': return handle_get_commits(cve_id, not args.no_cache) - else: # analyze-branches + elif args.action == 'analyze-branches': # analyze-branches return handle_analyze_branches(args) + elif args.action == 'apply-patch': + return handle_apply_patch(cve_id, args) + elif args.action == 'create-pr': + return handle_create_pr(cve_id, args) + else: + raise RuntimeError("action not supported: %s", args.action) + +def handle_apply_patch(cve_id, args): + """应用patch并把branch提交到fork分支""" + result = apply_patch( + fork_repo_url=args.fork_repo_url, + gitee_token=args.gitee_token, + branch=args.branch, + clone_dir=args.clone_dir, + patch_path=args.patch_path, + signer_name=args.signer_name, + signer_email=args.signer_email, + cve_id=cve_id, + issue_url=args.issue_url + ) + return result + +def handle_create_pr(cve_id, args): + """创建pr""" + result = create_pr( + cve_id=cve_id, + issue_url=args.issue_url, + fork_repo_url=args.fork_repo_url, + repo_url=args.repo_url, + branch=args.branch, + clone_dir=args.clone_dir, + gitee_token=args.gitee_token + ) + return result def handle_parse_issue(args): """处理issue解析逻辑""" diff --git a/servers/cvekit_mcp/src/cvekit/utils/apply_patch.py b/servers/cvekit_mcp/src/cvekit/utils/apply_patch.py new file mode 100644 index 0000000..8a322de --- /dev/null +++ b/servers/cvekit_mcp/src/cvekit/utils/apply_patch.py @@ -0,0 +1,225 @@ +import git +import logging +import os +import re +import subprocess + +from .gitee import setup_repository +from .patch import getUrlText +from .commits import get_vulnerability_commits + +logger = logging.getLogger(__name__) + + +def get_commit_reference(commit_id, repo_path): + # 判断目录是否存在 + if not os.path.exists(repo_path): + # 获取上一层目录 + parent_dir = os.path.dirname(repo_path) + result = subprocess.run( + ["git", "clone", "https://git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git", repo_path], + check=True, + cwd=parent_dir, + capture_output=True, + text=True + ) + if not os.path.exists(repo_path): + result = subprocess.run( + ["git", "clone", "https://kernel.googlesource.com/pub/scm/linux/kernel/git/stable/linux.git", repo_path], + check=True, + cwd=parent_dir, + capture_output=True, + text=True + ) + if not os.path.exists(repo_path): + result = subprocess.run( + ["git", "clone", "https://gitee.com/mirrors/linux_old1.git", repo_path], + check=True, + cwd=parent_dir, + capture_output=True, + text=True + ) + if not os.path.exists(repo_path): + raise RuntimeError("linux仓库克隆失败: https://git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git") + + repo = git.Repo(repo_path) + subprocess.run( + ["git", "pull"], + check=True, + cwd=repo_path, + capture_output=True, + text=True + ) + """获取提交的引用信息,如mainline版本或stable版本""" + is_stable = True + try: + name_rev = repo.git.name_rev(commit_id) + # 不带rc版本号的tag为stable版本 + if '-rc' in name_rev: + is_stable = False + # 解析name_rev输出,格式通常为: tags/~ + match = re.search(r'tags/([^~]+)', name_rev) + if match: + tag_name = match.group(1) + return tag_name, is_stable + return "unknown", is_stable + except Exception as e: + logger.error(f"获取提交引用失败: {e}") + return "unknown", is_stable + + +def generate_patch_header(commit_id, cve_id, bugzilla_url, patch_url, repo_path): + """生成符合规范的补丁头""" + ref_version, is_stable = get_commit_reference(commit_id, repo_path) + + inclusion_type = "stable inclusion" if is_stable else "mainline inclusion" + if ref_version == "unknown": + from_line = f"from mainline" if not is_stable else f"from stable" + else: + from_line = f"from mainline-{ref_version}" if not is_stable else f"from stable-{ref_version}" + + commit_text = getUrlText(patch_url) + pattern = re.compile(r"
(.+?)
", re.S) + subject = re.search(pattern, commit_text).group(1) + pattern = re.compile(r"
(.+?)
", re.S) + msg = re.search(pattern, commit_text).group(1) + msg = msg.replace('<', '<').replace('>', '>') + + header = f"""{subject} + +{inclusion_type} +{from_line} +commit id: {commit_id} +bugzilla: {bugzilla_url} +CVE: {cve_id} + +Reference: {patch_url} + +------------------- + +{msg} +""" + return header + + +def generate_commit_message(cve_id, issue_url, repo_path): + """生成commit信息""" + introduced_commit, fixed_commit = get_vulnerability_commits(cve_id) + patch_url = f"https://git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git/commit/?id={fixed_commit}" + message = generate_patch_header(fixed_commit, cve_id, issue_url, patch_url, repo_path=repo_path) + + return message + + +def apply_patch( + fork_repo_url: str, + gitee_token: str, + branch: str, + clone_dir: str, + patch_path: str, + signer_name: str, + signer_email: str, + cve_id: str, + issue_url: str, +): + """合并分支并且提交 + + Args: + fork_repo_url: git仓库地址 + gitee_token: Gitee访问令牌(必须有仓库写入权限) + branch: 处理的分支 + clone_dir: 本地克隆目录 + patch_path: patch文件路径 + signer_name: 签名者名称 + signer_email: 签名者邮箱 + cve_id: cve id + issue_url: issue链接 + + Returns: + patch应用信息字典 + """ + try: + commit_msg = generate_commit_message(cve_id, issue_url, repo_path=os.path.join(clone_dir, 'linux')) + except Exception as e: + logger.error(f"生成commit信息失败: {str(e)}") + return { + "status": "error", + "error": f"生成commit信息失败: {str(e)}" + } + # 解析fork URL获取组织名和仓库名 + parts = fork_repo_url.strip().rstrip('/').split('/') + org_name = parts[-2] + repo_name = parts[-1].replace('.git', '') + + repo, repo_path = setup_repository(fork_repo_url, gitee_token, clone_dir) + repo.git.pull() + branches = repo.git.branch().split() + try: + if branch in branches: + repo.git.checkout(branch) + else: + repo.git.checkout('-b', branch, f'origin/{branch}') + except Exception as e: + logger.error(f"切换分支失败: {str(e)}") + return { + "status": "error", + "error": f"切换分支失败: {str(e)}" + } + fix_branch = branch + try: + # 执行 git am patch_path + repo.git.apply(patch_path) + logger.info("补丁成功应用") + except git.exc.GitCommandError as e: + logger.error(f"应用补丁失败: {str(e)}") + + # 检查是否处于 am 过程中的冲突状态 + if "Applying" in str(e): + repo.git.am("--abort") + return { + "status": "error", + "error": f"无法完成补丁应用,请检查冲突并重试。: {str(e)}" + } + else: + repo.git.am("--abort") # 非冲突错误,直接中止 + logger.info("已中止补丁应用过程") + return { + "status": "error", + "error": f"无法应用补丁: {str(e)}" + } + + # 添加所有变更并提交 + repo.git.add("--all") + repo.git.commit("-m", commit_msg, "-s", f"--author={signer_name} <{signer_email}>") + + remote = f"fork-{org_name}" + # 推送变更到远程仓库 + repo_remote = None + for repo_remote in repo.remotes: + if repo_remote.name == remote: + break + if not repo_remote: + repo_remote = repo.create_remote(remote, fork_url) + try: + logger.info(f"开始推送变更到远程仓库: {repo_remote.url}") + repo.git.push(remote, fix_branch) + logger.info("变更推送成功") + except Exception as e: + try: + repo.git.push(f"{remote} --set-upstream", fix_branch) + except Exception as e: + try: + repo.git.push(remote, fix_branch, "--force") + except Exception as e: + logger.error(f"推送变更失败: {str(e)}") + return { + "status": "error", + "error": f"无法推送变更: {str(e)}" + } + + return { + "status": "success", + "remote": remote, + "branch": branch, + "repo_path": repo_path, + } diff --git a/servers/cvekit_mcp/src/cvekit/utils/create_pr.py b/servers/cvekit_mcp/src/cvekit/utils/create_pr.py new file mode 100644 index 0000000..0308e5e --- /dev/null +++ b/servers/cvekit_mcp/src/cvekit/utils/create_pr.py @@ -0,0 +1,100 @@ +import git +import logging +import subprocess +import re +import json + +from .patch import getUrlText +from .commits import get_vulnerability_commits + +logger = logging.getLogger(__name__) + + +def generate_pr_body(cve_id, issue_url): + """读取标题和内容""" + introduced_commit, fixed_commit = get_vulnerability_commits(cve_id) + + commit_url = f"https://git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git/commit/?id={fixed_commit}" + commit_text = getUrlText(commit_url) + pattern = re.compile(r"
(.+?)
", re.S) + subject = re.search(pattern, commit_text).group(1) + result = f"""{subject} + +{issue_url} +""" + return result + + +def create_pr( + cve_id: str, + issue_url: str, + gitee_token: str, + fork_repo_url: str, + repo_url: str, + branch: str, + clone_dir: str + ): + """创建PR + Args: + cve_id: cve id + issue_url: issue链接 + gitee_token: Gitee访问令牌(必须有仓库写入权限) + fork_repo_url: fork仓库地址 + repo_url: 提交PR仓库目标地址 + branch: 处理的分支 + clone_dir: 本地克隆目录 + + Returns: + 是否合并提交成功 + """ + # 解析repo URL获取组织名和仓库名 + parts = fork_repo_url.strip().rstrip('/').split('/') + head_org_name = parts[-2] + head_repo_name = parts[-1].replace('.git', '') + parts = repo_url.strip().rstrip('/').split('/') + base_org_name = parts[-2] + base_repo_name = parts[-1].replace('.git', '') + title = f"Fix {cve_id}" + body = generate_pr_body(cve_id, issue_url) + + try: + subprocess.run( + ['oegitext', 'config', '-token', gitee_token], + check=True, + cwd=clone_dir, + capture_output=True, + text=True + ) + except Exception as e: + logger.error(f"oegitext配置token失败: {str(e)}") + return { + "status": "error", + "error": f"oegitext配置token失败: {str(e)}" + } + + cmd = [ + 'oegitext', 'pull', '-cmd', 'create', '-user', base_org_name, '-repo', base_repo_name, + '-title', title, '-head', f'{head_org_name}/{head_repo_name}:{branch}', '-base', f'{branch}', + '-body', body, '-show' + ] + + try: + result = subprocess.run( + cmd, + check=True, + cwd=clone_dir, + capture_output=True, + text=True + ) + except Exception as e: + logger.error(f"提交pr失败: {str(e)}") + return { + "status": "error", + "error": f"提交pr失败: {str(e)}" + } + result = json.loads(result.stdout) + + return { + "status": "success", + "pr_html_url": result['html_url'], + } diff --git a/servers/cvekit_mcp/src/requirements.txt b/servers/cvekit_mcp/src/requirements.txt index 870186d..0937b4d 100644 --- a/servers/cvekit_mcp/src/requirements.txt +++ b/servers/cvekit_mcp/src/requirements.txt @@ -2,3 +2,4 @@ requests>=2.31.0 mcp>=1.0.0 PyGithub>=2.0.0 gitpython>=3.1.40 +curl_ffi diff --git a/servers/cvekit_mcp/src/server.py b/servers/cvekit_mcp/src/server.py index fd86213..d6bdb45 100644 --- a/servers/cvekit_mcp/src/server.py +++ b/servers/cvekit_mcp/src/server.py @@ -44,6 +44,26 @@ def run_cvekit(action: str, params: dict) -> dict: if 'signer_email' in params: cmd.append(f'--signer-email={params["signer_email"]}') + elif action == 'apply-patch': + if 'patch_path' in params: + cmd.append(f'--patch-path={params["patch_path"]}') + if 'fork_repo_url' in params: + cmd.append(f'--fork-repo-url={params["fork_repo_url"]}') + if 'branch' in params: + cmd.append(f'--branch={params["branch"]}') + if 'signer_name' in params: + cmd.append(f'--signer-name={params["signer_name"]}') + if 'signer_email' in params: + cmd.append(f'--signer-email={params["signer_email"]}') + + elif action == 'create-pr': + if 'branch' in params: + cmd.append(f'--branch={params["branch"]}') + if 'fork_repo_url' in params: + cmd.append(f'--fork-repo-url={params["fork_repo_url"]}') + if 'repo_url' in params: + cmd.append(f'--repo-url={params["repo_url"]}') + result = subprocess.run( cmd, check=True, @@ -138,7 +158,7 @@ def get_commits( @mcp.tool() def analyze_branches( issue_url: str = Field(..., description="Gitee Issue URL"), - branches: Optional[str] = Field(None, description="要分析的分支列表,逗号分隔"), + branches: Optional[str] = Field('OLK-5.10,OLK-6.6,master', description="要分析的分支列表,逗号分隔"), signer_name: Optional[str] = Field(None, description="提交者姓名"), signer_email: Optional[str] = Field(None, description="提交者邮箱"), gitee_token: Optional[str] = Field(None, description="Gitee访问令牌(可选)") @@ -147,6 +167,8 @@ def analyze_branches( 该函数是CVE修复流程的第四步: 分析introduced_commit在本地仓库的哪些分支被引入,如果引入的话,是否被fixed了,以此来分析哪些分支需要应用补丁 并检查从上游获取的补丁直接应用,是否存在冲突 + 该步骤中的参数branches为kernel的分支名,和issue分析中的受影响版本并不完全一致,若用户未指定要分析的分支名,采 + 用默认值即可 """ result = run_cvekit('analyze-branches', { 'issue_url': issue_url, @@ -181,5 +203,59 @@ def analyze_branches( "请确认以上分析结果" ) +@mcp.tool() +def apply_patch( + issue_url: str = Field(..., description="Gitee Issue URL"), + branch: Optional[str] = Field(description="要应用patch的分支名"), + fork_repo_url: Optional[str] = Field(description="fork仓库url"), + patch_path: Optional[str] = Field(description="patch路径"), + signer_name: Optional[str] = Field(description="提交者姓名"), + signer_email: Optional[str] = Field(None, description="提交者邮箱"), + gitee_token: Optional[str] = Field(None, description="Gitee访问令牌(可选)") +) -> str: + """ + 该函数是CVE修复流程的第五步: + 对于第四步中分析出的受影响分支,分别应用相对应的patch,参数中的patch_path为第四步的冲突点, + 若patch应用成功,提交之后,把该分支推送到fork仓,若patch应用失败,尝试解决冲突后,重新执行该步骤 + """ + result = run_cvekit('apply-patch', { + 'issue_url': issue_url, + 'branch': branch, + 'fork_repo_url': fork_repo_url, + 'patch_path': patch_path, + 'signer_name': signer_name, + 'signer_email': signer_email, + 'gitee_token': gitee_token + }) + + if 'error' in result or 'error' in result.get('status'): + return f"应用patch失败: {result['error']}" + return 'patch应用成功' + +@mcp.tool() +def create_pr( + issue_url: str = Field(..., description="Gitee Issue URL"), + branch: Optional[str] = Field(None, description="提交pr源分支名和目标分支名"), + fork_repo_url: Optional[str] = Field(None, description="fork仓库url"), + repo_url: Optional[str] = Field('https://gitee.com/openeuler/kernel', description="目标仓库url"), + signer_name: Optional[str] = Field(None, description="提交者姓名"), + signer_email: Optional[str] = Field(None, description="提交者邮箱"), + gitee_token: Optional[str] = Field(description="Gitee访问令牌") +) -> str: + """ + 该函数是CVE修复流程的第六步: + 对于第五步中推送成功的分支,提交pr,若用户未提供目标仓库url,则使用默认的目标仓库 + """ + result = run_cvekit('create-pr', { + 'issue_url': issue_url, + 'branch': branch, + 'fork_repo_url': fork_repo_url, + 'repo_url': repo_url, + 'gitee_token': gitee_token + }) + if 'error' in result or 'error' in result.get('status'): + return f"pr提交失败: {result.get('error')}" + return f"pr已提交: {result.get('pr_html_url')}" + if __name__ == "__main__": mcp.run() \ No newline at end of file -- Gitee