From 78e087e256824b30d73e2835342b37444d92ab72 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 12:16:10 +0800 Subject: [PATCH 01/60] =?UTF-8?q?=E9=80=82=E9=85=8Drag=E7=9A=84=E4=BD=9C?= =?UTF-8?q?=E8=80=85=E5=92=8C=E5=88=9B=E5=BB=BA=E6=97=B6=E9=97=B4=E7=9A=84?= =?UTF-8?q?=E8=BF=94=E5=9B=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/scheduler/context.py | 3 +++ apps/scheduler/scheduler/message.py | 9 +-------- apps/schemas/record.py | 6 +++++- apps/services/document.py | 4 ++++ apps/services/rag.py | 3 +++ 5 files changed, 16 insertions(+), 9 deletions(-) diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index b7088d8d..4c2c4cf0 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -114,11 +114,14 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: used_docs.append( RecordGroupDocument( _id=docs["id"], + author=docs.get("author", ""), + order=docs.get("order", 0), name=docs["name"], abstract=docs.get("abstract", ""), extension=docs.get("extension", ""), size=docs.get("size", 0), associated="answer", + created_at=docs.get("created_at", round(datetime.now(UTC).timestamp(), 3)), ) ) if docs.get("order") is not None: diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index c89fdd10..5997b48f 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -71,14 +71,7 @@ async def push_rag_message( # 如果是文本消息,直接拼接到答案中 full_answer += content_obj.content elif content_obj.event_type == EventType.DOCUMENT_ADD.value: - task.runtime.documents.append({ - "id": content_obj.content.get("id", ""), - "order": content_obj.content.get("order", 0), - "name": content_obj.content.get("name", ""), - "abstract": content_obj.content.get("abstract", ""), - "extension": content_obj.content.get("extension", ""), - "size": content_obj.content.get("size", 0), - }) + task.runtime.documents.append(content_obj.content) # 保存答案 task.runtime.answer = full_answer await TaskManager.save_task(task.id, task) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index d7acd368..e1a995f8 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -17,8 +17,9 @@ class RecordDocument(Document): """GET /api/record/{conversation_id} Result中的document数据结构""" id: str = Field(alias="_id", default="") + order: int = Field(default=0, description="文档顺序") abstract: str = Field(default="", description="文档摘要") - user_sub: None = None + author: str = Field(default="", description="文档作者") associated: Literal["question", "answer"] class Config: @@ -103,11 +104,14 @@ class RecordGroupDocument(BaseModel): """RecordGroup关联的文件""" id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") + order: int = Field(default=0, description="文档顺序") + author: str = Field(default="", description="文档作者") name: str = Field(default="", description="文档名称") abstract: str = Field(default="", description="文档摘要") extension: str = Field(default="", description="文档扩展名") size: int = Field(default=0, description="文档大小,单位是KB") associated: Literal["question", "answer"] + created_at: float = Field(default=0.0, description="文档创建时间") class Record(RecordData): diff --git a/apps/services/document.py b/apps/services/document.py index 203162da..451423a9 100644 --- a/apps/services/document.py +++ b/apps/services/document.py @@ -2,6 +2,7 @@ """文件Manager""" import base64 +from datetime import UTC, datetime import logging import uuid @@ -131,12 +132,15 @@ class DocumentManager: return [ RecordDocument( _id=doc.id, + order=doc.order, + author=doc.author, abstract=doc.abstract, name=doc.name, type=doc.extension, size=doc.size, conversation_id=record_group.get("conversation_id", ""), associated=doc.associated, + created_at=doc.created_at or round(datetime.now(tz=UTC).timestamp(), 3) ) for doc in docs if type is None or doc.associated == type ] diff --git a/apps/services/rag.py b/apps/services/rag.py index 6b6c843d..efbdfe94 100644 --- a/apps/services/rag.py +++ b/apps/services/rag.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """对接Euler Copilot RAG""" +from datetime import UTC, datetime import json import logging from collections.abc import AsyncGenerator @@ -156,9 +157,11 @@ class RAG: "id": doc_chunk["docId"], "order": doc_cnt, "name": doc_chunk.get("docName", ""), + "author": doc_chunk.get("docAuthor", ""), "extension": doc_chunk.get("docExtension", ""), "abstract": doc_chunk.get("docAbstract", ""), "size": doc_chunk.get("docSize", 0), + "created_at": doc_chunk.get("docCreatedAt", round(datetime.now(UTC).timestamp(), 3)), }) doc_id_map[doc_chunk["docId"]] = doc_cnt doc_index = doc_id_map[doc_chunk["docId"]] -- Gitee From c930151094eef40aa9f4e01e1ec59bd56ef75a21 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 15:40:30 +0800 Subject: [PATCH 02/60] =?UTF-8?q?=E5=AF=B9rag=E7=9A=84team=E5=92=8C?= =?UTF-8?q?=E5=9B=A2=E9=98=9F=E8=8E=B7=E5=8F=96=E5=A2=9E=E5=8A=A0=E5=BC=82?= =?UTF-8?q?=E5=B8=B8=E6=8D=95=E8=8E=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/services/conversation.py | 6 +++++- apps/services/knowledge.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/apps/services/conversation.py b/apps/services/conversation.py index 4bcade45..bac964db 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -59,7 +59,11 @@ class ConversationManager: model_name=llm.model_name, ) kb_item_list = [] - team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + try: + team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + except: + logger.error("[ConversationManager] 获取团队知识库列表失败") + team_kb_list = [] for team_kb in team_kb_list: for kb in team_kb["kbList"]: if str(kb["kbId"]) in kb_ids: diff --git a/apps/services/knowledge.py b/apps/services/knowledge.py index 9b4077f9..bd8dfc9e 100644 --- a/apps/services/knowledge.py +++ b/apps/services/knowledge.py @@ -138,7 +138,11 @@ class KnowledgeBaseManager: return [] kb_ids_update_success = [] kb_item_dict_list = [] - team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + try: + team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + except Exception as e: + logger.error(f"[KnowledgeBaseManager] 获取团队知识库列表失败: {e}") + team_kb_list = [] for team_kb in team_kb_list: for kb in team_kb["kbList"]: if str(kb["kbId"]) in kb_ids: -- Gitee From 059bf35ef464b8165ea1c816566c9c35c4bc742b Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 15:53:07 +0800 Subject: [PATCH 03/60] =?UTF-8?q?=E5=AE=8C=E5=96=84RecordDocument=E7=BB=93?= =?UTF-8?q?=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/schemas/record.py | 1 + 1 file changed, 1 insertion(+) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index e1a995f8..6e9617a3 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -19,6 +19,7 @@ class RecordDocument(Document): id: str = Field(alias="_id", default="") order: int = Field(default=0, description="文档顺序") abstract: str = Field(default="", description="文档摘要") + user_sub: None | None author: str = Field(default="", description="文档作者") associated: Literal["question", "answer"] -- Gitee From eefe717400bca7606c0c722d5e506fd9d9704c14 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 15:56:50 +0800 Subject: [PATCH 04/60] =?UTF-8?q?=E5=AE=8C=E5=96=84RecordDocument=E7=BB=93?= =?UTF-8?q?=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/schemas/record.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index 6e9617a3..b5e1b0c5 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -19,7 +19,7 @@ class RecordDocument(Document): id: str = Field(alias="_id", default="") order: int = Field(default=0, description="文档顺序") abstract: str = Field(default="", description="文档摘要") - user_sub: None | None + user_sub: None = None author: str = Field(default="", description="文档作者") associated: Literal["question", "answer"] -- Gitee From 60ea73002a36aa3452997caa649c9b80993ea2bd Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 16:21:07 +0800 Subject: [PATCH 05/60] =?UTF-8?q?=E5=AE=8C=E5=96=84DocumentAddContent?= =?UTF-8?q?=E7=9A=84=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/scheduler/message.py | 2 ++ apps/schemas/message.py | 6 +++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index 5997b48f..a2a45e41 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -108,10 +108,12 @@ async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tupl data=DocumentAddContent( documentId=content_obj.content.get("id", ""), documentOrder=content_obj.content.get("order", 0), + documentAuthor=content_obj.content.get("author", ""), documentName=content_obj.content.get("name", ""), documentAbstract=content_obj.content.get("abstract", ""), documentType=content_obj.content.get("extension", ""), documentSize=content_obj.content.get("size", 0), + createdAt=round(datetime.now(tz=UTC).timestamp(), 3), ).model_dump(exclude_none=True, by_alias=True), ) except Exception: diff --git a/apps/schemas/message.py b/apps/schemas/message.py index d0661224..5f0aee8a 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -2,7 +2,7 @@ """队列中的消息结构""" from typing import Any - +from datetime import UTC, datetime from pydantic import BaseModel, Field from apps.schemas.enum_var import EventType, StepStatus @@ -60,10 +60,14 @@ class DocumentAddContent(BaseModel): document_id: str = Field(description="文档UUID", alias="documentId") document_order: int = Field(description="文档在对话中的顺序,从1开始", alias="documentOrder") + document_author: str = Field(description="文档作者", alias="documentAuthor", default="") document_name: str = Field(description="文档名称", alias="documentName") document_abstract: str = Field(description="文档摘要", alias="documentAbstract", default="") document_type: str = Field(description="文档MIME类型", alias="documentType", default="") document_size: float = Field(ge=0, description="文档大小,单位是KB,保留两位小数", alias="documentSize", default=0) + created_at: float = Field( + description="文档创建时间,单位是秒", alias="createdAt", default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3) + ) class FlowStartContent(BaseModel): -- Gitee From 7b57aa351af2a882d5059e5f5b3a618d2238deec Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 24 Jul 2025 19:22:36 +0800 Subject: [PATCH 06/60] =?UTF-8?q?=E5=AE=8C=E5=96=84=E6=97=A0=E9=89=B4?= =?UTF-8?q?=E6=9D=83=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/dependency/user.py | 8 +++++++ apps/routers/api_key.py | 1 + apps/routers/auth.py | 3 +-- apps/scheduler/executor/step.py | 38 ++++++++++++++++----------------- apps/schemas/config.py | 9 +++++++- apps/schemas/message.py | 2 ++ apps/schemas/request_data.py | 4 ++-- 7 files changed, 41 insertions(+), 24 deletions(-) diff --git a/apps/dependency/user.py b/apps/dependency/user.py index fce67e51..87cbd290 100644 --- a/apps/dependency/user.py +++ b/apps/dependency/user.py @@ -5,10 +5,12 @@ import logging from fastapi import Depends from fastapi.security import OAuth2PasswordBearer +import secrets from starlette import status from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection +from apps.common.config import Config from apps.services.api_key import ApiKeyManager from apps.services.session import SessionManager @@ -48,6 +50,9 @@ async def get_session(request: HTTPConnection) -> str: :param request: HTTP请求 :return: Session ID """ + if Config().get_config().no_auth.enable: + # 如果启用了无认证访问,直接返回调试用户 + return secrets.token_hex(16) session_id = await _get_session_id_from_request(request) if not session_id: raise HTTPException( @@ -69,6 +74,9 @@ async def get_user(request: HTTPConnection) -> str: :param request: HTTP请求体 :return: 用户sub """ + if Config().get_config().no_auth.enable: + # 如果启用了无认证访问,直接返回调试用户 + return Config().get_config().no_auth.user_sub session_id = await _get_session_id_from_request(request) if not session_id: raise HTTPException( diff --git a/apps/routers/api_key.py b/apps/routers/api_key.py index 158cfc13..51366a21 100644 --- a/apps/routers/api_key.py +++ b/apps/routers/api_key.py @@ -6,6 +6,7 @@ from typing import Annotated from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse + from apps.dependency.user import get_user, verify_user from apps.schemas.api_key import GetAuthKeyRsp, PostAuthKeyMsg, PostAuthKeyRsp from apps.schemas.response_data import ResponseData diff --git a/apps/routers/auth.py b/apps/routers/auth.py index 1cba5ed6..4a3f8293 100644 --- a/apps/routers/auth.py +++ b/apps/routers/auth.py @@ -9,6 +9,7 @@ from fastapi import APIRouter, Depends, Request, status from fastapi.responses import HTMLResponse, JSONResponse from fastapi.templating import Jinja2Templates +from apps.common.config import Config from apps.common.oidc import oidc_provider from apps.dependency import get_session, get_user, verify_user from apps.schemas.collection import Audit @@ -47,8 +48,6 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse: user_info = await oidc_provider.get_oidc_user(token["access_token"]) user_sub: str | None = user_info.get("user_sub", None) - if user_sub: - await oidc_provider.set_token(user_sub, token["access_token"], token["refresh_token"]) except Exception as e: logger.exception("User login failed") status_code = status.HTTP_400_BAD_REQUEST if "auth error" in str(e) else status.HTTP_403_FORBIDDEN diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 6b3451fa..506f3bb1 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -86,8 +86,8 @@ class StepExecutor(BaseExecutor): logger.info("[StepExecutor] 初始化步骤 %s", self.step.step.name) # State写入ID和运行状态 - self.task.state.step_id = self.step.step_id # type: ignore[arg-type] - self.task.state.step_name = self.step.step.name # type: ignore[arg-type] + self.task.state.step_id = self.step.step_id # type: ignore[arg-type] + self.task.state.step_name = self.step.step.name # type: ignore[arg-type] # 获取并验证Call类 node_id = self.step.step.node @@ -127,13 +127,13 @@ class StepExecutor(BaseExecutor): return # 暂存旧数据 - current_step_id = self.task.state.step_id # type: ignore[arg-type] - current_step_name = self.task.state.step_name # type: ignore[arg-type] + current_step_id = self.task.state.step_id # type: ignore[arg-type] + current_step_name = self.task.state.step_name # type: ignore[arg-type] # 更新State - self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type] - self.task.state.step_name = "自动参数填充" # type: ignore[arg-type] - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type] + self.task.state.step_name = "自动参数填充" # type: ignore[arg-type] + self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 初始化填参 @@ -156,17 +156,17 @@ class StepExecutor(BaseExecutor): # 如果没有填全,则状态设置为待填参 if result.remaining_schema: - self.task.state.status = StepStatus.PARAM # type: ignore[arg-type] + self.task.state.status = StepStatus.PARAM # type: ignore[arg-type] else: - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] await self.push_message(EventType.STEP_OUTPUT.value, result.model_dump(by_alias=True, exclude_none=True)) # 更新输入 self.obj.input.update(result.slot_data) # 恢复State - self.task.state.step_id = current_step_id # type: ignore[arg-type] - self.task.state.step_name = current_step_name # type: ignore[arg-type] + self.task.state.step_id = current_step_id # type: ignore[arg-type] + self.task.state.step_name = current_step_name # type: ignore[arg-type] self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens @@ -212,7 +212,7 @@ class StepExecutor(BaseExecutor): await self._run_slot_filling() # 更新状态 - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 推送输入 await self.push_message(EventType.STEP_INPUT.value, self.obj.input) @@ -224,22 +224,22 @@ class StepExecutor(BaseExecutor): content = await self._process_chunk(iterator, to_user=self.obj.to_user) except Exception as e: logger.exception("[StepExecutor] 运行步骤失败,进行异常处理步骤") - self.task.state.status = StepStatus.ERROR # type: ignore[arg-type] + self.task.state.status = StepStatus.ERROR # type: ignore[arg-type] await self.push_message(EventType.STEP_OUTPUT.value, {}) if isinstance(e, CallError): - self.task.state.error_info = { # type: ignore[arg-type] + self.task.state.error_info = { # type: ignore[arg-type] "err_msg": e.message, "data": e.data, } else: - self.task.state.error_info = { # type: ignore[arg-type] + self.task.state.error_info = { # type: ignore[arg-type] "err_msg": str(e), "data": {}, } return # 更新执行状态 - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens self.task.tokens.full_time += round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time @@ -253,12 +253,12 @@ class StepExecutor(BaseExecutor): # 更新context history = FlowStepHistory( task_id=self.task.id, - flow_id=self.task.state.flow_id, # type: ignore[arg-type] - flow_name=self.task.state.flow_name, # type: ignore[arg-type] + flow_id=self.task.state.flow_id, # type: ignore[arg-type] + flow_name=self.task.state.flow_name, # type: ignore[arg-type] step_id=self.step.step_id, step_name=self.step.step.name, step_description=self.step.step.description, - status=self.task.state.status, # type: ignore[arg-type] + status=self.task.state.status, # type: ignore[arg-type] input_data=self.obj.input, output_data=output_data, ) diff --git a/apps/schemas/config.py b/apps/schemas/config.py index b88a81f1..99bcccde 100644 --- a/apps/schemas/config.py +++ b/apps/schemas/config.py @@ -6,6 +6,13 @@ from typing import Literal from pydantic import BaseModel, Field +class NoauthConfig(BaseModel): + """无认证配置""" + + enable: bool = Field(description="是否启用无认证访问", default=False) + user_sub: str = Field(description="调试用户的sub", default="admin") + + class DeployConfig(BaseModel): """部署配置""" @@ -122,7 +129,7 @@ class ExtraConfig(BaseModel): class ConfigModel(BaseModel): """配置文件的校验Class""" - + no_auth: NoauthConfig deploy: DeployConfig login: LoginConfig embedding: EmbeddingConfig diff --git a/apps/schemas/message.py b/apps/schemas/message.py index 5f0aee8a..cf70a82b 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -24,6 +24,8 @@ class MessageFlow(BaseModel): flow_id: str = Field(description="Flow ID", alias="flowId") step_id: str = Field(description="当前步骤ID", alias="stepId") step_name: str = Field(description="当前步骤名称", alias="stepName") + sub_step_id: str | None = Field(description="当前子步骤ID", alias="subStepId", default=None) + sub_step_name: str | None = Field(description="当前子步骤名称", alias="subStepName", default=None) step_status: StepStatus = Field(description="当前步骤状态", alias="stepStatus") diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 2305dd93..5d6dc7fc 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -16,8 +16,8 @@ class RequestDataApp(BaseModel): """模型对话中包含的app信息""" app_id: str = Field(description="应用ID", alias="appId") - flow_id: str = Field(description="Flow ID", alias="flowId") - params: dict[str, Any] = Field(description="插件参数") + flow_id: str | None = Field(default=None, description="Flow ID", alias="flowId") + params: dict[str, Any] | None = Field(default=None, description="插件参数") class MockRequestData(BaseModel): -- Gitee From f02748f597d63b3b072b0749681bf1827283987e Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 24 Jul 2025 21:13:04 +0800 Subject: [PATCH 07/60] =?UTF-8?q?=E5=AE=8C=E5=96=84Agent=E7=9A=84=E5=BC=80?= =?UTF-8?q?=E5=8F=91&=E4=BF=AE=E5=A4=8Dmcp=E6=B3=A8=E5=86=8C=E6=97=B6?= =?UTF-8?q?=E5=80=99=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/chat.py | 11 ++++++++--- apps/scheduler/call/mcp/mcp.py | 7 ------- apps/scheduler/executor/agent.py | 28 +++++++++------------------ apps/scheduler/pool/loader/mcp.py | 1 - apps/scheduler/scheduler/scheduler.py | 2 +- apps/schemas/enum_var.py | 3 +++ apps/schemas/mcp.py | 3 ++- apps/schemas/request_data.py | 1 + apps/schemas/task.py | 3 +++ apps/services/task.py | 9 --------- 10 files changed, 27 insertions(+), 41 deletions(-) diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 7fe5162c..589000be 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -36,11 +36,16 @@ async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> T # 生成group_id if not post_body.group_id: post_body.group_id = str(uuid.uuid4()) - # 创建或还原Task + if post_body.new_task: + # 创建或还原Task + task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) + if task: + await TaskManager.delete_task_by_task_id(task.id) task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) # 更改信息并刷新数据库 - task.runtime.question = post_body.question - task.ids.group_id = post_body.group_id + if post_body.new_task: + task.runtime.question = post_body.question + task.ids.group_id = post_body.group_id return task diff --git a/apps/scheduler/call/mcp/mcp.py b/apps/scheduler/call/mcp/mcp.py index 661e9ada..4e6a1bb7 100644 --- a/apps/scheduler/call/mcp/mcp.py +++ b/apps/scheduler/call/mcp/mcp.py @@ -35,7 +35,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): text_output: bool = Field(description="是否将结果以文本形式返回", default=True) to_user: bool = Field(description="是否将结果返回给用户", default=True) - @classmethod def info(cls) -> CallInfo: """ @@ -46,7 +45,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): """ return CallInfo(name="MCP", description="调用MCP Server,执行工具") - async def _init(self, call_vars: CallVars) -> MCPInput: """初始化MCP""" # 获取MCP交互类 @@ -63,7 +61,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): return MCPInput(avaliable_tools=avaliable_tools, max_steps=self.max_steps) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """执行MCP""" # 生成计划 @@ -80,7 +77,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): async for chunk in self._generate_answer(): yield chunk - async def _generate_plan(self) -> AsyncGenerator[CallOutputChunk, None]: """生成执行计划""" # 开始提示 @@ -103,7 +99,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): data=self._plan.model_dump(), ) - async def _execute_plan_item(self, plan_item: MCPPlanItem) -> AsyncGenerator[CallOutputChunk, None]: """执行单个计划项""" # 判断是否为Final @@ -141,7 +136,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): }, ) - async def _generate_answer(self) -> AsyncGenerator[CallOutputChunk, None]: """生成总结""" # 提示开始总结 @@ -163,7 +157,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): ).model_dump(), ) - def _create_output( self, text: str, diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index f6814dd3..2ff4f3d3 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -7,6 +7,8 @@ from pydantic import Field from apps.scheduler.executor.base import BaseExecutor from apps.scheduler.mcp_agent.agent.mcp import MCPAgent +from apps.schemas.task import ExecutorState, StepQueueItem +from apps.services.task import TaskManager logger = logging.getLogger(__name__) @@ -15,26 +17,14 @@ class MCPAgentExecutor(BaseExecutor): """MCP Agent执行器""" question: str = Field(description="用户输入") - max_steps: int = Field(default=10, description="最大步数") + max_steps: int = Field(default=20, description="最大步数") servers_id: list[str] = Field(description="MCP server id") agent_id: str = Field(default="", description="Agent ID") agent_description: str = Field(default="", description="Agent描述") - async def run(self) -> None: - """运行MCP Agent""" - agent = await MCPAgent.create( - servers_id=self.servers_id, - max_steps=self.max_steps, - task=self.task, - msg_queue=self.msg_queue, - question=self.question, - agent_id=self.agent_id, - description=self.agent_description, - ) - - try: - answer = await agent.run(self.question) - self.task = agent.task - self.task.runtime.answer = answer - except Exception as e: - logger.error(f"Error: {str(e)}") + async def load_state(self) -> None: + """从数据库中加载FlowExecutor的状态""" + logger.info("[FlowExecutor] 加载Executor状态") + # 尝试恢复State + if self.task.state: + self.task.context = await TaskManager.get_context_by_task_id(self.task.id) diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py index 66a516e7..1463d0a1 100644 --- a/apps/scheduler/pool/loader/mcp.py +++ b/apps/scheduler/pool/loader/mcp.py @@ -153,7 +153,6 @@ class MCPLoader(metaclass=SingletonMeta): # 检查目录 template_path = MCP_PATH / "template" / mcp_id await Path.mkdir(template_path, parents=True, exist_ok=True) - ProcessHandler.clear_finished_tasks() # 安装MCP模板 if not ProcessHandler.add_task(mcp_id, MCPLoader._install_template_task, mcp_id, config): err = f"安装任务无法执行,请稍后重试: {mcp_id}" diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index ed73638c..417f93d2 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -206,7 +206,7 @@ class Scheduler: task=self.task, msg_queue=queue, question=post_body.question, - max_steps=app_metadata.history_len, + history_len=app_metadata.history_len, servers_id=servers_id, background=background, agent_id=app_info.app_id, diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index 9a20ba84..a84dc3a3 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -15,6 +15,7 @@ class SlotType(str, Enum): class StepStatus(str, Enum): """步骤状态""" + WAITING = "waiting" RUNNING = "running" SUCCESS = "success" ERROR = "error" @@ -38,6 +39,8 @@ class EventType(str, Enum): TEXT_ADD = "text.add" GRAPH = "graph" DOCUMENT_ADD = "document.add" + STEP_WAITING_FOR_START = "step.waiting_for_start" + STEP_WAITING_FOR_PARAM = "step.waiting_for_param" FLOW_START = "flow.start" STEP_INPUT = "step.input" STEP_OUTPUT = "step.output" diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index 44021b0e..60c8f17b 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """MCP 相关数据结构""" +import uuid from enum import Enum from typing import Any @@ -117,7 +118,7 @@ class MCPToolSelectResult(BaseModel): class MCPPlanItem(BaseModel): """MCP 计划""" - + id: str = Field(default_factory=lambda: str(uuid.uuid4())) content: str = Field(description="计划内容") tool: str = Field(description="工具名称") instruction: str = Field(description="工具指令") diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 5d6dc7fc..a3a8848c 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -46,6 +46,7 @@ class RequestData(BaseModel): files: list[str] = Field(default=[], description="文件列表") app: RequestDataApp | None = Field(default=None, description="应用") debug: bool = Field(default=False, description="是否调试") + new_task: bool = Field(default=True, description="是否新建任务") class QuestionBlacklistRequest(BaseModel): diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 8efcb599..37fdebbf 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, Field from apps.schemas.enum_var import StepStatus from apps.schemas.flow import Step +from apps.schemas.mcp import MCPPlan class FlowStepHistory(BaseModel): @@ -42,6 +43,7 @@ class ExecutorState(BaseModel): # 附加信息 step_id: str = Field(description="当前步骤ID") step_name: str = Field(description="当前步骤名称") + step_description: str = Field(description="当前步骤描述", default="") app_id: str = Field(description="应用ID") slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) error_info: dict[str, Any] = Field(description="错误信息", default={}) @@ -75,6 +77,7 @@ class TaskRuntime(BaseModel): summary: str = Field(description="摘要", default="") filled: dict[str, Any] = Field(description="填充的槽位", default={}) documents: list[dict[str, Any]] = Field(description="文档列表", default=[]) + temporary_plans: MCPPlan | None = Field(description="临时计划列表", default=None) class Task(BaseModel): diff --git a/apps/services/task.py b/apps/services/task.py index 1e672be6..2456d96b 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -45,7 +45,6 @@ class TaskManager: return Task.model_validate(task) - @staticmethod async def get_task_by_group_id(group_id: str, conversation_id: str) -> Task | None: """获取组ID的最后一条问答组关联的任务""" @@ -58,7 +57,6 @@ class TaskManager: task = await task_collection.find_one({"_id": record_group_obj.task_id}) return Task.model_validate(task) - @staticmethod async def get_task_by_task_id(task_id: str) -> Task | None: """根据task_id获取任务""" @@ -68,7 +66,6 @@ class TaskManager: return None return Task.model_validate(task) - @staticmethod async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[dict[str, Any]]: """根据record_group_id获取flow信息""" @@ -95,7 +92,6 @@ class TaskManager: else: return flow_context_list - @staticmethod async def get_context_by_task_id(task_id: str, length: int = 0) -> list[dict[str, Any]]: """根据task_id获取flow信息""" @@ -115,7 +111,6 @@ class TaskManager: else: return flow_context - @staticmethod async def save_flow_context(task_id: str, flow_context: list[dict[str, Any]]) -> None: """保存flow信息到flow_context""" @@ -137,7 +132,6 @@ class TaskManager: except Exception: logger.exception("[TaskManager] 保存flow执行记录失败") - @staticmethod async def delete_task_by_task_id(task_id: str) -> None: """通过task_id删除Task信息""" @@ -148,7 +142,6 @@ class TaskManager: if task: await task_collection.delete_one({"_id": task_id}) - @staticmethod async def delete_tasks_by_conversation_id(conversation_id: str) -> None: """通过ConversationID删除Task信息""" @@ -167,7 +160,6 @@ class TaskManager: await task_collection.delete_many({"conversation_id": conversation_id}, session=session) await flow_context_collection.delete_many({"task_id": {"$in": task_ids}}, session=session) - @classmethod async def get_task( cls, @@ -212,7 +204,6 @@ class TaskManager: runtime=TaskRuntime(), ) - @classmethod async def save_task(cls, task_id: str, task: Task) -> None: """保存任务块""" -- Gitee From b929ed4ead156f6d7008ec575830375b7beb0218 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 12:16:10 +0800 Subject: [PATCH 08/60] =?UTF-8?q?=E9=80=82=E9=85=8Drag=E7=9A=84=E4=BD=9C?= =?UTF-8?q?=E8=80=85=E5=92=8C=E5=88=9B=E5=BB=BA=E6=97=B6=E9=97=B4=E7=9A=84?= =?UTF-8?q?=E8=BF=94=E5=9B=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/scheduler/context.py | 3 +++ apps/scheduler/scheduler/message.py | 9 +-------- apps/schemas/record.py | 6 +++++- apps/services/document.py | 4 ++++ apps/services/rag.py | 3 +++ 5 files changed, 16 insertions(+), 9 deletions(-) diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index b7088d8d..4c2c4cf0 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -114,11 +114,14 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: used_docs.append( RecordGroupDocument( _id=docs["id"], + author=docs.get("author", ""), + order=docs.get("order", 0), name=docs["name"], abstract=docs.get("abstract", ""), extension=docs.get("extension", ""), size=docs.get("size", 0), associated="answer", + created_at=docs.get("created_at", round(datetime.now(UTC).timestamp(), 3)), ) ) if docs.get("order") is not None: diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index c89fdd10..5997b48f 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -71,14 +71,7 @@ async def push_rag_message( # 如果是文本消息,直接拼接到答案中 full_answer += content_obj.content elif content_obj.event_type == EventType.DOCUMENT_ADD.value: - task.runtime.documents.append({ - "id": content_obj.content.get("id", ""), - "order": content_obj.content.get("order", 0), - "name": content_obj.content.get("name", ""), - "abstract": content_obj.content.get("abstract", ""), - "extension": content_obj.content.get("extension", ""), - "size": content_obj.content.get("size", 0), - }) + task.runtime.documents.append(content_obj.content) # 保存答案 task.runtime.answer = full_answer await TaskManager.save_task(task.id, task) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index d7acd368..e1a995f8 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -17,8 +17,9 @@ class RecordDocument(Document): """GET /api/record/{conversation_id} Result中的document数据结构""" id: str = Field(alias="_id", default="") + order: int = Field(default=0, description="文档顺序") abstract: str = Field(default="", description="文档摘要") - user_sub: None = None + author: str = Field(default="", description="文档作者") associated: Literal["question", "answer"] class Config: @@ -103,11 +104,14 @@ class RecordGroupDocument(BaseModel): """RecordGroup关联的文件""" id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") + order: int = Field(default=0, description="文档顺序") + author: str = Field(default="", description="文档作者") name: str = Field(default="", description="文档名称") abstract: str = Field(default="", description="文档摘要") extension: str = Field(default="", description="文档扩展名") size: int = Field(default=0, description="文档大小,单位是KB") associated: Literal["question", "answer"] + created_at: float = Field(default=0.0, description="文档创建时间") class Record(RecordData): diff --git a/apps/services/document.py b/apps/services/document.py index 203162da..451423a9 100644 --- a/apps/services/document.py +++ b/apps/services/document.py @@ -2,6 +2,7 @@ """文件Manager""" import base64 +from datetime import UTC, datetime import logging import uuid @@ -131,12 +132,15 @@ class DocumentManager: return [ RecordDocument( _id=doc.id, + order=doc.order, + author=doc.author, abstract=doc.abstract, name=doc.name, type=doc.extension, size=doc.size, conversation_id=record_group.get("conversation_id", ""), associated=doc.associated, + created_at=doc.created_at or round(datetime.now(tz=UTC).timestamp(), 3) ) for doc in docs if type is None or doc.associated == type ] diff --git a/apps/services/rag.py b/apps/services/rag.py index 6b6c843d..efbdfe94 100644 --- a/apps/services/rag.py +++ b/apps/services/rag.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """对接Euler Copilot RAG""" +from datetime import UTC, datetime import json import logging from collections.abc import AsyncGenerator @@ -156,9 +157,11 @@ class RAG: "id": doc_chunk["docId"], "order": doc_cnt, "name": doc_chunk.get("docName", ""), + "author": doc_chunk.get("docAuthor", ""), "extension": doc_chunk.get("docExtension", ""), "abstract": doc_chunk.get("docAbstract", ""), "size": doc_chunk.get("docSize", 0), + "created_at": doc_chunk.get("docCreatedAt", round(datetime.now(UTC).timestamp(), 3)), }) doc_id_map[doc_chunk["docId"]] = doc_cnt doc_index = doc_id_map[doc_chunk["docId"]] -- Gitee From 755d5c315b6d9c0d3833383a790b188b0f53b28d Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 15:40:30 +0800 Subject: [PATCH 09/60] =?UTF-8?q?=E5=AF=B9rag=E7=9A=84team=E5=92=8C?= =?UTF-8?q?=E5=9B=A2=E9=98=9F=E8=8E=B7=E5=8F=96=E5=A2=9E=E5=8A=A0=E5=BC=82?= =?UTF-8?q?=E5=B8=B8=E6=8D=95=E8=8E=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/services/conversation.py | 6 +++++- apps/services/knowledge.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/apps/services/conversation.py b/apps/services/conversation.py index 4bcade45..bac964db 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -59,7 +59,11 @@ class ConversationManager: model_name=llm.model_name, ) kb_item_list = [] - team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + try: + team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + except: + logger.error("[ConversationManager] 获取团队知识库列表失败") + team_kb_list = [] for team_kb in team_kb_list: for kb in team_kb["kbList"]: if str(kb["kbId"]) in kb_ids: diff --git a/apps/services/knowledge.py b/apps/services/knowledge.py index 9b4077f9..bd8dfc9e 100644 --- a/apps/services/knowledge.py +++ b/apps/services/knowledge.py @@ -138,7 +138,11 @@ class KnowledgeBaseManager: return [] kb_ids_update_success = [] kb_item_dict_list = [] - team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + try: + team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + except Exception as e: + logger.error(f"[KnowledgeBaseManager] 获取团队知识库列表失败: {e}") + team_kb_list = [] for team_kb in team_kb_list: for kb in team_kb["kbList"]: if str(kb["kbId"]) in kb_ids: -- Gitee From d4ec596d65a0f0dfe4b4b442454da0a39b8b1325 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 15:53:07 +0800 Subject: [PATCH 10/60] =?UTF-8?q?=E5=AE=8C=E5=96=84RecordDocument=E7=BB=93?= =?UTF-8?q?=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/schemas/record.py | 1 + 1 file changed, 1 insertion(+) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index e1a995f8..6e9617a3 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -19,6 +19,7 @@ class RecordDocument(Document): id: str = Field(alias="_id", default="") order: int = Field(default=0, description="文档顺序") abstract: str = Field(default="", description="文档摘要") + user_sub: None | None author: str = Field(default="", description="文档作者") associated: Literal["question", "answer"] -- Gitee From 5fa678be223c4ed7eead9744074fb5402584247d Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 15:56:50 +0800 Subject: [PATCH 11/60] =?UTF-8?q?=E5=AE=8C=E5=96=84RecordDocument=E7=BB=93?= =?UTF-8?q?=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/schemas/record.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index 6e9617a3..b5e1b0c5 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -19,7 +19,7 @@ class RecordDocument(Document): id: str = Field(alias="_id", default="") order: int = Field(default=0, description="文档顺序") abstract: str = Field(default="", description="文档摘要") - user_sub: None | None + user_sub: None = None author: str = Field(default="", description="文档作者") associated: Literal["question", "answer"] -- Gitee From b407dc0ce2f530220448ac87544f861f3a8d4513 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 16:21:07 +0800 Subject: [PATCH 12/60] =?UTF-8?q?=E5=AE=8C=E5=96=84DocumentAddContent?= =?UTF-8?q?=E7=9A=84=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/scheduler/message.py | 2 ++ apps/schemas/message.py | 6 +++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index 5997b48f..a2a45e41 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -108,10 +108,12 @@ async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tupl data=DocumentAddContent( documentId=content_obj.content.get("id", ""), documentOrder=content_obj.content.get("order", 0), + documentAuthor=content_obj.content.get("author", ""), documentName=content_obj.content.get("name", ""), documentAbstract=content_obj.content.get("abstract", ""), documentType=content_obj.content.get("extension", ""), documentSize=content_obj.content.get("size", 0), + createdAt=round(datetime.now(tz=UTC).timestamp(), 3), ).model_dump(exclude_none=True, by_alias=True), ) except Exception: diff --git a/apps/schemas/message.py b/apps/schemas/message.py index d0661224..5f0aee8a 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -2,7 +2,7 @@ """队列中的消息结构""" from typing import Any - +from datetime import UTC, datetime from pydantic import BaseModel, Field from apps.schemas.enum_var import EventType, StepStatus @@ -60,10 +60,14 @@ class DocumentAddContent(BaseModel): document_id: str = Field(description="文档UUID", alias="documentId") document_order: int = Field(description="文档在对话中的顺序,从1开始", alias="documentOrder") + document_author: str = Field(description="文档作者", alias="documentAuthor", default="") document_name: str = Field(description="文档名称", alias="documentName") document_abstract: str = Field(description="文档摘要", alias="documentAbstract", default="") document_type: str = Field(description="文档MIME类型", alias="documentType", default="") document_size: float = Field(ge=0, description="文档大小,单位是KB,保留两位小数", alias="documentSize", default=0) + created_at: float = Field( + description="文档创建时间,单位是秒", alias="createdAt", default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3) + ) class FlowStartContent(BaseModel): -- Gitee From af1b364f950f18d5c6165438ea9a04fe6697e070 Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 24 Jul 2025 19:22:36 +0800 Subject: [PATCH 13/60] =?UTF-8?q?=E5=AE=8C=E5=96=84=E6=97=A0=E9=89=B4?= =?UTF-8?q?=E6=9D=83=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/dependency/user.py | 8 +++++++ apps/routers/api_key.py | 1 + apps/routers/auth.py | 3 +-- apps/scheduler/executor/step.py | 38 ++++++++++++++++----------------- apps/schemas/config.py | 9 +++++++- apps/schemas/message.py | 2 ++ apps/schemas/request_data.py | 4 ++-- 7 files changed, 41 insertions(+), 24 deletions(-) diff --git a/apps/dependency/user.py b/apps/dependency/user.py index fce67e51..87cbd290 100644 --- a/apps/dependency/user.py +++ b/apps/dependency/user.py @@ -5,10 +5,12 @@ import logging from fastapi import Depends from fastapi.security import OAuth2PasswordBearer +import secrets from starlette import status from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection +from apps.common.config import Config from apps.services.api_key import ApiKeyManager from apps.services.session import SessionManager @@ -48,6 +50,9 @@ async def get_session(request: HTTPConnection) -> str: :param request: HTTP请求 :return: Session ID """ + if Config().get_config().no_auth.enable: + # 如果启用了无认证访问,直接返回调试用户 + return secrets.token_hex(16) session_id = await _get_session_id_from_request(request) if not session_id: raise HTTPException( @@ -69,6 +74,9 @@ async def get_user(request: HTTPConnection) -> str: :param request: HTTP请求体 :return: 用户sub """ + if Config().get_config().no_auth.enable: + # 如果启用了无认证访问,直接返回调试用户 + return Config().get_config().no_auth.user_sub session_id = await _get_session_id_from_request(request) if not session_id: raise HTTPException( diff --git a/apps/routers/api_key.py b/apps/routers/api_key.py index 158cfc13..51366a21 100644 --- a/apps/routers/api_key.py +++ b/apps/routers/api_key.py @@ -6,6 +6,7 @@ from typing import Annotated from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse + from apps.dependency.user import get_user, verify_user from apps.schemas.api_key import GetAuthKeyRsp, PostAuthKeyMsg, PostAuthKeyRsp from apps.schemas.response_data import ResponseData diff --git a/apps/routers/auth.py b/apps/routers/auth.py index 1cba5ed6..4a3f8293 100644 --- a/apps/routers/auth.py +++ b/apps/routers/auth.py @@ -9,6 +9,7 @@ from fastapi import APIRouter, Depends, Request, status from fastapi.responses import HTMLResponse, JSONResponse from fastapi.templating import Jinja2Templates +from apps.common.config import Config from apps.common.oidc import oidc_provider from apps.dependency import get_session, get_user, verify_user from apps.schemas.collection import Audit @@ -47,8 +48,6 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse: user_info = await oidc_provider.get_oidc_user(token["access_token"]) user_sub: str | None = user_info.get("user_sub", None) - if user_sub: - await oidc_provider.set_token(user_sub, token["access_token"], token["refresh_token"]) except Exception as e: logger.exception("User login failed") status_code = status.HTTP_400_BAD_REQUEST if "auth error" in str(e) else status.HTTP_403_FORBIDDEN diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 6b3451fa..506f3bb1 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -86,8 +86,8 @@ class StepExecutor(BaseExecutor): logger.info("[StepExecutor] 初始化步骤 %s", self.step.step.name) # State写入ID和运行状态 - self.task.state.step_id = self.step.step_id # type: ignore[arg-type] - self.task.state.step_name = self.step.step.name # type: ignore[arg-type] + self.task.state.step_id = self.step.step_id # type: ignore[arg-type] + self.task.state.step_name = self.step.step.name # type: ignore[arg-type] # 获取并验证Call类 node_id = self.step.step.node @@ -127,13 +127,13 @@ class StepExecutor(BaseExecutor): return # 暂存旧数据 - current_step_id = self.task.state.step_id # type: ignore[arg-type] - current_step_name = self.task.state.step_name # type: ignore[arg-type] + current_step_id = self.task.state.step_id # type: ignore[arg-type] + current_step_name = self.task.state.step_name # type: ignore[arg-type] # 更新State - self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type] - self.task.state.step_name = "自动参数填充" # type: ignore[arg-type] - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type] + self.task.state.step_name = "自动参数填充" # type: ignore[arg-type] + self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 初始化填参 @@ -156,17 +156,17 @@ class StepExecutor(BaseExecutor): # 如果没有填全,则状态设置为待填参 if result.remaining_schema: - self.task.state.status = StepStatus.PARAM # type: ignore[arg-type] + self.task.state.status = StepStatus.PARAM # type: ignore[arg-type] else: - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] await self.push_message(EventType.STEP_OUTPUT.value, result.model_dump(by_alias=True, exclude_none=True)) # 更新输入 self.obj.input.update(result.slot_data) # 恢复State - self.task.state.step_id = current_step_id # type: ignore[arg-type] - self.task.state.step_name = current_step_name # type: ignore[arg-type] + self.task.state.step_id = current_step_id # type: ignore[arg-type] + self.task.state.step_name = current_step_name # type: ignore[arg-type] self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens @@ -212,7 +212,7 @@ class StepExecutor(BaseExecutor): await self._run_slot_filling() # 更新状态 - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 推送输入 await self.push_message(EventType.STEP_INPUT.value, self.obj.input) @@ -224,22 +224,22 @@ class StepExecutor(BaseExecutor): content = await self._process_chunk(iterator, to_user=self.obj.to_user) except Exception as e: logger.exception("[StepExecutor] 运行步骤失败,进行异常处理步骤") - self.task.state.status = StepStatus.ERROR # type: ignore[arg-type] + self.task.state.status = StepStatus.ERROR # type: ignore[arg-type] await self.push_message(EventType.STEP_OUTPUT.value, {}) if isinstance(e, CallError): - self.task.state.error_info = { # type: ignore[arg-type] + self.task.state.error_info = { # type: ignore[arg-type] "err_msg": e.message, "data": e.data, } else: - self.task.state.error_info = { # type: ignore[arg-type] + self.task.state.error_info = { # type: ignore[arg-type] "err_msg": str(e), "data": {}, } return # 更新执行状态 - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens self.task.tokens.full_time += round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time @@ -253,12 +253,12 @@ class StepExecutor(BaseExecutor): # 更新context history = FlowStepHistory( task_id=self.task.id, - flow_id=self.task.state.flow_id, # type: ignore[arg-type] - flow_name=self.task.state.flow_name, # type: ignore[arg-type] + flow_id=self.task.state.flow_id, # type: ignore[arg-type] + flow_name=self.task.state.flow_name, # type: ignore[arg-type] step_id=self.step.step_id, step_name=self.step.step.name, step_description=self.step.step.description, - status=self.task.state.status, # type: ignore[arg-type] + status=self.task.state.status, # type: ignore[arg-type] input_data=self.obj.input, output_data=output_data, ) diff --git a/apps/schemas/config.py b/apps/schemas/config.py index b88a81f1..99bcccde 100644 --- a/apps/schemas/config.py +++ b/apps/schemas/config.py @@ -6,6 +6,13 @@ from typing import Literal from pydantic import BaseModel, Field +class NoauthConfig(BaseModel): + """无认证配置""" + + enable: bool = Field(description="是否启用无认证访问", default=False) + user_sub: str = Field(description="调试用户的sub", default="admin") + + class DeployConfig(BaseModel): """部署配置""" @@ -122,7 +129,7 @@ class ExtraConfig(BaseModel): class ConfigModel(BaseModel): """配置文件的校验Class""" - + no_auth: NoauthConfig deploy: DeployConfig login: LoginConfig embedding: EmbeddingConfig diff --git a/apps/schemas/message.py b/apps/schemas/message.py index 5f0aee8a..cf70a82b 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -24,6 +24,8 @@ class MessageFlow(BaseModel): flow_id: str = Field(description="Flow ID", alias="flowId") step_id: str = Field(description="当前步骤ID", alias="stepId") step_name: str = Field(description="当前步骤名称", alias="stepName") + sub_step_id: str | None = Field(description="当前子步骤ID", alias="subStepId", default=None) + sub_step_name: str | None = Field(description="当前子步骤名称", alias="subStepName", default=None) step_status: StepStatus = Field(description="当前步骤状态", alias="stepStatus") diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 2305dd93..5d6dc7fc 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -16,8 +16,8 @@ class RequestDataApp(BaseModel): """模型对话中包含的app信息""" app_id: str = Field(description="应用ID", alias="appId") - flow_id: str = Field(description="Flow ID", alias="flowId") - params: dict[str, Any] = Field(description="插件参数") + flow_id: str | None = Field(default=None, description="Flow ID", alias="flowId") + params: dict[str, Any] | None = Field(default=None, description="插件参数") class MockRequestData(BaseModel): -- Gitee From 9653ef8e6c3d6ba15aba377b9721a1f605f7cff7 Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 24 Jul 2025 21:13:04 +0800 Subject: [PATCH 14/60] =?UTF-8?q?=E5=AE=8C=E5=96=84Agent=E7=9A=84=E5=BC=80?= =?UTF-8?q?=E5=8F=91&=E4=BF=AE=E5=A4=8Dmcp=E6=B3=A8=E5=86=8C=E6=97=B6?= =?UTF-8?q?=E5=80=99=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/chat.py | 11 ++++++++--- apps/scheduler/call/mcp/mcp.py | 7 ------- apps/scheduler/executor/agent.py | 28 +++++++++------------------ apps/scheduler/scheduler/scheduler.py | 2 +- apps/schemas/enum_var.py | 3 +++ apps/schemas/mcp.py | 3 ++- apps/schemas/request_data.py | 1 + apps/schemas/task.py | 3 +++ apps/services/task.py | 9 --------- 9 files changed, 27 insertions(+), 40 deletions(-) diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 7fe5162c..589000be 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -36,11 +36,16 @@ async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> T # 生成group_id if not post_body.group_id: post_body.group_id = str(uuid.uuid4()) - # 创建或还原Task + if post_body.new_task: + # 创建或还原Task + task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) + if task: + await TaskManager.delete_task_by_task_id(task.id) task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) # 更改信息并刷新数据库 - task.runtime.question = post_body.question - task.ids.group_id = post_body.group_id + if post_body.new_task: + task.runtime.question = post_body.question + task.ids.group_id = post_body.group_id return task diff --git a/apps/scheduler/call/mcp/mcp.py b/apps/scheduler/call/mcp/mcp.py index 661e9ada..4e6a1bb7 100644 --- a/apps/scheduler/call/mcp/mcp.py +++ b/apps/scheduler/call/mcp/mcp.py @@ -35,7 +35,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): text_output: bool = Field(description="是否将结果以文本形式返回", default=True) to_user: bool = Field(description="是否将结果返回给用户", default=True) - @classmethod def info(cls) -> CallInfo: """ @@ -46,7 +45,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): """ return CallInfo(name="MCP", description="调用MCP Server,执行工具") - async def _init(self, call_vars: CallVars) -> MCPInput: """初始化MCP""" # 获取MCP交互类 @@ -63,7 +61,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): return MCPInput(avaliable_tools=avaliable_tools, max_steps=self.max_steps) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """执行MCP""" # 生成计划 @@ -80,7 +77,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): async for chunk in self._generate_answer(): yield chunk - async def _generate_plan(self) -> AsyncGenerator[CallOutputChunk, None]: """生成执行计划""" # 开始提示 @@ -103,7 +99,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): data=self._plan.model_dump(), ) - async def _execute_plan_item(self, plan_item: MCPPlanItem) -> AsyncGenerator[CallOutputChunk, None]: """执行单个计划项""" # 判断是否为Final @@ -141,7 +136,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): }, ) - async def _generate_answer(self) -> AsyncGenerator[CallOutputChunk, None]: """生成总结""" # 提示开始总结 @@ -163,7 +157,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): ).model_dump(), ) - def _create_output( self, text: str, diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index f6814dd3..2ff4f3d3 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -7,6 +7,8 @@ from pydantic import Field from apps.scheduler.executor.base import BaseExecutor from apps.scheduler.mcp_agent.agent.mcp import MCPAgent +from apps.schemas.task import ExecutorState, StepQueueItem +from apps.services.task import TaskManager logger = logging.getLogger(__name__) @@ -15,26 +17,14 @@ class MCPAgentExecutor(BaseExecutor): """MCP Agent执行器""" question: str = Field(description="用户输入") - max_steps: int = Field(default=10, description="最大步数") + max_steps: int = Field(default=20, description="最大步数") servers_id: list[str] = Field(description="MCP server id") agent_id: str = Field(default="", description="Agent ID") agent_description: str = Field(default="", description="Agent描述") - async def run(self) -> None: - """运行MCP Agent""" - agent = await MCPAgent.create( - servers_id=self.servers_id, - max_steps=self.max_steps, - task=self.task, - msg_queue=self.msg_queue, - question=self.question, - agent_id=self.agent_id, - description=self.agent_description, - ) - - try: - answer = await agent.run(self.question) - self.task = agent.task - self.task.runtime.answer = answer - except Exception as e: - logger.error(f"Error: {str(e)}") + async def load_state(self) -> None: + """从数据库中加载FlowExecutor的状态""" + logger.info("[FlowExecutor] 加载Executor状态") + # 尝试恢复State + if self.task.state: + self.task.context = await TaskManager.get_context_by_task_id(self.task.id) diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index ed73638c..417f93d2 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -206,7 +206,7 @@ class Scheduler: task=self.task, msg_queue=queue, question=post_body.question, - max_steps=app_metadata.history_len, + history_len=app_metadata.history_len, servers_id=servers_id, background=background, agent_id=app_info.app_id, diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index 9a20ba84..a84dc3a3 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -15,6 +15,7 @@ class SlotType(str, Enum): class StepStatus(str, Enum): """步骤状态""" + WAITING = "waiting" RUNNING = "running" SUCCESS = "success" ERROR = "error" @@ -38,6 +39,8 @@ class EventType(str, Enum): TEXT_ADD = "text.add" GRAPH = "graph" DOCUMENT_ADD = "document.add" + STEP_WAITING_FOR_START = "step.waiting_for_start" + STEP_WAITING_FOR_PARAM = "step.waiting_for_param" FLOW_START = "flow.start" STEP_INPUT = "step.input" STEP_OUTPUT = "step.output" diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index 44021b0e..60c8f17b 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """MCP 相关数据结构""" +import uuid from enum import Enum from typing import Any @@ -117,7 +118,7 @@ class MCPToolSelectResult(BaseModel): class MCPPlanItem(BaseModel): """MCP 计划""" - + id: str = Field(default_factory=lambda: str(uuid.uuid4())) content: str = Field(description="计划内容") tool: str = Field(description="工具名称") instruction: str = Field(description="工具指令") diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 5d6dc7fc..a3a8848c 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -46,6 +46,7 @@ class RequestData(BaseModel): files: list[str] = Field(default=[], description="文件列表") app: RequestDataApp | None = Field(default=None, description="应用") debug: bool = Field(default=False, description="是否调试") + new_task: bool = Field(default=True, description="是否新建任务") class QuestionBlacklistRequest(BaseModel): diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 8efcb599..37fdebbf 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, Field from apps.schemas.enum_var import StepStatus from apps.schemas.flow import Step +from apps.schemas.mcp import MCPPlan class FlowStepHistory(BaseModel): @@ -42,6 +43,7 @@ class ExecutorState(BaseModel): # 附加信息 step_id: str = Field(description="当前步骤ID") step_name: str = Field(description="当前步骤名称") + step_description: str = Field(description="当前步骤描述", default="") app_id: str = Field(description="应用ID") slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) error_info: dict[str, Any] = Field(description="错误信息", default={}) @@ -75,6 +77,7 @@ class TaskRuntime(BaseModel): summary: str = Field(description="摘要", default="") filled: dict[str, Any] = Field(description="填充的槽位", default={}) documents: list[dict[str, Any]] = Field(description="文档列表", default=[]) + temporary_plans: MCPPlan | None = Field(description="临时计划列表", default=None) class Task(BaseModel): diff --git a/apps/services/task.py b/apps/services/task.py index 1e672be6..2456d96b 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -45,7 +45,6 @@ class TaskManager: return Task.model_validate(task) - @staticmethod async def get_task_by_group_id(group_id: str, conversation_id: str) -> Task | None: """获取组ID的最后一条问答组关联的任务""" @@ -58,7 +57,6 @@ class TaskManager: task = await task_collection.find_one({"_id": record_group_obj.task_id}) return Task.model_validate(task) - @staticmethod async def get_task_by_task_id(task_id: str) -> Task | None: """根据task_id获取任务""" @@ -68,7 +66,6 @@ class TaskManager: return None return Task.model_validate(task) - @staticmethod async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[dict[str, Any]]: """根据record_group_id获取flow信息""" @@ -95,7 +92,6 @@ class TaskManager: else: return flow_context_list - @staticmethod async def get_context_by_task_id(task_id: str, length: int = 0) -> list[dict[str, Any]]: """根据task_id获取flow信息""" @@ -115,7 +111,6 @@ class TaskManager: else: return flow_context - @staticmethod async def save_flow_context(task_id: str, flow_context: list[dict[str, Any]]) -> None: """保存flow信息到flow_context""" @@ -137,7 +132,6 @@ class TaskManager: except Exception: logger.exception("[TaskManager] 保存flow执行记录失败") - @staticmethod async def delete_task_by_task_id(task_id: str) -> None: """通过task_id删除Task信息""" @@ -148,7 +142,6 @@ class TaskManager: if task: await task_collection.delete_one({"_id": task_id}) - @staticmethod async def delete_tasks_by_conversation_id(conversation_id: str) -> None: """通过ConversationID删除Task信息""" @@ -167,7 +160,6 @@ class TaskManager: await task_collection.delete_many({"conversation_id": conversation_id}, session=session) await flow_context_collection.delete_many({"task_id": {"$in": task_ids}}, session=session) - @classmethod async def get_task( cls, @@ -212,7 +204,6 @@ class TaskManager: runtime=TaskRuntime(), ) - @classmethod async def save_task(cls, task_id: str, task: Task) -> None: """保存任务块""" -- Gitee From 23fd74f6c25e2dd98bb487b485e42663da1a2369 Mon Sep 17 00:00:00 2001 From: zengxianghuai Date: Fri, 25 Jul 2025 17:34:51 +0800 Subject: [PATCH 15/60] =?UTF-8?q?=E5=A2=9E=E5=8A=A0choice=E5=B7=A5?= =?UTF-8?q?=E5=85=B7=E4=BB=A5=E5=8F=8A=E7=9B=B8=E5=85=B3=E8=8E=B7=E5=8F=96?= =?UTF-8?q?step=5Fid=E5=92=8C=E6=93=8D=E4=BD=9C=E7=AC=A6=E7=9A=84=E8=B7=AF?= =?UTF-8?q?=E7=94=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/parameter.py | 143 ++++++++++++++ apps/scheduler/call/choice/choice.py | 79 +++++++- .../call/choice/condition_handler.py | 184 ++++++++++++++++++ apps/scheduler/call/choice/schema.py | 66 +++++++ apps/scheduler/executor/flow.py | 9 + apps/schemas/response_data.py | 18 ++ apps/services/flow.py | 26 ++- 7 files changed, 516 insertions(+), 9 deletions(-) create mode 100644 apps/routers/parameter.py create mode 100644 apps/scheduler/call/choice/condition_handler.py diff --git a/apps/routers/parameter.py b/apps/routers/parameter.py new file mode 100644 index 00000000..33789dfd --- /dev/null +++ b/apps/routers/parameter.py @@ -0,0 +1,143 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends, Query, status +from fastapi.responses import JSONResponse + +from apps.dependency import get_user +from apps.dependency.user import verify_user +from apps.scheduler.call.choice.choice import Choice +from apps.schemas.response_data import ( + FlowStructureGetMsg, + FlowStructureGetRsp, + GetParamsMsg, + GetParamsRsp, + ResponseData, +) +from apps.services.application import AppManager +from apps.services.flow import FlowManager + +router = APIRouter( + prefix="/api/parameter", + tags=["parameter"], + dependencies=[ + Depends(verify_user), + ], +) + + +@router.get("", response_model={ + status.HTTP_403_FORBIDDEN: {"model": ResponseData}, + status.HTTP_404_NOT_FOUND: {"model": ResponseData}, + },) +async def get_parameters( + user_sub: Annotated[str, Depends(get_user)], + app_id: Annotated[str, Query(alias="appId")], + flow_id: Annotated[str, Query(alias="flowId")], + step_id: Annotated[str, Query(alias="stepId")], +) -> JSONResponse: + """Get parameters for node choice.""" + if not await AppManager.validate_user_app_access(user_sub, app_id): + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content=FlowStructureGetRsp( + code=status.HTTP_403_FORBIDDEN, + message="用户没有权限访问该流", + result=FlowStructureGetMsg(), + ).model_dump(exclude_none=True, by_alias=True), + ) + flow = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id) + if not flow: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content=FlowStructureGetRsp( + code=status.HTTP_404_NOT_FOUND, + message="未找到该流", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ) + result = await FlowManager.get_step_by_flow_and_step_id(flow, step_id) + if not result: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content=FlowStructureGetRsp( + code=status.HTTP_404_NOT_FOUND, + message="未找到该节点", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=GetParamsRsp( + code=status.HTTP_200_OK, + message="获取参数成功", + result=GetParamsMsg(result=result), + ).model_dump(exclude_none=True, by_alias=True), + ) + +async def operate_parameters(operate: str) -> list[str] | None: + """ + 根据操作类型获取对应的操作符参数列表 + + Args: + operate: 操作类型,支持 'int', 'str', 'bool' + + Returns: + 对应的操作符参数列表,若类型不支持则返回None + + """ + string = [ + "equal", + "not_equal", + "great",#长度大于 + "great_equals",#长度大于等于 + "less",#长度小于 + "less_equals",#长度小于等于 + "greater", + "greater_equals", + "smaller", + "smaller_equals", + ] + integer = [ + "equal", + "not_equal", + "great", + "great_equals", + "less", + "less_equals", + ] + boolen = ["equal", "not_equal", "is_empty", "not_empty"] + if operate in string: + return string + if operate in integer: + return integer + if operate in boolen: + return boolen + return None + +@router.get("/operate", response_model={ + status.HTTP_200_OK: {"model": ResponseData}, + status.HTTP_404_NOT_FOUND: {"model": ResponseData}, + },) +async def get_operate_parameters( + user_sub: Annotated[str, Depends(get_user)], + operate: Annotated[str, Query(alias="operate")] +) -> JSONResponse: + """Get parameters for node choice.""" + result = await operate_parameters(operate) + if not result: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content=ResponseData( + code=status.HTTP_404_NOT_FOUND, + message="未找到该符号", + result=[], + ).model_dump(exclude_none=True, by_alias=True), + ) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=ResponseData( + code=status.HTTP_200_OK, + message="获取参数成功", + result=result, + ).model_dump(exclude_none=True, by_alias=True), + ) \ No newline at end of file diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py index a5edf21a..e9cd7d41 100644 --- a/apps/scheduler/call/choice/choice.py +++ b/apps/scheduler/call/choice/choice.py @@ -1,19 +1,82 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """使用大模型或使用程序做出判断""" -from enum import Enum - -from apps.scheduler.call.choice.schema import ChoiceInput, ChoiceOutput -from apps.scheduler.call.core import CoreCall +import logging +from collections.abc import AsyncGenerator +from typing import Any +from pydantic import Field -class Operator(str, Enum): - """Choice工具支持的运算符""" +from apps.scheduler.call.choice.condition_handler import ConditionHandler +from apps.scheduler.call.choice.schema import ( + ChoiceBranch, + ChoiceInput, + ChoiceOutput, + Logic, +) +from apps.scheduler.call.core import CoreCall +from apps.schemas.enum_var import CallOutputType +from apps.schemas.scheduler import ( + CallError, + CallInfo, + CallOutputChunk, + CallVars, +) - pass +logger = logging.getLogger(__name__) class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): """Choice工具""" - pass + to_user: bool = Field(default=False) + choices: list[ChoiceBranch] = Field(description="分支", default=[]) + + @classmethod + def info(cls) -> CallInfo: + """返回Call的名称和描述""" + return CallInfo(name="Choice", description="使用大模型或使用程序做出判断") + + async def _prepare_message(self, call_vars: CallVars) -> list[dict[str, Any]]: + """替换choices中的系统变量""" + valid_choices = [] + for choice in self.choices: + if choice.logic not in [Logic.AND, Logic.OR]: + logger.warning("分支 %s 的逻辑运算符 %s 无效,已跳过",choice.branch_id, choice.logic) + continue + valid_conditions = [] + for condition in choice.conditions: + if condition.left.step_id: + condition.left.value = self._extract_history_variables( + condition.left.step_id, call_vars.history, + ) + valid_conditions.append(condition) + else: + logger.warning("条件 %s 的左侧变量 %s 无效,已跳过", condition.condition_id, condition.left) + continue + choice.conditions = valid_conditions + valid_choices.append(choice.dict()) + return valid_choices + + async def _init(self, call_vars: CallVars) -> ChoiceInput: + """初始化Choice工具""" + return ChoiceInput( + choices=await self._prepare_message(call_vars), + ) + + async def _exec( + self, input_data: dict[str, Any] + ) -> AsyncGenerator[CallOutputChunk, None]: + """执行Choice工具""" + # 解析输入数据 + data = ChoiceInput(**input_data) + ret: CallOutputChunk = CallOutputChunk( + type=CallOutputType.DATA, + content=None, + ) + condition_handler = ConditionHandler() + try: + ret.content = condition_handler.handler(data.choices) + yield ret + except Exception as e: + raise CallError(message=f"选择工具调用失败:{e!s}", data={}) from e diff --git a/apps/scheduler/call/choice/condition_handler.py b/apps/scheduler/call/choice/condition_handler.py new file mode 100644 index 00000000..feab5c64 --- /dev/null +++ b/apps/scheduler/call/choice/condition_handler.py @@ -0,0 +1,184 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""处理条件分支的工具""" + +import logging + +from pydantic import BaseModel + +from apps.scheduler.call.choice.schema import ChoiceBranch, ChoiceOutput, Condition, Logic, Operator, Type, Value + +logger = logging.getLogger(__name__) + + +class ConditionHandler(BaseModel): + """条件分支处理器""" + + def handler(self, choices: list[ChoiceBranch]) -> ChoiceOutput: + """处理条件""" + default_branch = [c for c in choices if c.is_default] + + for block_judgement in choices: + results = [] + if block_judgement.is_default: + continue + for condition in block_judgement.conditions: + result = self._judge_condition(condition) + results.append(result) + if block_judgement.logic == Logic.AND: + final_result = all(results) + elif block_judgement.logic == Logic.OR: + final_result = any(results) + + if final_result: + return { + "branch_id": block_judgement.branch_id, + "message": f"选择分支:{block_judgement.branch_id}", + } + + # 如果没有匹配的分支,选择默认分支 + if default_branch: + return { + "branch_id": default_branch[0].branch_id, + "message": f"选择默认分支:{default_branch[0].branch_id}", + } + return { + "branch_id": "", + "message": "没有匹配的分支,且没有默认分支", + } + + def _judge_condition(self, condition: Condition) -> bool: + """ + 判断条件是否成立。 + + Args: + condition (Condition): 'left', 'operator', 'right', 'type' + + Returns: + bool + + """ + left = condition.left + operator = condition.operator + right = condition.right + value_type = condition.type + + result = None + if value_type == Type.STRING: + result = self._judge_string_condition(left, operator, right) + elif value_type == Type.INT: + result = self._judge_int_condition(left, operator, right) + elif value_type == Type.BOOL: + result = self._judge_bool_condition(left, operator, right) + else: + logger.error("不支持的数据类型: %s", value_type) + msg = f"不支持的数据类型: {value_type}" + raise ValueError(msg) + return result + + def _judge_string_condition(self, left: Value, operator: Operator, right: Value) -> bool: + """ + 判断字符串类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operator (Operator): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, str): + logger.error("左值不是字符串类型: %s", left_value) + msg = "左值必须是字符串类型" + raise TypeError(msg) + right_value = right.value + result = False + if operator == Operator.EQUAL: + result = left_value == right_value + elif operator == Operator.NEQUAL: + result = left_value != right_value + elif operator == Operator.GREAT: + result = len(left_value) > len(right_value) + elif operator == Operator.GREAT_EQUALS: + result = len(left_value) >= len(right_value) + elif operator == Operator.LESS: + result = len(left_value) < len(right_value) + elif operator == Operator.LESS_EQUALS: + result = len(left_value) <= len(right_value) + elif operator == Operator.GREATER: + result = left_value > right_value + elif operator == Operator.GREATER_EQUALS: + result = left_value >= right_value + elif operator == Operator.SMALLER: + result = left_value < right_value + elif operator == Operator.SMALLER_EQUALS: + result = left_value <= right_value + elif operator == Operator.CONTAINS: + result = right_value in left_value + elif operator == Operator.NOT_CONTAINS: + result = right_value not in left_value + return result + + def _judge_int_condition(self, left: Value, operator: Operator, right: Value) -> bool: # noqa: PLR0911 + """ + 判断整数类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operator (Operator): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, int): + logger.error("左值不是整数类型: %s", left_value) + msg = "左值必须是整数类型" + raise TypeError(msg) + right_value = right.value + if operator == Operator.EQUAL: + return left_value == right_value + if operator == Operator.NEQUAL: + return left_value != right_value + if operator == Operator.GREAT: + return left_value > right_value + if operator == Operator.GREAT_EQUALS: + return left_value >= right_value + if operator == Operator.LESS: + return left_value < right_value + if operator == Operator.LESS_EQUALS: + return left_value <= right_value + return False + + def _judge_bool_condition(self, left: Value, operator: Operator, right: Value) -> bool: + """ + 判断布尔类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operator (Operator): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, bool): + logger.error("左值不是布尔类型: %s", left_value) + msg = "左值必须是布尔类型" + raise TypeError(msg) + right_value = right.value + if operator == Operator.EQUAL: + return left_value == right_value + if operator == Operator.NEQUAL: + return left_value != right_value + if operator == Operator.IS_EMPTY: + return left_value == "" + if operator == Operator.NOT_EMPTY: + return left_value != "" + return False diff --git a/apps/scheduler/call/choice/schema.py b/apps/scheduler/call/choice/schema.py index 60b62d09..ed1c628c 100644 --- a/apps/scheduler/call/choice/schema.py +++ b/apps/scheduler/call/choice/schema.py @@ -1,12 +1,78 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """Choice Call的输入和输出""" +from enum import Enum + +from pydantic import BaseModel, Field + from apps.scheduler.call.core import DataBase +class Operator(str, Enum): + """Choice Call支持的运算符""" + + EQUAL = "equal" + NEQUAL = "not_equal" + GREAT = "great" + GREAT_EQUALS = "great_equals" + LESS = "less" + LESS_EQUALS = "less_equals" + # string + CONTAINS = "contains" + NOT_CONTAINS = "not_contains" + GREATER = "greater" + GREATER_EQUALS = "greater_equals" + SMALLER = "smaller" + SMALLER_EQUALS = "smaller_equals" + # bool + IS_EMPTY = "is_empty" + NOT_EMPTY = "not_empty" + + +class Logic(str, Enum): + """Choice 工具支持的逻辑运算符""" + + AND = "and" + OR = "or" + + +class Type(str, Enum): + """Choice 工具支持的类型""" + + STRING = "string" + INT = "int" + BOOL = "bool" + + +class Value(BaseModel): + """值的结构""" + + step_id: str = Field(description="步骤id", default="") + value: str | int | bool = Field(description="值", default=None) + + +class Condition(BaseModel): + """单个条件""" + + type: Type = Field(description="值的类型", default=Type.STRING) + left: Value = Field(description="左值") + right: Value = Field(description="右值") + operator: Operator = Field(description="运算符", default="equal") + id: int = Field(description="条件ID") + + +class ChoiceBranch(BaseModel): + """子分支""" + + branch_id: str = Field(description="分支ID", default="") + logic: Logic = Field(description="逻辑运算符", default=Logic.AND) + conditions: list[Condition] = Field(description="条件列表", default=[]) + is_default: bool = Field(description="是否为默认分支", default=False) class ChoiceInput(DataBase): """Choice Call的输入""" + choices: list[ChoiceBranch] = Field(description="分支", default=[]) + class ChoiceOutput(DataBase): """Choice Call的输出""" diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index a70d0d70..d8d22c46 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -117,6 +117,15 @@ class FlowExecutor(BaseExecutor): # 如果当前步骤为结束,则直接返回 if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] return [] + if self.task.state.step_name == "Choice": + # 如果是choice节点,获取分支ID + branch_id = self.task.context[-1]["output_data"].get("branch_id", None) + if branch_id: + self.task.state.step_id = self.task.state.step_id + "." + branch_id + logger.info("[FlowExecutor] 分支ID:%s", branch_id) + else: + logger.warning("[FlowExecutor] 没有找到分支ID,返回空列表") + return [] next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] # 如果step没有任何出边,直接跳到end diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py index b2a18729..20d7ad9b 100644 --- a/apps/schemas/response_data.py +++ b/apps/schemas/response_data.py @@ -628,3 +628,21 @@ class ListLLMRsp(ResponseData): """GET /api/llm 返回数据结构""" result: list[LLMProviderInfo] = Field(default=[], title="Result") + +class Params(BaseModel): + """参数数据结构""" + + id: str = Field(..., description="StepID") + name: str = Field(..., description="Step名称") + parameters: dict[str, Any] = Field(..., description="参数") + operate: str = Field(..., description="比较符") + +class GetParamsMsg(BaseModel): + """GET /api/params 返回数据结构""" + + result: list[Params] = Field(..., title="Result") + +class GetParamsRsp(ResponseData): + """GET /api/params 返回数据结构""" + + result: GetParamsMsg \ No newline at end of file diff --git a/apps/services/flow.py b/apps/services/flow.py index 9275fc60..0a253c31 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -20,7 +20,7 @@ from apps.schemas.flow_topology import ( PositionItem, ) from apps.services.node import NodeManager - +from apps.schemas.response_data import Params logger = logging.getLogger(__name__) @@ -470,3 +470,27 @@ class FlowManager: return False else: return True + + @staticmethod + async def get_step_by_flow_and_step_id(flow: FlowItem, step_id: str) -> list[Params] | None: + """ + 寻找stepID对应的节点之前的所有节点的参数 + """ + params = [] + try: + for edge in flow.edges: + if edge.target_node == step_id: + id = edge.source_node + if id == "start": + break + params.append(Params( + id = id, + name = flow.nodes[id].name, + parameters = flow.nodes[id].parameters.get("parameters", {}) + )) + step_id = edge.source_node + return params + except Exception: + logger.exception("[FlowManager] 获取节点失败") + return None + \ No newline at end of file -- Gitee From 884000ea5b1287f8d86b9a137d116ed503e4e3c7 Mon Sep 17 00:00:00 2001 From: zengxianghuai Date: Fri, 25 Jul 2025 17:55:41 +0800 Subject: [PATCH 16/60] =?UTF-8?q?=E5=A2=9E=E5=8A=A0choice=E5=BC=82?= =?UTF-8?q?=E5=B8=B8=E6=8D=95=E8=8E=B7=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/call/choice/choice.py | 53 ++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 14 deletions(-) diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py index e9cd7d41..24df0dae 100644 --- a/apps/scheduler/call/choice/choice.py +++ b/apps/scheduler/call/choice/choice.py @@ -37,25 +37,50 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): """返回Call的名称和描述""" return CallInfo(name="Choice", description="使用大模型或使用程序做出判断") + def _raise_value_error(self, msg: str) -> None: + """统一处理 ValueError 异常抛出""" + logger.warning(msg) + raise ValueError(msg) + async def _prepare_message(self, call_vars: CallVars) -> list[dict[str, Any]]: """替换choices中的系统变量""" valid_choices = [] + for choice in self.choices: - if choice.logic not in [Logic.AND, Logic.OR]: - logger.warning("分支 %s 的逻辑运算符 %s 无效,已跳过",choice.branch_id, choice.logic) + try: + # 验证逻辑运算符 + if choice.logic not in [Logic.AND, Logic.OR]: + msg = f"无效的逻辑运算符: {choice.logic}" + self._raise_value_error(msg) + + valid_conditions = [] + for condition in choice.conditions: + # 处理左值 + if condition.left.step_id: + condition.left.value = self._extract_history_variables(condition.left.step_id, call_vars.history) + # 检查历史变量是否成功提取 + if condition.left.value is None: + msg = f"步骤 {condition.left.step_id} 的历史变量不存在" + self._raise_value_error(msg) + else: + msg = "左侧变量缺少step_id" + self._raise_value_error(msg) + + valid_conditions.append(condition) + + # 如果所有条件都无效,抛出异常 + if not valid_conditions: + msg = "分支没有有效条件" + self._raise_value_error(msg) + + # 更新有效条件 + choice.conditions = valid_conditions + valid_choices.append(choice.dict()) + + except ValueError as e: + logger.warning("分支 %s 处理失败: %s,已跳过", choice.branch_id, str(e)) continue - valid_conditions = [] - for condition in choice.conditions: - if condition.left.step_id: - condition.left.value = self._extract_history_variables( - condition.left.step_id, call_vars.history, - ) - valid_conditions.append(condition) - else: - logger.warning("条件 %s 的左侧变量 %s 无效,已跳过", condition.condition_id, condition.left) - continue - choice.conditions = valid_conditions - valid_choices.append(choice.dict()) + return valid_choices async def _init(self, call_vars: CallVars) -> ChoiceInput: -- Gitee From b51ed5f71018454a29c6dc0bee93593f4c92e177 Mon Sep 17 00:00:00 2001 From: zengxianghuai Date: Fri, 25 Jul 2025 18:45:20 +0800 Subject: [PATCH 17/60] =?UTF-8?q?=E4=BF=AE=E6=94=B9get=5Fparams=5Fby=5Fflo?= =?UTF-8?q?w=5Fand=5Fstep=5Fid=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/parameter.py | 4 +-- apps/services/flow.py | 63 ++++++++++++++++++++++++++------------- 2 files changed, 44 insertions(+), 23 deletions(-) diff --git a/apps/routers/parameter.py b/apps/routers/parameter.py index 33789dfd..538c2556 100644 --- a/apps/routers/parameter.py +++ b/apps/routers/parameter.py @@ -55,7 +55,7 @@ async def get_parameters( result={}, ).model_dump(exclude_none=True, by_alias=True), ) - result = await FlowManager.get_step_by_flow_and_step_id(flow, step_id) + result = await FlowManager.get_params_by_flow_and_step_id(flow, step_id) if not result: return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, @@ -120,7 +120,7 @@ async def operate_parameters(operate: str) -> list[str] | None: },) async def get_operate_parameters( user_sub: Annotated[str, Depends(get_user)], - operate: Annotated[str, Query(alias="operate")] + operate: Annotated[str, Query(alias="operate")], ) -> JSONResponse: """Get parameters for node choice.""" result = await operate_parameters(operate) diff --git a/apps/services/flow.py b/apps/services/flow.py index 0a253c31..cb8ad57f 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -472,25 +472,46 @@ class FlowManager: return True @staticmethod - async def get_step_by_flow_and_step_id(flow: FlowItem, step_id: str) -> list[Params] | None: - """ - 寻找stepID对应的节点之前的所有节点的参数 - """ + async def get_params_by_flow_and_step_id( + flow: FlowItem, step_id: str + ) -> list[Params] | None: + """递归收集指定节点之前所有路径上的节点参数""" params = [] - try: - for edge in flow.edges: - if edge.target_node == step_id: - id = edge.source_node - if id == "start": - break - params.append(Params( - id = id, - name = flow.nodes[id].name, - parameters = flow.nodes[id].parameters.get("parameters", {}) - )) - step_id = edge.source_node - return params - except Exception: - logger.exception("[FlowManager] 获取节点失败") - return None - \ No newline at end of file + collected = set() # 记录已收集参数的节点 + + async def backtrack(current_id: str, visited: set) -> None: + # 避免循环递归 + if current_id in visited: + return + visited.add(current_id) + + # 获取所有指向当前节点的边 + incoming_edges = [ + edge for edge in flow.edges if edge.target_node == current_id + ] + + for edge in incoming_edges: + source_id = edge.source_node + + # 跳过起始节点 + if source_id == "start": + continue + + # 收集当前节点的参数(如果未被收集过) + if source_id not in collected: + node = flow.nodes.get(source_id) + if node: + collected.add(source_id) + params.append( + Params( + id=source_id, + name=node.name, + parameters=node.parameters.get("parameters", {}), + ), + ) + + # 继续回溯,传递当前路径的visited集合副本 + await backtrack(source_id, visited.copy()) + + await backtrack(step_id, set()) + return params -- Gitee From e4dfb413f4147e9cee2fb570a89b05f9fdc57331 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 25 Jul 2025 19:09:10 +0800 Subject: [PATCH 18/60] =?UTF-8?q?=E5=AE=8C=E5=96=84mcp=20agent=E7=9A=84?= =?UTF-8?q?=E5=BC=80=E5=8F=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/queue.py | 2 +- apps/routers/record.py | 21 +- apps/scheduler/call/mcp/mcp.py | 2 +- apps/scheduler/executor/agent.py | 2 +- apps/scheduler/executor/flow.py | 28 +- apps/scheduler/executor/step.py | 18 +- apps/scheduler/mcp/host.py | 9 +- apps/scheduler/mcp_agent/__init__.py | 8 + apps/scheduler/mcp_agent/agent/base.py | 196 -------------- apps/scheduler/mcp_agent/agent/mcp.py | 81 ------ apps/scheduler/mcp_agent/agent/react.py | 35 --- apps/scheduler/mcp_agent/agent/toolcall.py | 238 ----------------- apps/scheduler/mcp_agent/host.py | 190 ++++++++++++++ apps/scheduler/mcp_agent/plan.py | 110 ++++++++ apps/scheduler/mcp_agent/prompt.py | 240 ++++++++++++++++++ apps/scheduler/mcp_agent/schema.py | 148 ----------- apps/scheduler/mcp_agent/select.py | 185 ++++++++++++++ apps/scheduler/mcp_agent/tool/__init__.py | 9 - apps/scheduler/mcp_agent/tool/base.py | 73 ------ apps/scheduler/mcp_agent/tool/terminate.py | 25 -- .../mcp_agent/tool/tool_collection.py | 55 ---- apps/scheduler/scheduler/context.py | 2 +- apps/schemas/enum_var.py | 11 + apps/schemas/record.py | 1 + apps/schemas/task.py | 12 +- tests/common/test_queue.py | 4 +- 26 files changed, 792 insertions(+), 913 deletions(-) create mode 100644 apps/scheduler/mcp_agent/__init__.py delete mode 100644 apps/scheduler/mcp_agent/agent/base.py delete mode 100644 apps/scheduler/mcp_agent/agent/mcp.py delete mode 100644 apps/scheduler/mcp_agent/agent/react.py delete mode 100644 apps/scheduler/mcp_agent/agent/toolcall.py create mode 100644 apps/scheduler/mcp_agent/host.py create mode 100644 apps/scheduler/mcp_agent/plan.py create mode 100644 apps/scheduler/mcp_agent/prompt.py delete mode 100644 apps/scheduler/mcp_agent/schema.py create mode 100644 apps/scheduler/mcp_agent/select.py delete mode 100644 apps/scheduler/mcp_agent/tool/__init__.py delete mode 100644 apps/scheduler/mcp_agent/tool/base.py delete mode 100644 apps/scheduler/mcp_agent/tool/terminate.py delete mode 100644 apps/scheduler/mcp_agent/tool/tool_collection.py diff --git a/apps/common/queue.py b/apps/common/queue.py index 5601c93a..911485b3 100644 --- a/apps/common/queue.py +++ b/apps/common/queue.py @@ -58,7 +58,7 @@ class MessageQueue: flowId=task.state.flow_id, stepId=task.state.step_id, stepName=task.state.step_name, - stepStatus=task.state.status, + stepStatus=task.state.step_status ) else: flow = None diff --git a/apps/routers/record.py b/apps/routers/record.py index 7384793b..f357f0de 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -83,22 +83,23 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ tmp_record.document = await DocumentManager.get_used_docs_by_record_group(user_sub, record_group.id) # 获得Record关联的flow数据 - flow_list = await TaskManager.get_context_by_record_id(record_group.id, record.id) - if flow_list: - first_flow = FlowStepHistory.model_validate(flow_list[0]) + flow_step_list = await TaskManager.get_context_by_record_id(record_group.id, record.id) + if flow_step_list: + first_step_history = FlowStepHistory.model_validate(flow_step_list[0]) tmp_record.flow = RecordFlow( - id=first_flow.flow_name, #TODO: 此处前端应该用name + id=first_step_history.flow_name, # TODO: 此处前端应该用name recordId=record.id, - flowId=first_flow.id, - stepNum=len(flow_list), + flowStatus=first_step_history.flow_status, + flowId=first_step_history.id, + stepNum=len(flow_step_list), steps=[], ) - for flow in flow_list: - flow_step = FlowStepHistory.model_validate(flow) + for flow_step in flow_step_list: + flow_step = FlowStepHistory.model_validate(flow_step) tmp_record.flow.steps.append( RecordFlowStep( - stepId=flow_step.step_name, #TODO: 此处前端应该用name - stepStatus=flow_step.status, + stepId=flow_step.step_name, # TODO: 此处前端应该用name + stepStatus=flow_step.step_status, input=flow_step.input_data, output=flow_step.output_data, ), diff --git a/apps/scheduler/call/mcp/mcp.py b/apps/scheduler/call/mcp/mcp.py index 4e6a1bb7..bd0257b4 100644 --- a/apps/scheduler/call/mcp/mcp.py +++ b/apps/scheduler/call/mcp/mcp.py @@ -31,7 +31,7 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): """MCP工具""" mcp_list: list[str] = Field(description="MCP Server ID列表", max_length=5, min_length=1) - max_steps: int = Field(description="最大步骤数", default=6) + max_steps: int = Field(description="最大步骤数", default=20) text_output: bool = Field(description="是否将结果以文本形式返回", default=True) to_user: bool = Field(description="是否将结果返回给用户", default=True) diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index 2ff4f3d3..cb8e183e 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -6,7 +6,7 @@ import logging from pydantic import Field from apps.scheduler.executor.base import BaseExecutor -from apps.scheduler.mcp_agent.agent.mcp import MCPAgent +from apps.scheduler.mcp_agent import host, plan, select from apps.schemas.task import ExecutorState, StepQueueItem from apps.services.task import TaskManager diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index a70d0d70..e5381a4a 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -11,7 +11,7 @@ from pydantic import Field from apps.scheduler.call.llm.prompt import LLM_ERROR_PROMPT from apps.scheduler.executor.base import BaseExecutor from apps.scheduler.executor.step import StepExecutor -from apps.schemas.enum_var import EventType, SpecialCallType, StepStatus +from apps.schemas.enum_var import EventType, SpecialCallType, FlowStatus, StepStatus from apps.schemas.flow import Flow, Step from apps.schemas.request_data import RequestDataApp from apps.schemas.task import ExecutorState, StepQueueItem @@ -47,7 +47,6 @@ class FlowExecutor(BaseExecutor): question: str = Field(description="用户输入") post_body_app: RequestDataApp = Field(description="请求体中的app信息") - async def load_state(self) -> None: """从数据库中加载FlowExecutor的状态""" logger.info("[FlowExecutor] 加载Executor状态") @@ -59,8 +58,9 @@ class FlowExecutor(BaseExecutor): self.task.state = ExecutorState( flow_id=str(self.flow_id), flow_name=self.flow.name, + flow_status=FlowStatus.RUNNING, description=str(self.flow.description), - status=StepStatus.RUNNING, + step_status=StepStatus.RUNNING, app_id=str(self.post_body_app.app_id), step_id="start", step_name="开始", @@ -70,7 +70,6 @@ class FlowExecutor(BaseExecutor): self._reached_end: bool = False self.step_queue: deque[StepQueueItem] = deque() - async def _invoke_runner(self, queue_item: StepQueueItem) -> None: """单一Step执行""" # 创建步骤Runner @@ -90,7 +89,6 @@ class FlowExecutor(BaseExecutor): # 更新Task(已存过库) self.task = step_runner.task - async def _step_process(self) -> None: """执行当前queue里面的所有步骤(在用户看来是单一Step)""" while True: @@ -102,7 +100,6 @@ class FlowExecutor(BaseExecutor): # 执行Step await self._invoke_runner(queue_item) - async def _find_next_id(self, step_id: str) -> list[str]: """查找下一个节点""" next_ids = [] @@ -111,14 +108,13 @@ class FlowExecutor(BaseExecutor): next_ids += [edge.edge_to] return next_ids - async def _find_flow_next(self) -> list[StepQueueItem]: """在当前步骤执行前,尝试获取下一步""" # 如果当前步骤为结束,则直接返回 - if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] + if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] return [] - next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] + next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] # 如果step没有任何出边,直接跳到end if not next_steps: return [ @@ -137,7 +133,6 @@ class FlowExecutor(BaseExecutor): for next_step in next_steps ] - async def run(self) -> None: """ 运行流,返回各步骤结果,直到无法继续执行 @@ -150,8 +145,8 @@ class FlowExecutor(BaseExecutor): # 获取首个步骤 first_step = StepQueueItem( - step_id=self.task.state.step_id, # type: ignore[arg-type] - step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type] + step_id=self.task.state.step_id, # type: ignore[arg-type] + step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type] ) # 头插开始前的系统步骤,并执行 @@ -166,11 +161,11 @@ class FlowExecutor(BaseExecutor): # 插入首个步骤 self.step_queue.append(first_step) - + self.task.state.flow_status = FlowStatus.RUNNING # type: ignore[arg-type] # 运行Flow(未达终点) while not self._reached_end: # 如果当前步骤出错,执行错误处理步骤 - if self.task.state.status == StepStatus.ERROR: # type: ignore[arg-type] + if self.task.state.step_status == StepStatus.ERROR: # type: ignore[arg-type] logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤") self.step_queue.clear() self.step_queue.appendleft(StepQueueItem( @@ -183,13 +178,14 @@ class FlowExecutor(BaseExecutor): params={ "user_prompt": LLM_ERROR_PROMPT.replace( "{{ error_info }}", - self.task.state.error_info["err_msg"], # type: ignore[arg-type] + self.task.state.error_info["err_msg"], # type: ignore[arg-type] ), }, ), enable_filling=False, to_user=False, )) + self.task.state.flow_status = FlowStatus.ERROR # type: ignore[arg-type] # 错误处理后结束 self._reached_end = True @@ -216,3 +212,5 @@ class FlowExecutor(BaseExecutor): self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.full_time # 推送Flow停止消息 await self.push_message(EventType.FLOW_STOP.value) + # 更新Task状态 + self.task.state.flow_status = FlowStatus.SUCCESS # type: ignore[arg-type] diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 506f3bb1..377a4c6e 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -119,7 +119,6 @@ class StepExecutor(BaseExecutor): logger.exception("[StepExecutor] 初始化Call失败") raise - async def _run_slot_filling(self) -> None: """运行自动参数填充;相当于特殊Step,但是不存库""" # 判断是否需要进行自动参数填充 @@ -133,7 +132,7 @@ class StepExecutor(BaseExecutor): # 更新State self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type] self.task.state.step_name = "自动参数填充" # type: ignore[arg-type] - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self.task.state.step_status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 初始化填参 @@ -156,9 +155,9 @@ class StepExecutor(BaseExecutor): # 如果没有填全,则状态设置为待填参 if result.remaining_schema: - self.task.state.status = StepStatus.PARAM # type: ignore[arg-type] + self.task.state.step_status = StepStatus.PARAM # type: ignore[arg-type] else: - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self.task.state.step_status = StepStatus.SUCCESS # type: ignore[arg-type] await self.push_message(EventType.STEP_OUTPUT.value, result.model_dump(by_alias=True, exclude_none=True)) # 更新输入 @@ -170,7 +169,6 @@ class StepExecutor(BaseExecutor): self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens - async def _process_chunk( self, iterator: AsyncGenerator[CallOutputChunk, None], @@ -202,7 +200,6 @@ class StepExecutor(BaseExecutor): return content - async def run(self) -> None: """运行单个步骤""" self.validate_flow_state(self.task) @@ -212,7 +209,7 @@ class StepExecutor(BaseExecutor): await self._run_slot_filling() # 更新状态 - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self.task.state.step_status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 推送输入 await self.push_message(EventType.STEP_INPUT.value, self.obj.input) @@ -224,7 +221,7 @@ class StepExecutor(BaseExecutor): content = await self._process_chunk(iterator, to_user=self.obj.to_user) except Exception as e: logger.exception("[StepExecutor] 运行步骤失败,进行异常处理步骤") - self.task.state.status = StepStatus.ERROR # type: ignore[arg-type] + self.task.state.step_status = StepStatus.ERROR # type: ignore[arg-type] await self.push_message(EventType.STEP_OUTPUT.value, {}) if isinstance(e, CallError): self.task.state.error_info = { # type: ignore[arg-type] @@ -239,7 +236,7 @@ class StepExecutor(BaseExecutor): return # 更新执行状态 - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self.task.state.step_status = StepStatus.SUCCESS # type: ignore[arg-type] self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens self.task.tokens.full_time += round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time @@ -255,10 +252,11 @@ class StepExecutor(BaseExecutor): task_id=self.task.id, flow_id=self.task.state.flow_id, # type: ignore[arg-type] flow_name=self.task.state.flow_name, # type: ignore[arg-type] + flow_status=self.task.state.flow_status, # type: ignore[arg-type] step_id=self.step.step_id, step_name=self.step.step.name, step_description=self.step.step.description, - status=self.task.state.status, # type: ignore[arg-type] + step_status=self.task.state.step_status, input_data=self.obj.input, output_data=output_data, ) diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py index 78aa7bc3..acdd4871 100644 --- a/apps/scheduler/mcp/host.py +++ b/apps/scheduler/mcp/host.py @@ -40,7 +40,6 @@ class MCPHost: lstrip_blocks=True, ) - async def get_client(self, mcp_id: str) -> MCPClient | None: """获取MCP客户端""" mongo = MongoDB() @@ -59,7 +58,6 @@ class MCPHost: logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", self._user_sub, mcp_id) return None - async def assemble_memory(self) -> str: """组装记忆""" task = await TaskManager.get_task_by_task_id(self._task_id) @@ -78,7 +76,6 @@ class MCPHost: context_list=context_list, ) - async def _save_memory( self, tool: MCPTool, @@ -105,11 +102,12 @@ class MCPHost: task_id=self._task_id, flow_id=self._runtime_id, flow_name=self._runtime_name, + flow_status=StepStatus.SUCCESS, step_id=tool.name, step_name=tool.name, # description是规划的实际内容 step_description=plan_item.content, - status=StepStatus.SUCCESS, + step_status=StepStatus.SUCCESS, input_data=input_data, output_data=output_data, ) @@ -125,7 +123,6 @@ class MCPHost: return output_data - async def _fill_params(self, tool: MCPTool, query: str) -> dict[str, Any]: """填充工具参数""" # 更清晰的输入·指令,这样可以调用generate @@ -146,7 +143,6 @@ class MCPHost: ) return await json_generator.generate() - async def call_tool(self, tool: MCPTool, plan_item: MCPPlanItem) -> list[dict[str, Any]]: """调用工具""" # 拿到Client @@ -170,7 +166,6 @@ class MCPHost: return processed_result - async def get_tool_list(self, mcp_id_list: list[str]) -> list[MCPTool]: """获取工具列表""" mongo = MongoDB() diff --git a/apps/scheduler/mcp_agent/__init__.py b/apps/scheduler/mcp_agent/__init__.py new file mode 100644 index 00000000..12f5cb68 --- /dev/null +++ b/apps/scheduler/mcp_agent/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""Scheduler MCP 模块""" + +from apps.scheduler.mcp.host import MCPHost +from apps.scheduler.mcp.plan import MCPPlanner +from apps.scheduler.mcp.select import MCPSelector + +__all__ = ["MCPHost", "MCPPlanner", "MCPSelector"] diff --git a/apps/scheduler/mcp_agent/agent/base.py b/apps/scheduler/mcp_agent/agent/base.py deleted file mode 100644 index eccb58a9..00000000 --- a/apps/scheduler/mcp_agent/agent/base.py +++ /dev/null @@ -1,196 +0,0 @@ -"""MCP Agent基类""" -import logging -from abc import ABC, abstractmethod -from contextlib import asynccontextmanager - -from pydantic import BaseModel, Field, model_validator - -from apps.common.queue import MessageQueue -from apps.schemas.enum_var import AgentState -from apps.schemas.task import Task -from apps.llm.reasoning import ReasoningLLM -from apps.scheduler.mcp_agent.schema import Memory, Message, Role -from apps.services.activity import Activity - -logger = logging.getLogger(__name__) - - -class BaseAgent(BaseModel, ABC): - """ - 用于管理代理状态和执行的抽象基类。 - - 为状态转换、内存管理、 - 以及分步执行循环。子类必须实现`step`方法。 - """ - - msg_queue: MessageQueue - task: Task - name: str = Field(..., description="Agent名称") - agent_id: str = Field(default="", description="Agent ID") - description: str = Field(default="", description="Agent描述") - question: str - # Prompts - next_step_prompt: str | None = Field( - None, description="判断下一步动作的提示" - ) - - # Dependencies - llm: ReasoningLLM = Field(default_factory=ReasoningLLM, description="大模型实例") - memory: Memory = Field(default_factory=Memory, description="Agent记忆库") - state: AgentState = Field( - default=AgentState.IDLE, description="Agent状态" - ) - servers_id: list[str] = Field(default_factory=list, description="MCP server id") - - # Execution control - max_steps: int = Field(default=10, description="终止前的最大步长") - current_step: int = Field(default=0, description="执行中的当前步骤") - - duplicate_threshold: int = 2 - - user_prompt: str = r""" - 当前步骤:{step} 工具输出结果:{result} - 请总结当前正在执行的步骤和对应的工具输出结果,内容包括当前步骤是多少,执行的工具是什么,输出是什么。 - 最终以报告的形式展示。 - 如果工具输出结果中执行的工具为terminate,请按照状态输出本次交互过程最终结果并完成对整个报告的总结,不需要输出你的分析过程。 - """ - """用户提示词""" - - class Config: - arbitrary_types_allowed = True - extra = "allow" # Allow extra fields for flexibility in subclasses - - @model_validator(mode="after") - def initialize_agent(self) -> "BaseAgent": - """初始化Agent""" - if self.llm is None or not isinstance(self.llm, ReasoningLLM): - self.llm = ReasoningLLM() - if not isinstance(self.memory, Memory): - self.memory = Memory() - return self - - @asynccontextmanager - async def state_context(self, new_state: AgentState): - """ - Agent状态转换上下文管理器 - - Args: - new_state: 要转变的状态 - - :return: None - :raise ValueError: 如果new_state无效 - """ - if not isinstance(new_state, AgentState): - raise ValueError(f"无效状态: {new_state}") - - previous_state = self.state - self.state = new_state - try: - yield - except Exception as e: - self.state = AgentState.ERROR # Transition to ERROR on failure - raise e - finally: - self.state = previous_state # Revert to previous state - - def update_memory( - self, - role: Role, - content: str, - **kwargs, - ) -> None: - """添加信息到Agent的memory中""" - message_map = { - "user": Message.user_message, - "system": Message.system_message, - "assistant": Message.assistant_message, - "tool": lambda content, **kw: Message.tool_message(content, **kw), - } - - if role not in message_map: - raise ValueError(f"不支持的消息角色: {role}") - - # Create message with appropriate parameters based on role - kwargs = {**(kwargs if role == "tool" else {})} - self.memory.add_message(message_map[role](content, **kwargs)) - - async def run(self, request: str | None = None) -> str: - """异步执行Agent的主循环""" - self.task.runtime.question = request - if self.state != AgentState.IDLE: - raise RuntimeError(f"无法从以下状态运行智能体: {self.state}") - - if request: - self.update_memory("user", request) - - results: list[str] = [] - async with self.state_context(AgentState.RUNNING): - while ( - self.current_step < self.max_steps and self.state != AgentState.FINISHED - ): - if not await Activity.is_active(self.task.ids.user_sub): - logger.info("用户终止会话,任务停止!") - return "" - self.current_step += 1 - logger.info(f"执行步骤{self.current_step}/{self.max_steps}") - step_result = await self.step() - - # Check for stuck state - if self.is_stuck(): - self.handle_stuck_state() - result = f"Step {self.current_step}: {step_result}" - results.append(result) - - if self.current_step >= self.max_steps: - self.current_step = 0 - self.state = AgentState.IDLE - result = f"任务终止: 已达到最大步数 ({self.max_steps})" - await self.msg_queue.push_output( - self.task, - event_type="text.add", - data={"text": result}, # type: ignore[arg-type] - ) - results.append(result) - return "\n".join(results) if results else "未执行任何步骤" - - @abstractmethod - async def step(self) -> str: - """ - 执行代理工作流程中的单个步骤。 - - 必须由子类实现,以定义具体的行为。 - """ - - def handle_stuck_state(self): - """通过添加更改策略的提示来处理卡住状态""" - stuck_prompt = "\ - 观察到重复响应。考虑新策略,避免重复已经尝试过的无效路径" - self.next_step_prompt = f"{stuck_prompt}\n{self.next_step_prompt}" - logger.warning(f"检测到智能体处于卡住状态。新增提示:{stuck_prompt}") - - def is_stuck(self) -> bool: - """通过检测重复内容来检查代理是否卡在循环中""" - if len(self.memory.messages) < 2: - return False - - last_message = self.memory.messages[-1] - if not last_message.content: - return False - - duplicate_count = sum( - 1 - for msg in reversed(self.memory.messages[:-1]) - if msg.role == "assistant" and msg.content == last_message.content - ) - - return duplicate_count >= self.duplicate_threshold - - @property - def messages(self) -> list[Message]: - """从Agent memory中检索消息列表""" - return self.memory.messages - - @messages.setter - def messages(self, value: list[Message]) -> None: - """设置Agent memory的消息列表""" - self.memory.messages = value diff --git a/apps/scheduler/mcp_agent/agent/mcp.py b/apps/scheduler/mcp_agent/agent/mcp.py deleted file mode 100644 index 378da368..00000000 --- a/apps/scheduler/mcp_agent/agent/mcp.py +++ /dev/null @@ -1,81 +0,0 @@ -"""MCP Agent""" -import logging - -from pydantic import Field - -from apps.scheduler.mcp.host import MCPHost -from apps.scheduler.mcp_agent.agent.toolcall import ToolCallAgent -from apps.scheduler.mcp_agent.tool import Terminate, ToolCollection - -logger = logging.getLogger(__name__) - - -class MCPAgent(ToolCallAgent): - """ - 用于与MCP(模型上下文协议)服务器交互。 - - 使用SSE或stdio传输连接到MCP服务器 - 并使服务器的工具 - """ - - name: str = "MCPAgent" - description: str = "一个多功能的智能体,能够使用多种工具(包括基于MCP的工具)解决各种任务" - - # Add general-purpose tools to the tool collection - available_tools: ToolCollection = Field( - default_factory=lambda: ToolCollection( - Terminate(), - ), - ) - - special_tool_names: list[str] = Field(default_factory=lambda: [Terminate().name]) - - _initialized: bool = False - - @classmethod - async def create(cls, **kwargs) -> "MCPAgent": # noqa: ANN003 - """创建并初始化MCP Agent实例""" - instance = cls(**kwargs) - await instance.initialize_mcp_servers() - instance._initialized = True - return instance - - async def initialize_mcp_servers(self) -> None: - """初始化与已配置的MCP服务器的连接""" - mcp_host = MCPHost( - self.task.ids.user_sub, - self.task.id, - self.agent_id, - self.description, - ) - mcps = {} - for mcp_id in self.servers_id: - client = await mcp_host.get_client(mcp_id) - if client: - mcps[mcp_id] = client - - for mcp_id, mcp_client in mcps.items(): - new_tools = [] - for tool in mcp_client.tools: - original_name = tool.name - # Always prefix with server_id to ensure uniqueness - tool_name = f"mcp_{mcp_id}_{original_name}" - - server_tool = MCPClientTool( - name=tool_name, - description=tool.description, - parameters=tool.inputSchema, - session=mcp_client.session, - server_id=mcp_id, - original_name=original_name, - ) - new_tools.append(server_tool) - self.available_tools.add_tools(*new_tools) - - async def think(self) -> bool: - """使用适当的上下文处理当前状态并决定下一步操作""" - if not self._initialized: - await self.initialize_mcp_servers() - self._initialized = True - - return await super().think() diff --git a/apps/scheduler/mcp_agent/agent/react.py b/apps/scheduler/mcp_agent/agent/react.py deleted file mode 100644 index b56efd8b..00000000 --- a/apps/scheduler/mcp_agent/agent/react.py +++ /dev/null @@ -1,35 +0,0 @@ -from abc import ABC, abstractmethod - -from pydantic import Field - -from apps.schemas.enum_var import AgentState -from apps.llm.reasoning import ReasoningLLM -from apps.scheduler.mcp_agent.agent.base import BaseAgent -from apps.scheduler.mcp_agent.schema import Memory - - -class ReActAgent(BaseAgent, ABC): - name: str - description: str | None = None - - system_prompt: str | None = None - next_step_prompt: str | None = None - - llm: ReasoningLLM | None = Field(default_factory=ReasoningLLM) - memory: Memory = Field(default_factory=Memory) - state: AgentState = AgentState.IDLE - - @abstractmethod - async def think(self) -> bool: - """处理当前状态并决定下一步操作""" - - @abstractmethod - async def act(self) -> str: - """执行已决定的行动""" - - async def step(self) -> str: - """执行一个步骤:思考和行动""" - should_act = await self.think() - if not should_act: - return "思考完成-无需采取任何行动" - return await self.act() diff --git a/apps/scheduler/mcp_agent/agent/toolcall.py b/apps/scheduler/mcp_agent/agent/toolcall.py deleted file mode 100644 index 1e22099c..00000000 --- a/apps/scheduler/mcp_agent/agent/toolcall.py +++ /dev/null @@ -1,238 +0,0 @@ -import asyncio -import json -import logging -from typing import Any, Optional - -from pydantic import Field - -from apps.schemas.enum_var import AgentState -from apps.llm.function import JsonGenerator -from apps.llm.patterns import Select -from apps.scheduler.mcp_agent.agent.react import ReActAgent -from apps.scheduler.mcp_agent.schema import Function, Message, ToolCall -from apps.scheduler.mcp_agent.tool import Terminate, ToolCollection - -logger = logging.getLogger(__name__) - - -class ToolCallAgent(ReActAgent): - """用于处理工具/函数调用的基本Agent类""" - - name: str = "toolcall" - description: str = "可以执行工具调用的智能体" - - available_tools: ToolCollection = ToolCollection( - Terminate(), - ) - tool_choices: str = "auto" - special_tool_names: list[str] = Field(default_factory=lambda: [Terminate().name]) - - tool_calls: list[ToolCall] = Field(default_factory=list) - _current_base64_image: str | None = None - - max_observe: int | bool | None = None - - async def think(self) -> bool: - """使用工具处理当前状态并决定下一步行动""" - messages = [] - for message in self.messages: - if isinstance(message, Message): - message = message.to_dict() - messages.append(message) - try: - # 通过工具获得响应 - select_obj = Select() - choices = [] - for available_tool in self.available_tools.to_params(): - choices.append(available_tool.get("function")) - - tool = await select_obj.generate(question=self.question, choices=choices) - if tool in self.available_tools.tool_map: - schema = self.available_tools.tool_map[tool].parameters - json_generator = JsonGenerator( - query="根据跟定的信息,获取工具参数", - conversation=messages, - schema=schema, - ) # JsonGenerator - parameters = await json_generator.generate() - - else: - raise ValueError(f"尝试调用不存在的工具: {tool}") - except Exception as e: - raise - self.tool_calls = tool_calls = [ToolCall(id=tool, function=Function(name=tool, arguments=parameters))] - content = f"选择的执行工具为:{tool}, 参数为{parameters}" - - logger.info( - f"{self.name} 选择 {len(tool_calls) if tool_calls else 0}个工具执行" - ) - if tool_calls: - logger.info( - f"准备使用的工具: {[call.function.name for call in tool_calls]}" - ) - logger.info(f"工具参数: {tool_calls[0].function.arguments}") - - try: - - assistant_msg = ( - Message.from_tool_calls(content=content, tool_calls=self.tool_calls) - if self.tool_calls - else Message.assistant_message(content) - ) - self.memory.add_message(assistant_msg) - - if not self.tool_calls: - return bool(content) - - return bool(self.tool_calls) - except Exception as e: - logger.error(f"{self.name}的思考过程遇到了问题:: {e}") - self.memory.add_message( - Message.assistant_message( - f"处理时遇到错误: {str(e)}" - ) - ) - return False - - async def act(self) -> str: - """执行工具调用并处理其结果""" - if not self.tool_calls: - # 如果没有工具调用,则返回最后的消息内容 - return self.messages[-1].content or "没有要执行的内容或命令" - - results = [] - for command in self.tool_calls: - await self.msg_queue.push_output( - self.task, - event_type="text.add", - data={"text": f"正在执行工具{command.function.name}"} - ) - - self._current_base64_image = None - - result = await self.execute_tool(command) - - if self.max_observe: - result = result[: self.max_observe] - - push_result = "" - async for chunk in self.llm.call( - messages=[{"role": "system", "content": "You are a helpful asistant."}, - {"role": "user", "content": self.user_prompt.format( - step=self.current_step, - result=result, - )}, ], streaming=False - ): - push_result += chunk - self.task.tokens.input_tokens += self.llm.input_tokens - self.task.tokens.output_tokens += self.llm.output_tokens - await self.msg_queue.push_output( - self.task, - event_type="text.add", - data={"text": push_result}, # type: ignore[arg-type] - ) - - await self.msg_queue.push_output( - self.task, - event_type="text.add", - data={"text": f"工具{command.function.name}执行完成"}, # type: ignore[arg-type] - ) - - logger.info( - f"工具'{command.function.name}'执行完成! 执行结果为: {result}" - ) - - # 将工具响应添加到内存 - tool_msg = Message.tool_message( - content=result, - tool_call_id=command.id, - name=command.function.name, - ) - self.memory.add_message(tool_msg) - results.append(result) - self.question += ( - f"\n已执行工具{command.function.name}, " - f"作用是{self.available_tools.tool_map[command.function.name].description},结果为{result}" - ) - - return "\n\n".join(results) - - async def execute_tool(self, command: ToolCall) -> str: - """执行单个工具调用""" - if not command or not command.function or not command.function.name: - return "错误:无效的命令格式" - - name = command.function.name - if name not in self.available_tools.tool_map: - return f"错误:未知工具 '{name}'" - - try: - # 解析参数 - args = command.function.arguments - # 执行工具 - logger.info(f"激活工具:'{name}'...") - result = await self.available_tools.execute(name=name, tool_input=args) - - # 执行特殊工具 - await self._handle_special_tool(name=name, result=result) - - # 格式化结果 - observation = ( - f"观察到执行的工具 `{name}`的输出:\n{str(result)}" - if result - else f"工具 `{name}` 已完成,无输出" - ) - - return observation - except json.JSONDecodeError: - error_msg = f"解析{name}的参数时出错:JSON格式无效" - logger.error( - f"{name}”的参数没有意义-无效的JSON,参数:{command.function.arguments}" - ) - return f"错误: {error_msg}" - except Exception as e: - error_msg = f"工具 '{name}' 遇到问题: {str(e)}" - logger.exception(error_msg) - return f"错误: {error_msg}" - - async def _handle_special_tool(self, name: str, result: Any, **kwargs): - """处理特殊工具的执行和状态变化""" - if not self._is_special_tool(name): - return - - if self._should_finish_execution(name=name, result=result, **kwargs): - # 将智能体状态设为finished - logger.info(f"特殊工具'{name}'已完成任务!") - self.state = AgentState.FINISHED - - @staticmethod - def _should_finish_execution(**kwargs) -> bool: - """确定工具执行是否应完成""" - return True - - def _is_special_tool(self, name: str) -> bool: - """检查工具名称是否在特殊工具列表中""" - return name.lower() in [n.lower() for n in self.special_tool_names] - - async def cleanup(self): - """清理Agent工具使用的资源。""" - logger.info(f"正在清理智能体的资源'{self.name}'...") - for tool_name, tool_instance in self.available_tools.tool_map.items(): - if hasattr(tool_instance, "cleanup") and asyncio.iscoroutinefunction( - tool_instance.cleanup - ): - try: - logger.debug(f"清理工具: {tool_name}") - await tool_instance.cleanup() - except Exception as e: - logger.error( - f"清理工具时发生错误'{tool_name}': {e}", exc_info=True - ) - logger.info(f"智能体清理完成'{self.name}'.") - - async def run(self, request: Optional[str] = None) -> str: - """运行Agent""" - try: - return await super().run(request) - finally: - await self.cleanup() diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py new file mode 100644 index 00000000..acdd4871 --- /dev/null +++ b/apps/scheduler/mcp_agent/host.py @@ -0,0 +1,190 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""MCP宿主""" + +import json +import logging +from typing import Any + +from jinja2 import BaseLoader +from jinja2.sandbox import SandboxedEnvironment +from mcp.types import TextContent + +from apps.common.mongo import MongoDB +from apps.llm.function import JsonGenerator +from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE +from apps.scheduler.pool.mcp.client import MCPClient +from apps.scheduler.pool.mcp.pool import MCPPool +from apps.schemas.enum_var import StepStatus +from apps.schemas.mcp import MCPPlanItem, MCPTool +from apps.schemas.task import FlowStepHistory +from apps.services.task import TaskManager + +logger = logging.getLogger(__name__) + + +class MCPHost: + """MCP宿主服务""" + + def __init__(self, user_sub: str, task_id: str, runtime_id: str, runtime_name: str) -> None: + """初始化MCP宿主""" + self._user_sub = user_sub + self._task_id = task_id + # 注意:runtime在工作流中是flow_id和step_description,在Agent中可为标识Agent的id和description + self._runtime_id = runtime_id + self._runtime_name = runtime_name + self._context_list = [] + self._env = SandboxedEnvironment( + loader=BaseLoader(), + autoescape=False, + trim_blocks=True, + lstrip_blocks=True, + ) + + async def get_client(self, mcp_id: str) -> MCPClient | None: + """获取MCP客户端""" + mongo = MongoDB() + mcp_collection = mongo.get_collection("mcp") + + # 检查用户是否启用了这个mcp + mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": self._user_sub}) + if not mcp_db_result: + logger.warning("用户 %s 未启用MCP %s", self._user_sub, mcp_id) + return None + + # 获取MCP配置 + try: + return await MCPPool().get(mcp_id, self._user_sub) + except KeyError: + logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", self._user_sub, mcp_id) + return None + + async def assemble_memory(self) -> str: + """组装记忆""" + task = await TaskManager.get_task_by_task_id(self._task_id) + if not task: + logger.error("任务 %s 不存在", self._task_id) + return "" + + context_list = [] + for ctx_id in self._context_list: + context = next((ctx for ctx in task.context if ctx["_id"] == ctx_id), None) + if not context: + continue + context_list.append(context) + + return self._env.from_string(MEMORY_TEMPLATE).render( + context_list=context_list, + ) + + async def _save_memory( + self, + tool: MCPTool, + plan_item: MCPPlanItem, + input_data: dict[str, Any], + result: str, + ) -> dict[str, Any]: + """保存记忆""" + try: + output_data = json.loads(result) + except Exception: # noqa: BLE001 + logger.warning("[MCPHost] 得到的数据不是dict格式!尝试转换为str") + output_data = { + "message": result, + } + + if not isinstance(output_data, dict): + output_data = { + "message": result, + } + + # 创建context;注意用法 + context = FlowStepHistory( + task_id=self._task_id, + flow_id=self._runtime_id, + flow_name=self._runtime_name, + flow_status=StepStatus.SUCCESS, + step_id=tool.name, + step_name=tool.name, + # description是规划的实际内容 + step_description=plan_item.content, + step_status=StepStatus.SUCCESS, + input_data=input_data, + output_data=output_data, + ) + + # 保存到task + task = await TaskManager.get_task_by_task_id(self._task_id) + if not task: + logger.error("任务 %s 不存在", self._task_id) + return {} + self._context_list.append(context.id) + task.context.append(context.model_dump(by_alias=True, exclude_none=True)) + await TaskManager.save_task(self._task_id, task) + + return output_data + + async def _fill_params(self, tool: MCPTool, query: str) -> dict[str, Any]: + """填充工具参数""" + # 更清晰的输入·指令,这样可以调用generate + llm_query = rf""" + 请使用参数生成工具,生成满足以下目标的工具参数: + + {query} + """ + + # 进行生成 + json_generator = JsonGenerator( + llm_query, + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": await self.assemble_memory()}, + ], + tool.input_schema, + ) + return await json_generator.generate() + + async def call_tool(self, tool: MCPTool, plan_item: MCPPlanItem) -> list[dict[str, Any]]: + """调用工具""" + # 拿到Client + client = await MCPPool().get(tool.mcp_id, self._user_sub) + if client is None: + err = f"[MCPHost] MCP Server不合法: {tool.mcp_id}" + logger.error(err) + raise ValueError(err) + + # 填充参数 + params = await self._fill_params(tool, plan_item.instruction) + # 调用工具 + result = await client.call_tool(tool.name, params) + # 保存记忆 + processed_result = [] + for item in result.content: + if not isinstance(item, TextContent): + logger.error("MCP结果类型不支持: %s", item) + continue + processed_result.append(await self._save_memory(tool, plan_item, params, item.text)) + + return processed_result + + async def get_tool_list(self, mcp_id_list: list[str]) -> list[MCPTool]: + """获取工具列表""" + mongo = MongoDB() + mcp_collection = mongo.get_collection("mcp") + + # 获取工具列表 + tool_list = [] + for mcp_id in mcp_id_list: + # 检查用户是否启用了这个mcp + mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": self._user_sub}) + if not mcp_db_result: + logger.warning("用户 %s 未启用MCP %s", self._user_sub, mcp_id) + continue + # 获取MCP工具配置 + try: + for tool in mcp_db_result["tools"]: + tool_list.extend([MCPTool.model_validate(tool)]) + except KeyError: + logger.warning("用户 %s 的MCP Tool %s 配置错误", self._user_sub, mcp_id) + continue + + return tool_list diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py new file mode 100644 index 00000000..cd4f5975 --- /dev/null +++ b/apps/scheduler/mcp_agent/plan.py @@ -0,0 +1,110 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""MCP 用户目标拆解与规划""" + +from jinja2 import BaseLoader +from jinja2.sandbox import SandboxedEnvironment + +from apps.llm.function import JsonGenerator +from apps.llm.reasoning import ReasoningLLM +from apps.scheduler.mcp.prompt import CREATE_PLAN, FINAL_ANSWER +from apps.schemas.mcp import MCPPlan, MCPTool + + +class MCPPlanner: + """MCP 用户目标拆解与规划""" + + def __init__(self, user_goal: str) -> None: + """初始化MCP规划器""" + self.user_goal = user_goal + self._env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=True, + trim_blocks=True, + lstrip_blocks=True, + ) + self.input_tokens = 0 + self.output_tokens = 0 + + + async def create_plan(self, tool_list: list[MCPTool], max_steps: int = 6) -> MCPPlan: + """规划下一步的执行流程,并输出""" + # 获取推理结果 + result = await self._get_reasoning_plan(tool_list, max_steps) + + # 解析为结构化数据 + return await self._parse_plan_result(result, max_steps) + + + async def _get_reasoning_plan(self, tool_list: list[MCPTool], max_steps: int) -> str: + """获取推理大模型的结果""" + # 格式化Prompt + template = self._env.from_string(CREATE_PLAN) + prompt = template.render( + goal=self.user_goal, + tools=tool_list, + max_num=max_steps, + ) + + # 调用推理大模型 + message = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ] + reasoning_llm = ReasoningLLM() + result = "" + async for chunk in reasoning_llm.call( + message, + streaming=False, + temperature=0.07, + result_only=True, + ): + result += chunk + + # 保存token用量 + self.input_tokens = reasoning_llm.input_tokens + self.output_tokens = reasoning_llm.output_tokens + + return result + + + async def _parse_plan_result(self, result: str, max_steps: int) -> MCPPlan: + """将推理结果解析为结构化数据""" + # 格式化Prompt + schema = MCPPlan.model_json_schema() + schema["properties"]["plans"]["maxItems"] = max_steps + + # 使用Function模型解析结果 + json_generator = JsonGenerator( + result, + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": result}, + ], + schema, + ) + plan = await json_generator.generate() + return MCPPlan.model_validate(plan) + + + async def generate_answer(self, plan: MCPPlan, memory: str) -> str: + """生成最终回答""" + template = self._env.from_string(FINAL_ANSWER) + prompt = template.render( + plan=plan, + memory=memory, + goal=self.user_goal, + ) + + llm = ReasoningLLM() + result = "" + async for chunk in llm.call( + [{"role": "user", "content": prompt}], + streaming=False, + temperature=0.07, + ): + result += chunk + + self.input_tokens = llm.input_tokens + self.output_tokens = llm.output_tokens + + return result diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py new file mode 100644 index 00000000..b322fb08 --- /dev/null +++ b/apps/scheduler/mcp_agent/prompt.py @@ -0,0 +1,240 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""MCP相关的大模型Prompt""" + +from textwrap import dedent + +MCP_SELECT = dedent(r""" + 你是一个乐于助人的智能助手。 + 你的任务是:根据当前目标,选择最合适的MCP Server。 + + ## 选择MCP Server时的注意事项: + + 1. 确保充分理解当前目标,选择最合适的MCP Server。 + 2. 请在给定的MCP Server列表中选择,不要自己生成MCP Server。 + 3. 请先给出你选择的理由,再给出你的选择。 + 4. 当前目标将在下面给出,MCP Server列表也会在下面给出。 + 请将你的思考过程放在"思考过程"部分,将你的选择放在"选择结果"部分。 + 5. 选择必须是JSON格式,严格按照下面的模板,不要输出任何其他内容: + + ```json + { + "mcp": "你选择的MCP Server的名称" + } + ``` + + 6. 下面的示例仅供参考,不要将示例中的内容作为选择MCP Server的依据。 + + ## 示例 + + ### 目标 + + 我需要一个MCP Server来完成一个任务。 + + ### MCP Server列表 + + - **mcp_1**: "MCP Server 1";MCP Server 1的描述 + - **mcp_2**: "MCP Server 2";MCP Server 2的描述 + + ### 请一步一步思考: + + 因为当前目标需要一个MCP Server来完成一个任务,所以选择mcp_1。 + + ### 选择结果 + + ```json + { + "mcp": "mcp_1" + } + ``` + + ## 现在开始! + + ### 目标 + + {{goal}} + + ### MCP Server列表 + + {% for mcp in mcp_list %} + - **{{mcp.id}}**: "{{mcp.name}}";{{mcp.description}} + {% endfor %} + + ### 请一步一步思考: + +""") +CREATE_PLAN = dedent(r""" + 你是一个计划生成器。 + 请分析用户的目标,并生成一个计划。你后续将根据这个计划,一步一步地完成用户的目标。 + + # 一个好的计划应该: + + 1. 能够成功完成用户的目标 + 2. 计划中的每一个步骤必须且只能使用一个工具。 + 3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。 + 4. 计划中的最后一步必须是Final工具,以确保计划执行结束。 + + # 生成计划时的注意事项: + + - 每一条计划包含3个部分: + - 计划内容:描述单个计划步骤的大致内容 + - 工具ID:必须从下文的工具列表中选择 + - 工具指令:改写用户的目标,使其更符合工具的输入要求 + - 必须按照如下格式生成计划,不要输出任何额外数据: + + ```json + { + "plans": [ + { + "content": "计划内容", + "tool": "工具ID", + "instruction": "工具指令" + } + ] + } + ``` + + - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。\ +思考过程应放置在 XML标签中。 + - 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。 + - 计划不得多于{{ max_num }}条,且每条计划内容应少于150字。 + + # 工具 + + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + + + {% for tool in tools %} + - {{ tool.id }}{{tool.name}};{{ tool.description }} + {% endfor %} + - Final结束步骤,当执行到这一步时,\ +表示计划执行结束,所得到的结果将作为最终结果。 + + + # 样例 + + ## 目标 + + 在后台运行一个新的alpine:latest容器,将主机/root文件夹挂载至/data,并执行top命令。 + + ## 计划 + + + 1. 这个目标需要使用Docker来完成,首先需要选择合适的MCP Server + 2. 目标可以拆解为以下几个部分: + - 运行alpine:latest容器 + - 挂载主机目录 + - 在后台运行 + - 执行top命令 + 3. 需要先选择MCP Server,然后生成Docker命令,最后执行命令 + + + ```json + { + "plans": [ + { + "content": "选择一个支持Docker的MCP Server", + "tool": "mcp_selector", + "instruction": "需要一个支持Docker容器运行的MCP Server" + }, + { + "content": "使用Result[0]中选择的MCP Server,生成Docker命令", + "tool": "command_generator", + "instruction": "生成Docker命令:在后台运行alpine:latest容器,挂载/root到/data,执行top命令" + }, + { + "content": "在Result[0]的MCP Server上执行Result[1]生成的命令", + "tool": "command_executor", + "instruction": "执行Docker命令" + }, + { + "content": "任务执行完成,容器已在后台运行,结果为Result[2]", + "tool": "Final", + "instruction": "" + } + ] + } + ``` + + # 现在开始生成计划: + + ## 目标 + + {{goal}} + + # 计划 +""") +EVALUATE_PLAN = dedent(r""" + 你是一个计划评估器。 + 请根据给定的计划,和当前计划执行的实际情况,分析当前计划是否合理和完整,并生成改进后的计划。 + + # 一个好的计划应该: + + 1. 能够成功完成用户的目标 + 2. 计划中的每一个步骤必须且只能使用一个工具。 + 3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。 + 4. 计划中的最后一步必须是Final工具,以确保计划执行结束。 + + # 你此前的计划是: + + {{ plan }} + + # 这个计划的执行情况是: + + 计划的执行情况将放置在 XML标签中。 + + + {{ memory }} + + + # 进行评估时的注意事项: + + - 请一步一步思考,解析用户的目标,并指导你接下来的生成。思考过程应放置在 XML标签中。 + - 评估结果分为两个部分: + - 计划评估的结论 + - 改进后的计划 + - 请按照以下JSON格式输出评估结果: + + ```json + { + "evaluation": "评估结果", + "plans": [ + { + "content": "改进后的计划内容", + "tool": "工具ID", + "instruction": "工具指令" + } + ] + } + ``` + + # 现在开始评估计划: + +""") +FINAL_ANSWER = dedent(r""" + 综合理解计划执行结果和背景信息,向用户报告目标的完成情况。 + + # 用户目标 + + {{ goal }} + + # 计划执行情况 + + 为了完成上述目标,你实施了以下计划: + + {{ memory }} + + # 其他背景信息: + + {{ status }} + + # 现在,请根据以上信息,向用户报告目标的完成情况: + +""") +MEMORY_TEMPLATE = dedent(r""" + {% for ctx in context_list %} + - 第{{ loop.index }}步:{{ ctx.step_description }} + - 调用工具 `{{ ctx.step_id }}`,并提供参数 `{{ ctx.input_data }}` + - 执行状态:{{ ctx.status }} + - 得到数据:`{{ ctx.output_data }}` + {% endfor %} +""") diff --git a/apps/scheduler/mcp_agent/schema.py b/apps/scheduler/mcp_agent/schema.py deleted file mode 100644 index 61413907..00000000 --- a/apps/scheduler/mcp_agent/schema.py +++ /dev/null @@ -1,148 +0,0 @@ -"""MCP Agent执行数据结构""" -from typing import Any, Self - -from pydantic import BaseModel, Field - -from apps.schemas.enum_var import Role - - -class Function(BaseModel): - """工具函数""" - - name: str - arguments: dict[str, Any] - - -class ToolCall(BaseModel): - """Represents a tool/function call in a message""" - - id: str - type: str = "function" - function: Function - - -class Message(BaseModel): - """Represents a chat message in the conversation""" - - role: Role = Field(...) - content: str | None = Field(default=None) - tool_calls: list[ToolCall] | None = Field(default=None) - name: str | None = Field(default=None) - tool_call_id: str | None = Field(default=None) - - def __add__(self, other) -> list["Message"]: - """支持 Message + list 或 Message + Message 的操作""" - if isinstance(other, list): - return [self] + other - elif isinstance(other, Message): - return [self, other] - else: - raise TypeError( - f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'" - ) - - def __radd__(self, other) -> list["Message"]: - """支持 list + Message 的操作""" - if isinstance(other, list): - return other + [self] - else: - raise TypeError( - f"unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'" - ) - - def to_dict(self) -> dict: - """Convert message to dictionary format""" - message = {"role": self.role} - if self.content is not None: - message["content"] = self.content - if self.tool_calls is not None: - message["tool_calls"] = [tool_call.dict() for tool_call in self.tool_calls] - if self.name is not None: - message["name"] = self.name - if self.tool_call_id is not None: - message["tool_call_id"] = self.tool_call_id - return message - - @classmethod - def user_message(cls, content: str) -> Self: - """Create a user message""" - return cls(role=Role.USER, content=content) - - @classmethod - def system_message(cls, content: str) -> Self: - """Create a system message""" - return cls(role=Role.SYSTEM, content=content) - - @classmethod - def assistant_message( - cls, content: str | None = None, - ) -> Self: - """Create an assistant message""" - return cls(role=Role.ASSISTANT, content=content) - - @classmethod - def tool_message( - cls, content: str, name: str, tool_call_id: str, - ) -> Self: - """Create a tool message""" - return cls( - role=Role.TOOL, - content=content, - name=name, - tool_call_id=tool_call_id, - ) - - @classmethod - def from_tool_calls( - cls, - tool_calls: list[Any], - content: str | list[str] = "", - **kwargs, # noqa: ANN003 - ) -> Self: - """Create ToolCallsMessage from raw tool calls. - - Args: - tool_calls: Raw tool calls from LLM - content: Optional message content - """ - formatted_calls = [ - {"id": call.id, "function": call.function.model_dump(), "type": "function"} - for call in tool_calls - ] - return cls( - role=Role.ASSISTANT, - content=content, - tool_calls=formatted_calls, - **kwargs, - ) - - -class Memory(BaseModel): - messages: list[Message] = Field(default_factory=list) - max_messages: int = Field(default=100) - - def add_message(self, message: Message) -> None: - """Add a message to memory""" - self.messages.append(message) - # Optional: Implement message limit - if len(self.messages) > self.max_messages: - self.messages = self.messages[-self.max_messages:] - - def add_messages(self, messages: list[Message]) -> None: - """Add multiple messages to memory""" - self.messages.extend(messages) - # Optional: Implement message limit - if len(self.messages) > self.max_messages: - self.messages = self.messages[-self.max_messages:] - - def clear(self) -> None: - """Clear all messages""" - self.messages.clear() - - def get_recent_messages(self, n: int) -> list[Message]: - """Get n most recent messages""" - return self.messages[-n:] - - def to_dict_list(self) -> list[dict]: - """Convert messages to list of dicts""" - return [msg.to_dict() for msg in self.messages] diff --git a/apps/scheduler/mcp_agent/select.py b/apps/scheduler/mcp_agent/select.py new file mode 100644 index 00000000..2ff50344 --- /dev/null +++ b/apps/scheduler/mcp_agent/select.py @@ -0,0 +1,185 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""选择MCP Server及其工具""" + +import logging + +from jinja2 import BaseLoader +from jinja2.sandbox import SandboxedEnvironment + +from apps.common.lance import LanceDB +from apps.common.mongo import MongoDB +from apps.llm.embedding import Embedding +from apps.llm.function import FunctionLLM +from apps.llm.reasoning import ReasoningLLM +from apps.scheduler.mcp.prompt import ( + MCP_SELECT, +) +from apps.schemas.mcp import ( + MCPCollection, + MCPSelectResult, + MCPTool, +) + +logger = logging.getLogger(__name__) + + +class MCPSelector: + """MCP选择器""" + + def __init__(self) -> None: + """初始化助手类""" + self.input_tokens = 0 + self.output_tokens = 0 + + @staticmethod + def _assemble_sql(mcp_list: list[str]) -> str: + """组装SQL""" + sql = "(" + for mcp_id in mcp_list: + sql += f"'{mcp_id}', " + return sql.rstrip(", ") + ")" + + + async def _get_top_mcp_by_embedding( + self, + query: str, + mcp_list: list[str], + ) -> list[dict[str, str]]: + """通过向量检索获取Top5 MCP Server""" + logger.info("[MCPHelper] 查询MCP Server向量: %s, %s", query, mcp_list) + mcp_table = await LanceDB().get_table("mcp") + query_embedding = await Embedding.get_embedding([query]) + mcp_vecs = await (await mcp_table.search( + query=query_embedding, + vector_column_name="embedding", + )).where(f"id IN {MCPSelector._assemble_sql(mcp_list)}").limit(5).to_list() + + # 拿到名称和description + logger.info("[MCPHelper] 查询MCP Server名称和描述: %s", mcp_vecs) + mcp_collection = MongoDB().get_collection("mcp") + llm_mcp_list: list[dict[str, str]] = [] + for mcp_vec in mcp_vecs: + mcp_id = mcp_vec["id"] + mcp_data = await mcp_collection.find_one({"_id": mcp_id}) + if not mcp_data: + logger.warning("[MCPHelper] 查询MCP Server名称和描述失败: %s", mcp_id) + continue + mcp_data = MCPCollection.model_validate(mcp_data) + llm_mcp_list.extend([{ + "id": mcp_id, + "name": mcp_data.name, + "description": mcp_data.description, + }]) + return llm_mcp_list + + + async def _get_mcp_by_llm( + self, + query: str, + mcp_list: list[dict[str, str]], + mcp_ids: list[str], + ) -> MCPSelectResult: + """通过LLM选择最合适的MCP Server""" + # 初始化jinja2环境 + env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=True, + trim_blocks=True, + lstrip_blocks=True, + ) + template = env.from_string(MCP_SELECT) + # 渲染模板 + mcp_prompt = template.render( + mcp_list=mcp_list, + goal=query, + ) + + # 调用大模型进行推理 + result = await self._call_reasoning(mcp_prompt) + + # 使用小模型提取JSON + return await self._call_function_mcp(result, mcp_ids) + + + async def _call_reasoning(self, prompt: str) -> str: + """调用大模型进行推理""" + logger.info("[MCPHelper] 调用推理大模型") + llm = ReasoningLLM() + message = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ] + result = "" + async for chunk in llm.call(message): + result += chunk + self.input_tokens += llm.input_tokens + self.output_tokens += llm.output_tokens + return result + + + async def _call_function_mcp(self, reasoning_result: str, mcp_ids: list[str]) -> MCPSelectResult: + """调用结构化输出小模型提取JSON""" + logger.info("[MCPHelper] 调用结构化输出小模型") + llm = FunctionLLM() + message = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": reasoning_result}, + ] + schema = MCPSelectResult.model_json_schema() + # schema中加入选项 + schema["properties"]["mcp_id"]["enum"] = mcp_ids + result = await llm.call(messages=message, schema=schema) + try: + result = MCPSelectResult.model_validate(result) + except Exception: + logger.exception("[MCPHelper] 解析MCP Select Result失败") + raise + return result + + + async def select_top_mcp( + self, + query: str, + mcp_list: list[str], + ) -> MCPSelectResult: + """ + 选择最合适的MCP Server + + 先通过Embedding选择Top5,然后通过LLM选择Top 1 + """ + # 通过向量检索获取Top5 + llm_mcp_list = await self._get_top_mcp_by_embedding(query, mcp_list) + + # 通过LLM选择最合适的 + return await self._get_mcp_by_llm(query, llm_mcp_list, mcp_list) + + + @staticmethod + async def select_top_tool(query: str, mcp_list: list[str], top_n: int = 10) -> list[MCPTool]: + """选择最合适的工具""" + tool_vector = await LanceDB().get_table("mcp_tool") + query_embedding = await Embedding.get_embedding([query]) + tool_vecs = await (await tool_vector.search( + query=query_embedding, + vector_column_name="embedding", + )).where(f"mcp_id IN {MCPSelector._assemble_sql(mcp_list)}").limit(top_n).to_list() + + # 拿到工具 + tool_collection = MongoDB().get_collection("mcp") + llm_tool_list = [] + + for tool_vec in tool_vecs: + # 到MongoDB里找对应的工具 + logger.info("[MCPHelper] 查询MCP Tool名称和描述: %s", tool_vec["mcp_id"]) + tool_data = await tool_collection.aggregate([ + {"$match": {"_id": tool_vec["mcp_id"]}}, + {"$unwind": "$tools"}, + {"$match": {"tools.id": tool_vec["id"]}}, + {"$project": {"_id": 0, "tools": 1}}, + {"$replaceRoot": {"newRoot": "$tools"}}, + ]) + async for tool in tool_data: + tool_obj = MCPTool.model_validate(tool) + llm_tool_list.append(tool_obj) + + return llm_tool_list diff --git a/apps/scheduler/mcp_agent/tool/__init__.py b/apps/scheduler/mcp_agent/tool/__init__.py deleted file mode 100644 index 4593f317..00000000 --- a/apps/scheduler/mcp_agent/tool/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from apps.scheduler.mcp_agent.tool.base import BaseTool -from apps.scheduler.mcp_agent.tool.terminate import Terminate -from apps.scheduler.mcp_agent.tool.tool_collection import ToolCollection - -__all__ = [ - "BaseTool", - "Terminate", - "ToolCollection", -] diff --git a/apps/scheduler/mcp_agent/tool/base.py b/apps/scheduler/mcp_agent/tool/base.py deleted file mode 100644 index 04ad45c4..00000000 --- a/apps/scheduler/mcp_agent/tool/base.py +++ /dev/null @@ -1,73 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, Optional - -from pydantic import BaseModel, Field - - -class BaseTool(ABC, BaseModel): - name: str - description: str - parameters: Optional[dict] = None - - class Config: - arbitrary_types_allowed = True - - async def __call__(self, **kwargs) -> Any: - return await self.execute(**kwargs) - - @abstractmethod - async def execute(self, **kwargs) -> Any: - """使用给定的参数执行工具""" - - def to_param(self) -> Dict: - """将工具转换为函数调用格式""" - return { - "type": "function", - "function": { - "name": self.name, - "description": self.description, - "parameters": self.parameters, - }, - } - - -class ToolResult(BaseModel): - """表示工具执行的结果""" - - output: Any = Field(default=None) - error: Optional[str] = Field(default=None) - system: Optional[str] = Field(default=None) - - class Config: - arbitrary_types_allowed = True - - def __bool__(self): - return any(getattr(self, field) for field in self.__fields__) - - def __add__(self, other: "ToolResult"): - def combine_fields( - field: Optional[str], other_field: Optional[str], concatenate: bool = True - ): - if field and other_field: - if concatenate: - return field + other_field - raise ValueError("Cannot combine tool results") - return field or other_field - - return ToolResult( - output=combine_fields(self.output, other.output), - error=combine_fields(self.error, other.error), - system=combine_fields(self.system, other.system), - ) - - def __str__(self): - return f"Error: {self.error}" if self.error else self.output - - def replace(self, **kwargs): - """返回一个新的ToolResult,其中替换了给定的字段""" - # return self.copy(update=kwargs) - return type(self)(**{**self.dict(), **kwargs}) - - -class ToolFailure(ToolResult): - """表示失败的ToolResult""" diff --git a/apps/scheduler/mcp_agent/tool/terminate.py b/apps/scheduler/mcp_agent/tool/terminate.py deleted file mode 100644 index 84aa1203..00000000 --- a/apps/scheduler/mcp_agent/tool/terminate.py +++ /dev/null @@ -1,25 +0,0 @@ -from apps.scheduler.mcp_agent.tool.base import BaseTool - - -_TERMINATE_DESCRIPTION = """当请求得到满足或助理无法继续处理任务时,终止交互。 -当您完成所有任务后,调用此工具结束工作。""" - - -class Terminate(BaseTool): - name: str = "terminate" - description: str = _TERMINATE_DESCRIPTION - parameters: dict = { - "type": "object", - "properties": { - "status": { - "type": "string", - "description": "交互的完成状态", - "enum": ["success", "failure"], - } - }, - "required": ["status"], - } - - async def execute(self, status: str) -> str: - """Finish the current execution""" - return f"交互已完成,状态为: {status}" diff --git a/apps/scheduler/mcp_agent/tool/tool_collection.py b/apps/scheduler/mcp_agent/tool/tool_collection.py deleted file mode 100644 index 95bda317..00000000 --- a/apps/scheduler/mcp_agent/tool/tool_collection.py +++ /dev/null @@ -1,55 +0,0 @@ -"""用于管理多个工具的集合类""" -import logging -from typing import Any - -from apps.scheduler.mcp_agent.tool.base import BaseTool, ToolFailure, ToolResult - -logger = logging.getLogger(__name__) - - -class ToolCollection: - """定义工具的集合""" - - class Config: - arbitrary_types_allowed = True - - def __init__(self, *tools: BaseTool): - self.tools = tools - self.tool_map = {tool.name: tool for tool in tools} - - def __iter__(self): - return iter(self.tools) - - def to_params(self) -> list[dict[str, Any]]: - return [tool.to_param() for tool in self.tools] - - async def execute( - self, *, name: str, tool_input: dict[str, Any] = None - ) -> ToolResult: - tool = self.tool_map.get(name) - if not tool: - return ToolFailure(error=f"Tool {name} is invalid") - try: - result = await tool(**tool_input) - return result - except Exception as e: - return ToolFailure(error=f"Failed to execute tool {name}: {e}") - - def add_tool(self, tool: BaseTool): - """ - 将单个工具添加到集合中。 - - 如果已存在同名工具,则将跳过该工具并记录警告。 - """ - if tool.name in self.tool_map: - logger.warning(f"Tool {tool.name} already exists in collection, skipping") - return self - - self.tools += (tool,) - self.tool_map[tool.name] = tool - return self - - def add_tools(self, *tools: BaseTool): - for tool in tools: - self.add_tool(tool) - return self diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 4c2c4cf0..32331cf3 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -214,7 +214,7 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: await AppCenterManager.update_recent_app(user_sub, post_body.app.app_id) # 若状态为成功,删除Task - if not task.state or task.state.status == StepStatus.SUCCESS: + if not task.state or task.state.flow_status == StepStatus.SUCCESS or task.state.flow_status == StepStatus.ERROR or task.state.flow_status == StepStatus.CANCELLED: await TaskManager.delete_task_by_task_id(task.id) else: # 更新Task diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index a84dc3a3..20e9c0f9 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -20,6 +20,17 @@ class StepStatus(str, Enum): SUCCESS = "success" ERROR = "error" PARAM = "param" + CANCELLED = "cancelled" + + +class FlowStatus(str, Enum): + """Flow状态""" + + WAITING = "waiting" + RUNNING = "running" + SUCCESS = "success" + ERROR = "error" + CANCELLED = "cancelled" class DocumentStatus(str, Enum): diff --git a/apps/schemas/record.py b/apps/schemas/record.py index b5e1b0c5..0c3d7185 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -44,6 +44,7 @@ class RecordFlow(BaseModel): id: str record_id: str = Field(alias="recordId") flow_id: str = Field(alias="flowId") + flow_status: StepStatus = Field(alias="flowStatus", default=StepStatus.SUCCESS) step_num: int = Field(alias="stepNum") steps: list[RecordFlowStep] diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 37fdebbf..98d8c6b3 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -7,7 +7,7 @@ from typing import Any from pydantic import BaseModel, Field -from apps.schemas.enum_var import StepStatus +from apps.schemas.enum_var import FlowStatus, StepStatus from apps.schemas.flow import Step from apps.schemas.mcp import MCPPlan @@ -23,10 +23,11 @@ class FlowStepHistory(BaseModel): task_id: str = Field(description="任务ID") flow_id: str = Field(description="FlowID") flow_name: str = Field(description="Flow名称") + flow_status: FlowStatus = Field(description="Flow状态") step_id: str = Field(description="当前步骤名称") step_name: str = Field(description="当前步骤名称") - step_description: str = Field(description="当前步骤描述") - status: StepStatus = Field(description="当前步骤状态") + step_description: str = Field(description="当前步骤描述", default="") + step_status: StepStatus = Field(description="当前步骤状态") input_data: dict[str, Any] = Field(description="当前Step执行的输入", default={}) output_data: dict[str, Any] = Field(description="当前Step执行后的结果", default={}) created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) @@ -39,10 +40,11 @@ class ExecutorState(BaseModel): flow_id: str = Field(description="Flow ID") flow_name: str = Field(description="Flow名称") description: str = Field(description="Flow描述") - status: StepStatus = Field(description="Flow执行状态") - # 附加信息 + flow_status: FlowStatus = Field(description="Flow状态") + # 任务级数据 step_id: str = Field(description="当前步骤ID") step_name: str = Field(description="当前步骤名称") + step_status: StepStatus = Field(description="当前步骤状态") step_description: str = Field(description="当前步骤描述", default="") app_id: str = Field(description="应用ID") slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) diff --git a/tests/common/test_queue.py b/tests/common/test_queue.py index 5375180a..db1f5ead 100644 --- a/tests/common/test_queue.py +++ b/tests/common/test_queue.py @@ -74,8 +74,8 @@ async def test_push_output_with_flow(message_queue, mock_task): mock_task.state.flow_id = "flow_id" mock_task.state.step_id = "step_id" mock_task.state.step_name = "step_name" - mock_task.state.status = "running" - + mock_task.state.step_status = "running" + await message_queue.init("test_task") await message_queue.push_output(mock_task, EventType.TEXT_ADD, {}) -- Gitee From 4a58ac8a819e5cacc6401aeb3caf868b86be36ce Mon Sep 17 00:00:00 2001 From: zxstty Date: Sun, 27 Jul 2025 15:33:20 +0800 Subject: [PATCH 19/60] =?UTF-8?q?=E5=AE=8C=E5=96=84=E9=80=89=E6=8B=A9?= =?UTF-8?q?=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/main.py | 2 + apps/routers/parameter.py | 100 ++----- apps/scheduler/call/__init__.py | 3 +- apps/scheduler/call/choice/choice.py | 95 ++++-- .../call/choice/condition_handler.py | 272 +++++++++++++----- apps/scheduler/call/choice/schema.py | 63 ++-- apps/scheduler/call/core.py | 26 +- apps/scheduler/executor/flow.py | 20 +- apps/scheduler/executor/step.py | 3 - apps/scheduler/slot/slot.py | 43 ++- apps/schemas/config.py | 2 +- apps/schemas/parameters.py | 69 +++++ apps/schemas/response_data.py | 47 ++- apps/services/flow.py | 46 --- apps/services/node.py | 6 +- apps/services/parameter.py | 86 ++++++ 16 files changed, 564 insertions(+), 319 deletions(-) create mode 100644 apps/schemas/parameters.py create mode 100644 apps/services/parameter.py diff --git a/apps/main.py b/apps/main.py index c4ca2bfb..c26e5e47 100644 --- a/apps/main.py +++ b/apps/main.py @@ -36,6 +36,7 @@ from apps.routers import ( record, service, user, + parameter ) from apps.scheduler.pool.pool import Pool @@ -66,6 +67,7 @@ app.include_router(llm.router) app.include_router(mcp_service.router) app.include_router(flow.router) app.include_router(user.router) +app.include_router(parameter.router) # logger配置 LOGGER_FORMAT = "%(funcName)s() - %(message)s" diff --git a/apps/routers/parameter.py b/apps/routers/parameter.py index 538c2556..6edbe2e1 100644 --- a/apps/routers/parameter.py +++ b/apps/routers/parameter.py @@ -5,13 +5,10 @@ from fastapi.responses import JSONResponse from apps.dependency import get_user from apps.dependency.user import verify_user -from apps.scheduler.call.choice.choice import Choice +from apps.services.parameter import ParameterManager from apps.schemas.response_data import ( - FlowStructureGetMsg, - FlowStructureGetRsp, - GetParamsMsg, - GetParamsRsp, - ResponseData, + GetOperaRsp, + GetParamsRsp ) from apps.services.application import AppManager from apps.services.flow import FlowManager @@ -25,10 +22,7 @@ router = APIRouter( ) -@router.get("", response_model={ - status.HTTP_403_FORBIDDEN: {"model": ResponseData}, - status.HTTP_404_NOT_FOUND: {"model": ResponseData}, - },) +@router.get("", response_model=GetParamsRsp) async def get_parameters( user_sub: Annotated[str, Depends(get_user)], app_id: Annotated[str, Query(alias="appId")], @@ -39,105 +33,45 @@ async def get_parameters( if not await AppManager.validate_user_app_access(user_sub, app_id): return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, - content=FlowStructureGetRsp( + content=GetParamsRsp( code=status.HTTP_403_FORBIDDEN, message="用户没有权限访问该流", - result=FlowStructureGetMsg(), + result=[], ).model_dump(exclude_none=True, by_alias=True), ) flow = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id) if not flow: return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, - content=FlowStructureGetRsp( + content=GetParamsRsp( code=status.HTTP_404_NOT_FOUND, message="未找到该流", - result={}, - ).model_dump(exclude_none=True, by_alias=True), - ) - result = await FlowManager.get_params_by_flow_and_step_id(flow, step_id) - if not result: - return JSONResponse( - status_code=status.HTTP_404_NOT_FOUND, - content=FlowStructureGetRsp( - code=status.HTTP_404_NOT_FOUND, - message="未找到该节点", - result={}, + result=[], ).model_dump(exclude_none=True, by_alias=True), ) + result = await ParameterManager.get_pre_params_by_flow_and_step_id(flow, step_id) return JSONResponse( status_code=status.HTTP_200_OK, content=GetParamsRsp( code=status.HTTP_200_OK, message="获取参数成功", - result=GetParamsMsg(result=result), + result=result ).model_dump(exclude_none=True, by_alias=True), ) -async def operate_parameters(operate: str) -> list[str] | None: - """ - 根据操作类型获取对应的操作符参数列表 - - Args: - operate: 操作类型,支持 'int', 'str', 'bool' - - Returns: - 对应的操作符参数列表,若类型不支持则返回None - """ - string = [ - "equal", - "not_equal", - "great",#长度大于 - "great_equals",#长度大于等于 - "less",#长度小于 - "less_equals",#长度小于等于 - "greater", - "greater_equals", - "smaller", - "smaller_equals", - ] - integer = [ - "equal", - "not_equal", - "great", - "great_equals", - "less", - "less_equals", - ] - boolen = ["equal", "not_equal", "is_empty", "not_empty"] - if operate in string: - return string - if operate in integer: - return integer - if operate in boolen: - return boolen - return None - -@router.get("/operate", response_model={ - status.HTTP_200_OK: {"model": ResponseData}, - status.HTTP_404_NOT_FOUND: {"model": ResponseData}, - },) +@router.get("/operate", response_model=GetOperaRsp) async def get_operate_parameters( user_sub: Annotated[str, Depends(get_user)], - operate: Annotated[str, Query(alias="operate")], + param_type: Annotated[str, Query(alias="ParamType")], ) -> JSONResponse: """Get parameters for node choice.""" - result = await operate_parameters(operate) - if not result: - return JSONResponse( - status_code=status.HTTP_404_NOT_FOUND, - content=ResponseData( - code=status.HTTP_404_NOT_FOUND, - message="未找到该符号", - result=[], - ).model_dump(exclude_none=True, by_alias=True), - ) + result = await ParameterManager.get_operate_and_bind_type(param_type) return JSONResponse( status_code=status.HTTP_200_OK, - content=ResponseData( + content=GetOperaRsp( code=status.HTTP_200_OK, - message="获取参数成功", - result=result, + message="获取操作成功", + result=result ).model_dump(exclude_none=True, by_alias=True), - ) \ No newline at end of file + ) diff --git a/apps/scheduler/call/__init__.py b/apps/scheduler/call/__init__.py index 2ee8b862..c5a6f054 100644 --- a/apps/scheduler/call/__init__.py +++ b/apps/scheduler/call/__init__.py @@ -8,7 +8,7 @@ from apps.scheduler.call.mcp.mcp import MCP from apps.scheduler.call.rag.rag import RAG from apps.scheduler.call.sql.sql import SQL from apps.scheduler.call.suggest.suggest import Suggestion - +from apps.scheduler.call.choice.choice import Choice # 只包含需要在编排界面展示的工具 __all__ = [ "API", @@ -18,4 +18,5 @@ __all__ = [ "SQL", "Graph", "Suggestion", + "Choice" ] diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py index 24df0dae..8cab8288 100644 --- a/apps/scheduler/call/choice/choice.py +++ b/apps/scheduler/call/choice/choice.py @@ -1,6 +1,8 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """使用大模型或使用程序做出判断""" +import ast +import copy import logging from collections.abc import AsyncGenerator from typing import Any @@ -9,11 +11,13 @@ from pydantic import Field from apps.scheduler.call.choice.condition_handler import ConditionHandler from apps.scheduler.call.choice.schema import ( + Condition, ChoiceBranch, ChoiceInput, ChoiceOutput, Logic, ) +from apps.schemas.parameters import Type from apps.scheduler.call.core import CoreCall from apps.schemas.enum_var import CallOutputType from apps.schemas.scheduler import ( @@ -30,17 +34,13 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): """Choice工具""" to_user: bool = Field(default=False) - choices: list[ChoiceBranch] = Field(description="分支", default=[]) + choices: list[ChoiceBranch] = Field(description="分支", default=[ChoiceBranch(), + ChoiceBranch(conditions=[Condition()], is_default=False)]) @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="Choice", description="使用大模型或使用程序做出判断") - - def _raise_value_error(self, msg: str) -> None: - """统一处理 ValueError 异常抛出""" - logger.warning(msg) - raise ValueError(msg) + return CallInfo(name="选择器", description="使用大模型或使用程序做出判断") async def _prepare_message(self, call_vars: CallVars) -> list[dict[str, Any]]: """替换choices中的系统变量""" @@ -51,31 +51,76 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): # 验证逻辑运算符 if choice.logic not in [Logic.AND, Logic.OR]: msg = f"无效的逻辑运算符: {choice.logic}" - self._raise_value_error(msg) + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue valid_conditions = [] - for condition in choice.conditions: + for i in range(len(choice.conditions)): + condition = copy.deepcopy(choice.conditions[i]) # 处理左值 - if condition.left.step_id: - condition.left.value = self._extract_history_variables(condition.left.step_id, call_vars.history) + if condition.left.step_id is not None: + condition.left.value = self._extract_history_variables( + condition.left.step_id+'/'+condition.left.value, call_vars.history) # 检查历史变量是否成功提取 - if condition.left.value is None: - msg = f"步骤 {condition.left.step_id} 的历史变量不存在" - self._raise_value_error(msg) + if condition.left.value is None: + msg = f"步骤 {condition.left.step_id} 的历史变量不存在" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if not ConditionHandler.check_value_type( + condition.left.value, condition.left.type): + msg = f"左值类型不匹配: {condition.left.value} 应为 {condition.left.type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue else: msg = "左侧变量缺少step_id" - self._raise_value_error(msg) - - valid_conditions.append(condition) + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + # 处理右值 + if condition.right.step_id is not None: + condition.right.value = self._extract_history_variables( + condition.right.step_id+'/'+condition.right.value, call_vars.history) + # 检查历史变量是否成功提取 + if condition.right.value is None: + msg = f"步骤 {condition.right.step_id} 的历史变量不存在" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if not ConditionHandler.check_value_type( + condition.right.value, condition.right.type): + msg = f"右值类型不匹配: {condition.right.value} 应为 {condition.right.type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + else: + # 如果右值没有step_id,尝试从call_vars中获取 + right_value_type = await ConditionHandler.get_value_type_from_operate( + condition.operate) + if right_value_type is None: + msg = f"不支持的运算符: {condition.operate}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if condition.right.type != right_value_type: + msg = f"右值类型不匹配: {condition.right.value} 应为 {right_value_type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if right_value_type == Type.STRING: + condition.right.value = str(condition.right.value) + else: + condition.right.value = ast.literal_eval(condition.right.value) + if not ConditionHandler.check_value_type( + condition.right.value, condition.right.type): + msg = f"右值类型不匹配: {condition.right.value} 应为 {condition.right.type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + valid_conditions.append(condition) # 如果所有条件都无效,抛出异常 if not valid_conditions: msg = "分支没有有效条件" - self._raise_value_error(msg) + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue # 更新有效条件 choice.conditions = valid_conditions - valid_choices.append(choice.dict()) + valid_choices.append(choice) except ValueError as e: logger.warning("分支 %s 处理失败: %s,已跳过", choice.branch_id, str(e)) @@ -95,13 +140,11 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): """执行Choice工具""" # 解析输入数据 data = ChoiceInput(**input_data) - ret: CallOutputChunk = CallOutputChunk( - type=CallOutputType.DATA, - content=None, - ) - condition_handler = ConditionHandler() try: - ret.content = condition_handler.handler(data.choices) - yield ret + branch_id = ConditionHandler.handler(data.choices) + yield CallOutputChunk( + type=CallOutputType.DATA, + content=ChoiceOutput(branch_id=branch_id).model_dump(exclude_none=True, by_alias=True), + ) except Exception as e: raise CallError(message=f"选择工具调用失败:{e!s}", data={}) from e diff --git a/apps/scheduler/call/choice/condition_handler.py b/apps/scheduler/call/choice/condition_handler.py index feab5c64..7542f294 100644 --- a/apps/scheduler/call/choice/condition_handler.py +++ b/apps/scheduler/call/choice/condition_handler.py @@ -1,19 +1,79 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """处理条件分支的工具""" + import logging from pydantic import BaseModel -from apps.scheduler.call.choice.schema import ChoiceBranch, ChoiceOutput, Condition, Logic, Operator, Type, Value +from apps.schemas.parameters import ( + Type, + NumberOperate, + StringOperate, + ListOperate, + BoolOperate, + DictOperate, +) + +from apps.scheduler.call.choice.schema import ( + ChoiceBranch, + Condition, + Logic, + Value +) logger = logging.getLogger(__name__) class ConditionHandler(BaseModel): """条件分支处理器""" + @staticmethod + async def get_value_type_from_operate(operate: NumberOperate | StringOperate | ListOperate | + BoolOperate | DictOperate) -> Type: + """获取右值的类型""" + if isinstance(operate, NumberOperate): + return Type.NUMBER + if operate in [ + StringOperate.EQUAL, StringOperate.NOT_EQUAL, StringOperate.CONTAINS, StringOperate.NOT_CONTAINS, + StringOperate.STARTS_WITH, StringOperate.ENDS_WITH, StringOperate.REGEX_MATCH]: + return Type.STRING + if operate in [StringOperate.LENGTH_EQUAL, StringOperate.LENGTH_GREATER_THAN, + StringOperate.LENGTH_GREATER_THAN_OR_EQUAL, StringOperate.LENGTH_LESS_THAN, + StringOperate.LENGTH_LESS_THAN_OR_EQUAL]: + return Type.NUMBER + if operate in [ListOperate.EQUAL, ListOperate.NOT_EQUAL]: + return Type.LIST + if operate in [ListOperate.CONTAINS, ListOperate.NOT_CONTAINS]: + return Type.STRING + if operate in [ListOperate.LENGTH_EQUAL, ListOperate.LENGTH_GREATER_THAN, + ListOperate.LENGTH_GREATER_THAN_OR_EQUAL, ListOperate.LENGTH_LESS_THAN, + ListOperate.LENGTH_LESS_THAN_OR_EQUAL]: + return Type.NUMBER + if operate in [BoolOperate.EQUAL, BoolOperate.NOT_EQUAL]: + return Type.BOOL + if operate in [DictOperate.EQUAL, DictOperate.NOT_EQUAL]: + return Type.DICT + if operate in [DictOperate.CONTAINS_KEY, DictOperate.NOT_CONTAINS_KEY]: + return Type.STRING + return None + + @staticmethod + def check_value_type(value: Value, expected_type: Type) -> bool: + """检查值的类型是否符合预期""" + if expected_type == Type.STRING and isinstance(value.value, str): + return True + if expected_type == Type.NUMBER and isinstance(value.value, (int, float)): + return True + if expected_type == Type.LIST and isinstance(value.value, list): + return True + if expected_type == Type.DICT and isinstance(value.value, dict): + return True + if expected_type == Type.BOOL and isinstance(value.value, bool): + return True + return False - def handler(self, choices: list[ChoiceBranch]) -> ChoiceOutput: + @staticmethod + def handler(choices: list[ChoiceBranch]) -> str: """处理条件""" default_branch = [c for c in choices if c.is_default] @@ -22,7 +82,7 @@ class ConditionHandler(BaseModel): if block_judgement.is_default: continue for condition in block_judgement.conditions: - result = self._judge_condition(condition) + result = ConditionHandler._judge_condition(condition) results.append(result) if block_judgement.logic == Logic.AND: final_result = all(results) @@ -30,58 +90,55 @@ class ConditionHandler(BaseModel): final_result = any(results) if final_result: - return { - "branch_id": block_judgement.branch_id, - "message": f"选择分支:{block_judgement.branch_id}", - } + return block_judgement.branch_id # 如果没有匹配的分支,选择默认分支 if default_branch: - return { - "branch_id": default_branch[0].branch_id, - "message": f"选择默认分支:{default_branch[0].branch_id}", - } - return { - "branch_id": "", - "message": "没有匹配的分支,且没有默认分支", - } - - def _judge_condition(self, condition: Condition) -> bool: + return default_branch[0].branch_id + return "" + + @staticmethod + def _judge_condition(condition: Condition) -> bool: """ 判断条件是否成立。 Args: - condition (Condition): 'left', 'operator', 'right', 'type' + condition (Condition): 'left', 'operate', 'right', 'type' Returns: bool """ left = condition.left - operator = condition.operator + operate = condition.operate right = condition.right value_type = condition.type result = None if value_type == Type.STRING: - result = self._judge_string_condition(left, operator, right) - elif value_type == Type.INT: - result = self._judge_int_condition(left, operator, right) + result = ConditionHandler._judge_string_condition(left, operate, right) + elif value_type == Type.NUMBER: + result = ConditionHandler._judge_int_condition(left, operate, right) elif value_type == Type.BOOL: - result = self._judge_bool_condition(left, operator, right) + result = ConditionHandler._judge_bool_condition(left, operate, right) + elif value_type == Type.LIST: + result = ConditionHandler._judge_list_condition(left, operate, right) + elif value_type == Type.DICT: + result = ConditionHandler._judge_dict_condition(left, operate, right) else: logger.error("不支持的数据类型: %s", value_type) msg = f"不支持的数据类型: {value_type}" raise ValueError(msg) return result - def _judge_string_condition(self, left: Value, operator: Operator, right: Value) -> bool: + @staticmethod + def _judge_string_condition(left: Value, operate: StringOperate, right: Value) -> bool: """ 判断字符串类型的条件。 Args: left (Value): 左值,包含 'value' 键。 - operator (Operator): 操作符 + operate (Operate): 操作符 right (Value): 右值,包含 'value' 键。 Returns: @@ -95,39 +152,41 @@ class ConditionHandler(BaseModel): raise TypeError(msg) right_value = right.value result = False - if operator == Operator.EQUAL: - result = left_value == right_value - elif operator == Operator.NEQUAL: - result = left_value != right_value - elif operator == Operator.GREAT: - result = len(left_value) > len(right_value) - elif operator == Operator.GREAT_EQUALS: - result = len(left_value) >= len(right_value) - elif operator == Operator.LESS: - result = len(left_value) < len(right_value) - elif operator == Operator.LESS_EQUALS: - result = len(left_value) <= len(right_value) - elif operator == Operator.GREATER: - result = left_value > right_value - elif operator == Operator.GREATER_EQUALS: - result = left_value >= right_value - elif operator == Operator.SMALLER: - result = left_value < right_value - elif operator == Operator.SMALLER_EQUALS: - result = left_value <= right_value - elif operator == Operator.CONTAINS: - result = right_value in left_value - elif operator == Operator.NOT_CONTAINS: - result = right_value not in left_value - return result + if operate == StringOperate.EQUAL: + return left_value == right_value + elif operate == StringOperate.NOT_EQUAL: + return left_value != right_value + elif operate == StringOperate.CONTAINS: + return right_value in left_value + elif operate == StringOperate.NOT_CONTAINS: + return right_value not in left_value + elif operate == StringOperate.STARTS_WITH: + return left_value.startswith(right_value) + elif operate == StringOperate.ENDS_WITH: + return left_value.endswith(right_value) + elif operate == StringOperate.REGEX_MATCH: + import re + return bool(re.match(right_value, left_value)) + elif operate == StringOperate.LENGTH_EQUAL: + return len(left_value) == right_value + elif operate == StringOperate.LENGTH_GREATER_THAN: + return len(left_value) > right_value + elif operate == StringOperate.LENGTH_GREATER_THAN_OR_EQUAL: + return len(left_value) >= right_value + elif operate == StringOperate.LENGTH_LESS_THAN: + return len(left_value) < right_value + elif operate == StringOperate.LENGTH_LESS_THAN_OR_EQUAL: + return len(left_value) <= right_value + return False - def _judge_int_condition(self, left: Value, operator: Operator, right: Value) -> bool: # noqa: PLR0911 + @staticmethod + def _judge_number_condition(left: Value, operate: NumberOperate, right: Value) -> bool: # noqa: PLR0911 """ - 判断整数类型的条件。 + 判断数字类型的条件。 Args: left (Value): 左值,包含 'value' 键。 - operator (Operator): 操作符 + operate (Operate): 操作符 right (Value): 右值,包含 'value' 键。 Returns: @@ -135,32 +194,33 @@ class ConditionHandler(BaseModel): """ left_value = left.value - if not isinstance(left_value, int): - logger.error("左值不是整数类型: %s", left_value) - msg = "左值必须是整数类型" + if not isinstance(left_value, (int, float)): + logger.error("左值不是数字类型: %s", left_value) + msg = "左值必须是数字类型" raise TypeError(msg) right_value = right.value - if operator == Operator.EQUAL: + if operate == NumberOperate.EQUAL: return left_value == right_value - if operator == Operator.NEQUAL: + elif operate == NumberOperate.NOT_EQUAL: return left_value != right_value - if operator == Operator.GREAT: + elif operate == NumberOperate.GREATER_THAN: return left_value > right_value - if operator == Operator.GREAT_EQUALS: - return left_value >= right_value - if operator == Operator.LESS: + elif operate == NumberOperate.LESS_THAN: # noqa: PLR2004 return left_value < right_value - if operator == Operator.LESS_EQUALS: + elif operate == NumberOperate.GREATER_THAN_OR_EQUAL: + return left_value >= right_value + elif operate == NumberOperate.LESS_THAN_OR_EQUAL: return left_value <= right_value return False - def _judge_bool_condition(self, left: Value, operator: Operator, right: Value) -> bool: + @staticmethod + def _judge_bool_condition(left: Value, operate: BoolOperate, right: Value) -> bool: """ 判断布尔类型的条件。 Args: left (Value): 左值,包含 'value' 键。 - operator (Operator): 操作符 + operate (Operate): 操作符 right (Value): 右值,包含 'value' 键。 Returns: @@ -173,12 +233,82 @@ class ConditionHandler(BaseModel): msg = "左值必须是布尔类型" raise TypeError(msg) right_value = right.value - if operator == Operator.EQUAL: + if operate == BoolOperate.EQUAL: + return left_value == right_value + elif operate == BoolOperate.NOT_EQUAL: + return left_value != right_value + elif operate == BoolOperate.IS_EMPTY: + return not left_value + elif operate == BoolOperate.NOT_EMPTY: + return left_value + return False + + @staticmethod + def _judge_list_condition(left: Value, operate: ListOperate, right: Value): + """ + 判断列表类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operate (Operate): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, list): + logger.error("左值不是列表类型: %s", left_value) + msg = "左值必须是列表类型" + raise TypeError(msg) + right_value = right.value + if operate == ListOperate.EQUAL: + return left_value == right_value + elif operate == ListOperate.NOT_EQUAL: + return left_value != right_value + elif operate == ListOperate.CONTAINS: + return right_value in left_value + elif operate == ListOperate.NOT_CONTAINS: + return right_value not in left_value + elif operate == ListOperate.LENGTH_EQUAL: + return len(left_value) == right_value + elif operate == ListOperate.LENGTH_GREATER_THAN: + return len(left_value) > right_value + elif operate == ListOperate.LENGTH_GREATER_THAN_OR_EQUAL: + return len(left_value) >= right_value + elif operate == ListOperate.LENGTH_LESS_THAN: + return len(left_value) < right_value + elif operate == ListOperate.LENGTH_LESS_THAN_OR_EQUAL: + return len(left_value) <= right_value + return False + + @staticmethod + def _judge_dict_condition(left: Value, operate: DictOperate, right: Value): + """ + 判断字典类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operate (Operate): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, dict): + logger.error("左值不是字典类型: %s", left_value) + msg = "左值必须是字典类型" + raise TypeError(msg) + right_value = right.value + if operate == DictOperate.EQUAL: return left_value == right_value - if operator == Operator.NEQUAL: + elif operate == DictOperate.NOT_EQUAL: return left_value != right_value - if operator == Operator.IS_EMPTY: - return left_value == "" - if operator == Operator.NOT_EMPTY: - return left_value != "" + elif operate == DictOperate.CONTAINS_KEY: + return right_value in left_value + elif operate == DictOperate.NOT_CONTAINS_KEY: + return right_value not in left_value return False diff --git a/apps/scheduler/call/choice/schema.py b/apps/scheduler/call/choice/schema.py index ed1c628c..b95b1668 100644 --- a/apps/scheduler/call/choice/schema.py +++ b/apps/scheduler/call/choice/schema.py @@ -1,30 +1,20 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """Choice Call的输入和输出""" +import uuid from enum import Enum from pydantic import BaseModel, Field +from apps.schemas.parameters import ( + Type, + NumberOperate, + StringOperate, + ListOperate, + BoolOperate, + DictOperate, +) from apps.scheduler.call.core import DataBase -class Operator(str, Enum): - """Choice Call支持的运算符""" - - EQUAL = "equal" - NEQUAL = "not_equal" - GREAT = "great" - GREAT_EQUALS = "great_equals" - LESS = "less" - LESS_EQUALS = "less_equals" - # string - CONTAINS = "contains" - NOT_CONTAINS = "not_contains" - GREATER = "greater" - GREATER_EQUALS = "greater_equals" - SMALLER = "smaller" - SMALLER_EQUALS = "smaller_equals" - # bool - IS_EMPTY = "is_empty" - NOT_EMPTY = "not_empty" class Logic(str, Enum): @@ -34,38 +24,31 @@ class Logic(str, Enum): OR = "or" -class Type(str, Enum): - """Choice 工具支持的类型""" - - STRING = "string" - INT = "int" - BOOL = "bool" - - -class Value(BaseModel): +class Value(DataBase): """值的结构""" - step_id: str = Field(description="步骤id", default="") - value: str | int | bool = Field(description="值", default=None) + step_id: str | None = Field(description="步骤id", default=None) + type: Type | None = Field(description="值的类型", default=None) + value: str | float | int | bool | list | dict | None = Field(description="值", default=None) -class Condition(BaseModel): +class Condition(DataBase): """单个条件""" - type: Type = Field(description="值的类型", default=Type.STRING) - left: Value = Field(description="左值") - right: Value = Field(description="右值") - operator: Operator = Field(description="运算符", default="equal") - id: int = Field(description="条件ID") + left: Value = Field(description="左值", default=Value()) + right: Value = Field(description="右值", default=Value()) + operate: NumberOperate | StringOperate | ListOperate | BoolOperate | DictOperate | None = Field( + description="运算符", default=None) + id: str = Field(description="条件ID", default_factory=lambda: str(uuid.uuid4())) -class ChoiceBranch(BaseModel): +class ChoiceBranch(DataBase): """子分支""" - branch_id: str = Field(description="分支ID", default="") + branch_id: str = Field(description="分支ID", default_factory=lambda: str(uuid.uuid4())) logic: Logic = Field(description="逻辑运算符", default=Logic.AND) conditions: list[Condition] = Field(description="条件列表", default=[]) - is_default: bool = Field(description="是否为默认分支", default=False) + is_default: bool = Field(description="是否为默认分支", default=True) class ChoiceInput(DataBase): @@ -76,3 +59,5 @@ class ChoiceInput(DataBase): class ChoiceOutput(DataBase): """Choice Call的输出""" + + branch_id: str = Field(description="分支ID", default="") diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index 2b1cbba8..af28c6a3 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -76,21 +76,18 @@ class CoreCall(BaseModel): extra="allow", ) - def __init_subclass__(cls, input_model: type[DataBase], output_model: type[DataBase], **kwargs: Any) -> None: """初始化子类""" super().__init_subclass__(**kwargs) cls.input_model = input_model cls.output_model = output_model - @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" err = "[CoreCall] 必须手动实现info方法" raise NotImplementedError(err) - @staticmethod def _assemble_call_vars(executor: "StepExecutor") -> CallVars: """组装CallVars""" @@ -120,7 +117,6 @@ class CoreCall(BaseModel): summary=executor.task.runtime.summary, ) - @staticmethod def _extract_history_variables(path: str, history: dict[str, FlowStepHistory]) -> Any: """ @@ -131,18 +127,16 @@ class CoreCall(BaseModel): :return: 变量 """ split_path = path.split("/") + if len(split_path) < 2: + err = f"[CoreCall] 路径格式错误: {path}" + logger.error(err) + return None if split_path[0] not in history: err = f"[CoreCall] 步骤{split_path[0]}不存在" logger.error(err) - raise CallError( - message=err, - data={ - "step_id": split_path[0], - }, - ) - + return None data = history[split_path[0]].output_data - for key in split_path[1:]: + for key in split_path[2:]: if key not in data: err = f"[CoreCall] 输出Key {key} 不存在" logger.error(err) @@ -156,7 +150,6 @@ class CoreCall(BaseModel): data = data[key] return data - @classmethod async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self: """实例化Call类""" @@ -170,36 +163,30 @@ class CoreCall(BaseModel): await obj._set_input(executor) return obj - async def _set_input(self, executor: "StepExecutor") -> None: """获取Call的输入""" self._sys_vars = self._assemble_call_vars(executor) input_data = await self._init(self._sys_vars) self.input = input_data.model_dump(by_alias=True, exclude_none=True) - async def _init(self, call_vars: CallVars) -> DataBase: """初始化Call类,并返回Call的输入""" err = "[CoreCall] 初始化方法必须手动实现" raise NotImplementedError(err) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """Call类实例的流式输出方法""" yield CallOutputChunk(type=CallOutputType.TEXT, content="") - async def _after_exec(self, input_data: dict[str, Any]) -> None: """Call类实例的执行后方法""" - async def exec(self, executor: "StepExecutor", input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """Call类实例的执行方法""" async for chunk in self._exec(input_data): yield chunk await self._after_exec(input_data) - async def _llm(self, messages: list[dict[str, Any]]) -> str: """Call可直接使用的LLM非流式调用""" result = "" @@ -210,7 +197,6 @@ class CoreCall(BaseModel): self.output_tokens = llm.output_tokens return result - async def _json(self, messages: list[dict[str, Any]], schema: type[BaseModel]) -> BaseModel: """Call可直接使用的JSON生成""" json = FunctionLLM() diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index d8d22c46..a86ec4ac 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -47,7 +47,6 @@ class FlowExecutor(BaseExecutor): question: str = Field(description="用户输入") post_body_app: RequestDataApp = Field(description="请求体中的app信息") - async def load_state(self) -> None: """从数据库中加载FlowExecutor的状态""" logger.info("[FlowExecutor] 加载Executor状态") @@ -70,7 +69,6 @@ class FlowExecutor(BaseExecutor): self._reached_end: bool = False self.step_queue: deque[StepQueueItem] = deque() - async def _invoke_runner(self, queue_item: StepQueueItem) -> None: """单一Step执行""" # 创建步骤Runner @@ -90,7 +88,6 @@ class FlowExecutor(BaseExecutor): # 更新Task(已存过库) self.task = step_runner.task - async def _step_process(self) -> None: """执行当前queue里面的所有步骤(在用户看来是单一Step)""" while True: @@ -102,7 +99,6 @@ class FlowExecutor(BaseExecutor): # 执行Step await self._invoke_runner(queue_item) - async def _find_next_id(self, step_id: str) -> list[str]: """查找下一个节点""" next_ids = [] @@ -111,15 +107,14 @@ class FlowExecutor(BaseExecutor): next_ids += [edge.edge_to] return next_ids - async def _find_flow_next(self) -> list[StepQueueItem]: """在当前步骤执行前,尝试获取下一步""" # 如果当前步骤为结束,则直接返回 - if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] + if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] return [] if self.task.state.step_name == "Choice": # 如果是choice节点,获取分支ID - branch_id = self.task.context[-1]["output_data"].get("branch_id", None) + branch_id = self.task.context[-1]["output_data"]["branch_id"] if branch_id: self.task.state.step_id = self.task.state.step_id + "." + branch_id logger.info("[FlowExecutor] 分支ID:%s", branch_id) @@ -127,7 +122,7 @@ class FlowExecutor(BaseExecutor): logger.warning("[FlowExecutor] 没有找到分支ID,返回空列表") return [] - next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] + next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] # 如果step没有任何出边,直接跳到end if not next_steps: return [ @@ -146,7 +141,6 @@ class FlowExecutor(BaseExecutor): for next_step in next_steps ] - async def run(self) -> None: """ 运行流,返回各步骤结果,直到无法继续执行 @@ -159,8 +153,8 @@ class FlowExecutor(BaseExecutor): # 获取首个步骤 first_step = StepQueueItem( - step_id=self.task.state.step_id, # type: ignore[arg-type] - step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type] + step_id=self.task.state.step_id, # type: ignore[arg-type] + step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type] ) # 头插开始前的系统步骤,并执行 @@ -179,7 +173,7 @@ class FlowExecutor(BaseExecutor): # 运行Flow(未达终点) while not self._reached_end: # 如果当前步骤出错,执行错误处理步骤 - if self.task.state.status == StepStatus.ERROR: # type: ignore[arg-type] + if self.task.state.status == StepStatus.ERROR: # type: ignore[arg-type] logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤") self.step_queue.clear() self.step_queue.appendleft(StepQueueItem( @@ -192,7 +186,7 @@ class FlowExecutor(BaseExecutor): params={ "user_prompt": LLM_ERROR_PROMPT.replace( "{{ error_info }}", - self.task.state.error_info["err_msg"], # type: ignore[arg-type] + self.task.state.error_info["err_msg"], # type: ignore[arg-type] ), }, ), diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 506f3bb1..f3aeb82c 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -119,7 +119,6 @@ class StepExecutor(BaseExecutor): logger.exception("[StepExecutor] 初始化Call失败") raise - async def _run_slot_filling(self) -> None: """运行自动参数填充;相当于特殊Step,但是不存库""" # 判断是否需要进行自动参数填充 @@ -170,7 +169,6 @@ class StepExecutor(BaseExecutor): self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens - async def _process_chunk( self, iterator: AsyncGenerator[CallOutputChunk, None], @@ -202,7 +200,6 @@ class StepExecutor(BaseExecutor): return content - async def run(self) -> None: """运行单个步骤""" self.validate_flow_state(self.task) diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py index 4caab4d2..89433cad 100644 --- a/apps/scheduler/slot/slot.py +++ b/apps/scheduler/slot/slot.py @@ -12,6 +12,8 @@ from jsonschema.exceptions import ValidationError from jsonschema.protocols import Validator from jsonschema.validators import extend +from apps.schemas.response_data import ParamsNode +from apps.scheduler.call.choice.schema import Type from apps.scheduler.slot.parser import ( SlotConstParser, SlotDateParser, @@ -221,6 +223,45 @@ class Slot: return data return _extract_type_desc(self._schema) + def get_params_node_from_schema(self, root: str = "") -> ParamsNode: + """从JSON Schema中提取ParamsNode""" + def _extract_params_node(schema_node: dict[str, Any], name: str = "", path: str = "") -> ParamsNode: + """递归提取ParamsNode""" + if "type" not in schema_node: + return None + + param_type = schema_node["type"] + if param_type == "object": + param_type = Type.DICT + elif param_type == "array": + param_type = Type.LIST + elif param_type == "string": + param_type = Type.STRING + elif param_type == "number": + param_type = Type.NUMBER + elif param_type == "boolean": + param_type = Type.BOOL + else: + logger.warning(f"[Slot] 不支持的参数类型: {param_type}") + return None + sub_params = [] + + if param_type == "object" and "properties" in schema_node: + for key, value in schema_node["properties"].items(): + sub_params.append(_extract_params_node(value, name=key, path=f"{path}/{key}")) + else: + # 对于非对象类型,直接返回空子参数 + sub_params = None + return ParamsNode(paramName=name, + paramPath=path, + paramType=param_type, + subParams=sub_params) + try: + return _extract_params_node(self._schema, name=root, path=root) + except Exception as e: + logger.error(f"[Slot] 提取ParamsNode失败: {e!s}\n{traceback.format_exc()}") + return None + def _flatten_schema(self, schema: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: """将JSON Schema扁平化""" result = {} @@ -276,7 +317,6 @@ class Slot: logger.exception("[Slot] 错误schema不合法: %s", error.schema) return {}, [] - def _assemble_patch( self, key: str, @@ -329,7 +369,6 @@ class Slot: logger.info("[Slot] 组装patch: %s", patch_list) return patch_list - def convert_json(self, json_data: str | dict[str, Any]) -> dict[str, Any]: """将用户手动填充的参数专为真实JSON""" json_dict = json.loads(json_data) if isinstance(json_data, str) else json_data diff --git a/apps/schemas/config.py b/apps/schemas/config.py index 99bcccde..e91f5f75 100644 --- a/apps/schemas/config.py +++ b/apps/schemas/config.py @@ -129,7 +129,7 @@ class ExtraConfig(BaseModel): class ConfigModel(BaseModel): """配置文件的校验Class""" - no_auth: NoauthConfig + no_auth: NoauthConfig = Field(description="无认证配置", default=NoauthConfig()) deploy: DeployConfig login: LoginConfig embedding: EmbeddingConfig diff --git a/apps/schemas/parameters.py b/apps/schemas/parameters.py new file mode 100644 index 00000000..bd908d23 --- /dev/null +++ b/apps/schemas/parameters.py @@ -0,0 +1,69 @@ +from enum import Enum + + +class NumberOperate(str, Enum): + """Choice 工具支持的数字运算符""" + + EQUAL = "number_equal" + NOT_EQUAL = "number_not_equal" + GREATER_THAN = "number_greater_than" + LESS_THAN = "number_less_than" + GREATER_THAN_OR_EQUAL = "number_greater_than_or_equal" + LESS_THAN_OR_EQUAL = "number_less_than_or_equal" + + +class StringOperate(str, Enum): + """Choice 工具支持的字符串运算符""" + + EQUAL = "string_equal" + NOT_EQUAL = "string_not_equal" + CONTAINS = "string_contains" + NOT_CONTAINS = "string_not_contains" + STARTS_WITH = "string_starts_with" + ENDS_WITH = "string_ends_with" + LENGTH_EQUAL = "string_length_equal" + LENGTH_GREATER_THAN = "string_length_greater_than" + LENGTH_GREATER_THAN_OR_EQUAL = "string_length_greater_than_or_equal" + LENGTH_LESS_THAN = "string_length_less_than" + LENGTH_LESS_THAN_OR_EQUAL = "string_length_less_than_or_equal" + REGEX_MATCH = "string_regex_match" + + +class ListOperate(str, Enum): + """Choice 工具支持的列表运算符""" + + EQUAL = "list_equal" + NOT_EQUAL = "list_not_equal" + CONTAINS = "list_contains" + NOT_CONTAINS = "list_not_contains" + LENGTH_EQUAL = "list_length_equal" + LENGTH_GREATER_THAN = "list_length_greater_than" + LENGTH_GREATER_THAN_OR_EQUAL = "list_length_greater_than_or_equal" + LENGTH_LESS_THAN = "list_length_less_than" + LENGTH_LESS_THAN_OR_EQUAL = "list_length_less_than_or_equal" + + +class BoolOperate(str, Enum): + """Choice 工具支持的布尔运算符""" + + EQUAL = "bool_equal" + NOT_EQUAL = "bool_not_equal" + + +class DictOperate(str, Enum): + """Choice 工具支持的字典运算符""" + + EQUAL = "dict_equal" + NOT_EQUAL = "dict_not_equal" + CONTAINS_KEY = "dict_contains_key" + NOT_CONTAINS_KEY = "dict_not_contains_key" + + +class Type(str, Enum): + """Choice 工具支持的类型""" + + STRING = "string" + NUMBER = "number" + LIST = "list" + DICT = "dict" + BOOL = "bool" diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py index 20d7ad9b..b1dc77b7 100644 --- a/apps/schemas/response_data.py +++ b/apps/schemas/response_data.py @@ -14,6 +14,14 @@ from apps.schemas.flow_topology import ( NodeServiceItem, PositionItem, ) +from apps.schemas.parameters import ( + Type, + NumberOperate, + StringOperate, + ListOperate, + BoolOperate, + DictOperate, +) from apps.schemas.mcp import MCPInstallStatus, MCPTool, MCPType from apps.schemas.record import RecordData from apps.schemas.user import UserInfo @@ -629,20 +637,41 @@ class ListLLMRsp(ResponseData): result: list[LLMProviderInfo] = Field(default=[], title="Result") -class Params(BaseModel): + +class ParamsNode(BaseModel): """参数数据结构""" + param_name: str = Field(..., description="参数名称", alias="paramName") + param_path: str = Field(..., description="参数路径", alias="paramPath") + param_type: Type = Field(..., description="参数类型", alias="paramType") + sub_params: list["ParamsNode"] | None = Field( + default=None, description="子参数列表", alias="subParams" + ) - id: str = Field(..., description="StepID") - name: str = Field(..., description="Step名称") - parameters: dict[str, Any] = Field(..., description="参数") - operate: str = Field(..., description="比较符") -class GetParamsMsg(BaseModel): - """GET /api/params 返回数据结构""" +class StepParams(BaseModel): + """参数数据结构""" + step_id: str = Field(..., description="步骤ID", alias="stepId") + name: str = Field(..., description="Step名称") + params_node: ParamsNode | None = Field( + default=None, description="参数节点", alias="paramsNode") - result: list[Params] = Field(..., title="Result") class GetParamsRsp(ResponseData): """GET /api/params 返回数据结构""" - result: GetParamsMsg \ No newline at end of file + result: list[StepParams] = Field( + default=[], description="参数列表", alias="result" + ) + + +class OperateAndBindType(BaseModel): + """操作和绑定类型数据结构""" + + operate: NumberOperate | StringOperate | ListOperate | BoolOperate | DictOperate = Field(description="操作类型") + bind_type: Type = Field(description="绑定类型") + + +class GetOperaRsp(ResponseData): + """GET /api/operate 返回数据结构""" + + result: list[OperateAndBindType] = Field(..., title="Result") diff --git a/apps/services/flow.py b/apps/services/flow.py index cb8ad57f..c9fd86fd 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -20,7 +20,6 @@ from apps.schemas.flow_topology import ( PositionItem, ) from apps.services.node import NodeManager -from apps.schemas.response_data import Params logger = logging.getLogger(__name__) @@ -470,48 +469,3 @@ class FlowManager: return False else: return True - - @staticmethod - async def get_params_by_flow_and_step_id( - flow: FlowItem, step_id: str - ) -> list[Params] | None: - """递归收集指定节点之前所有路径上的节点参数""" - params = [] - collected = set() # 记录已收集参数的节点 - - async def backtrack(current_id: str, visited: set) -> None: - # 避免循环递归 - if current_id in visited: - return - visited.add(current_id) - - # 获取所有指向当前节点的边 - incoming_edges = [ - edge for edge in flow.edges if edge.target_node == current_id - ] - - for edge in incoming_edges: - source_id = edge.source_node - - # 跳过起始节点 - if source_id == "start": - continue - - # 收集当前节点的参数(如果未被收集过) - if source_id not in collected: - node = flow.nodes.get(source_id) - if node: - collected.add(source_id) - params.append( - Params( - id=source_id, - name=node.name, - parameters=node.parameters.get("parameters", {}), - ), - ) - - # 继续回溯,传递当前路径的visited集合副本 - await backtrack(source_id, visited.copy()) - - await backtrack(step_id, set()) - return params diff --git a/apps/services/node.py b/apps/services/node.py index 6f0d492e..bf48e71f 100644 --- a/apps/services/node.py +++ b/apps/services/node.py @@ -16,6 +16,7 @@ NODE_TYPE_MAP = { "API": APINode, } + class NodeManager: """Node管理器""" @@ -29,7 +30,6 @@ class NodeManager: raise ValueError(err) return node["call_id"] - @staticmethod async def get_node(node_id: str) -> NodePool: """获取Node的类型""" @@ -40,7 +40,6 @@ class NodeManager: raise ValueError(err) return NodePool.model_validate(node) - @staticmethod async def get_node_name(node_id: str) -> str: """获取node的名称""" @@ -52,7 +51,6 @@ class NodeManager: return "" return node_doc["name"] - @staticmethod def merge_params_schema(params_schema: dict[str, Any], known_params: dict[str, Any]) -> dict[str, Any]: """递归合并参数Schema,将known_params中的值填充到params_schema的对应位置""" @@ -75,7 +73,6 @@ class NodeManager: return params_schema - @staticmethod async def get_node_params(node_id: str) -> tuple[dict[str, Any], dict[str, Any]]: """获取Node数据""" @@ -100,7 +97,6 @@ class NodeManager: err = f"[NodeManager] Call {call_id} 不存在" logger.error(err) raise ValueError(err) - # 返回参数Schema return ( NodeManager.merge_params_schema(call_class.model_json_schema(), node_data.known_params or {}), diff --git a/apps/services/parameter.py b/apps/services/parameter.py new file mode 100644 index 00000000..ae375e97 --- /dev/null +++ b/apps/services/parameter.py @@ -0,0 +1,86 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""flow Manager""" + +import logging + +from pymongo import ASCENDING + +from apps.services.node import NodeManager +from apps.schemas.flow_topology import FlowItem +from apps.scheduler.slot.slot import Slot +from apps.scheduler.call.choice.condition_handler import ConditionHandler +from apps.scheduler.call.choice.schema import ( + NumberOperate, + StringOperate, + ListOperate, + BoolOperate, + DictOperate, + Type +) +from apps.schemas.response_data import ( + OperateAndBindType, + ParamsNode, + StepParams, +) +from apps.services.node import NodeManager +logger = logging.getLogger(__name__) + + +class ParameterManager: + """Parameter Manager""" + @staticmethod + async def get_operate_and_bind_type(param_type: Type) -> list[OperateAndBindType]: + """Get operate and bind type""" + result = [] + operate = None + if param_type == Type.NUMBER: + operate = NumberOperate + elif param_type == Type.STRING: + operate = StringOperate + elif param_type == Type.LIST: + operate = ListOperate + elif param_type == Type.BOOL: + operate = BoolOperate + elif param_type == Type.DICT: + operate = DictOperate + if operate: + for item in operate: + result.append(OperateAndBindType( + operate=item, + bind_type=ConditionHandler.get_value_type_from_operate(item))) + return result + + @staticmethod + async def get_pre_params_by_flow_and_step_id(flow: FlowItem, step_id: str) -> list[StepParams]: + """Get pre params by flow and step id""" + index = 0 + q = [step_id] + in_edges = {} + step_id_to_node_id = {} + for step in flow.nodes: + step_id_to_node_id[step.step_id] = step.node_id + for edge in flow.edges: + if edge.target_node not in in_edges: + in_edges[edge.target_node] = [] + in_edges[edge.target_node].append(edge.source_node) + while index < len(q): + tmp_step_id = q[index] + index += 1 + for i in range(len(in_edges.get(tmp_step_id, []))): + pre_node_id = in_edges[tmp_step_id][i] + if pre_node_id not in q: + q.append(pre_node_id) + pre_step_params = [] + for step_id in q: + node_id = step_id_to_node_id.get(step_id) + params_schema, output_schema = await NodeManager.get_node_params(node_id) + slot = Slot(output_schema) + params_node = slot.get_params_node_from_schema(root='/output') + pre_step_params.append( + StepParams( + stepId=node_id, + name=params_schema.get("name", ""), + paramsNode=params_node + ) + ) + return pre_step_params -- Gitee From 09f16bad671439d7d6e33e7d10405b1b5dfc78e4 Mon Sep 17 00:00:00 2001 From: zxstty Date: Mon, 28 Jul 2025 14:52:27 +0800 Subject: [PATCH 20/60] =?UTF-8?q?=E5=AE=8C=E5=96=84Agent=E7=9A=84=E5=BC=80?= =?UTF-8?q?=E5=8F=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/schemas/enum_var.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index 20e9c0f9..578d8121 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -56,6 +56,8 @@ class EventType(str, Enum): STEP_INPUT = "step.input" STEP_OUTPUT = "step.output" FLOW_STOP = "flow.stop" + FLOW_FAILED = "flow.failed" + FLOW_CANCELLED = "flow.cancelled" DONE = "done" -- Gitee From 36b0e9b2bd6472758fa9409edfb9cbe64c24e7f3 Mon Sep 17 00:00:00 2001 From: zxstty Date: Mon, 28 Jul 2025 16:42:58 +0800 Subject: [PATCH 21/60] =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=97=A0=E9=89=B4=E6=9D=83=E7=94=A8=E6=88=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/main.py | 16 ++++++++ .../call/choice/condition_handler.py | 41 ++++++++++--------- apps/schemas/request_data.py | 1 + 3 files changed, 39 insertions(+), 19 deletions(-) diff --git a/apps/main.py b/apps/main.py index c26e5e47..17d4abb4 100644 --- a/apps/main.py +++ b/apps/main.py @@ -83,12 +83,28 @@ logging.basicConfig( ) +async def add_no_auth_user() -> None: + """ + 添加无认证用户 + """ + from apps.common.mongo import MongoDB + from apps.schemas.collection import User + mongo = MongoDB() + user_collection = mongo.get_collection("user") + await user_collection.insert_one(User( + _id=Config().get_config().no_auth.user_sub, + is_admin=True, + ).model_dump(by_alias=True)) + + async def init_resources() -> None: """初始化必要资源""" WordsCheck() await LanceDB().init() await Pool.init() TokenCalculator() + if Config().get_config().no_auth.enable: + await add_no_auth_user() # 运行 if __name__ == "__main__": diff --git a/apps/scheduler/call/choice/condition_handler.py b/apps/scheduler/call/choice/condition_handler.py index 7542f294..6111ba90 100644 --- a/apps/scheduler/call/choice/condition_handler.py +++ b/apps/scheduler/call/choice/condition_handler.py @@ -83,7 +83,11 @@ class ConditionHandler(BaseModel): continue for condition in block_judgement.conditions: result = ConditionHandler._judge_condition(condition) - results.append(result) + if result is not None: + results.append(result) + if not results: + logger.warning(f"[Choice] 分支 {block_judgement.branch_id} 条件处理失败: 没有有效的条件") + continue if block_judgement.logic == Logic.AND: final_result = all(results) elif block_judgement.logic == Logic.OR: @@ -118,7 +122,7 @@ class ConditionHandler(BaseModel): if value_type == Type.STRING: result = ConditionHandler._judge_string_condition(left, operate, right) elif value_type == Type.NUMBER: - result = ConditionHandler._judge_int_condition(left, operate, right) + result = ConditionHandler._judge_number_condition(left, operate, right) elif value_type == Type.BOOL: result = ConditionHandler._judge_bool_condition(left, operate, right) elif value_type == Type.LIST: @@ -126,9 +130,9 @@ class ConditionHandler(BaseModel): elif value_type == Type.DICT: result = ConditionHandler._judge_dict_condition(left, operate, right) else: - logger.error("不支持的数据类型: %s", value_type) msg = f"不支持的数据类型: {value_type}" - raise ValueError(msg) + logger.error(f"[Choice] 条件处理失败: {msg}") + return None return result @staticmethod @@ -147,11 +151,10 @@ class ConditionHandler(BaseModel): """ left_value = left.value if not isinstance(left_value, str): - logger.error("左值不是字符串类型: %s", left_value) - msg = "左值必须是字符串类型" - raise TypeError(msg) + msg = f"左值必须是字符串类型 ({left_value})" + logger.warning(msg) + return None right_value = right.value - result = False if operate == StringOperate.EQUAL: return left_value == right_value elif operate == StringOperate.NOT_EQUAL: @@ -195,9 +198,9 @@ class ConditionHandler(BaseModel): """ left_value = left.value if not isinstance(left_value, (int, float)): - logger.error("左值不是数字类型: %s", left_value) - msg = "左值必须是数字类型" - raise TypeError(msg) + msg = f"左值必须是数字类型 ({left_value})" + logger.warning(msg) + return None right_value = right.value if operate == NumberOperate.EQUAL: return left_value == right_value @@ -229,9 +232,9 @@ class ConditionHandler(BaseModel): """ left_value = left.value if not isinstance(left_value, bool): - logger.error("左值不是布尔类型: %s", left_value) msg = "左值必须是布尔类型" - raise TypeError(msg) + logger.warning(msg) + return None right_value = right.value if operate == BoolOperate.EQUAL: return left_value == right_value @@ -259,9 +262,9 @@ class ConditionHandler(BaseModel): """ left_value = left.value if not isinstance(left_value, list): - logger.error("左值不是列表类型: %s", left_value) - msg = "左值必须是列表类型" - raise TypeError(msg) + msg = f"左值必须是列表类型 ({left_value})" + logger.warning(msg) + return None right_value = right.value if operate == ListOperate.EQUAL: return left_value == right_value @@ -299,9 +302,9 @@ class ConditionHandler(BaseModel): """ left_value = left.value if not isinstance(left_value, dict): - logger.error("左值不是字典类型: %s", left_value) - msg = "左值必须是字典类型" - raise TypeError(msg) + msg = f"左值必须是字典类型 ({left_value})" + logger.warning(msg) + return None right_value = right.value if operate == DictOperate.EQUAL: return left_value == right_value diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index a3a8848c..af0d4a01 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -107,6 +107,7 @@ class ActiveMCPServiceRequest(BaseModel): """POST /api/mcp/{serviceId} 请求数据结构""" active: bool = Field(description="是否激活mcp服务") + mcp_env: dict[str, Any] | None = Field(default=None, description="MCP服务环境变量", alias="mcpEnv") class UpdateServiceRequest(BaseModel): -- Gitee From e89bb6612f8bb9b406a17bfd5e92cd30d65e7c54 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 29 Jul 2025 09:58:02 +0800 Subject: [PATCH 22/60] =?UTF-8?q?=E5=AE=8C=E5=96=84mcp=20agent=E7=9A=84pro?= =?UTF-8?q?mt&=E4=BF=AE=E5=A4=8D=E5=8E=BB=E9=99=A4=E5=86=97=E4=BD=99?= =?UTF-8?q?=E7=BB=93=E6=9E=84=E5=88=86=E6=94=AF=E8=8A=82=E7=82=B9=E7=9B=B8?= =?UTF-8?q?=E5=85=B3bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/mcp_service.py | 3 +- apps/scheduler/mcp_agent/plan.py | 9 +- apps/scheduler/mcp_agent/prompt.py | 330 ++++++++++++++++++++++++++--- apps/scheduler/mcp_agent/select.py | 30 +-- apps/scheduler/pool/loader/mcp.py | 18 +- apps/schemas/enum_var.py | 2 +- apps/schemas/mcp.py | 4 +- apps/schemas/request_data.py | 2 +- apps/services/flow_validate.py | 20 +- apps/services/mcp_service.py | 4 +- 10 files changed, 356 insertions(+), 66 deletions(-) diff --git a/apps/routers/mcp_service.py b/apps/routers/mcp_service.py index a845a376..de484e78 100644 --- a/apps/routers/mcp_service.py +++ b/apps/routers/mcp_service.py @@ -36,6 +36,7 @@ router = APIRouter( dependencies=[Depends(verify_user)], ) + async def _check_user_admin(user_sub: str) -> None: user = await UserManager.get_userinfo_by_user_sub(user_sub) if not user: @@ -282,7 +283,7 @@ async def active_or_deactivate_mcp_service( """激活/取消激活mcp""" try: if data.active: - await MCPServiceManager.active_mcpservice(user_sub, service_id) + await MCPServiceManager.active_mcpservice(user_sub, service_id, data.mcp_env) else: await MCPServiceManager.deactive_mcpservice(user_sub, service_id) except Exception as e: diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index cd4f5975..47e803f6 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -25,7 +25,6 @@ class MCPPlanner: self.input_tokens = 0 self.output_tokens = 0 - async def create_plan(self, tool_list: list[MCPTool], max_steps: int = 6) -> MCPPlan: """规划下一步的执行流程,并输出""" # 获取推理结果 @@ -34,8 +33,10 @@ class MCPPlanner: # 解析为结构化数据 return await self._parse_plan_result(result, max_steps) - - async def _get_reasoning_plan(self, tool_list: list[MCPTool], max_steps: int) -> str: + async def _get_reasoning_plan( + self, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan = MCPPlan(), + tool_list: list[MCPTool] = [], + max_steps: int = 10) -> str: """获取推理大模型的结果""" # 格式化Prompt template = self._env.from_string(CREATE_PLAN) @@ -66,7 +67,6 @@ class MCPPlanner: return result - async def _parse_plan_result(self, result: str, max_steps: int) -> MCPPlan: """将推理结果解析为结构化数据""" # 格式化Prompt @@ -85,7 +85,6 @@ class MCPPlanner: plan = await json_generator.generate() return MCPPlan.model_validate(plan) - async def generate_answer(self, plan: MCPPlan, memory: str) -> str: """生成最终回答""" template = self._env.from_string(FINAL_ANSWER) diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py index b322fb08..649a67a9 100644 --- a/apps/scheduler/mcp_agent/prompt.py +++ b/apps/scheduler/mcp_agent/prompt.py @@ -61,6 +61,53 @@ MCP_SELECT = dedent(r""" ### 请一步一步思考: +""") +EVALUATE_GAOL = dedent(r""" + 你是一个计划评估器。 + 请根据用户的目标和当前的工具集合以及一些附加信息,判断基于当前的工具集合,是否能够完成用户的目标。 + 如果能够完成,请返回`true`,否则返回`false`。 + 推理过程必须清晰明了,能够让人理解你的判断依据。 + 必须按照以下格式回答: + ```json + { + "can_complete": true/false, + "resoning": "你的推理过程" + } + ``` + + # 样例 + ## 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈,并调优 + + ## 工具集合 + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + + - mysql_analyzer分析MySQL数据库性能 + - performance_tuner调优数据库性能 + - Final结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 + + + ## 附加信息 + 1. 当前MySQL数据库的版本是8.0.26 + 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf + + ## + ```json + { + "can_complete": true, + "resoning": "当前的工具集合中包含mysql_analyzer和performance_tuner,能够完成对MySQL数据库的性能分析和调优,因此可以完成用户的目标。" + } + ``` + + # 目标 + {{ goal }} + + # 工具集合 + {{ tools }} + + # 附加信息 + {{ additional_info }} + """) CREATE_PLAN = dedent(r""" 你是一个计划生成器。 @@ -163,78 +210,299 @@ CREATE_PLAN = dedent(r""" # 计划 """) -EVALUATE_PLAN = dedent(r""" - 你是一个计划评估器。 - 请根据给定的计划,和当前计划执行的实际情况,分析当前计划是否合理和完整,并生成改进后的计划。 +RECREATE_PLAN = dedent(r""" + 你是一个计划重建器。 + 请根据用户的目标、当前计划和运行报错,重新生成一个计划。 # 一个好的计划应该: 1. 能够成功完成用户的目标 2. 计划中的每一个步骤必须且只能使用一个工具。 3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。 - 4. 计划中的最后一步必须是Final工具,以确保计划执行结束。 + 4. 你的计划必须避免之前的错误,并且能够成功执行。 + 5. 计划中的最后一步必须是Final工具,以确保计划执行结束。 - # 你此前的计划是: + # 生成计划时的注意事项: - {{ plan }} + - 每一条计划包含3个部分: + - 计划内容:描述单个计划步骤的大致内容 + - 工具ID:必须从下文的工具列表中选择 + - 工具指令:改写用户的目标,使其更符合工具的输入要求 + - 必须按照如下格式生成计划,不要输出任何额外数据: - # 这个计划的执行情况是: + ```json + { + "plans": [ + { + "content": "计划内容", + "tool": "工具ID", + "instruction": "工具指令" + } + ] + } + ``` - 计划的执行情况将放置在 XML标签中。 + - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。\ +思考过程应放置在 XML标签中。 + - 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。 + - 计划不得多于{{ max_num }}条,且每条计划内容应少于150字。 - - {{ memory }} - + # 样例 - # 进行评估时的注意事项: + ## 目标 - - 请一步一步思考,解析用户的目标,并指导你接下来的生成。思考过程应放置在 XML标签中。 - - 评估结果分为两个部分: - - 计划评估的结论 - - 改进后的计划 - - 请按照以下JSON格式输出评估结果: + 请帮我扫描一下192.168.1.1的这台机器的端口,看看有哪些端口开放。 + ## 工具 + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + + - command_generator生成命令行指令 + - tool_selector选择合适的工具 + - command_executor执行命令行指令 + - Final结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 + + ## 当前计划 ```json { - "evaluation": "评估结果", "plans": [ { - "content": "改进后的计划内容", - "tool": "工具ID", - "instruction": "工具指令" + "content": "生成端口扫描命令", + "tool": "command_generator", + "instruction": "生成端口扫描命令:扫描192.168.1.1的开放端口" + }, + { + "content": "在执行Result[0]生成的命令", + "tool": "command_executor", + "instruction": "执行端口扫描命令" + }, + { + "content": "任务执行完成,端口扫描结果为Result[2]", + "tool": "Final", + "instruction": "" + } + ] + } + ``` + ## 运行报错 + 执行端口扫描命令时,出现了错误:`-bash: curl: command not found`。 + ## 重新生成的计划 + + + 1. 这个目标需要使用网络扫描工具来完成,首先需要选择合适的网络扫描工具 + 2. 目标可以拆解为以下几个部分: + - 生成端口扫描命令 + - 执行端口扫描命令 + 3.但是在执行端口扫描命令时,出现了错误:`-bash: curl: command not found`。 + 4.我将计划调整为: + - 需要先生成一个命令,查看当前机器支持哪些网络扫描工具 + - 执行这个命令,查看当前机器支持哪些网络扫描工具 + - 然后从中选择一个网络扫描工具 + - 基于选择的网络扫描工具,生成端口扫描命令 + - 执行端口扫描命令 + + + ```json + { + "plans": [ + { + "content": "需要生成一条命令查看当前机器支持哪些网络扫描工具", + "tool": "command_generator", + "instruction": "选择一个前机器支持哪些网络扫描工具" + }, + { + "content": "执行Result[0]中生成的命令,查看当前机器支持哪些网络扫描工具", + "tool": "command_executor", + "instruction": "执行Result[0]中生成的命令" + }, + { + "content": "从Result[1]中选择一个网络扫描工具,生成端口扫描命令", + "tool": "tool_selector", + "instruction": "选择一个网络扫描工具,生成端口扫描命令" + }, + { + "content": "基于result[2]中选择的网络扫描工具,生成端口扫描命令", + "tool": "command_generator", + "instruction": "生成端口扫描命令:扫描192.168.1.1的开放端口" + }, + { + "content": "在Result[0]的MCP Server上执行Result[3]生成的命令", + "tool": "command_executor", + "instruction": "执行端口扫描命令" + }, + { + "content": "任务执行完成,端口扫描结果为Result[4]", + "tool": "Final", + "instruction": "" } ] } ``` - # 现在开始评估计划: + # 现在开始重新生成计划: + + # 目标 + + {{goal}} + + # 工具 + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + + + {% for tool in tools %} + - {{ tool.id }}{{tool.name}};{{ tool.description }} + {% endfor %} + - Final结束步骤,当执行到这一步时,\ +表示计划执行结束,所得到的结果将作为最终结果。 + + + # 当前计划 + {{ current_plan }} + + # 运行报错 + {{ error_message }} + + # 重新生成的计划 """) +RISK_EVALUATE = dedent(r""" + 你是一个工具执行计划评估器。 + 你的任务是根据当前工具的名称、描述和入参以及附加信息,判断当前工具执行的风险并输出提示。 + ```json + { + "risk": "高/中/低", + "message": "提示信息" + } + ``` + # 样例 + ## 工具名称 + mysql_analyzer + ## 工具描述 + 分析MySQL数据库性能 + ## 工具入参 + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + ## 附加信息 + 1. 当前MySQL数据库的版本是8.0.26 + 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf,并含有以下配置项 + ```ini + [mysqld] + innodb_buffer_pool_size=1G + innodb_log_file_size=256M + ``` + ## 输出 + ```json + { + "risk": "中", + "message": "当前工具将连接到MySQL数据库并分析性能,可能会对数据库性能产生一定影响。请确保在非生产环境中执行此操作。" + } + ``` + # 工具名称 + {{ tool_name }} + # 工具描述 + {{ tool_description }} + # 工具入参 + {{ tool_input }} + # 附加信息 + {{ additional_info }} + # 输出 + """ + ) + +# 获取缺失的参数的json结构体 +GET_MISSING_PARAMS = dedent(r""" + 你是一个工具参数获取器。 + 你的任务是根据当前工具的名称、描述和入参和入参的schema以及运行报错,获取当前工具缺失的参数并输出提示。 + ```json + { + "host": "请补充主机地址", + "port": "请补充端口号", + "username": "请补充用户名", + "password": "请补充密码" + } + ``` + # 样例 + ## 工具名称 + mysql_analyzer + ## 工具描述 + 分析MySQL数据库性能 + ## 工具入参 + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + ## 工具入参schema + { + "type": "object", + "properties": { + "host": {"type": "string", "description": "MySQL数据库的主机地址"}, + "port": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + ], + "description": "MySQL数据库的端口号(可以是数字或字符串)" + }, + "username": {"type": "string", "description": "MySQL数据库的用户名"}, + "password": {"type": "string", "description": "MySQL数据库的密码"} + }, + "required": ["host", "port", "username", "password"] + } + ## 运行报错 + 执行端口扫描命令时,出现了错误:`password is not correct`。 + ## 输出 + ```json + { + "host": "192.0.0.1", + "port": 3306, + "username": "请补充用户名", + "password": "请补充密码" + } + ``` + # 工具名称 + {{ tool_name }} + # 工具描述 + {{ tool_description }} + # 工具入参 + {{ tool_input }} + # 工具入参schema + {{ tool_input_schema }} + # 运行报错 + {{ error_message }} + # 输出 + """ + ) FINAL_ANSWER = dedent(r""" 综合理解计划执行结果和背景信息,向用户报告目标的完成情况。 # 用户目标 - {{ goal }} + {{goal}} # 计划执行情况 为了完成上述目标,你实施了以下计划: - {{ memory }} + {{memory}} # 其他背景信息: - {{ status }} + {{status}} # 现在,请根据以上信息,向用户报告目标的完成情况: """) + MEMORY_TEMPLATE = dedent(r""" - {% for ctx in context_list %} - - 第{{ loop.index }}步:{{ ctx.step_description }} - - 调用工具 `{{ ctx.step_id }}`,并提供参数 `{{ ctx.input_data }}` - - 执行状态:{{ ctx.status }} - - 得到数据:`{{ ctx.output_data }}` - {% endfor %} + { % for ctx in context_list % } + - 第{{loop.index}}步:{{ctx.step_description}} + - 调用工具 `{{ctx.step_id}}`,并提供参数 `{{ctx.input_data}}` + - 执行状态:{{ctx.status}} + - 得到数据:`{{ctx.output_data}}` + { % endfor % } """) diff --git a/apps/scheduler/mcp_agent/select.py b/apps/scheduler/mcp_agent/select.py index 2ff50344..95f588e2 100644 --- a/apps/scheduler/mcp_agent/select.py +++ b/apps/scheduler/mcp_agent/select.py @@ -2,9 +2,10 @@ """选择MCP Server及其工具""" import logging - +import uuid from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment +from typing import AsyncGenerator from apps.common.lance import LanceDB from apps.common.mongo import MongoDB @@ -39,7 +40,6 @@ class MCPSelector: sql += f"'{mcp_id}', " return sql.rstrip(", ") + ")" - async def _get_top_mcp_by_embedding( self, query: str, @@ -72,7 +72,6 @@ class MCPSelector: }]) return llm_mcp_list - async def _get_mcp_by_llm( self, query: str, @@ -100,8 +99,7 @@ class MCPSelector: # 使用小模型提取JSON return await self._call_function_mcp(result, mcp_ids) - - async def _call_reasoning(self, prompt: str) -> str: + async def _call_reasoning(self, prompt: str) -> AsyncGenerator[str, None]: """调用大模型进行推理""" logger.info("[MCPHelper] 调用推理大模型") llm = ReasoningLLM() @@ -109,13 +107,8 @@ class MCPSelector: {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}, ] - result = "" async for chunk in llm.call(message): - result += chunk - self.input_tokens += llm.input_tokens - self.output_tokens += llm.output_tokens - return result - + yield chunk async def _call_function_mcp(self, reasoning_result: str, mcp_ids: list[str]) -> MCPSelectResult: """调用结构化输出小模型提取JSON""" @@ -136,7 +129,6 @@ class MCPSelector: raise return result - async def select_top_mcp( self, query: str, @@ -153,7 +145,6 @@ class MCPSelector: # 通过LLM选择最合适的 return await self._get_mcp_by_llm(query, llm_mcp_list, mcp_list) - @staticmethod async def select_top_tool(query: str, mcp_list: list[str], top_n: int = 10) -> list[MCPTool]: """选择最合适的工具""" @@ -181,5 +172,16 @@ class MCPSelector: async for tool in tool_data: tool_obj = MCPTool.model_validate(tool) llm_tool_list.append(tool_obj) - + llm_tool_list.append( + MCPTool( + id="00000000-0000-0000-0000-000000000000", + name="Final", + description="It is the final step, indicating the end of the plan execution.") + ) + llm_tool_list.append( + MCPTool( + id="00000000-0000-0000-0000-000000000001", + name="Chat", + description="It is a chat tool to communicate with the user.") + ) return llm_tool_list diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py index 1463d0a1..648baa53 100644 --- a/apps/scheduler/pool/loader/mcp.py +++ b/apps/scheduler/pool/loader/mcp.py @@ -11,6 +11,7 @@ import shutil import asyncer from anyio import Path from sqids.sqids import Sqids +from typing import Any from apps.common.lance import LanceDB from apps.common.mongo import MongoDB @@ -391,7 +392,7 @@ class MCPLoader(metaclass=SingletonMeta): ) @staticmethod - async def user_active_template(user_sub: str, mcp_id: str) -> None: + async def user_active_template(user_sub: str, mcp_id: str, mcp_env: dict[str, Any]) -> None: """ 用户激活MCP模板 @@ -409,7 +410,8 @@ class MCPLoader(metaclass=SingletonMeta): if await user_path.exists(): err = f"MCP模板“{mcp_id}”已存在或有同名文件,无法激活" raise FileExistsError(err) - + mcp_config = await MCPLoader.get_config(mcp_id) + mcp_config.config.env.update(mcp_env) # 拷贝文件 await asyncer.asyncify(shutil.copytree)( template_path.as_posix(), @@ -417,7 +419,17 @@ class MCPLoader(metaclass=SingletonMeta): dirs_exist_ok=True, symlinks=True, ) - + user_config_path = user_path / "config.json" + # 更新用户配置 + f = await user_config_path.open("w", encoding="utf-8", errors="ignore") + await f.write( + json.dumps( + mcp_config.model_dump(by_alias=True, exclude_none=True), + indent=4, + ensure_ascii=False, + ) + ) + await f.aclose() # 更新数据库 mongo = MongoDB() mcp_collection = mongo.get_collection("mcp") diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index 578d8121..d50eeaf7 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -97,7 +97,7 @@ class NodeType(str, Enum): START = "start" END = "end" NORMAL = "normal" - CHOICE = "choice" + CHOICE = "Choice" class SaveType(str, Enum): diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index 60c8f17b..693caa09 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -118,7 +118,7 @@ class MCPToolSelectResult(BaseModel): class MCPPlanItem(BaseModel): """MCP 计划""" - id: str = Field(default_factory=lambda: str(uuid.uuid4())) + step_id: str = Field(description="步骤的ID", default="") content: str = Field(description="计划内容") tool: str = Field(description="工具名称") instruction: str = Field(description="工具指令") @@ -127,4 +127,4 @@ class MCPPlanItem(BaseModel): class MCPPlan(BaseModel): """MCP 计划""" - plans: list[MCPPlanItem] = Field(description="计划列表") + plans: list[MCPPlanItem] = Field(description="计划列表", default=[]) diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index af0d4a01..793ff456 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -107,7 +107,7 @@ class ActiveMCPServiceRequest(BaseModel): """POST /api/mcp/{serviceId} 请求数据结构""" active: bool = Field(description="是否激活mcp服务") - mcp_env: dict[str, Any] | None = Field(default=None, description="MCP服务环境变量", alias="mcpEnv") + mcp_env: dict[str, Any] = Field(default={}, description="MCP服务环境变量", alias="mcpEnv") class UpdateServiceRequest(BaseModel): diff --git a/apps/services/flow_validate.py b/apps/services/flow_validate.py index 78e8d340..d799c5dc 100644 --- a/apps/services/flow_validate.py +++ b/apps/services/flow_validate.py @@ -52,18 +52,26 @@ class FlowService: if node.call_id == NodeType.CHOICE.value: node.parameters = node.parameters["input_parameters"] if "choices" not in node.parameters: - node.parameters["choices"] = [] + logger.error(f"[FlowService] 节点{node.name}的分支字段缺失") + raise FlowBranchValidationError(f"[FlowService] 节点{node.name}的分支字段缺失") + if not node.parameters["choices"]: + logger.error(f"[FlowService] 节点{node.name}的分支字段为空") + raise FlowBranchValidationError(f"[FlowService] 节点{node.name}的分支字段为空") for choice in node.parameters["choices"]: - if choice["branchId"] in node_branch_map[node.step_id]: - err = f"[FlowService] 节点{node.name}的分支{choice['branchId']}重复" + if "branch_id" not in choice: + err = f"[FlowService] 节点{node.name}的分支choice缺少branch_id字段" + logger.error(err) + raise FlowBranchValidationError(err) + if choice["branch_id"] in node_branch_map[node.step_id]: + err = f"[FlowService] 节点{node.name}的分支{choice['branch_id']}重复" logger.error(err) raise Exception(err) for illegal_char in branch_illegal_chars: - if illegal_char in choice["branchId"]: - err = f"[FlowService] 节点{node.name}的分支{choice['branchId']}名称中含有非法字符" + if illegal_char in choice["branch_id"]: + err = f"[FlowService] 节点{node.name}的分支{choice['branch_id']}名称中含有非法字符" logger.error(err) raise Exception(err) - node_branch_map[node.step_id].add(choice["branchId"]) + node_branch_map[node.step_id].add(choice["branch_id"]) else: node_branch_map[node.step_id].add("") valid_edges = [] diff --git a/apps/services/mcp_service.py b/apps/services/mcp_service.py index 7cb880c0..2c84a211 100644 --- a/apps/services/mcp_service.py +++ b/apps/services/mcp_service.py @@ -198,7 +198,6 @@ class MCPServiceManager: base_filters = {"author": {"$regex": keyword, "$options": "i"}} return base_filters - @staticmethod async def create_mcpservice(data: UpdateMCPServiceRequest, user_sub: str) -> str: """ @@ -297,6 +296,7 @@ class MCPServiceManager: async def active_mcpservice( user_sub: str, service_id: str, + mcp_env: dict[str, Any] = {}, ) -> None: """ 激活MCP服务 @@ -310,7 +310,7 @@ class MCPServiceManager: for item in status: mcp_status = item.get("status", MCPInstallStatus.INSTALLING) if mcp_status == MCPInstallStatus.READY: - await MCPLoader.user_active_template(user_sub, service_id) + await MCPLoader.user_active_template(user_sub, service_id, mcp_env) else: err = "[MCPServiceManager] MCP服务未准备就绪" raise RuntimeError(err) -- Gitee From 7f48934e3c1bb5d240912a5da746cb01b93d8bdf Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 29 Jul 2025 11:06:54 +0800 Subject: [PATCH 23/60] =?UTF-8?q?=E5=AE=8C=E5=96=84parameter=E8=8E=B7?= =?UTF-8?q?=E5=8F=96=E7=9A=84=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/mcp_agent/prompt.py | 38 ++++++++++++++++++++++-------- apps/scheduler/slot/slot.py | 2 +- apps/services/parameter.py | 11 +++++---- 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py index 649a67a9..579a5081 100644 --- a/apps/scheduler/mcp_agent/prompt.py +++ b/apps/scheduler/mcp_agent/prompt.py @@ -437,19 +437,37 @@ GET_MISSING_PARAMS = dedent(r""" "password": "password" } ## 工具入参schema - { - "type": "object", - "properties": { - "host": {"type": "string", "description": "MySQL数据库的主机地址"}, + { + "type": "object", + "properties": { + "host": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL数据库的主机地址(可以为字符串或null)" + }, "port": { "anyOf": [ {"type": "string"}, - {"type": "integer"}, + {"type": "null"} ], - "description": "MySQL数据库的端口号(可以是数字或字符串)" + "description": "MySQL数据库的端口号(可以是数字、字符串或null)" }, - "username": {"type": "string", "description": "MySQL数据库的用户名"}, - "password": {"type": "string", "description": "MySQL数据库的密码"} + "username": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL数据库的用户名(可以为字符串或null)" + }, + "password": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL数据库的密码(可以为字符串或null)" + } }, "required": ["host", "port", "username", "password"] } @@ -460,8 +478,8 @@ GET_MISSING_PARAMS = dedent(r""" { "host": "192.0.0.1", "port": 3306, - "username": "请补充用户名", - "password": "请补充密码" + "username": null, + "password": null } ``` # 工具名称 diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py index 89433cad..c1b27abf 100644 --- a/apps/scheduler/slot/slot.py +++ b/apps/scheduler/slot/slot.py @@ -257,7 +257,7 @@ class Slot: paramType=param_type, subParams=sub_params) try: - return _extract_params_node(self._schema, name=root, path=root) + return _extract_params_node(self._schema, name=root, path="/"+root) except Exception as e: logger.error(f"[Slot] 提取ParamsNode失败: {e!s}\n{traceback.format_exc()}") return None diff --git a/apps/services/parameter.py b/apps/services/parameter.py index ae375e97..259c4e45 100644 --- a/apps/services/parameter.py +++ b/apps/services/parameter.py @@ -57,8 +57,10 @@ class ParameterManager: q = [step_id] in_edges = {} step_id_to_node_id = {} + step_id_to_node_name = {} for step in flow.nodes: step_id_to_node_id[step.step_id] = step.node_id + step_id_to_node_name[step.step_id] = step.name for edge in flow.edges: if edge.target_node not in in_edges: in_edges[edge.target_node] = [] @@ -71,15 +73,16 @@ class ParameterManager: if pre_node_id not in q: q.append(pre_node_id) pre_step_params = [] - for step_id in q: + for i in range(1, len(q)): + step_id = q[i] node_id = step_id_to_node_id.get(step_id) params_schema, output_schema = await NodeManager.get_node_params(node_id) slot = Slot(output_schema) - params_node = slot.get_params_node_from_schema(root='/output') + params_node = slot.get_params_node_from_schema(root='output') pre_step_params.append( StepParams( - stepId=node_id, - name=params_schema.get("name", ""), + stepId=step_id, + name=step_id_to_node_name.get(step_id), paramsNode=params_node ) ) -- Gitee From 74e54c0851b960bf6cd686e83637eaea5f96f82a Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 29 Jul 2025 11:58:41 +0800 Subject: [PATCH 24/60] =?UTF-8?q?=E5=AE=8C=E5=96=84=E6=B5=81=E9=AA=8C?= =?UTF-8?q?=E8=AF=81=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/services/flow_validate.py | 1 - 1 file changed, 1 deletion(-) diff --git a/apps/services/flow_validate.py b/apps/services/flow_validate.py index d799c5dc..0b3a3e6e 100644 --- a/apps/services/flow_validate.py +++ b/apps/services/flow_validate.py @@ -141,7 +141,6 @@ class FlowService: branches = {} in_deg = {} out_deg = {} - for e in edges: if e.edge_id in ids: err = f"[FlowService] 边{e.edge_id}的id重复" -- Gitee From 62b13cba494a4412efbab1226bdd457106231f5b Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 29 Jul 2025 15:05:18 +0800 Subject: [PATCH 25/60] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dchoice=E8=8A=82?= =?UTF-8?q?=E7=82=B9=E4=BF=9D=E5=AD=98=E5=8F=82=E6=95=B0=E4=B8=A2=E5=A4=B1?= =?UTF-8?q?=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/flow.py | 6 ++++++ apps/services/flow.py | 1 + apps/services/flow_validate.py | 8 ++++---- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/apps/routers/flow.py b/apps/routers/flow.py index fc38c1bf..68c0c9f3 100644 --- a/apps/routers/flow.py +++ b/apps/routers/flow.py @@ -2,6 +2,7 @@ """FastAPI Flow拓扑结构展示API""" from typing import Annotated +import logging from fastapi import APIRouter, Body, Depends, Query, status from fastapi.responses import JSONResponse @@ -25,6 +26,8 @@ from apps.services.application import AppManager from apps.services.flow import FlowManager from apps.services.flow_validate import FlowService +logger = logging.getLogger(__name__) + router = APIRouter( prefix="/api/flow", tags=["flow"], @@ -130,8 +133,11 @@ async def put_flow( ).model_dump(exclude_none=True, by_alias=True), ) put_body.flow = await FlowService.remove_excess_structure_from_flow(put_body.flow) + logger.error(f'{put_body.flow}') await FlowService.validate_flow_illegal(put_body.flow) + logger.error(f'{put_body.flow}') put_body.flow.connectivity = await FlowService.validate_flow_connectivity(put_body.flow) + logger.error(f'{put_body.flow}') result = await FlowManager.put_flow_by_app_and_flow_id(app_id, flow_id, put_body.flow) if result is None: return JSONResponse( diff --git a/apps/services/flow.py b/apps/services/flow.py index c9fd86fd..097908cd 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -413,6 +413,7 @@ class FlowManager: flow_config.debug = await FlowManager.is_flow_config_equal(old_flow_config, flow_config) else: flow_config.debug = False + logger.error(f'{flow_config}') await flow_loader.save(app_id, flow_id, flow_config) except Exception: logger.exception("[FlowManager] 存储/更新流失败") diff --git a/apps/services/flow_validate.py b/apps/services/flow_validate.py index 0b3a3e6e..66a2d087 100644 --- a/apps/services/flow_validate.py +++ b/apps/services/flow_validate.py @@ -50,14 +50,14 @@ class FlowService: logger.error(f"[FlowService] 获取步骤的call_id失败{node.call_id}由于:{e}") node_branch_map[node.step_id] = set() if node.call_id == NodeType.CHOICE.value: - node.parameters = node.parameters["input_parameters"] - if "choices" not in node.parameters: + input_parameters = node.parameters["input_parameters"] + if "choices" not in input_parameters: logger.error(f"[FlowService] 节点{node.name}的分支字段缺失") raise FlowBranchValidationError(f"[FlowService] 节点{node.name}的分支字段缺失") - if not node.parameters["choices"]: + if not input_parameters["choices"]: logger.error(f"[FlowService] 节点{node.name}的分支字段为空") raise FlowBranchValidationError(f"[FlowService] 节点{node.name}的分支字段为空") - for choice in node.parameters["choices"]: + for choice in input_parameters["choices"]: if "branch_id" not in choice: err = f"[FlowService] 节点{node.name}的分支choice缺少branch_id字段" logger.error(err) -- Gitee From ddbb1f864d841d8edf8af0ccfe29255cb458dc63 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 29 Jul 2025 16:44:28 +0800 Subject: [PATCH 26/60] =?UTF-8?q?=E4=BF=AE=E5=A4=8DGET=20/api/parameters?= =?UTF-8?q?=E6=97=B6=E7=94=B1=E7=A9=BA=E8=8A=82=E7=82=B9=E5=BC=95=E8=B5=B7?= =?UTF-8?q?=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/call/facts/facts.py | 9 +-- apps/scheduler/executor/flow.py | 16 +++- apps/scheduler/mcp/host.py | 2 +- apps/scheduler/mcp_agent/host.py | 4 +- apps/scheduler/mcp_agent/plan.py | 119 +++++++++++++++++++---------- apps/scheduler/mcp_agent/prompt.py | 41 ++++++++-- apps/scheduler/slot/slot.py | 6 +- apps/schemas/enum_var.py | 1 + apps/schemas/flow_topology.py | 5 +- apps/schemas/mcp.py | 7 ++ apps/services/flow.py | 5 +- apps/services/flow_validate.py | 7 +- apps/services/node.py | 5 +- 13 files changed, 156 insertions(+), 71 deletions(-) diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py index f8aebcd7..2b9df0c6 100644 --- a/apps/scheduler/call/facts/facts.py +++ b/apps/scheduler/call/facts/facts.py @@ -30,13 +30,11 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): answer: str = Field(description="用户输入") - @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" return CallInfo(name="提取事实", description="从对话上下文和文档片段中提取事实。") - @classmethod async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self: """初始化工具""" @@ -51,7 +49,6 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): await obj._set_input(executor) return obj - async def _init(self, call_vars: CallVars) -> FactsInput: """初始化工具""" # 组装必要变量 @@ -65,7 +62,6 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): message=message, ) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """执行工具""" data = FactsInput(**input_data) @@ -83,7 +79,7 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): facts_obj: FactsGen = await self._json([ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": facts_prompt}, - ], FactsGen) # type: ignore[arg-type] + ], FactsGen) # type: ignore[arg-type] # 更新用户画像 domain_tpl = env.from_string(DOMAIN_PROMPT) @@ -91,7 +87,7 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): domain_list: DomainGen = await self._json([ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": domain_prompt}, - ], DomainGen) # type: ignore[arg-type] + ], DomainGen) # type: ignore[arg-type] for domain in domain_list.keywords: await UserDomainManager.update_user_domain_by_user_sub_and_domain_name(data.user_sub, domain) @@ -104,7 +100,6 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): ).model_dump(by_alias=True, exclude_none=True), ) - async def exec(self, executor: "StepExecutor", input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """执行工具""" async for chunk in self._exec(input_data): diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index fbae705d..382ef929 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -172,6 +172,7 @@ class FlowExecutor(BaseExecutor): self.step_queue.append(first_step) self.task.state.flow_status = FlowStatus.RUNNING # type: ignore[arg-type] # 运行Flow(未达终点) + is_error = False while not self._reached_end: # 如果当前步骤出错,执行错误处理步骤 if self.task.state.step_status == StepStatus.ERROR: # type: ignore[arg-type] @@ -194,7 +195,7 @@ class FlowExecutor(BaseExecutor): enable_filling=False, to_user=False, )) - self.task.state.flow_status = FlowStatus.ERROR # type: ignore[arg-type] + is_error = True # 错误处理后结束 self._reached_end = True @@ -209,6 +210,12 @@ class FlowExecutor(BaseExecutor): for step in next_step: self.step_queue.append(step) + # 更新Task状态 + if is_error: + self.task.state.flow_status = FlowStatus.ERROR # type: ignore[arg-type] + else: + self.task.state.flow_status = FlowStatus.SUCCESS # type: ignore[arg-type] + # 尾插运行结束后的系统步骤 for step in FIXED_STEPS_AFTER_END: self.step_queue.append(StepQueueItem( @@ -220,6 +227,7 @@ class FlowExecutor(BaseExecutor): # FlowStop需要返回总时间,需要倒推最初的开始时间(当前时间减去当前已用总时间) self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.full_time # 推送Flow停止消息 - await self.push_message(EventType.FLOW_STOP.value) - # 更新Task状态 - self.task.state.flow_status = FlowStatus.SUCCESS # type: ignore[arg-type] + if is_error: + await self.push_message(EventType.FLOW_FAILED.value) + else: + await self.push_message(EventType.FLOW_SUCCESS.value) diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py index acdd4871..aa196112 100644 --- a/apps/scheduler/mcp/host.py +++ b/apps/scheduler/mcp/host.py @@ -102,7 +102,7 @@ class MCPHost: task_id=self._task_id, flow_id=self._runtime_id, flow_name=self._runtime_name, - flow_status=StepStatus.SUCCESS, + flow_status=StepStatus.RUNNING, step_id=tool.name, step_name=tool.name, # description是规划的实际内容 diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py index acdd4871..a8ebec7b 100644 --- a/apps/scheduler/mcp_agent/host.py +++ b/apps/scheduler/mcp_agent/host.py @@ -123,7 +123,7 @@ class MCPHost: return output_data - async def _fill_params(self, tool: MCPTool, query: str) -> dict[str, Any]: + async def _fill_params(self, schema: dict[str, Any], query: str) -> dict[str, Any]: """填充工具参数""" # 更清晰的输入·指令,这样可以调用generate llm_query = rf""" @@ -139,7 +139,7 @@ class MCPHost: {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": await self.assemble_memory()}, ], - tool.input_schema, + schema, ) return await json_generator.generate() diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index 47e803f6..771115d2 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -1,19 +1,23 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """MCP 用户目标拆解与规划""" - +from typing import Any from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment -from apps.llm.function import JsonGenerator from apps.llm.reasoning import ReasoningLLM -from apps.scheduler.mcp.prompt import CREATE_PLAN, FINAL_ANSWER -from apps.schemas.mcp import MCPPlan, MCPTool +from apps.llm.function import JsonGenerator +from apps.scheduler.mcp_agent.prompt import EVALUATE_GAOL, CREATE_PLAN, RECREATE_PLAN, FINAL_ANSWER +from apps.schemas.mcp import ( + GoalEvaluationResult, + MCPPlan, + MCPTool +) class MCPPlanner: """MCP 用户目标拆解与规划""" - def __init__(self, user_goal: str) -> None: + def __init__(self, user_goal: str, resoning_llm: ReasoningLLM = None) -> None: """初始化MCP规划器""" self.user_goal = user_goal self._env = SandboxedEnvironment( @@ -24,36 +28,17 @@ class MCPPlanner: ) self.input_tokens = 0 self.output_tokens = 0 + self.resoning_llm = resoning_llm or ReasoningLLM() - async def create_plan(self, tool_list: list[MCPTool], max_steps: int = 6) -> MCPPlan: - """规划下一步的执行流程,并输出""" - # 获取推理结果 - result = await self._get_reasoning_plan(tool_list, max_steps) - - # 解析为结构化数据 - return await self._parse_plan_result(result, max_steps) - - async def _get_reasoning_plan( - self, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan = MCPPlan(), - tool_list: list[MCPTool] = [], - max_steps: int = 10) -> str: - """获取推理大模型的结果""" - # 格式化Prompt - template = self._env.from_string(CREATE_PLAN) - prompt = template.render( - goal=self.user_goal, - tools=tool_list, - max_num=max_steps, - ) - + async def get_resoning_result(self, prompt: str) -> str: + """获取推理结果""" # 调用推理大模型 message = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}, ] - reasoning_llm = ReasoningLLM() result = "" - async for chunk in reasoning_llm.call( + async for chunk in self.resoning_llm.call( message, streaming=False, temperature=0.07, @@ -62,18 +47,12 @@ class MCPPlanner: result += chunk # 保存token用量 - self.input_tokens = reasoning_llm.input_tokens - self.output_tokens = reasoning_llm.output_tokens - + self.input_tokens = self.resoning_llm.input_tokens + self.output_tokens = self.resoning_llm.output_tokens return result - async def _parse_plan_result(self, result: str, max_steps: int) -> MCPPlan: - """将推理结果解析为结构化数据""" - # 格式化Prompt - schema = MCPPlan.model_json_schema() - schema["properties"]["plans"]["maxItems"] = max_steps - - # 使用Function模型解析结果 + async def _parse_result(self, result: str, schema: dict[str, Any]) -> str: + """解析推理结果""" json_generator = JsonGenerator( result, [ @@ -82,7 +61,69 @@ class MCPPlanner: ], schema, ) - plan = await json_generator.generate() + json_result = await json_generator.generate() + return json_result + + async def evaluate_goal(self, tool_list: list[MCPTool]) -> str: + pass + + async def _get_reasoning_evaluation(self, tool_list: list[MCPTool]) -> str: + """获取推理大模型的评估结果""" + template = self._env.from_string(EVALUATE_GAOL) + prompt = template.render( + goal=self.user_goal, + tools=tool_list, + ) + result = await self.get_resoning_result(prompt) + return result + + async def _parse_evaluation_result(self, result: str) -> str: + """将推理结果解析为结构化数据""" + schema = GoalEvaluationResult.model_json_schema() + evaluation = await self._parse_result(result, schema) + # 使用GoalEvaluationResult模型解析结果 + return GoalEvaluationResult.model_validate(evaluation) + + async def create_plan(self, tool_list: list[MCPTool], max_steps: int = 6) -> MCPPlan: + """规划下一步的执行流程,并输出""" + # 获取推理结果 + result = await self._get_reasoning_plan(tool_list, max_steps) + + # 解析为结构化数据 + return await self._parse_plan_result(result, max_steps) + + async def _get_reasoning_plan( + self, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan = MCPPlan(), + tool_list: list[MCPTool] = [], + max_steps: int = 10) -> str: + """获取推理大模型的结果""" + # 格式化Prompt + if is_replan: + template = self._env.from_string(RECREATE_PLAN) + prompt = template.render( + current_plan=current_plan, + error_message=error_message, + goal=self.user_goal, + tools=tool_list, + max_num=max_steps, + ) + else: + template = self._env.from_string(CREATE_PLAN) + prompt = template.render( + goal=self.user_goal, + tools=tool_list, + max_num=max_steps, + ) + result = await self.get_resoning_result(prompt) + return result + + async def _parse_plan_result(self, result: str, max_steps: int) -> MCPPlan: + """将推理结果解析为结构化数据""" + # 格式化Prompt + schema = MCPPlan.model_json_schema() + schema["properties"]["plans"]["maxItems"] = max_steps + plan = await self._parse_result(result, schema) + # 使用Function模型解析结果 return MCPPlan.model_validate(plan) async def generate_answer(self, plan: MCPPlan, memory: str) -> str: diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py index 579a5081..7cef9e11 100644 --- a/apps/scheduler/mcp_agent/prompt.py +++ b/apps/scheduler/mcp_agent/prompt.py @@ -103,12 +103,45 @@ EVALUATE_GAOL = dedent(r""" {{ goal }} # 工具集合 - {{ tools }} + + {% for tool in tools %} + - {{ tool.id }}{{tool.name}};{{ tool.description }} + {% endfor %} + # 附加信息 {{ additional_info }} """) +GENERATE_FLOW_NAME = dedent(r""" + 你是一个智能助手,你的任务是根据用户的目标,生成一个合适的流程名称。 + + # 生成流程名称时的注意事项: + 1. 流程名称应该简洁明了,能够准确表达达成用户目标的过程。 + 2. 流程名称应该包含关键的操作或步骤,例如“扫描”、“分析”、“调优”等。 + 3. 流程名称应该避免使用过于复杂或专业的术语,以便用户能够理解。 + 4. 流程名称应该尽量简短,小于20个字或者单词。 + + - 必须按照如下格式生成流程名称,不要输出任何额外数据: + ```json + { + "flow_name": "生成的流程名称" + } + ``` + # 样例 + ## 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈,并调优 + ## 输出 + ```json + { + "flow_name": "MySQL性能分析与调优" + } + ``` + # 现在开始生成流程名称: + # 目标 + {{ goal }} + # 输出 + """) CREATE_PLAN = dedent(r""" 你是一个计划生成器。 请分析用户的目标,并生成一个计划。你后续将根据这个计划,一步一步地完成用户的目标。 @@ -153,8 +186,6 @@ CREATE_PLAN = dedent(r""" {% for tool in tools %} - {{ tool.id }}{{tool.name}};{{ tool.description }} {% endfor %} - - Final结束步骤,当执行到这一步时,\ -表示计划执行结束,所得到的结果将作为最终结果。 # 样例 @@ -352,8 +383,6 @@ RECREATE_PLAN = dedent(r""" {% for tool in tools %} - {{ tool.id }}{{tool.name}};{{ tool.description }} {% endfor %} - - Final结束步骤,当执行到这一步时,\ -表示计划执行结束,所得到的结果将作为最终结果。 # 当前计划 @@ -415,7 +444,7 @@ RISK_EVALUATE = dedent(r""" # 获取缺失的参数的json结构体 GET_MISSING_PARAMS = dedent(r""" 你是一个工具参数获取器。 - 你的任务是根据当前工具的名称、描述和入参和入参的schema以及运行报错,获取当前工具缺失的参数并输出提示。 + 你的任务是根据当前工具的名称、描述和入参和入参的schema以及运行报错,将当前缺失的参数设置为null,并输出一个JSON格式的字符串。 ```json { "host": "请补充主机地址", diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py index c1b27abf..045c4862 100644 --- a/apps/scheduler/slot/slot.py +++ b/apps/scheduler/slot/slot.py @@ -248,7 +248,9 @@ class Slot: if param_type == "object" and "properties" in schema_node: for key, value in schema_node["properties"].items(): - sub_params.append(_extract_params_node(value, name=key, path=f"{path}/{key}")) + sub_param = _extract_params_node(value, name=key, path=f"{path}/{key}") + if sub_param: + sub_params.append(sub_param) else: # 对于非对象类型,直接返回空子参数 sub_params = None @@ -257,7 +259,7 @@ class Slot: paramType=param_type, subParams=sub_params) try: - return _extract_params_node(self._schema, name=root, path="/"+root) + return _extract_params_node(self._schema, name=root, path="/" + root) except Exception as e: logger.error(f"[Slot] 提取ParamsNode失败: {e!s}\n{traceback.format_exc()}") return None diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index d50eeaf7..8458e103 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -57,6 +57,7 @@ class EventType(str, Enum): STEP_OUTPUT = "step.output" FLOW_STOP = "flow.stop" FLOW_FAILED = "flow.failed" + FLOW_SUCCESS = "flow.success" FLOW_CANCELLED = "flow.cancelled" DONE = "done" diff --git a/apps/schemas/flow_topology.py b/apps/schemas/flow_topology.py index d0ab666a..fecf2160 100644 --- a/apps/schemas/flow_topology.py +++ b/apps/schemas/flow_topology.py @@ -5,6 +5,7 @@ from typing import Any from pydantic import BaseModel, Field +from apps.schemas.enum_var import SpecialCallType from apps.schemas.enum_var import EdgeType @@ -51,7 +52,7 @@ class NodeItem(BaseModel): service_id: str = Field(alias="serviceId", default="") node_id: str = Field(alias="nodeId", default="") name: str = Field(default="") - call_id: str = Field(alias="callId", default="Empty") + call_id: str = Field(alias="callId", default=SpecialCallType.EMPTY.value) description: str = Field(default="") enable: bool = Field(default=True) parameters: dict[str, Any] = Field(default={}) @@ -81,6 +82,6 @@ class FlowItem(BaseModel): nodes: list[NodeItem] = Field(default=[]) edges: list[EdgeItem] = Field(default=[]) created_at: float | None = Field(alias="createdAt", default=0) - connectivity: bool = Field(default=False,description="图的开始节点和结束节点是否联通,并且除结束节点都有出边") + connectivity: bool = Field(default=False, description="图的开始节点和结束节点是否联通,并且除结束节点都有出边") focus_point: PositionItem = Field(alias="focusPoint", default=PositionItem()) debug: bool = Field(default=False) diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index 693caa09..2ee50061 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -104,6 +104,13 @@ class MCPToolVector(LanceModel): embedding: Vector(dim=1024) = Field(description="MCP工具描述的向量信息") # type: ignore[call-arg] +class GoalEvaluationResult(BaseModel): + """MCP 目标评估结果""" + + can_complete: bool = Field(description="是否可以完成目标") + reason: str = Field(description="评估原因") + + class MCPSelectResult(BaseModel): """MCP选择结果""" diff --git a/apps/services/flow.py b/apps/services/flow.py index 097908cd..4d682e5a 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -257,10 +257,7 @@ class FlowManager: ) for node_id, node_config in flow_config.steps.items(): input_parameters = node_config.params - if node_config.node not in ("Empty"): - _, output_parameters = await NodeManager.get_node_params(node_config.node) - else: - output_parameters = {} + _, output_parameters = await NodeManager.get_node_params(node_config.node) parameters = { "input_parameters": input_parameters, "output_parameters": Slot(output_parameters).extract_type_desc_from_schema(), diff --git a/apps/services/flow_validate.py b/apps/services/flow_validate.py index 66a2d087..16c23053 100644 --- a/apps/services/flow_validate.py +++ b/apps/services/flow_validate.py @@ -4,6 +4,7 @@ import collections import logging +from apps.schemas.enum_var import SpecialCallType from apps.exceptions import FlowBranchValidationError, FlowEdgeValidationError, FlowNodeValidationError from apps.schemas.enum_var import NodeType from apps.schemas.flow_topology import EdgeItem, FlowItem, NodeItem @@ -38,14 +39,14 @@ class FlowService: for node in flow_item.nodes: from apps.scheduler.pool.pool import Pool from pydantic import BaseModel - if node.node_id != 'start' and node.node_id != 'end' and node.node_id != 'Empty': + if node.node_id != 'start' and node.node_id != 'end' and node.node_id != SpecialCallType.EMPTY.value: try: call_class: type[BaseModel] = await Pool().get_call(node.call_id) if not call_class: - node.node_id = 'Empty' + node.node_id = SpecialCallType.EMPTY.value node.description = '【对应的api工具被删除!节点不可用!请联系相关人员!】\n\n'+node.description except Exception as e: - node.node_id = 'Empty' + node.node_id = SpecialCallType.EMPTY.value node.description = '【对应的api工具被删除!节点不可用!请联系相关人员!】\n\n'+node.description logger.error(f"[FlowService] 获取步骤的call_id失败{node.call_id}由于:{e}") node_branch_map[node.step_id] = set() diff --git a/apps/services/node.py b/apps/services/node.py index bf48e71f..3fb311bf 100644 --- a/apps/services/node.py +++ b/apps/services/node.py @@ -4,6 +4,7 @@ import logging from typing import TYPE_CHECKING, Any +from apps.schemas.enum_var import SpecialCallType from apps.common.mongo import MongoDB from apps.schemas.node import APINode from apps.schemas.pool import NodePool @@ -77,7 +78,9 @@ class NodeManager: async def get_node_params(node_id: str) -> tuple[dict[str, Any], dict[str, Any]]: """获取Node数据""" from apps.scheduler.pool.pool import Pool - + if node_id == SpecialCallType.EMPTY.value: + # 如果是空节点,返回空Schema + return {}, {} # 查找Node信息 logger.info("[NodeManager] 获取节点 %s", node_id) node_collection = MongoDB().get_collection("node") -- Gitee From e133af5d656ec84f1e09c40d9ce36ce3ccd9b07b Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 29 Jul 2025 17:07:22 +0800 Subject: [PATCH 27/60] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=8F=90=E5=8F=96node?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/slot/slot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py index 045c4862..58c90bac 100644 --- a/apps/scheduler/slot/slot.py +++ b/apps/scheduler/slot/slot.py @@ -246,7 +246,7 @@ class Slot: return None sub_params = [] - if param_type == "object" and "properties" in schema_node: + if param_type == Type.DICT and "properties" in schema_node: for key, value in schema_node["properties"].items(): sub_param = _extract_params_node(value, name=key, path=f"{path}/{key}") if sub_param: -- Gitee From b9efb73808405a6f6516a00e34355c00de03979b Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 29 Jul 2025 17:12:57 +0800 Subject: [PATCH 28/60] =?UTF-8?q?choice=E7=9A=84=E5=88=86=E6=94=AF?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E8=A1=A8=E8=BE=BE=E5=BC=8F=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?name?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/call/choice/schema.py | 1 + 1 file changed, 1 insertion(+) diff --git a/apps/scheduler/call/choice/schema.py b/apps/scheduler/call/choice/schema.py index b95b1668..d97a0c8d 100644 --- a/apps/scheduler/call/choice/schema.py +++ b/apps/scheduler/call/choice/schema.py @@ -29,6 +29,7 @@ class Value(DataBase): step_id: str | None = Field(description="步骤id", default=None) type: Type | None = Field(description="值的类型", default=None) + name: str | None = Field(description="值的名称", default=None) value: str | float | int | bool | list | dict | None = Field(description="值", default=None) -- Gitee From 3948571b23607a587ddeb56c212f8ff831bbb08a Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 29 Jul 2025 17:32:04 +0800 Subject: [PATCH 29/60] =?UTF-8?q?=E6=9C=AC=E5=9C=B0=E6=A8=A1=E5=BC=8F?= =?UTF-8?q?=E4=B8=8B=E7=94=A8=E6=88=B7=E5=90=8D=E7=A7=B0=E7=AD=89=E4=BA=8E?= =?UTF-8?q?=E5=BD=93=E5=89=8D=E6=93=8D=E4=BD=9C=E7=B3=BB=E7=BB=9F=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E5=90=8D=E7=A7=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/dependency/user.py | 9 ++++++--- apps/main.py | 8 +++++++- apps/schemas/config.py | 1 - 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/apps/dependency/user.py b/apps/dependency/user.py index 87cbd290..8f5848a3 100644 --- a/apps/dependency/user.py +++ b/apps/dependency/user.py @@ -1,6 +1,6 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """用户鉴权""" - +import os import logging from fastapi import Depends @@ -75,8 +75,11 @@ async def get_user(request: HTTPConnection) -> str: :return: 用户sub """ if Config().get_config().no_auth.enable: - # 如果启用了无认证访问,直接返回调试用户 - return Config().get_config().no_auth.user_sub + # 如果启用了无认证访问,直接返回当前操作系统用户的名称 + username = os.environ.get('USERNAME') # 适用于 Windows 系统 + if not username: + username = os.environ.get('USER') # 适用于 Linux 和 macOS 系统 + return username or "admin" session_id = await _get_session_id_from_request(request) if not session_id: raise HTTPException( diff --git a/apps/main.py b/apps/main.py index 17d4abb4..1646f671 100644 --- a/apps/main.py +++ b/apps/main.py @@ -89,10 +89,16 @@ async def add_no_auth_user() -> None: """ from apps.common.mongo import MongoDB from apps.schemas.collection import User + import os mongo = MongoDB() user_collection = mongo.get_collection("user") + username = os.environ.get('USERNAME') # 适用于 Windows 系统 + if not username: + username = os.environ.get('USER') # 适用于 Linux 和 macOS 系统 + if not username: + username = "admin" await user_collection.insert_one(User( - _id=Config().get_config().no_auth.user_sub, + _id=username, is_admin=True, ).model_dump(by_alias=True)) diff --git a/apps/schemas/config.py b/apps/schemas/config.py index e91f5f75..675a9ba7 100644 --- a/apps/schemas/config.py +++ b/apps/schemas/config.py @@ -10,7 +10,6 @@ class NoauthConfig(BaseModel): """无认证配置""" enable: bool = Field(description="是否启用无认证访问", default=False) - user_sub: str = Field(description="调试用户的sub", default="admin") class DeployConfig(BaseModel): -- Gitee From ba0b20901eea32adf4363f7c7a6de8d31aa0b754 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 29 Jul 2025 21:36:36 +0800 Subject: [PATCH 30/60] =?UTF-8?q?=E5=AE=8C=E5=96=84=E5=87=BA=E5=8F=82?= =?UTF-8?q?=E8=BF=94=E5=9B=9E=E3=80=81=E5=87=BA=E9=A4=90=E5=85=B3=E8=81=94?= =?UTF-8?q?=E5=92=8C=E9=BB=98=E8=AE=A4=E5=8F=82=E6=95=B0=E6=B8=B2=E6=9F=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/call/core.py | 4 +- apps/scheduler/call/slot/slot.py | 4 - apps/scheduler/mcp_agent/plan.py | 21 +++++- apps/scheduler/mcp_agent/prompt.py | 2 +- apps/scheduler/slot/slot.py | 115 ++++++++++++++++++++++------- apps/services/parameter.py | 2 +- 6 files changed, 109 insertions(+), 39 deletions(-) diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index af28c6a3..5bed8030 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -127,7 +127,7 @@ class CoreCall(BaseModel): :return: 变量 """ split_path = path.split("/") - if len(split_path) < 2: + if len(split_path) < 1: err = f"[CoreCall] 路径格式错误: {path}" logger.error(err) return None @@ -136,7 +136,7 @@ class CoreCall(BaseModel): logger.error(err) return None data = history[split_path[0]].output_data - for key in split_path[2:]: + for key in split_path[1:]: if key not in data: err = f"[CoreCall] 输出Key {key} 不存在" logger.error(err) diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py index 4f8e1010..d24e1661 100644 --- a/apps/scheduler/call/slot/slot.py +++ b/apps/scheduler/call/slot/slot.py @@ -32,13 +32,11 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): facts: list[str] = Field(description="事实信息", default=[]) step_num: int = Field(description="历史步骤数", default=1) - @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" return CallInfo(name="参数自动填充", description="根据步骤历史,自动填充参数") - async def _llm_slot_fill(self, remaining_schema: dict[str, Any]) -> tuple[str, dict[str, Any]]: """使用大模型填充参数;若大模型解析度足够,则直接返回结果""" env = SandboxedEnvironment( @@ -106,7 +104,6 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): await obj._set_input(executor) return obj - async def _init(self, call_vars: CallVars) -> SlotInput: """初始化""" self._flow_history = [] @@ -126,7 +123,6 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): remaining_schema=remaining_schema, ) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """执行参数填充""" data = SlotInput(**input_data) diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index 771115d2..40bec6ef 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -6,7 +6,12 @@ from jinja2.sandbox import SandboxedEnvironment from apps.llm.reasoning import ReasoningLLM from apps.llm.function import JsonGenerator -from apps.scheduler.mcp_agent.prompt import EVALUATE_GAOL, CREATE_PLAN, RECREATE_PLAN, FINAL_ANSWER +from apps.scheduler.mcp_agent.prompt import ( + EVALUATE_GOAL, + CREATE_PLAN, + RECREATE_PLAN, + FINAL_ANSWER +) from apps.schemas.mcp import ( GoalEvaluationResult, MCPPlan, @@ -64,8 +69,16 @@ class MCPPlanner: json_result = await json_generator.generate() return json_result - async def evaluate_goal(self, tool_list: list[MCPTool]) -> str: - pass + async def evaluate_goal(self, tool_list: list[MCPTool]) -> GoalEvaluationResult: + """评估用户目标的可行性""" + # 获取推理结果 + result = await self._get_reasoning_evaluation(tool_list) + + # 解析为结构化数据 + evaluation = await self._parse_evaluation_result(result) + + # 返回评估结果 + return evaluation async def _get_reasoning_evaluation(self, tool_list: list[MCPTool]) -> str: """获取推理大模型的评估结果""" @@ -77,7 +90,7 @@ class MCPPlanner: result = await self.get_resoning_result(prompt) return result - async def _parse_evaluation_result(self, result: str) -> str: + async def _parse_evaluation_result(self, result: str) -> GoalEvaluationResult: """将推理结果解析为结构化数据""" schema = GoalEvaluationResult.model_json_schema() evaluation = await self._parse_result(result, schema) diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py index 7cef9e11..8933f69a 100644 --- a/apps/scheduler/mcp_agent/prompt.py +++ b/apps/scheduler/mcp_agent/prompt.py @@ -62,7 +62,7 @@ MCP_SELECT = dedent(r""" ### 请一步一步思考: """) -EVALUATE_GAOL = dedent(r""" +EVALUATE_GOAL = dedent(r""" 你是一个计划评估器。 请根据用户的目标和当前的工具集合以及一些附加信息,判断基于当前的工具集合,是否能够完成用户的目标。 如果能够完成,请返回`true`,否则返回`false`。 diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py index 58c90bac..f5e5354f 100644 --- a/apps/scheduler/slot/slot.py +++ b/apps/scheduler/slot/slot.py @@ -128,7 +128,7 @@ class Slot: # Schema标准 return [_process_json_value(item, spec_data["items"]) for item in json_value] if spec_data["type"] == "object" and isinstance(json_value, dict): - # 若Schema不标准,则不进行处理 + # 若Schema不标准,则不进行处理F if "properties" not in spec_data: return json_value # Schema标准 @@ -156,35 +156,60 @@ class Slot: @staticmethod def _generate_example(schema_node: dict) -> Any: # noqa: PLR0911 """根据schema生成示例值""" + if "anyOf" in schema_node or "oneOf" in schema_node: + # 如果有anyOf,随机返回一个示例 + for item in schema_node["anyOf"] if "anyOf" in schema_node else schema_node["oneOf"]: + example = Slot._generate_example(item) + if example is not None: + return example + + if "allOf" in schema_node: + # 如果有allOf,返回所有示例的合并 + example = None + for item in schema_node["allOf"]: + if example is None: + example = Slot._generate_example(item) + else: + other_example = Slot._generate_example(item) + if isinstance(example, dict) and isinstance(other_example, dict): + example.update(other_example) + else: + example = None + break + return example + if "default" in schema_node: return schema_node["default"] if "type" not in schema_node: return None - + type_value = schema_node["type"] + if isinstance(type_value, list): + # 如果是多类型,随机返回一个示例 + if len(type_value) > 1: + type_value = type_value[0] # 处理类型为 object 的节点 - if schema_node["type"] == "object": + if type_value == "object": data = {} properties = schema_node.get("properties", {}) for name, schema in properties.items(): data[name] = Slot._generate_example(schema) return data - # 处理类型为 array 的节点 - if schema_node["type"] == "array": + elif type_value == "array": items_schema = schema_node.get("items", {}) return [Slot._generate_example(items_schema)] # 处理类型为 string 的节点 - if schema_node["type"] == "string": + elif type_value == "string": return "" # 处理类型为 number 或 integer 的节点 - if schema_node["type"] in ["number", "integer"]: + elif type_value in ["number", "integer"]: return 0 # 处理类型为 boolean 的节点 - if schema_node["type"] == "boolean": + elif type_value == "boolean": return False # 处理其他类型或未定义类型 @@ -198,29 +223,63 @@ class Slot: """从JSON Schema中提取类型描述""" def _extract_type_desc(schema_node: dict[str, Any]) -> dict[str, Any]: - if "type" not in schema_node and "anyOf" not in schema_node: - return {} - data = {"type": schema_node.get("type", ""), "description": schema_node.get("description", "")} - if "anyOf" in schema_node: - data["type"] = "anyOf" - # 处理类型为 object 的节点 - if "anyOf" in schema_node: - data["items"] = {} - type_index = 0 - for type_index, sub_schema in enumerate(schema_node["anyOf"]): - sub_result = _extract_type_desc(sub_schema) - if sub_result: - data["items"]["type_"+str(type_index)] = sub_result - if schema_node.get("type", "") == "object": + # 处理组合关键字 + special_keys = ["anyOf", "allOf", "oneOf"] + for key in special_keys: + if key in schema_node: + data = { + "type": key, + "description": schema_node.get("description", ""), + "items": {}, + } + type_index = 0 + for item in schema_node[key]: + if isinstance(item, dict): + data["items"][f"item_{type_index}"] = _extract_type_desc(item) + else: + data["items"][f"item_{type_index}"] = {"type": item, "description": ""} + type_index += 1 + return data + # 处理基本类型 + type_val = schema_node.get("type", "") + description = schema_node.get("description", "") + + # 处理多类型数组 + if isinstance(type_val, list): + if len(type_val) > 1: + data = {"type": "union", "description": description, "items": {}} + type_index = 0 + for t in type_val: + if t == "object": + tmp_dict = {} + for key, val in schema_node.get("properties", {}).items(): + tmp_dict[key] = _extract_type_desc(val) + data["items"][f"item_{type_index}"] = tmp_dict + elif t == "array": + items_schema = schema_node.get("items", {}) + data["items"][f"item_{type_index}"] = _extract_type_desc(items_schema) + else: + data["items"][f"item_{type_index}"] = {"type": t, "description": description} + type_index += 1 + return data + elif len(type_val) == 1: + type_val = type_val[0] + else: + type_val = "" + + data = {"type": type_val, "description": description} + + # 递归处理对象和数组 + if type_val == "object": data["items"] = {} for key, val in schema_node.get("properties", {}).items(): data["items"][key] = _extract_type_desc(val) - - # 处理类型为 array 的节点 - if schema_node.get("type", "") == "array": + elif type_val == "array": items_schema = schema_node.get("items", {}) - data["items"] = _extract_type_desc(items_schema) + data["items"]["items"] = _extract_type_desc(items_schema) + return data + return _extract_type_desc(self._schema) def get_params_node_from_schema(self, root: str = "") -> ParamsNode: @@ -231,13 +290,15 @@ class Slot: return None param_type = schema_node["type"] + if isinstance(param_type, list): + return None # 不支持多类型 if param_type == "object": param_type = Type.DICT elif param_type == "array": param_type = Type.LIST elif param_type == "string": param_type = Type.STRING - elif param_type == "number": + elif param_type in ["number", "integer"]: param_type = Type.NUMBER elif param_type == "boolean": param_type = Type.BOOL diff --git a/apps/services/parameter.py b/apps/services/parameter.py index 259c4e45..c58fb39b 100644 --- a/apps/services/parameter.py +++ b/apps/services/parameter.py @@ -78,7 +78,7 @@ class ParameterManager: node_id = step_id_to_node_id.get(step_id) params_schema, output_schema = await NodeManager.get_node_params(node_id) slot = Slot(output_schema) - params_node = slot.get_params_node_from_schema(root='output') + params_node = slot.get_params_node_from_schema() pre_step_params.append( StepParams( stepId=step_id, -- Gitee From e36416e01b26c6082aa45755fa2be81feb40d3ae Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 29 Jul 2025 22:18:35 +0800 Subject: [PATCH 31/60] =?UTF-8?q?=E5=AE=8C=E5=96=84mcp=20agent=E7=9A=84?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/mcp_agent/plan.py | 69 +++++++++++++++++++++++++++++- apps/scheduler/mcp_agent/prompt.py | 42 ++++++++---------- apps/scheduler/mcp_agent/select.py | 7 +-- apps/scheduler/slot/slot.py | 52 ++++++++++++++++++++++ apps/schemas/mcp.py | 15 +++++++ 5 files changed, 155 insertions(+), 30 deletions(-) diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index 40bec6ef..a7c1f132 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -8,15 +8,20 @@ from apps.llm.reasoning import ReasoningLLM from apps.llm.function import JsonGenerator from apps.scheduler.mcp_agent.prompt import ( EVALUATE_GOAL, + GENERATE_FLOW_NAME, CREATE_PLAN, RECREATE_PLAN, + RISK_EVALUATE, + GET_MISSING_PARAMS, FINAL_ANSWER ) from apps.schemas.mcp import ( GoalEvaluationResult, + ToolRisk, MCPPlan, MCPTool ) +from apps.scheduler.slot.slot import Slot class MCPPlanner: @@ -31,9 +36,9 @@ class MCPPlanner: trim_blocks=True, lstrip_blocks=True, ) + self.resoning_llm = resoning_llm or ReasoningLLM() self.input_tokens = 0 self.output_tokens = 0 - self.resoning_llm = resoning_llm or ReasoningLLM() async def get_resoning_result(self, prompt: str) -> str: """获取推理结果""" @@ -82,7 +87,7 @@ class MCPPlanner: async def _get_reasoning_evaluation(self, tool_list: list[MCPTool]) -> str: """获取推理大模型的评估结果""" - template = self._env.from_string(EVALUATE_GAOL) + template = self._env.from_string(EVALUATE_GOAL) prompt = template.render( goal=self.user_goal, tools=tool_list, @@ -97,6 +102,18 @@ class MCPPlanner: # 使用GoalEvaluationResult模型解析结果 return GoalEvaluationResult.model_validate(evaluation) + async def get_flow_name(self) -> str: + """获取当前流程的名称""" + result = await self._get_reasoning_flow_name() + return result + + async def _get_reasoning_flow_name(self) -> str: + """获取推理大模型的流程名称""" + template = self._env.from_string(GENERATE_FLOW_NAME) + prompt = template.render(goal=self.user_goal) + result = await self.get_resoning_result(prompt) + return result + async def create_plan(self, tool_list: list[MCPTool], max_steps: int = 6) -> MCPPlan: """规划下一步的执行流程,并输出""" # 获取推理结果 @@ -139,6 +156,54 @@ class MCPPlanner: # 使用Function模型解析结果 return MCPPlan.model_validate(plan) + async def get_tool_risk(self, tool: MCPTool, input_parm: dict[str, Any], additional_info: str = "") -> ToolRisk: + """获取MCP工具的风险评估结果""" + # 获取推理结果 + result = await self._get_reasoning_risk(tool, input_parm, additional_info) + + # 解析为结构化数据 + risk = await self._parse_risk_result(result) + + # 返回风险评估结果 + return risk + + async def _get_reasoning_risk(self, tool: MCPTool, input_param: dict[str, Any], additional_info: str) -> str: + """获取推理大模型的风险评估结果""" + template = self._env.from_string(RISK_EVALUATE) + prompt = template.render( + tool=tool, + input_param=input_param, + additional_info=additional_info, + ) + result = await self.get_resoning_result(prompt) + return result + + async def _parse_risk_result(self, result: str) -> ToolRisk: + """将推理结果解析为结构化数据""" + schema = ToolRisk.model_json_schema() + risk = await self._parse_result(result, schema) + # 使用ToolRisk模型解析结果 + return ToolRisk.model_validate(risk) + + async def get_missing_param( + self, tool: MCPTool, schema: dict[str, Any], + input_param: dict[str, Any], + error_message: str) -> list[str]: + """获取缺失的参数""" + slot = Slot(schema=schema) + schema_with_null = slot.add_null_to_basic_types() + template = self._env.from_string(GET_MISSING_PARAMS) + prompt = template.render( + tool=tool, + input_param=input_param, + schema=schema_with_null, + error_message=error_message, + ) + result = await self.get_resoning_result(prompt) + # 解析为结构化数据 + input_param_with_null = await self._parse_result(result, schema_with_null) + return input_param_with_null + async def generate_answer(self, plan: MCPPlan, memory: str) -> str: """生成最终回答""" template = self._env.from_string(FINAL_ANSWER) diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py index 8933f69a..74dd06c0 100644 --- a/apps/scheduler/mcp_agent/prompt.py +++ b/apps/scheduler/mcp_agent/prompt.py @@ -121,22 +121,12 @@ GENERATE_FLOW_NAME = dedent(r""" 2. 流程名称应该包含关键的操作或步骤,例如“扫描”、“分析”、“调优”等。 3. 流程名称应该避免使用过于复杂或专业的术语,以便用户能够理解。 4. 流程名称应该尽量简短,小于20个字或者单词。 - - - 必须按照如下格式生成流程名称,不要输出任何额外数据: - ```json - { - "flow_name": "生成的流程名称" - } - ``` + 5. 只输出流程名称,不要输出其他内容。 # 样例 ## 目标 我需要扫描当前mysql数据库,分析性能瓶颈,并调优 ## 输出 - ```json - { - "flow_name": "MySQL性能分析与调优" - } - ``` + 扫描MySQL数据库并分析性能瓶颈,进行调优 # 现在开始生成流程名称: # 目标 {{ goal }} @@ -398,7 +388,7 @@ RISK_EVALUATE = dedent(r""" 你的任务是根据当前工具的名称、描述和入参以及附加信息,判断当前工具执行的风险并输出提示。 ```json { - "risk": "高/中/低", + "risk": "low/medium/high", "message": "提示信息" } ``` @@ -429,12 +419,13 @@ RISK_EVALUATE = dedent(r""" "message": "当前工具将连接到MySQL数据库并分析性能,可能会对数据库性能产生一定影响。请确保在非生产环境中执行此操作。" } ``` - # 工具名称 - {{ tool_name }} - # 工具描述 - {{ tool_description }} + # 工具 + + {{ tool.name }} + {{ tool.description }} + # 工具入参 - {{ tool_input }} + {{ input_param }} # 附加信息 {{ additional_info }} # 输出 @@ -511,14 +502,15 @@ GET_MISSING_PARAMS = dedent(r""" "password": null } ``` - # 工具名称 - {{ tool_name }} - # 工具描述 - {{ tool_description }} + # 工具 + + {{ tool.name }} + {{ tool.description }} + # 工具入参 - {{ tool_input }} - # 工具入参schema - {{ tool_input_schema }} + {{ input_param }} + # 工具入参schema(部分字段允许为null) + {{ input_schema }} # 运行报错 {{ error_message }} # 输出 diff --git a/apps/scheduler/mcp_agent/select.py b/apps/scheduler/mcp_agent/select.py index 95f588e2..37d1e752 100644 --- a/apps/scheduler/mcp_agent/select.py +++ b/apps/scheduler/mcp_agent/select.py @@ -7,6 +7,7 @@ from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment from typing import AsyncGenerator +from apps.llm.reasoning import ReasoningLLM from apps.common.lance import LanceDB from apps.common.mongo import MongoDB from apps.llm.embedding import Embedding @@ -27,8 +28,9 @@ logger = logging.getLogger(__name__) class MCPSelector: """MCP选择器""" - def __init__(self) -> None: + def __init__(self, resoning_llm: ReasoningLLM = None) -> None: """初始化助手类""" + self.resoning_llm = resoning_llm or ReasoningLLM() self.input_tokens = 0 self.output_tokens = 0 @@ -102,12 +104,11 @@ class MCPSelector: async def _call_reasoning(self, prompt: str) -> AsyncGenerator[str, None]: """调用大模型进行推理""" logger.info("[MCPHelper] 调用推理大模型") - llm = ReasoningLLM() message = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}, ] - async for chunk in llm.call(message): + async for chunk in self.resoning_llm.call(message): yield chunk async def _call_function_mcp(self, reasoning_result: str, mcp_ids: list[str]) -> MCPSelectResult: diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py index f5e5354f..40aaffce 100644 --- a/apps/scheduler/slot/slot.py +++ b/apps/scheduler/slot/slot.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """参数槽位管理""" +import copy import json import logging import traceback @@ -475,3 +476,54 @@ class Slot: return schema_template return {} + + def add_null_to_basic_types(self) -> dict[str, Any]: + """ + 递归地为 JSON Schema 中的基础类型(bool、number等)添加 null 选项 + """ + def add_null_to_basic_types(schema: dict[str, Any]) -> dict[str, Any]: + """ + 递归地为 JSON Schema 中的基础类型(bool、number等)添加 null 选项 + + 参数: + schema (dict): 原始 JSON Schema + + 返回: + dict: 修改后的 JSON Schema + """ + # 如果不是字典类型(schema),直接返回 + if not isinstance(schema, dict): + return schema + + # 处理当前节点的 type 字段 + if 'type' in schema: + # 处理单一类型字符串 + if isinstance(schema['type'], str): + if schema['type'] in ['boolean', 'number', 'string', 'integer']: + schema['type'] = [schema['type'], 'null'] + + # 处理类型数组 + elif isinstance(schema['type'], list): + for i, t in enumerate(schema['type']): + if isinstance(t, str) and t in ['boolean', 'number', 'string', 'integer']: + if 'null' not in schema['type']: + schema['type'].append('null') + break + + # 递归处理 properties 字段(对象类型) + if 'properties' in schema: + for prop, prop_schema in schema['properties'].items(): + schema['properties'][prop] = add_null_to_basic_types(prop_schema) + + # 递归处理 items 字段(数组类型) + if 'items' in schema: + schema['items'] = add_null_to_basic_types(schema['items']) + + # 递归处理 anyOf, oneOf, allOf 字段 + for keyword in ['anyOf', 'oneOf', 'allOf']: + if keyword in schema: + schema[keyword] = [add_null_to_basic_types(sub_schema) for sub_schema in schema[keyword]] + + return schema + schema_copy = copy.deepcopy(self._schema) + return add_null_to_basic_types(schema_copy) diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index 2ee50061..959a273c 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -111,6 +111,21 @@ class GoalEvaluationResult(BaseModel): reason: str = Field(description="评估原因") +class Risk(Enum, str): + """MCP工具风险类型""" + + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + + +class ToolRisk(BaseModel): + """MCP工具风险评估结果""" + + risk: Risk = Field(description="风险类型", default=Risk.LOW) + reason: str = Field(description="风险原因", default="") + + class MCPSelectResult(BaseModel): """MCP选择结果""" -- Gitee From 734c8c0cc8b22e9eeebd4a772336ffc508bac123 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 30 Jul 2025 09:32:22 +0800 Subject: [PATCH 32/60] fix bug --- apps/schemas/mcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index 959a273c..368865ac 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -111,7 +111,7 @@ class GoalEvaluationResult(BaseModel): reason: str = Field(description="评估原因") -class Risk(Enum, str): +class Risk(str, Enum): """MCP工具风险类型""" LOW = "low" -- Gitee From e377e5aca1555e184195ed0e5f522181567aed03 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 30 Jul 2025 09:35:23 +0800 Subject: [PATCH 33/60] fix bug --- apps/scheduler/slot/slot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py index 40aaffce..74ea3da8 100644 --- a/apps/scheduler/slot/slot.py +++ b/apps/scheduler/slot/slot.py @@ -321,7 +321,7 @@ class Slot: paramType=param_type, subParams=sub_params) try: - return _extract_params_node(self._schema, name=root, path="/" + root) + return _extract_params_node(self._schema, name=root, path=root) except Exception as e: logger.error(f"[Slot] 提取ParamsNode失败: {e!s}\n{traceback.format_exc()}") return None -- Gitee From adcf1eb10ea8e3ad10acd4065d64e1dd77124f07 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 30 Jul 2025 09:46:30 +0800 Subject: [PATCH 34/60] fix bug --- apps/scheduler/slot/slot.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py index 74ea3da8..41c516d2 100644 --- a/apps/scheduler/slot/slot.py +++ b/apps/scheduler/slot/slot.py @@ -268,17 +268,23 @@ class Slot: else: type_val = "" - data = {"type": type_val, "description": description} + data = {"type": type_val, "description": description, " items": {}} # 递归处理对象和数组 if type_val == "object": - data["items"] = {} for key, val in schema_node.get("properties", {}).items(): data["items"][key] = _extract_type_desc(val) elif type_val == "array": items_schema = schema_node.get("items", {}) - data["items"]["items"] = _extract_type_desc(items_schema) - + if isinstance(items_schema, list): + item_index = 0 + for item in items_schema: + data["items"][f"item_{item_index}"] = _extract_type_desc(item) + item_index += 1 + else: + data["items"]["item"] = _extract_type_desc(items_schema) + if data["items"] == {}: + del data["items"] return data return _extract_type_desc(self._schema) -- Gitee From 254fdc62a44fcfe00381269e9f7facc1e0971be7 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 30 Jul 2025 09:51:57 +0800 Subject: [PATCH 35/60] fix bug --- apps/scheduler/slot/slot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py index 41c516d2..7a6313a1 100644 --- a/apps/scheduler/slot/slot.py +++ b/apps/scheduler/slot/slot.py @@ -268,7 +268,7 @@ class Slot: else: type_val = "" - data = {"type": type_val, "description": description, " items": {}} + data = {"type": type_val, "description": description, "items": {}} # 递归处理对象和数组 if type_val == "object": -- Gitee From 52564e5c7e1f1b967e5453ccd23eca88767f5e2f Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 30 Jul 2025 12:05:16 +0800 Subject: [PATCH 36/60] =?UTF-8?q?=E9=80=9A=E8=BF=87=E7=9B=91=E8=A7=86?= =?UTF-8?q?=E6=9C=BA=E5=88=B6=E6=9D=A5=E6=9A=82=E5=81=9C=E9=97=AE=E7=AD=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/chat.py | 5 +- apps/routers/record.py | 18 +++-- apps/scheduler/mcp_agent/prompt.py | 97 ++++++++++++++++++++++----- apps/scheduler/scheduler/context.py | 8 ++- apps/scheduler/scheduler/message.py | 25 ++++--- apps/scheduler/scheduler/scheduler.py | 57 ++++++++++++++-- apps/schemas/enum_var.py | 3 +- apps/schemas/record.py | 13 +++- apps/schemas/task.py | 16 +++-- apps/services/rag.py | 3 - apps/services/task.py | 2 +- 11 files changed, 190 insertions(+), 57 deletions(-) diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 589000be..26a87481 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -13,6 +13,7 @@ from fastapi.responses import JSONResponse, StreamingResponse from apps.common.queue import MessageQueue from apps.common.wordscheck import WordsCheck from apps.dependency import get_session, get_user +from apps.schemas.enum_var import FlowStatus from apps.scheduler.scheduler import Scheduler from apps.scheduler.scheduler.context import save_data from apps.schemas.request_data import RequestData @@ -82,8 +83,8 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) # 获取最终答案 task = scheduler.task - if not task.runtime.answer: - logger.error("[Chat] 答案为空") + if task.state.flow_status == FlowStatus.ERROR: + logger.error("[Chat] 生成答案失败") yield "data: [ERROR]\n\n" await Activity.remove_active(user_sub) return diff --git a/apps/routers/record.py b/apps/routers/record.py index f357f0de..367138e2 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -81,19 +81,17 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ # 获得Record关联的文档 tmp_record.document = await DocumentManager.get_used_docs_by_record_group(user_sub, record_group.id) - + tmp_record.flow = RecordFlow( + id=record.flow_history.flow_id, # TODO: 此处前端应该用name + recordId=record.id, + flowStatus=record.flow_history.flow_staus, + flowId=record.flow_history.flow_id, + stepNum=len(flow_step_list), + steps=[], + ) # 获得Record关联的flow数据 flow_step_list = await TaskManager.get_context_by_record_id(record_group.id, record.id) if flow_step_list: - first_step_history = FlowStepHistory.model_validate(flow_step_list[0]) - tmp_record.flow = RecordFlow( - id=first_step_history.flow_name, # TODO: 此处前端应该用name - recordId=record.id, - flowStatus=first_step_history.flow_status, - flowId=first_step_history.id, - stepNum=len(flow_step_list), - steps=[], - ) for flow_step in flow_step_list: flow_step = FlowStepHistory.model_validate(flow_step) tmp_record.flow.steps.append( diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py index 74dd06c0..cf05afc8 100644 --- a/apps/scheduler/mcp_agent/prompt.py +++ b/apps/scheduler/mcp_agent/prompt.py @@ -115,7 +115,7 @@ EVALUATE_GOAL = dedent(r""" """) GENERATE_FLOW_NAME = dedent(r""" 你是一个智能助手,你的任务是根据用户的目标,生成一个合适的流程名称。 - + # 生成流程名称时的注意事项: 1. 流程名称应该简洁明了,能够准确表达达成用户目标的过程。 2. 流程名称应该包含关键的操作或步骤,例如“扫描”、“分析”、“调优”等。 @@ -431,7 +431,74 @@ RISK_EVALUATE = dedent(r""" # 输出 """ ) - +# 根据当前计划和报错信息决定下一步执行,具体计划有需要用户补充工具入参、重计划当前步骤、重计划接下来的所有计划 +JUDGE_NEXT_STEP = dedent(r""" + 你是一个计划决策器。 + 你的任务是根据当前计划、当前使用的工具、工具入参和工具运行报错,决定下一步执行的操作。 + 请根据以下规则进行判断: + 1. 仅通过补充工具入参来解决问题的,返回 fill_params; + 2. 需要重计划当前步骤的,返回 replan_current_step; + 3. 需要重计划接下来的所有计划的,返回 replan_all_steps; + 你的输出要以json格式返回,格式如下: + ```json + { + "next_step": "fill_params/replan_current_step/replan_all_steps", + "reason": "你的判断依据" + } + ``` + 注意: + reason字段必须清晰明了,能够让人理解你的判断依据,并且不超过50个中文字或者100个英文单词。 + # 样例 + ## 当前计划 + {"plans": [ + { + "content": "生成端口扫描命令", + "tool": "command_generator", + "instruction": "生成端口扫描命令:扫描192.168.1.1的开放端口" + }, + { + "content": "在执行Result[0]生成的命令", + "tool": "command_executor", + "instruction": "执行端口扫描命令" + }, + { + "content": "任务执行完成,端口扫描结果为Result[2]", + "tool": "Final", + "instruction": "" + } + ]} + ## 当前使用的工具 + + command_executor + 执行命令行指令 + + ## 工具入参 + { + "command": "nmap -sS -p--open 192.168.1.1" + } + ## 工具运行报错 + 执行端口扫描命令时,出现了错误:`-bash: nmap: command not found`。 + ## 输出 + ```json + { + "next_step": "replan_all_steps", + "reason": "当前工具执行报错,提示nmap命令未找到,需要增加command_generator和command_executor的步骤,生成nmap安装命令并执行,之后再生成端口扫描命令并执行。" + } + ``` + # 当前计划 + {{ current_plan }} + # 当前使用的工具 + + {{ tool.name }} + {{ tool.description }} + + # 工具入参 + {{ input_param }} + # 工具运行报错 + {{ error_message }} + # 输出 + """ + ) # 获取缺失的参数的json结构体 GET_MISSING_PARAMS = dedent(r""" 你是一个工具参数获取器。 @@ -445,18 +512,18 @@ GET_MISSING_PARAMS = dedent(r""" } ``` # 样例 - ## 工具名称 + # 工具名称 mysql_analyzer - ## 工具描述 + # 工具描述 分析MySQL数据库性能 - ## 工具入参 + # 工具入参 { "host": "192.0.0.1", "port": 3306, "username": "root", "password": "password" } - ## 工具入参schema + # 工具入参schema { "type": "object", "properties": { @@ -491,9 +558,9 @@ GET_MISSING_PARAMS = dedent(r""" }, "required": ["host", "port", "username", "password"] } - ## 运行报错 + # 运行报错 执行端口扫描命令时,出现了错误:`password is not correct`。 - ## 输出 + # 输出 ```json { "host": "192.0.0.1", @@ -503,16 +570,16 @@ GET_MISSING_PARAMS = dedent(r""" } ``` # 工具 - - {{ tool.name }} - {{ tool.description }} - + < tool > + < name > {{tool.name}} < /name > + < description > {{tool.description}} < /description > + < / tool > # 工具入参 - {{ input_param }} + {{input_param}} # 工具入参schema(部分字段允许为null) - {{ input_schema }} + {{input_schema}} # 运行报错 - {{ error_message }} + {{error_message}} # 输出 """ ) diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 32331cf3..d7ce8652 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -10,6 +10,7 @@ from apps.llm.patterns.facts import Facts from apps.schemas.collection import Document from apps.schemas.enum_var import StepStatus from apps.schemas.record import ( + FlowHistory, Record, RecordContent, RecordDocument, @@ -188,7 +189,12 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: feature={}, ), createdAt=current_time, - flow=[i["_id"] for i in task.context], + flow_history=FlowHistory( + flow_id=task.state.flow_id, + flow_name=task.state.flow_name, + flow_status=task.state.flow_status, + history_ids=[context["_id"] for context in task.context], + ) ) # 检查是否存在group_id diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index a2a45e41..d43ba7fb 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -15,6 +15,7 @@ from apps.schemas.message import ( InitContentFeature, TextAddContent, ) +from apps.schemas.enum_var import FlowStatus from apps.schemas.rag_data import RAGEventData, RAGQueryReq from apps.schemas.record import RecordDocument from apps.schemas.task import Task @@ -61,21 +62,25 @@ async def push_init_message( async def push_rag_message( task: Task, queue: MessageQueue, user_sub: str, llm: LLM, history: list[dict[str, str]], doc_ids: list[str], - rag_data: RAGQueryReq,) -> Task: + rag_data: RAGQueryReq,) -> None: """推送RAG消息""" full_answer = "" - - async for chunk in RAG.chat_with_llm_base_on_rag(user_sub, llm, history, doc_ids, rag_data): - task, content_obj = await _push_rag_chunk(task, queue, chunk) - if content_obj.event_type == EventType.TEXT_ADD.value: - # 如果是文本消息,直接拼接到答案中 - full_answer += content_obj.content - elif content_obj.event_type == EventType.DOCUMENT_ADD.value: - task.runtime.documents.append(content_obj.content) + try: + async for chunk in RAG.chat_with_llm_base_on_rag(user_sub, llm, history, doc_ids, rag_data): + task, content_obj = await _push_rag_chunk(task, queue, chunk) + if content_obj.event_type == EventType.TEXT_ADD.value: + # 如果是文本消息,直接拼接到答案中 + full_answer += content_obj.content + elif content_obj.event_type == EventType.DOCUMENT_ADD.value: + task.runtime.documents.append(content_obj.content) + task.state.flow_status = FlowStatus.SUCCESS + except Exception as e: + logger.error(f"[Scheduler] RAG服务发生错误: {e}") + task.state.flow_status = FlowStatus.ERROR # 保存答案 task.runtime.answer = full_answer + task.tokens.full_time = round(datetime.now(UTC).timestamp(), 2) - task.tokens.time await TaskManager.save_task(task.id, task) - return task async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tuple[Task, RAGEventData]: diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 417f93d2..91930f8c 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """Scheduler模块""" +import asyncio import logging from datetime import UTC, datetime @@ -17,11 +18,12 @@ from apps.scheduler.scheduler.message import ( push_rag_message, ) from apps.schemas.collection import LLM -from apps.schemas.enum_var import AppType, EventType +from apps.schemas.enum_var import FlowStatus, AppType, EventType from apps.schemas.pool import AppPool from apps.schemas.rag_data import RAGQueryReq from apps.schemas.request_data import RequestData from apps.schemas.scheduler import ExecutorBackground +from apps.services.activity import Activity from apps.schemas.task import Task from apps.services.appcenter import AppCenterManager from apps.services.knowledge import KnowledgeBaseManager @@ -41,10 +43,30 @@ class Scheduler: """初始化""" self.used_docs = [] self.task = task - self.queue = queue self.post_body = post_body + async def _monitor_activity(self, kill_event, user_sub): + """监控用户活动状态,不活跃时终止工作流""" + try: + check_interval = 0.5 # 每0.5秒检查一次 + + while not kill_event.is_set(): + # 检查用户活动状态 + is_active = await Activity.is_active(user_sub) + + if not is_active: + logger.warning("[Scheduler] 用户 %s 不活跃,终止工作流", user_sub) + kill_event.set() + break + + # 控制检查频率 + await asyncio.sleep(check_interval) + except asyncio.CancelledError: + logger.info("[Scheduler] 活动监控任务已取消") + except Exception as e: + logger.error(f"[Scheduler] 活动监控过程中发生错误: {e}") + async def run(self) -> None: # noqa: PLR0911 """运行调度器""" try: @@ -95,6 +117,9 @@ class Scheduler: # 如果是智能问答,直接执行 logger.info("[Scheduler] 开始执行") + # 创建用于通信的事件 + kill_event = asyncio.Event() + monitor = asyncio.create_task(self._monitor_activity(kill_event, self.task.ids.user_sub)) if not self.post_body.app or self.post_body.app.app_id == "": self.task = await push_init_message(self.task, self.queue, 3, is_flow=False) rag_data = RAGQueryReq( @@ -102,8 +127,11 @@ class Scheduler: query=self.post_body.question, tokensLimit=llm.max_tokens, ) - self.task = await push_rag_message(self.task, self.queue, self.task.ids.user_sub, llm, history, doc_ids, rag_data) - self.task.tokens.full_time = round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time + + # 启动监控任务和主任务 + main_task = asyncio.create_task(push_rag_message( + self.task, self.queue, self.task.ids.user_sub, llm, history, doc_ids, rag_data)) + else: # 查找对应的App元数据 app_data = await AppCenterManager.fetch_app_data_by_id(self.post_body.app.app_id) @@ -127,8 +155,27 @@ class Scheduler: conversation=context, facts=facts, ) - await self.run_executor(self.queue, self.post_body, executor_background) + # 启动监控任务和主任务 + main_task = asyncio.create_task(self.run_executor(self.queue, self.post_body, executor_background)) + # 等待任一任务完成 + done, pending = await asyncio.wait( + [main_task, monitor], + return_when=asyncio.FIRST_COMPLETED + ) + + # 如果是监控任务触发,终止主任务 + if kill_event.is_set(): + logger.warning("[Scheduler] 用户活动状态检测不活跃,正在终止工作流执行...") + main_task.cancel() + need_change_cancel_flow_state = [FlowStatus.RUNNING, FlowStatus.WAITING] + if self.task.state.flow_status in need_change_cancel_flow_state: + self.task.state.flow_status = FlowStatus.CANCELLED + try: + await main_task + logger.info("[Scheduler] 工作流执行已被终止") + except Exception as e: + logger.error(f"[Scheduler] 终止工作流时发生错误: {e}") # 更新Task,发送结束消息 logger.info("[Scheduler] 发送结束消息") await self.queue.push_output(self.task, event_type=EventType.DONE.value, data={}) diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index 8458e103..49a6c250 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -14,7 +14,7 @@ class SlotType(str, Enum): class StepStatus(str, Enum): """步骤状态""" - + UNKNOWN = "unknown" WAITING = "waiting" RUNNING = "running" SUCCESS = "success" @@ -26,6 +26,7 @@ class StepStatus(str, Enum): class FlowStatus(str, Enum): """Flow状态""" + UNKNOWN = "unknown" WAITING = "waiting" RUNNING = "running" SUCCESS = "success" diff --git a/apps/schemas/record.py b/apps/schemas/record.py index 0c3d7185..3dd81ca9 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -10,7 +10,7 @@ from pydantic import BaseModel, Field from apps.schemas.collection import ( Document, ) -from apps.schemas.enum_var import CommentType, StepStatus +from apps.schemas.enum_var import CommentType, FlowStatus, StepStatus class RecordDocument(Document): @@ -116,6 +116,14 @@ class RecordGroupDocument(BaseModel): created_at: float = Field(default=0.0, description="文档创建时间") +class FlowHistory(BaseModel): + """Flow执行历史""" + flow_id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") + flow_name: str = Field(default="", description="Flow名称") + flow_staus: FlowStatus = Field(default=FlowStatus.SUCCESS, description="Flow执行状态") + history_ids: list[str] = Field(default=[], description="Flow执行历史ID列表") + + class Record(RecordData): """问答,用于保存在MongoDB中""" @@ -123,7 +131,8 @@ class Record(RecordData): key: dict[str, Any] = {} content: str comment: RecordComment = Field(default=RecordComment()) - flow: list[str] = Field(default=[]) + flow_history: FlowHistory = Field( + default=FlowHistory(), description="Flow执行历史信息") class RecordGroup(BaseModel): diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 98d8c6b3..586ed1da 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -37,15 +37,17 @@ class ExecutorState(BaseModel): """FlowExecutor状态""" # 执行器级数据 - flow_id: str = Field(description="Flow ID") - flow_name: str = Field(description="Flow名称") - description: str = Field(description="Flow描述") - flow_status: FlowStatus = Field(description="Flow状态") + flow_id: str = Field(description="Flow ID", default="") + flow_name: str = Field(description="Flow名称", default="") + description: str = Field(description="Flow描述", default="") + flow_status: FlowStatus = Field(description="Flow状态", default=FlowStatus.UNKNOWN) # 任务级数据 - step_id: str = Field(description="当前步骤ID") - step_name: str = Field(description="当前步骤名称") - step_status: StepStatus = Field(description="当前步骤状态") + step_id: str = Field(description="当前步骤ID", default="") + step_name: str = Field(description="当前步骤名称", default="") + step_status: StepStatus = Field(description="当前步骤状态", default=StepStatus.UNKNOWN) step_description: str = Field(description="当前步骤描述", default="") + retry_times: int = Field(description="当前步骤重试次数", default=0) + error_message: str = Field(description="当前步骤错误信息", default="") app_id: str = Field(description="应用ID") slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) error_info: dict[str, Any] = Field(description="错误信息", default={}) diff --git a/apps/services/rag.py b/apps/services/rag.py index efbdfe94..8b95d2b1 100644 --- a/apps/services/rag.py +++ b/apps/services/rag.py @@ -18,7 +18,6 @@ from apps.schemas.collection import LLM from apps.schemas.config import LLMConfig from apps.schemas.enum_var import EventType from apps.schemas.rag_data import RAGQueryReq -from apps.services.activity import Activity from apps.services.session import SessionManager logger = logging.getLogger(__name__) @@ -257,8 +256,6 @@ class RAG: result_only=False, model=llm.model_name, ): - if not await Activity.is_active(user_sub): - return chunk = buffer + chunk # 防止脚注被截断 if len(chunk) >= 2 and chunk[-2:] != "]]": diff --git a/apps/services/task.py b/apps/services/task.py index 2456d96b..39085305 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -82,7 +82,7 @@ class TaskManager: return [] flow_context_list = [] - for flow_context_id in records[0]["records"]["flow"]: + for flow_context_id in records[0]["records"]["flow_history"]["history_ids"]: flow_context = await flow_context_collection.find_one({"_id": flow_context_id}) if flow_context: flow_context_list.append(flow_context) -- Gitee From eb97467fcfc56ae939c4b6b3aec2ee18c5ebd5ac Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 30 Jul 2025 15:39:53 +0800 Subject: [PATCH 37/60] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E8=BF=90=E7=AE=97?= =?UTF-8?q?=E7=AC=A6=E5=8F=B7=E8=BF=94=E5=9B=9Ebug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/services/parameter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/services/parameter.py b/apps/services/parameter.py index c58fb39b..8f83eab0 100644 --- a/apps/services/parameter.py +++ b/apps/services/parameter.py @@ -47,7 +47,7 @@ class ParameterManager: for item in operate: result.append(OperateAndBindType( operate=item, - bind_type=ConditionHandler.get_value_type_from_operate(item))) + bind_type=(await ConditionHandler.get_value_type_from_operate(item)))) return result @staticmethod -- Gitee From 86397ffb313882c0f113bfd92ab945936ee7a874 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 30 Jul 2025 16:39:32 +0800 Subject: [PATCH 38/60] =?UTF-8?q?=E5=AE=8C=E5=96=84record=E6=94=B9?= =?UTF-8?q?=E9=80=A0=E5=90=8E=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/record.py | 16 ++++++++-------- apps/scheduler/scheduler/context.py | 2 +- apps/schemas/record.py | 2 +- apps/schemas/task.py | 4 ++-- apps/services/task.py | 3 +-- 5 files changed, 13 insertions(+), 14 deletions(-) diff --git a/apps/routers/record.py b/apps/routers/record.py index 367138e2..29ff68e0 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -81,17 +81,17 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ # 获得Record关联的文档 tmp_record.document = await DocumentManager.get_used_docs_by_record_group(user_sub, record_group.id) - tmp_record.flow = RecordFlow( - id=record.flow_history.flow_id, # TODO: 此处前端应该用name - recordId=record.id, - flowStatus=record.flow_history.flow_staus, - flowId=record.flow_history.flow_id, - stepNum=len(flow_step_list), - steps=[], - ) # 获得Record关联的flow数据 flow_step_list = await TaskManager.get_context_by_record_id(record_group.id, record.id) if flow_step_list: + tmp_record.flow = RecordFlow( + id=record.flow.flow_id, # TODO: 此处前端应该用name + recordId=record.id, + flowStatus=record.flow.flow_staus, + flowId=record.flow.flow_id, + stepNum=len(flow_step_list), + steps=[], + ) for flow_step in flow_step_list: flow_step = FlowStepHistory.model_validate(flow_step) tmp_record.flow.steps.append( diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index d7ce8652..109a5456 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -189,7 +189,7 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: feature={}, ), createdAt=current_time, - flow_history=FlowHistory( + flow=FlowHistory( flow_id=task.state.flow_id, flow_name=task.state.flow_name, flow_status=task.state.flow_status, diff --git a/apps/schemas/record.py b/apps/schemas/record.py index 3dd81ca9..d3222762 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -131,7 +131,7 @@ class Record(RecordData): key: dict[str, Any] = {} content: str comment: RecordComment = Field(default=RecordComment()) - flow_history: FlowHistory = Field( + flow: FlowHistory = Field( default=FlowHistory(), description="Flow执行历史信息") diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 586ed1da..e1901416 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -48,7 +48,7 @@ class ExecutorState(BaseModel): step_description: str = Field(description="当前步骤描述", default="") retry_times: int = Field(description="当前步骤重试次数", default=0) error_message: str = Field(description="当前步骤错误信息", default="") - app_id: str = Field(description="应用ID") + app_id: str = Field(description="应用ID", default="") slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) error_info: dict[str, Any] = Field(description="错误信息", default={}) @@ -94,7 +94,7 @@ class Task(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") ids: TaskIds = Field(description="任务涉及的各种ID") context: list[dict[str, Any]] = Field(description="Flow的步骤执行信息", default=[]) - state: ExecutorState | None = Field(description="Flow的状态", default=None) + state: ExecutorState = Field(description="Flow的状态", default=ExecutorState()) tokens: TaskTokens = Field(description="Token信息") runtime: TaskRuntime = Field(description="任务运行时数据") created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) diff --git a/apps/services/task.py b/apps/services/task.py index 39085305..dceee324 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -82,7 +82,7 @@ class TaskManager: return [] flow_context_list = [] - for flow_context_id in records[0]["records"]["flow_history"]["history_ids"]: + for flow_context_id in records[0]["records"]["flow"]["history_ids"]: flow_context = await flow_context_collection.find_one({"_id": flow_context_id}) if flow_context: flow_context_list.append(flow_context) @@ -199,7 +199,6 @@ class TaskManager: conversation_id=post_body.conversation_id, group_id=post_body.group_id if post_body.group_id else "", ), - state=None, tokens=TaskTokens(), runtime=TaskRuntime(), ) -- Gitee From f9cfd5975219b2f00e3026162918fd5281105568 Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 31 Jul 2025 00:09:40 +0800 Subject: [PATCH 39/60] =?UTF-8?q?=E5=AE=8C=E5=96=84choice=E8=8A=82?= =?UTF-8?q?=E7=82=B9=E9=BB=98=E8=AE=A4=E5=88=86=E6=94=AF=E7=9A=84=E5=8C=B9?= =?UTF-8?q?=E9=85=8D=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/call/choice/choice.py | 2 +- apps/scheduler/call/choice/condition_handler.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py index 8cab8288..01ac7106 100644 --- a/apps/scheduler/call/choice/choice.py +++ b/apps/scheduler/call/choice/choice.py @@ -113,7 +113,7 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): valid_conditions.append(condition) # 如果所有条件都无效,抛出异常 - if not valid_conditions: + if not valid_conditions and not choice.is_default: msg = "分支没有有效条件" logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") continue diff --git a/apps/scheduler/call/choice/condition_handler.py b/apps/scheduler/call/choice/condition_handler.py index 6111ba90..6f10f2c8 100644 --- a/apps/scheduler/call/choice/condition_handler.py +++ b/apps/scheduler/call/choice/condition_handler.py @@ -75,12 +75,11 @@ class ConditionHandler(BaseModel): @staticmethod def handler(choices: list[ChoiceBranch]) -> str: """处理条件""" - default_branch = [c for c in choices if c.is_default] - for block_judgement in choices: + for block_judgement in choices[::-1]: results = [] if block_judgement.is_default: - continue + return block_judgement.branch_id for condition in block_judgement.conditions: result = ConditionHandler._judge_condition(condition) if result is not None: @@ -96,9 +95,6 @@ class ConditionHandler(BaseModel): if final_result: return block_judgement.branch_id - # 如果没有匹配的分支,选择默认分支 - if default_branch: - return default_branch[0].branch_id return "" @staticmethod -- Gitee From cf4d309de30a924705a02194aae4568ad4459c21 Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 31 Jul 2025 15:15:08 +0800 Subject: [PATCH 40/60] =?UTF-8?q?=E5=9F=BA=E4=BA=8E=E5=BD=93=E5=89=8Dnode?= =?UTF-8?q?=E7=9A=84type=E5=88=A4=E6=96=AD=E6=80=8E=E4=B9=88=E8=B5=B0?= =?UTF-8?q?=E4=B8=8B=E4=B8=80=E6=AD=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/executor/flow.py | 22 +++++++++++++--------- apps/schemas/enum_var.py | 2 +- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index 382ef929..157c9159 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -46,6 +46,10 @@ class FlowExecutor(BaseExecutor): flow_id: str = Field(description="Flow ID") question: str = Field(description="用户输入") post_body_app: RequestDataApp = Field(description="请求体中的app信息") + current_step: StepQueueItem | None = Field( + description="当前执行的步骤", + default=None + ) async def load_state(self) -> None: """从数据库中加载FlowExecutor的状态""" @@ -70,13 +74,13 @@ class FlowExecutor(BaseExecutor): self._reached_end: bool = False self.step_queue: deque[StepQueueItem] = deque() - async def _invoke_runner(self, queue_item: StepQueueItem) -> None: + async def _invoke_runner(self) -> None: """单一Step执行""" # 创建步骤Runner step_runner = StepExecutor( msg_queue=self.msg_queue, task=self.task, - step=queue_item, + step=self.current_step, background=self.background, question=self.question, ) @@ -84,8 +88,8 @@ class FlowExecutor(BaseExecutor): # 初始化步骤 await step_runner.init() # 运行Step - await step_runner.run() + await step_runner.run() # 更新Task(已存过库) self.task = step_runner.task @@ -93,12 +97,12 @@ class FlowExecutor(BaseExecutor): """执行当前queue里面的所有步骤(在用户看来是单一Step)""" while True: try: - queue_item = self.step_queue.pop() + self.current_step = self.step_queue.pop() except IndexError: break # 执行Step - await self._invoke_runner(queue_item) + await self._invoke_runner() async def _find_next_id(self, step_id: str) -> list[str]: """查找下一个节点""" @@ -113,17 +117,17 @@ class FlowExecutor(BaseExecutor): # 如果当前步骤为结束,则直接返回 if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] return [] - if self.task.state.step_name == "Choice": + if self.current_step.step.type == SpecialCallType.CHOICE.value: # 如果是choice节点,获取分支ID branch_id = self.task.context[-1]["output_data"]["branch_id"] if branch_id: - self.task.state.step_id = self.task.state.step_id + "." + branch_id + next_steps = await self._find_next_id(self.task.state.step_id + "." + branch_id) logger.info("[FlowExecutor] 分支ID:%s", branch_id) else: logger.warning("[FlowExecutor] 没有找到分支ID,返回空列表") return [] - - next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] + else: + next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] # 如果step没有任何出边,直接跳到end if not next_steps: return [ diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index 49a6c250..5ff3381c 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -162,7 +162,7 @@ class SpecialCallType(str, Enum): LLM = "LLM" START = "start" END = "end" - CHOICE = "choice" + CHOICE = "Choice" class CommentType(str, Enum): -- Gitee From d95f677fdfd555405fba8aafc49861e39d9e24d9 Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 31 Jul 2025 16:22:09 +0800 Subject: [PATCH 41/60] =?UTF-8?q?context=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/executor/flow.py | 2 +- apps/scheduler/scheduler/context.py | 2 +- apps/schemas/task.py | 2 +- apps/services/task.py | 5 +++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index 157c9159..d8400fdf 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -119,7 +119,7 @@ class FlowExecutor(BaseExecutor): return [] if self.current_step.step.type == SpecialCallType.CHOICE.value: # 如果是choice节点,获取分支ID - branch_id = self.task.context[-1]["output_data"]["branch_id"] + branch_id = self.task.context[-1].output_data["branch_id"] if branch_id: next_steps = await self._find_next_id(self.task.state.step_id + "." + branch_id) logger.info("[FlowExecutor] 分支ID:%s", branch_id) diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 109a5456..72f64a71 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -193,7 +193,7 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: flow_id=task.state.flow_id, flow_name=task.state.flow_name, flow_status=task.state.flow_status, - history_ids=[context["_id"] for context in task.context], + history_ids=[context.id for context in task.context], ) ) diff --git a/apps/schemas/task.py b/apps/schemas/task.py index e1901416..9e9f1531 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -93,7 +93,7 @@ class Task(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") ids: TaskIds = Field(description="任务涉及的各种ID") - context: list[dict[str, Any]] = Field(description="Flow的步骤执行信息", default=[]) + context: list[FlowStepHistory] = Field(description="Flow的步骤执行信息", default=[]) state: ExecutorState = Field(description="Flow的状态", default=ExecutorState()) tokens: TaskTokens = Field(description="Token信息") runtime: TaskRuntime = Field(description="任务运行时数据") diff --git a/apps/services/task.py b/apps/services/task.py index dceee324..0f237a3f 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -13,6 +13,7 @@ from apps.schemas.task import ( TaskIds, TaskRuntime, TaskTokens, + FlowStepHistory ) from apps.services.record import RecordManager @@ -67,7 +68,7 @@ class TaskManager: return Task.model_validate(task) @staticmethod - async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[dict[str, Any]]: + async def get_context_by_record_id(record_group_id: str, record_id: str) -> FlowStepHistory: """根据record_group_id获取flow信息""" record_group_collection = MongoDB().get_collection("record_group") flow_context_collection = MongoDB().get_collection("flow_context") @@ -85,7 +86,7 @@ class TaskManager: for flow_context_id in records[0]["records"]["flow"]["history_ids"]: flow_context = await flow_context_collection.find_one({"_id": flow_context_id}) if flow_context: - flow_context_list.append(flow_context) + flow_context_list.append(FlowStepHistory.model_validate(flow_context)) except Exception: logger.exception("[TaskManager] 获取record_id的flow信息失败") return [] -- Gitee From ff332ab39f2e264f792bf5b08e959e3b67a2f89f Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 31 Jul 2025 16:24:13 +0800 Subject: [PATCH 42/60] =?UTF-8?q?context=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/record.py | 1 - apps/services/task.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/apps/routers/record.py b/apps/routers/record.py index 29ff68e0..f73e6c98 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -93,7 +93,6 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ steps=[], ) for flow_step in flow_step_list: - flow_step = FlowStepHistory.model_validate(flow_step) tmp_record.flow.steps.append( RecordFlowStep( stepId=flow_step.step_name, # TODO: 此处前端应该用name diff --git a/apps/services/task.py b/apps/services/task.py index 0f237a3f..93604a93 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -68,7 +68,7 @@ class TaskManager: return Task.model_validate(task) @staticmethod - async def get_context_by_record_id(record_group_id: str, record_id: str) -> FlowStepHistory: + async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[FlowStepHistory]: """根据record_group_id获取flow信息""" record_group_collection = MongoDB().get_collection("record_group") flow_context_collection = MongoDB().get_collection("flow_context") -- Gitee From 15dc14679a106bef9ffa66b9f31cc1d9b4aaf502 Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 31 Jul 2025 22:26:56 +0800 Subject: [PATCH 43/60] =?UTF-8?q?=E5=AE=8C=E5=96=84=E9=87=8D=E5=A4=8D?= =?UTF-8?q?=E5=90=AF=E5=8A=A8=E6=B7=BB=E5=8A=A0=E6=9C=AC=E5=9C=B0=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E5=A4=B1=E8=B4=A5=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/main.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/apps/main.py b/apps/main.py index 1646f671..58e2bd40 100644 --- a/apps/main.py +++ b/apps/main.py @@ -97,10 +97,13 @@ async def add_no_auth_user() -> None: username = os.environ.get('USER') # 适用于 Linux 和 macOS 系统 if not username: username = "admin" - await user_collection.insert_one(User( - _id=username, - is_admin=True, - ).model_dump(by_alias=True)) + try: + await user_collection.insert_one(User( + _id=username, + is_admin=True, + ).model_dump(by_alias=True)) + except Exception as e: + logging.warning(f"添加无认证用户失败: {e}") async def init_resources() -> None: -- Gitee From 65953620cfa12b8153732a0c2bb05b78610be1a4 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 1 Aug 2025 11:13:52 +0800 Subject: [PATCH 44/60] fix buh --- apps/scheduler/scheduler/scheduler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 91930f8c..f6325369 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -229,7 +229,8 @@ class Scheduler: # 初始化Executor logger.info("[Scheduler] 初始化Executor") - + logger.error(f"{flow_data}") + logger.error(f"{self.task}") flow_exec = FlowExecutor( flow_id=flow_id, flow=flow_data, -- Gitee From 3b1fb85f828acd35c1e3b4c68afb092455b04868 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 1 Aug 2025 11:38:56 +0800 Subject: [PATCH 45/60] fix bug --- apps/scheduler/executor/flow.py | 2 +- apps/schemas/enum_var.py | 1 + apps/schemas/task.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index d8400fdf..ebd4da8e 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -55,7 +55,7 @@ class FlowExecutor(BaseExecutor): """从数据库中加载FlowExecutor的状态""" logger.info("[FlowExecutor] 加载Executor状态") # 尝试恢复State - if self.task.state: + if self.task.state and self.task.state.flow_status != FlowStatus.INIT: self.task.context = await TaskManager.get_context_by_task_id(self.task.id) else: # 创建ExecutorState diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index 5ff3381c..3fb65028 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -27,6 +27,7 @@ class FlowStatus(str, Enum): """Flow状态""" UNKNOWN = "unknown" + INIT = "init" WAITING = "waiting" RUNNING = "running" SUCCESS = "success" diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 9e9f1531..602123e7 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -40,7 +40,7 @@ class ExecutorState(BaseModel): flow_id: str = Field(description="Flow ID", default="") flow_name: str = Field(description="Flow名称", default="") description: str = Field(description="Flow描述", default="") - flow_status: FlowStatus = Field(description="Flow状态", default=FlowStatus.UNKNOWN) + flow_status: FlowStatus = Field(description="Flow状态", default=FlowStatus.INIT) # 任务级数据 step_id: str = Field(description="当前步骤ID", default="") step_name: str = Field(description="当前步骤名称", default="") -- Gitee From 6168945d3c707522480dbbecc9c9eda18e844871 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 1 Aug 2025 11:42:59 +0800 Subject: [PATCH 46/60] fix bug --- apps/scheduler/executor/step.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 377a4c6e..5a95e407 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -260,7 +260,7 @@ class StepExecutor(BaseExecutor): input_data=self.obj.input, output_data=output_data, ) - self.task.context.append(history.model_dump(exclude_none=True, by_alias=True)) + self.task.context.append(history) # 推送输出 await self.push_message(EventType.STEP_OUTPUT.value, output_data) -- Gitee From 9a0c015e79f3dc5302f1669cfed1b46347b3fcf7 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 1 Aug 2025 11:48:26 +0800 Subject: [PATCH 47/60] fix bug --- apps/services/task.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/apps/services/task.py b/apps/services/task.py index 93604a93..d133cbfb 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -94,7 +94,7 @@ class TaskManager: return flow_context_list @staticmethod - async def get_context_by_task_id(task_id: str, length: int = 0) -> list[dict[str, Any]]: + async def get_context_by_task_id(task_id: str, length: int = 0) -> list[FlowStepHistory]: """根据task_id获取flow信息""" flow_context_collection = MongoDB().get_collection("flow_context") @@ -105,7 +105,8 @@ class TaskManager: ).sort( "created_at", -1, ).limit(length): - flow_context += [history] + for i in range(len(flow_context)): + flow_context.append(FlowStepHistory.model_validate(history)) except Exception: logger.exception("[TaskManager] 获取task_id的flow信息失败") return [] @@ -113,7 +114,7 @@ class TaskManager: return flow_context @staticmethod - async def save_flow_context(task_id: str, flow_context: list[dict[str, Any]]) -> None: + async def save_flow_context(task_id: str, flow_context: list[FlowStepHistory]) -> None: """保存flow信息到flow_context""" flow_context_collection = MongoDB().get_collection("flow_context") try: @@ -121,12 +122,12 @@ class TaskManager: # 查找是否存在 current_context = await flow_context_collection.find_one({ "task_id": task_id, - "_id": history["_id"], + "_id": history.id, }) if current_context: await flow_context_collection.update_one( {"_id": current_context["_id"]}, - {"$set": history}, + {"$set": history.model_dump(exclude_none=True, by_alias=True)}, ) else: await flow_context_collection.insert_one(history) -- Gitee From c83aced55db60209d6c8d5cad3e33088e93121b6 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 1 Aug 2025 11:53:33 +0800 Subject: [PATCH 48/60] fix bug --- apps/services/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/services/task.py b/apps/services/task.py index d133cbfb..a8cb1848 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -130,7 +130,7 @@ class TaskManager: {"$set": history.model_dump(exclude_none=True, by_alias=True)}, ) else: - await flow_context_collection.insert_one(history) + await flow_context_collection.insert_one(history.model_dump(exclude_none=True, by_alias=True)) except Exception: logger.exception("[TaskManager] 保存flow执行记录失败") -- Gitee From 75f1670dcdeb04676b27e0bb7ee68ffeb52c8d91 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 1 Aug 2025 17:43:45 +0800 Subject: [PATCH 49/60] =?UTF-8?q?=E5=AE=8C=E5=96=84mcp=20agent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/auth.py | 5 ++-- apps/routers/user.py | 26 ++++++++++++++++++-- apps/scheduler/executor/agent.py | 42 +++++++++++++++++++++++++++++--- apps/scheduler/mcp/host.py | 4 +-- apps/scheduler/mcp_agent/host.py | 6 ++--- apps/schemas/collection.py | 1 + apps/schemas/message.py | 6 +++++ apps/schemas/request_data.py | 9 ++++++- apps/schemas/response_data.py | 1 + apps/schemas/task.py | 2 +- apps/schemas/user.py | 1 + apps/services/record.py | 5 +++- apps/services/user.py | 22 ++++++++++++++++- tests/manager/test_user.py | 2 +- 14 files changed, 115 insertions(+), 17 deletions(-) diff --git a/apps/routers/auth.py b/apps/routers/auth.py index 4a3f8293..72416fa3 100644 --- a/apps/routers/auth.py +++ b/apps/routers/auth.py @@ -74,7 +74,7 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse: status_code=status.HTTP_403_FORBIDDEN, ) - await UserManager.update_userinfo_by_user_sub(user_sub) + await UserManager.update_refresh_revision_by_user_sub(user_sub) current_session = await SessionManager.create_session(user_host, user_sub) @@ -177,6 +177,7 @@ async def userinfo( user_sub=user_sub, revision=user.is_active, is_admin=user.is_admin, + auto_execute=user.auto_execute, ), ).model_dump(exclude_none=True, by_alias=True), ) @@ -192,7 +193,7 @@ async def userinfo( ) async def update_revision_number(request: Request, user_sub: Annotated[str, Depends(get_user)]) -> JSONResponse: # noqa: ARG001 """更新用户协议信息""" - ret: bool = await UserManager.update_userinfo_by_user_sub(user_sub, refresh_revision=True) + ret: bool = await UserManager.update_refresh_revision_by_user_sub(user_sub, refresh_revision=True) if not ret: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/apps/routers/user.py b/apps/routers/user.py index 54e12f44..8c197204 100644 --- a/apps/routers/user.py +++ b/apps/routers/user.py @@ -3,10 +3,11 @@ from typing import Annotated -from fastapi import APIRouter, Depends, status +from fastapi import APIRouter, Body, Depends, status from fastapi.responses import JSONResponse from apps.dependency import get_user +from apps.schemas.request_data import UserUpdateRequest from apps.schemas.response_data import UserGetMsp, UserGetRsp from apps.schemas.user import UserInfo from apps.services.user import UserManager @@ -18,7 +19,7 @@ router = APIRouter( @router.get("") -async def chat( +async def get_user_sub( user_sub: Annotated[str, Depends(get_user)], ) -> JSONResponse: """查询所有用户接口""" @@ -42,3 +43,24 @@ async def chat( result=UserGetMsp(userInfoList=user_info_list), ).model_dump(exclude_none=True, by_alias=True), ) + + +@router.post("") +async def update_user_info( + user_sub: Annotated[str, Depends(get_user)], + *, + data: Annotated[UserUpdateRequest, Body(..., description="用户更新信息")], +) -> JSONResponse: + """更新用户信息接口""" + # 更新用户信息 + + result = await UserManager.update_userinfo_by_user_sub(user_sub, data) + if not result: + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"code": status.HTTP_200_OK, "message": "用户信息更新成功"}, + ) + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={"code": status.HTTP_404_NOT_FOUND, "message": "用户信息更新失败"}, + ) diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index cb8e183e..d6ed5917 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -6,25 +6,61 @@ import logging from pydantic import Field from apps.scheduler.executor.base import BaseExecutor +from apps.schemas.enum_var import EventType, SpecialCallType, FlowStatus, StepStatus from apps.scheduler.mcp_agent import host, plan, select +from apps.schemas.mcp import MCPServerConfig, MCPTool from apps.schemas.task import ExecutorState, StepQueueItem +from apps.schemas.message import param from apps.services.task import TaskManager - +from apps.services.appcenter import AppCenterManager +from apps.services.mcp_service import MCPServiceManager logger = logging.getLogger(__name__) class MCPAgentExecutor(BaseExecutor): """MCP Agent执行器""" - question: str = Field(description="用户输入") max_steps: int = Field(default=20, description="最大步数") servers_id: list[str] = Field(description="MCP server id") agent_id: str = Field(default="", description="Agent ID") agent_description: str = Field(default="", description="Agent描述") + mcp_list: list[MCPServerConfig] = Field(description="MCP服务器列表", default=[]) + tool_list: list[MCPTool] = Field(description="MCP工具列表", default=[]) + params: param | None = Field( + default=None, description="流执行过程中的参数补充", alias="params" + ) + + async def load_mcp_list(self) -> None: + """加载MCP服务器列表""" + logger.info("[MCPAgentExecutor] 加载MCP服务器列表") + # 获取MCP服务器列表 + app = await AppCenterManager.fetch_app_data_by_id(self.agent_id) + mcp_ids = app.mcp_service + for mcp_id in mcp_ids: + self.mcp_list.append( + await MCPServiceManager.get_mcp_service(mcp_id) + ) + + async def load_tools(self) -> None: + """加载MCP工具列表""" + logger.info("[MCPAgentExecutor] 加载MCP工具列表") + # 获取工具列表 + mcp_ids = [mcp.id for mcp in self.mcp_list] + for mcp_id in mcp_ids: + if not await MCPServiceManager.is_mcp_enabled(mcp_id, self.agent_id): + logger.warning("MCP %s 未启用,跳过工具加载", mcp_id) + continue + # 获取MCP工具 + tools = await MCPServiceManager.get_mcp_tools(mcp_id) + self.tool_list.extend(tools) async def load_state(self) -> None: """从数据库中加载FlowExecutor的状态""" logger.info("[FlowExecutor] 加载Executor状态") # 尝试恢复State - if self.task.state: + if self.task.state and self.task.state.flow_status != FlowStatus.INIT: self.task.context = await TaskManager.get_context_by_task_id(self.task.id) + + async def run(self) -> None: + """执行MCP Agent的主逻辑""" + # 初始化MCP服务 diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py index aa196112..8e110839 100644 --- a/apps/scheduler/mcp/host.py +++ b/apps/scheduler/mcp/host.py @@ -67,7 +67,7 @@ class MCPHost: context_list = [] for ctx_id in self._context_list: - context = next((ctx for ctx in task.context if ctx["_id"] == ctx_id), None) + context = next((ctx for ctx in task.context if ctx.id == ctx_id), None) if not context: continue context_list.append(context) @@ -118,7 +118,7 @@ class MCPHost: logger.error("任务 %s 不存在", self._task_id) return {} self._context_list.append(context.id) - task.context.append(context.model_dump(by_alias=True, exclude_none=True)) + task.context.append(context.model_dump(exclude_none=True, by_alias=True)) await TaskManager.save_task(self._task_id, task) return output_data diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py index a8ebec7b..f05d8957 100644 --- a/apps/scheduler/mcp_agent/host.py +++ b/apps/scheduler/mcp_agent/host.py @@ -67,10 +67,10 @@ class MCPHost: context_list = [] for ctx_id in self._context_list: - context = next((ctx for ctx in task.context if ctx["_id"] == ctx_id), None) + context = next((ctx for ctx in task.context if ctx.id == ctx_id), None) if not context: continue - context_list.append(context) + context_list.append(context.model_dump(exclude_none=True, by_alias=True)) return self._env.from_string(MEMORY_TEMPLATE).render( context_list=context_list, @@ -118,7 +118,7 @@ class MCPHost: logger.error("任务 %s 不存在", self._task_id) return {} self._context_list.append(context.id) - task.context.append(context.model_dump(by_alias=True, exclude_none=True)) + task.context.append(context) await TaskManager.save_task(self._task_id, task) return output_data diff --git a/apps/schemas/collection.py b/apps/schemas/collection.py index 0ff66c72..a2991f85 100644 --- a/apps/schemas/collection.py +++ b/apps/schemas/collection.py @@ -61,6 +61,7 @@ class User(BaseModel): fav_apps: list[str] = [] fav_services: list[str] = [] is_admin: bool = Field(default=False, description="是否为管理员") + auto_execute: bool = Field(default=True, description="是否自动执行任务") class LLM(BaseModel): diff --git a/apps/schemas/message.py b/apps/schemas/message.py index cf70a82b..1e58fb1b 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -9,6 +9,12 @@ from apps.schemas.enum_var import EventType, StepStatus from apps.schemas.record import RecordMetadata +class param(BaseModel): + """流执行过程中的参数补充""" + content: dict[str, Any] | bool = Field(default={}, description="流执行过程中的参数补充内容") + description: str = Field(default="", description="流执行过程中的参数补充描述") + + class HeartbeatData(BaseModel): """心跳事件的数据结构""" diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 793ff456..7bb5ac76 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -10,6 +10,7 @@ from apps.schemas.appcenter import AppData from apps.schemas.enum_var import CommentType from apps.schemas.flow_topology import FlowItem from apps.schemas.mcp import MCPType +from apps.schemas.message import param class RequestDataApp(BaseModel): @@ -17,7 +18,7 @@ class RequestDataApp(BaseModel): app_id: str = Field(description="应用ID", alias="appId") flow_id: str | None = Field(default=None, description="Flow ID", alias="flowId") - params: dict[str, Any] | None = Field(default=None, description="插件参数") + params: param | None = Field(default=None, description="流执行过程中的参数补充", alias="params") class MockRequestData(BaseModel): @@ -185,3 +186,9 @@ class UpdateKbReq(BaseModel): """更新知识库请求体""" kb_ids: list[str] = Field(description="知识库ID列表", alias="kbIds", default=[]) + + +class UserUpdateRequest(BaseModel): + """更新用户信息请求体""" + + auto_execute: bool = Field(default=False, description="是否自动执行", alias="autoExecute") diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py index b1dc77b7..7f162326 100644 --- a/apps/schemas/response_data.py +++ b/apps/schemas/response_data.py @@ -55,6 +55,7 @@ class AuthUserMsg(BaseModel): user_sub: str revision: bool is_admin: bool + auto_execute: bool class AuthUserRsp(ResponseData): diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 602123e7..b08da0a8 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -50,7 +50,7 @@ class ExecutorState(BaseModel): error_message: str = Field(description="当前步骤错误信息", default="") app_id: str = Field(description="应用ID", default="") slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) - error_info: dict[str, Any] = Field(description="错误信息", default={}) + error_info: str = Field(description="错误信息", default="") class TaskIds(BaseModel): diff --git a/apps/schemas/user.py b/apps/schemas/user.py index 61aa2587..debb3446 100644 --- a/apps/schemas/user.py +++ b/apps/schemas/user.py @@ -9,3 +9,4 @@ class UserInfo(BaseModel): user_sub: str = Field(alias="userSub", default="") user_name: str = Field(alias="userName", default="") + auto_execute: bool = Field(alias="autoExecute", default=False) diff --git a/apps/services/record.py b/apps/services/record.py index 5c8a89df..e1925ef6 100644 --- a/apps/services/record.py +++ b/apps/services/record.py @@ -49,6 +49,10 @@ class RecordManager: mongo = MongoDB() group_collection = mongo.get_collection("record_group") try: + await group_collection.update_one( + {"_id": group_id, "user_sub": user_sub}, + {"$pull": {"records": {"id": record.id}}} + ) await group_collection.update_one( {"_id": group_id, "user_sub": user_sub}, {"$push": {"records": record.model_dump(by_alias=True)}}, @@ -151,7 +155,6 @@ class RecordManager: logger.exception("[RecordManager] 验证记录是否在组中失败") return False - @staticmethod async def check_group_id(group_id: str, user_sub: str) -> bool: """检查group_id是否存在""" diff --git a/apps/services/user.py b/apps/services/user.py index 2721d377..476f1ef4 100644 --- a/apps/services/user.py +++ b/apps/services/user.py @@ -4,6 +4,7 @@ import logging from datetime import UTC, datetime +from apps.schemas.request_data import UserUpdateRequest from apps.common.mongo import MongoDB from apps.schemas.collection import User from apps.services.conversation import ConversationManager @@ -52,7 +53,26 @@ class UserManager: return User(**user_data) if user_data else None @staticmethod - async def update_userinfo_by_user_sub(user_sub: str, *, refresh_revision: bool = False) -> bool: + async def update_userinfo_by_user_sub(user_sub: str, data: UserUpdateRequest) -> bool: + """ + 根据用户sub更新用户信息 + + :param user_sub: 用户sub + :param data: 用户更新信息 + :return: 是否更新成功 + """ + mongo = MongoDB() + user_collection = mongo.get_collection("user") + update_dict = { + "$set": { + "auto_execute": data.auto_execute, + } + } + result = await user_collection.update_one({"_id": user_sub}, update_dict) + return result.modified_count > 0 + + @staticmethod + async def update_refresh_revision_by_user_sub(user_sub: str, *, refresh_revision: bool = False) -> bool: """ 根据用户sub更新用户信息 diff --git a/tests/manager/test_user.py b/tests/manager/test_user.py index d350ca30..ab6eeab4 100644 --- a/tests/manager/test_user.py +++ b/tests/manager/test_user.py @@ -73,7 +73,7 @@ class TestUserManager(unittest.TestCase): mock_mysql_db_instance.get_session.return_value = mock_session # 调用被测方法 - updated_userinfo = UserManager.update_userinfo_by_user_sub(userinfo, refresh_revision=True) + updated_userinfo = UserManager.update_refresh_revision_by_user_sub(userinfo, refresh_revision=True) # 断言返回的用户信息的 revision_number 是否与原始用户信息一致 self.assertEqual(updated_userinfo.revision_number, userinfo.revision_number) -- Gitee From dc80695ae21a3fc59058fb03a4cc49d3c587c73c Mon Sep 17 00:00:00 2001 From: zxstty Date: Sat, 2 Aug 2025 17:20:11 +0800 Subject: [PATCH 50/60] =?UTF-8?q?=E5=AE=8C=E5=96=84task=E6=9C=BA=E5=88=B6?= =?UTF-8?q?=E5=92=8Cmcp=20agent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/queue.py | 2 + apps/routers/chat.py | 35 +++--- apps/routers/record.py | 5 +- apps/scheduler/executor/agent.py | 58 +++++----- apps/scheduler/mcp_agent/host.py | 165 ++++++++++------------------ apps/scheduler/mcp_agent/plan.py | 26 ++--- apps/scheduler/mcp_agent/prompt.py | 101 ++++++++++++++--- apps/scheduler/scheduler/context.py | 15 ++- apps/schemas/message.py | 4 +- apps/schemas/record.py | 3 +- apps/schemas/request_data.py | 1 + apps/schemas/task.py | 7 +- apps/services/conversation.py | 2 +- apps/services/record.py | 18 ++- apps/services/task.py | 85 +++++++------- 15 files changed, 289 insertions(+), 238 deletions(-) diff --git a/apps/common/queue.py b/apps/common/queue.py index 911485b3..089d475e 100644 --- a/apps/common/queue.py +++ b/apps/common/queue.py @@ -56,6 +56,8 @@ class MessageQueue: flow = MessageFlow( appId=task.state.app_id, flowId=task.state.flow_id, + flowName=task.state.flow_name, + flowStatus=task.state.flow_status, stepId=task.state.step_id, stepName=task.state.step_name, stepStatus=task.state.step_status diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 26a87481..06bc2dd7 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -22,6 +22,8 @@ from apps.schemas.task import Task from apps.services.activity import Activity from apps.services.blacklist import QuestionBlacklistManager, UserBlacklistManager from apps.services.flow import FlowManager +from apps.services.conversation import ConversationManager +from apps.services.record import RecordManager from apps.services.task import TaskManager RECOMMEND_TRES = 5 @@ -32,25 +34,33 @@ router = APIRouter( ) -async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> Task: +async def init_task(post_body: RequestData, user_sub: str) -> Task: """初始化Task""" # 生成group_id if not post_body.group_id: post_body.group_id = str(uuid.uuid4()) - if post_body.new_task: - # 创建或还原Task - task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) - if task: - await TaskManager.delete_task_by_task_id(task.id) - task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) + # 更改信息并刷新数据库 if post_body.new_task: - task.runtime.question = post_body.question - task.ids.group_id = post_body.group_id + conversation = await ConversationManager.get_conversation_by_conversation_id( + user_sub=user_sub, + conversation_id=post_body.conversation_id, + ) + if not conversation: + err = "[Chat] 用户没有权限访问该对话!" + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=err) + task_ids = await TaskManager.delete_tasks_by_conversation_id(post_body.conversation_id) + await RecordManager.update_record_flow_status_to_cancelled_by_task_ids(task_ids) + task = await TaskManager.init_new_task(user_sub=user_sub, conversation_id=post_body.conversation_id, post_body=post_body) + else: + if not post_body.task_id: + err = "[Chat] task_id 不可为空!" + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="task_id cannot be empty") + task = await TaskManager.get_task_by_conversation_id(post_body.task_id) return task -async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) -> AsyncGenerator[str, None]: +async def chat_generator(post_body: RequestData, user_sub: str) -> AsyncGenerator[str, None]: """进行实际问答,并从MQ中获取消息""" try: await Activity.set_active(user_sub) @@ -62,7 +72,7 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) await Activity.remove_active(user_sub) return - task = await init_task(post_body, user_sub, session_id) + task = await init_task(post_body, user_sub) # 创建queue;由Scheduler进行关闭 queue = MessageQueue() @@ -120,7 +130,6 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) async def chat( post_body: RequestData, user_sub: Annotated[str, Depends(get_user)], - session_id: Annotated[str, Depends(get_session)], ) -> StreamingResponse: """LLM流式对话接口""" # 问题黑名单检测 @@ -133,7 +142,7 @@ async def chat( if await Activity.is_active(user_sub): raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests") - res = chat_generator(post_body, user_sub, session_id) + res = chat_generator(post_body, user_sub) return StreamingResponse( content=res, media_type="text/event-stream", diff --git a/apps/routers/record.py b/apps/routers/record.py index f73e6c98..8d0105ff 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -65,7 +65,7 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ tmp_record = RecordData( id=record.id, groupId=record_group.id, - taskId=record_group.task_id, + taskId=record.task_id, conversationId=conversation_id, content=record_data, metadata=record.metadata @@ -87,8 +87,9 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ tmp_record.flow = RecordFlow( id=record.flow.flow_id, # TODO: 此处前端应该用name recordId=record.id, - flowStatus=record.flow.flow_staus, flowId=record.flow.flow_id, + flowName=record.flow.flow_name, + flowStatus=record.flow.flow_staus, stepNum=len(flow_step_list), steps=[], ) diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index d6ed5917..603ea65b 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -5,10 +5,13 @@ import logging from pydantic import Field +from apps.llm.reasoning import ReasoningLLM from apps.scheduler.executor.base import BaseExecutor from apps.schemas.enum_var import EventType, SpecialCallType, FlowStatus, StepStatus -from apps.scheduler.mcp_agent import host, plan, select -from apps.schemas.mcp import MCPServerConfig, MCPTool +from apps.scheduler.mcp_agent.host import MCPHost +from apps.scheduler.mcp_agent.plan import MCPPlanner +from apps.scheduler.pool.mcp.client import MCPClient +from apps.schemas.mcp import MCPCollection, MCPTool from apps.schemas.task import ExecutorState, StepQueueItem from apps.schemas.message import param from apps.services.task import TaskManager @@ -24,43 +27,48 @@ class MCPAgentExecutor(BaseExecutor): servers_id: list[str] = Field(description="MCP server id") agent_id: str = Field(default="", description="Agent ID") agent_description: str = Field(default="", description="Agent描述") - mcp_list: list[MCPServerConfig] = Field(description="MCP服务器列表", default=[]) + mcp_list: list[MCPCollection] = Field(description="MCP服务器列表", default=[]) + mcp_client: dict[str, MCPClient] = Field( + description="MCP客户端列表,key为mcp_id", default={} + ) tool_list: list[MCPTool] = Field(description="MCP工具列表", default=[]) params: param | None = Field( default=None, description="流执行过程中的参数补充", alias="params" ) + resoning_llm: ReasoningLLM = Field( + default=ReasoningLLM(), + description="推理大模型", + ) + + async def load_state(self) -> None: + """从数据库中加载FlowExecutor的状态""" + logger.info("[FlowExecutor] 加载Executor状态") + # 尝试恢复State + if self.task.state and self.task.state.flow_status != FlowStatus.INIT: + self.task.context = await TaskManager.get_context_by_task_id(self.task.id) - async def load_mcp_list(self) -> None: + async def load_mcp(self) -> None: """加载MCP服务器列表""" logger.info("[MCPAgentExecutor] 加载MCP服务器列表") # 获取MCP服务器列表 app = await AppCenterManager.fetch_app_data_by_id(self.agent_id) mcp_ids = app.mcp_service for mcp_id in mcp_ids: - self.mcp_list.append( - await MCPServiceManager.get_mcp_service(mcp_id) - ) - - async def load_tools(self) -> None: - """加载MCP工具列表""" - logger.info("[MCPAgentExecutor] 加载MCP工具列表") - # 获取工具列表 - mcp_ids = [mcp.id for mcp in self.mcp_list] - for mcp_id in mcp_ids: - if not await MCPServiceManager.is_mcp_enabled(mcp_id, self.agent_id): - logger.warning("MCP %s 未启用,跳过工具加载", mcp_id) + mcp_service = await MCPServiceManager.get_mcp_service(mcp_id) + if self.task.ids.user_sub not in mcp_service.activated: + logger.warning( + "[MCPAgentExecutor] 用户 %s 未启用MCP %s", + self.task.ids.user_sub, + mcp_id, + ) continue - # 获取MCP工具 - tools = await MCPServiceManager.get_mcp_tools(mcp_id) - self.tool_list.extend(tools) - async def load_state(self) -> None: - """从数据库中加载FlowExecutor的状态""" - logger.info("[FlowExecutor] 加载Executor状态") - # 尝试恢复State - if self.task.state and self.task.state.flow_status != FlowStatus.INIT: - self.task.context = await TaskManager.get_context_by_task_id(self.task.id) + self.mcp_list.append(mcp_service) + self.mcp_client[mcp_id] = await MCPHost.get_client(self.task.ids.user_sub, mcp_id) + self.tool_list.extend(mcp_service.tools) async def run(self) -> None: """执行MCP Agent的主逻辑""" # 初始化MCP服务 + self.load_state() + self.load_mcp() diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py index f05d8957..3217f539 100644 --- a/apps/scheduler/mcp_agent/host.py +++ b/apps/scheduler/mcp_agent/host.py @@ -14,116 +14,53 @@ from apps.llm.function import JsonGenerator from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE from apps.scheduler.pool.mcp.client import MCPClient from apps.scheduler.pool.mcp.pool import MCPPool +from apps.scheduler.mcp_agent.prompt import REPAIR_PARAMS from apps.schemas.enum_var import StepStatus from apps.schemas.mcp import MCPPlanItem, MCPTool -from apps.schemas.task import FlowStepHistory +from apps.schemas.task import Task, FlowStepHistory from apps.services.task import TaskManager logger = logging.getLogger(__name__) +_env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=True, + trim_blocks=True, + lstrip_blocks=True, +) + class MCPHost: """MCP宿主服务""" - def __init__(self, user_sub: str, task_id: str, runtime_id: str, runtime_name: str) -> None: - """初始化MCP宿主""" - self._user_sub = user_sub - self._task_id = task_id - # 注意:runtime在工作流中是flow_id和step_description,在Agent中可为标识Agent的id和description - self._runtime_id = runtime_id - self._runtime_name = runtime_name - self._context_list = [] - self._env = SandboxedEnvironment( - loader=BaseLoader(), - autoescape=False, - trim_blocks=True, - lstrip_blocks=True, - ) - - async def get_client(self, mcp_id: str) -> MCPClient | None: + @staticmethod + async def get_client(user_sub, mcp_id: str) -> MCPClient | None: """获取MCP客户端""" mongo = MongoDB() mcp_collection = mongo.get_collection("mcp") # 检查用户是否启用了这个mcp - mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": self._user_sub}) + mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": user_sub}) if not mcp_db_result: - logger.warning("用户 %s 未启用MCP %s", self._user_sub, mcp_id) + logger.warning("用户 %s 未启用MCP %s", user_sub, mcp_id) return None # 获取MCP配置 try: - return await MCPPool().get(mcp_id, self._user_sub) + return await MCPPool().get(mcp_id, user_sub) except KeyError: - logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", self._user_sub, mcp_id) + logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", user_sub, mcp_id) return None - async def assemble_memory(self) -> str: + @staticmethod + async def assemble_memory(task: Task) -> str: """组装记忆""" - task = await TaskManager.get_task_by_task_id(self._task_id) - if not task: - logger.error("任务 %s 不存在", self._task_id) - return "" - - context_list = [] - for ctx_id in self._context_list: - context = next((ctx for ctx in task.context if ctx.id == ctx_id), None) - if not context: - continue - context_list.append(context.model_dump(exclude_none=True, by_alias=True)) - return self._env.from_string(MEMORY_TEMPLATE).render( - context_list=context_list, - ) - - async def _save_memory( - self, - tool: MCPTool, - plan_item: MCPPlanItem, - input_data: dict[str, Any], - result: str, - ) -> dict[str, Any]: - """保存记忆""" - try: - output_data = json.loads(result) - except Exception: # noqa: BLE001 - logger.warning("[MCPHost] 得到的数据不是dict格式!尝试转换为str") - output_data = { - "message": result, - } - - if not isinstance(output_data, dict): - output_data = { - "message": result, - } - - # 创建context;注意用法 - context = FlowStepHistory( - task_id=self._task_id, - flow_id=self._runtime_id, - flow_name=self._runtime_name, - flow_status=StepStatus.SUCCESS, - step_id=tool.name, - step_name=tool.name, - # description是规划的实际内容 - step_description=plan_item.content, - step_status=StepStatus.SUCCESS, - input_data=input_data, - output_data=output_data, + return _env.from_string(MEMORY_TEMPLATE).render( + context_list=task.context, ) - # 保存到task - task = await TaskManager.get_task_by_task_id(self._task_id) - if not task: - logger.error("任务 %s 不存在", self._task_id) - return {} - self._context_list.append(context.id) - task.context.append(context) - await TaskManager.save_task(self._task_id, task) - - return output_data - - async def _fill_params(self, schema: dict[str, Any], query: str) -> dict[str, Any]: + async def _get_first_input_params(schema: dict[str, Any], query: str) -> dict[str, Any]: """填充工具参数""" # 更清晰的输入·指令,这样可以调用generate llm_query = rf""" @@ -137,23 +74,48 @@ class MCPHost: llm_query, [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": await self.assemble_memory()}, + {"role": "user", "content": await MCPHost.assemble_memory()}, + ], + schema, + ) + return await json_generator.generate() + + async def _fill_params(mcp_tool: MCPTool, schema: dict[str, Any], + current_input: dict[str, Any], + error_message: str = "", params: dict[str, Any] = {}, + params_description: str = "") -> dict[str, Any]: + llm_query = "请生成修复之后的工具参数" + prompt = _env.from_string(REPAIR_PARAMS).render( + tool_name=mcp_tool.name, + tool_description=mcp_tool.description, + input_schema=schema, + current_input=current_input, + error_message=error_message, + params=params, + params_description=params_description, + ) + + json_generator = JsonGenerator( + llm_query, + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, ], schema, ) return await json_generator.generate() - async def call_tool(self, tool: MCPTool, plan_item: MCPPlanItem) -> list[dict[str, Any]]: + async def call_tool(user_sub: str, tool: MCPTool, plan_item: MCPPlanItem) -> list[dict[str, Any]]: """调用工具""" # 拿到Client - client = await MCPPool().get(tool.mcp_id, self._user_sub) + client = await MCPPool().get(tool.mcp_id, user_sub) if client is None: err = f"[MCPHost] MCP Server不合法: {tool.mcp_id}" logger.error(err) raise ValueError(err) # 填充参数 - params = await self._fill_params(tool, plan_item.instruction) + params = await MCPHost._fill_params(tool, plan_item.instruction) # 调用工具 result = await client.call_tool(tool.name, params) # 保存记忆 @@ -162,29 +124,12 @@ class MCPHost: if not isinstance(item, TextContent): logger.error("MCP结果类型不支持: %s", item) continue - processed_result.append(await self._save_memory(tool, plan_item, params, item.text)) - - return processed_result - - async def get_tool_list(self, mcp_id_list: list[str]) -> list[MCPTool]: - """获取工具列表""" - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") - - # 获取工具列表 - tool_list = [] - for mcp_id in mcp_id_list: - # 检查用户是否启用了这个mcp - mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": self._user_sub}) - if not mcp_db_result: - logger.warning("用户 %s 未启用MCP %s", self._user_sub, mcp_id) - continue - # 获取MCP工具配置 + result = item.text try: - for tool in mcp_db_result["tools"]: - tool_list.extend([MCPTool.model_validate(tool)]) - except KeyError: - logger.warning("用户 %s 的MCP Tool %s 配置错误", self._user_sub, mcp_id) + json_result = json.loads(result) + except Exception as e: + logger.error("MCP结果解析失败: %s, 错误: %s", result, e) continue + processed_result.append(json_result) - return tool_list + return processed_result diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index a7c1f132..13e7a98d 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -1,6 +1,6 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """MCP 用户目标拆解与规划""" -from typing import Any +from typing import Any, AsyncGenerator from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment @@ -57,8 +57,8 @@ class MCPPlanner: result += chunk # 保存token用量 - self.input_tokens = self.resoning_llm.input_tokens - self.output_tokens = self.resoning_llm.output_tokens + self.input_tokens += self.resoning_llm.input_tokens + self.output_tokens += self.resoning_llm.output_tokens return result async def _parse_result(self, result: str, schema: dict[str, Any]) -> str: @@ -171,7 +171,8 @@ class MCPPlanner: """获取推理大模型的风险评估结果""" template = self._env.from_string(RISK_EVALUATE) prompt = template.render( - tool=tool, + tool_name=tool.name, + tool_description=tool.description, input_param=input_param, additional_info=additional_info, ) @@ -194,7 +195,8 @@ class MCPPlanner: schema_with_null = slot.add_null_to_basic_types() template = self._env.from_string(GET_MISSING_PARAMS) prompt = template.render( - tool=tool, + tool_name=tool.name, + tool_description=tool.description, input_param=input_param, schema=schema_with_null, error_message=error_message, @@ -204,7 +206,7 @@ class MCPPlanner: input_param_with_null = await self._parse_result(result, schema_with_null) return input_param_with_null - async def generate_answer(self, plan: MCPPlan, memory: str) -> str: + async def generate_answer(self, plan: MCPPlan, memory: str) -> AsyncGenerator[str, None]: """生成最终回答""" template = self._env.from_string(FINAL_ANSWER) prompt = template.render( @@ -213,16 +215,12 @@ class MCPPlanner: goal=self.user_goal, ) - llm = ReasoningLLM() - result = "" - async for chunk in llm.call( + async for chunk in self.resoning_llm.call( [{"role": "user", "content": prompt}], streaming=False, temperature=0.07, ): - result += chunk + yield chunk - self.input_tokens = llm.input_tokens - self.output_tokens = llm.output_tokens - - return result + self.input_tokens = self.resoning_llm.input_tokens + self.output_tokens = self.resoning_llm.output_tokens diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py index cf05afc8..9cbc2f5b 100644 --- a/apps/scheduler/mcp_agent/prompt.py +++ b/apps/scheduler/mcp_agent/prompt.py @@ -421,8 +421,8 @@ RISK_EVALUATE = dedent(r""" ``` # 工具 - {{ tool.name }} - {{ tool.description }} + {{ tool_name }} + {{ tool_description }} # 工具入参 {{ input_param }} @@ -464,7 +464,7 @@ JUDGE_NEXT_STEP = dedent(r""" { "content": "任务执行完成,端口扫描结果为Result[2]", "tool": "Final", - "instruction": "" + "instruction": "" } ]} ## 当前使用的工具 @@ -489,8 +489,8 @@ JUDGE_NEXT_STEP = dedent(r""" {{ current_plan }} # 当前使用的工具 - {{ tool.name }} - {{ tool.description }} + {{ tool_name }} + {{ tool_description }} # 工具入参 {{ input_param }} @@ -571,8 +571,8 @@ GET_MISSING_PARAMS = dedent(r""" ``` # 工具 < tool > - < name > {{tool.name}} < /name > - < description > {{tool.description}} < /description > + < name > {{tool_name}} < /name > + < description > {{tool_description}} < /description > < / tool > # 工具入参 {{input_param}} @@ -583,32 +583,107 @@ GET_MISSING_PARAMS = dedent(r""" # 输出 """ ) +REPAIR_PARAMS = dedent(r""" + 你是一个工具参数修复器。 + 你的任务是根据当前的工具信息、工具入参的schema、工具当前的入参、工具的报错、补充的参数和补充的参数描述,修复当前工具的入参。 + + # 样例 + ## 工具信息 + + mysql_analyzer + 分析MySQL数据库性能 + + ## 工具入参的schema + { + "type": "object", + "properties": { + "host": { + "type": "string", + "description": "MySQL数据库的主机地址" + }, + "port": { + "type": "integer", + "description": "MySQL数据库的端口号" + }, + "username": { + "type": "string", + "description": "MySQL数据库的用户名" + }, + "password": { + "type": "string", + "description": "MySQL数据库的密码" + } + }, + "required": ["host", "port", "username", "password"] + } + ## 工具当前的入参 + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + ## 工具的报错 + 执行端口扫描命令时,出现了错误:`password is not correct`。 + ## 补充的参数 + { + "username": "admin", + "password": "admin123" + } + ## 补充的参数描述 + 用户希望使用admin用户和admin123密码来连接MySQL数据库。 + # 输出 + ```json + { + "host": "192.0.0.1", + "port": 3306, + "username": "admin", + "password": "admin123" + } + ``` + # 工具 + + {{tool_name}} + {{tool_description}} + + # 工具入参scheme + {{input_schema}} + # 工具入参 + {{input_param}} + # 运行报错 + {{error_message}} + # 补充的参数 + {{params}} + # 补充的参数描述 + {{params_description}} + # 输出 + """ + ) FINAL_ANSWER = dedent(r""" 综合理解计划执行结果和背景信息,向用户报告目标的完成情况。 # 用户目标 - {{goal}} + {{ goal }} # 计划执行情况 为了完成上述目标,你实施了以下计划: - {{memory}} + {{ memory }} # 其他背景信息: - {{status}} + {{ status }} # 现在,请根据以上信息,向用户报告目标的完成情况: """) - MEMORY_TEMPLATE = dedent(r""" - { % for ctx in context_list % } + {% for ctx in context_list % } - 第{{loop.index}}步:{{ctx.step_description}} - 调用工具 `{{ctx.step_id}}`,并提供参数 `{{ctx.input_data}}` - 执行状态:{{ctx.status}} - 得到数据:`{{ctx.output_data}}` - { % endfor % } + {% endfor % } """) diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 72f64a71..3b26f42f 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -199,21 +199,21 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: # 检查是否存在group_id if not await RecordManager.check_group_id(task.ids.group_id, user_sub): - record_group = await RecordManager.create_record_group( - task.ids.group_id, user_sub, post_body.conversation_id, task.id, + record_group_id = await RecordManager.create_record_group( + task.ids.group_id, user_sub, post_body.conversation_id ) - if not record_group: + if not record_group_id: logger.error("[Scheduler] 创建问答组失败") return else: - record_group = task.ids.group_id + record_group_id = task.ids.group_id # 修改文件状态 - await DocumentManager.change_doc_status(user_sub, post_body.conversation_id, record_group) + await DocumentManager.change_doc_status(user_sub, post_body.conversation_id, record_group_id) # 保存Record - await RecordManager.insert_record_data_into_record_group(user_sub, record_group, record) + await RecordManager.insert_record_data_into_record_group(user_sub, record_group_id, record) # 保存与答案关联的文件 - await DocumentManager.save_answer_doc(user_sub, record_group, used_docs) + await DocumentManager.save_answer_doc(user_sub, record_group_id, used_docs) if post_body.app and post_body.app.app_id: # 更新最近使用的应用 @@ -223,5 +223,4 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: if not task.state or task.state.flow_status == StepStatus.SUCCESS or task.state.flow_status == StepStatus.ERROR or task.state.flow_status == StepStatus.CANCELLED: await TaskManager.delete_task_by_task_id(task.id) else: - # 更新Task await TaskManager.save_task(task.id, task) diff --git a/apps/schemas/message.py b/apps/schemas/message.py index 1e58fb1b..e7341324 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -5,7 +5,7 @@ from typing import Any from datetime import UTC, datetime from pydantic import BaseModel, Field -from apps.schemas.enum_var import EventType, StepStatus +from apps.schemas.enum_var import EventType, FlowStatus, StepStatus from apps.schemas.record import RecordMetadata @@ -28,6 +28,8 @@ class MessageFlow(BaseModel): app_id: str = Field(description="插件ID", alias="appId") flow_id: str = Field(description="Flow ID", alias="flowId") + flow_name: str = Field(description="Flow名称", alias="flowName") + flow_status: FlowStatus = Field(description="Flow状态", alias="flowStatus", default=FlowStatus.UNKNOWN) step_id: str = Field(description="当前步骤ID", alias="stepId") step_name: str = Field(description="当前步骤名称", alias="stepName") sub_step_id: str | None = Field(description="当前子步骤ID", alias="subStepId", default=None) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index d3222762..6a394375 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -44,6 +44,7 @@ class RecordFlow(BaseModel): id: str record_id: str = Field(alias="recordId") flow_id: str = Field(alias="flowId") + flow_name: str = Field(alias="flowName", default="") flow_status: StepStatus = Field(alias="flowStatus", default=StepStatus.SUCCESS) step_num: int = Field(alias="stepNum") steps: list[RecordFlowStep] @@ -129,6 +130,7 @@ class Record(RecordData): user_sub: str key: dict[str, Any] = {} + task_id: str content: str comment: RecordComment = Field(default=RecordComment()) flow: FlowHistory = Field( @@ -149,5 +151,4 @@ class RecordGroup(BaseModel): records: list[Record] = [] docs: list[RecordGroupDocument] = [] # 问题不变,所用到的文档不变 conversation_id: str - task_id: str created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 7bb5ac76..8719c2e9 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -47,6 +47,7 @@ class RequestData(BaseModel): files: list[str] = Field(default=[], description="文件列表") app: RequestDataApp | None = Field(default=None, description="应用") debug: bool = Field(default=False, description="是否调试") + task_id: str | None = Field(default=None, alias="taskId", description="任务ID") new_task: bool = Field(default=True, description="是否新建任务") diff --git a/apps/schemas/task.py b/apps/schemas/task.py index b08da0a8..eccc95a5 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -46,11 +46,12 @@ class ExecutorState(BaseModel): step_name: str = Field(description="当前步骤名称", default="") step_status: StepStatus = Field(description="当前步骤状态", default=StepStatus.UNKNOWN) step_description: str = Field(description="当前步骤描述", default="") - retry_times: int = Field(description="当前步骤重试次数", default=0) - error_message: str = Field(description="当前步骤错误信息", default="") app_id: str = Field(description="应用ID", default="") - slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) + current_input: dict[str, Any] = Field(description="当前输入数据", default={}) + params: dict[str, Any] = Field(description="补充的参数", default={}) + params_description: str = Field(description="补充的参数描述", default="") error_info: str = Field(description="错误信息", default="") + retry_times: int = Field(description="当前步骤重试次数", default=0) class TaskIds(BaseModel): diff --git a/apps/services/conversation.py b/apps/services/conversation.py index bac964db..6bacb727 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -140,4 +140,4 @@ class ConversationManager: await record_group_collection.delete_many({"conversation_id": conversation_id}, session=session) await session.commit_transaction() - await TaskManager.delete_tasks_by_conversation_id(conversation_id) + await TaskManager.delete_tasks_and_flow_context_by_conversation_id(conversation_id) diff --git a/apps/services/record.py b/apps/services/record.py index e1925ef6..6b61f91e 100644 --- a/apps/services/record.py +++ b/apps/services/record.py @@ -9,7 +9,7 @@ from apps.schemas.record import ( Record, RecordGroup, ) - +from apps.schemas.enum_var import FlowStatus logger = logging.getLogger(__name__) @@ -17,7 +17,7 @@ class RecordManager: """问答对相关操作""" @staticmethod - async def create_record_group(group_id: str, user_sub: str, conversation_id: str, task_id: str) -> str | None: + async def create_record_group(group_id: str, user_sub: str, conversation_id: str) -> str | None: """创建问答组""" mongo = MongoDB() record_group_collection = mongo.get_collection("record_group") @@ -26,7 +26,6 @@ class RecordManager: _id=group_id, user_sub=user_sub, conversation_id=conversation_id, - task_id=task_id, ) try: @@ -137,6 +136,19 @@ class RecordManager: logger.exception("[RecordManager] 查询问答组失败") return [] + @staticmethod + async def update_record_flow_status_to_cancelled_by_task_ids(task_ids: list[str]) -> None: + """更新Record关联的Flow状态""" + record_group_collection = MongoDB().get_collection("record_group") + try: + await record_group_collection.update_many( + {"records.flow.flow_id": {"$in": task_ids}, "records.flow.flow_status": {"$nin": [FlowStatus.ERROR.value, FlowStatus.SUCCESS.value]}}, + {"$set": {"records.$[elem].flow.flow_status": FlowStatus.CANCELLED}}, + array_filters=[{"elem.flow.flow_id": {"$in": task_ids}}], + ) + except Exception: + logger.exception("[RecordManager] 更新Record关联的Flow状态失败") + @staticmethod async def verify_record_in_group(group_id: str, record_id: str, user_sub: str) -> bool: """ diff --git a/apps/services/task.py b/apps/services/task.py index a8cb1848..2f75a8c3 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -113,6 +113,28 @@ class TaskManager: else: return flow_context + @staticmethod + async def init_new_task( + cls, + user_sub: str, + session_id: str | None = None, + post_body: RequestData | None = None, + ) -> Task: + """获取任务块""" + return Task( + _id=str(uuid.uuid4()), + ids=TaskIds( + user_sub=user_sub if user_sub else "", + session_id=session_id if session_id else "", + conversation_id=post_body.conversation_id, + group_id=post_body.group_id if post_body.group_id else "", + ), + question=post_body.question if post_body else "", + group_id=post_body.group_id if post_body else "", + tokens=TaskTokens(), + runtime=TaskRuntime(), + ) + @staticmethod async def save_flow_context(task_id: str, flow_context: list[FlowStepHistory]) -> None: """保存flow信息到flow_context""" @@ -145,7 +167,25 @@ class TaskManager: await task_collection.delete_one({"_id": task_id}) @staticmethod - async def delete_tasks_by_conversation_id(conversation_id: str) -> None: + async def delete_tasks_by_conversation_id(conversation_id: str) -> list[str]: + """通过ConversationID删除Task信息""" + mongo = MongoDB() + task_collection = mongo.get_collection("task") + task_ids = [] + try: + async for task in task_collection.find( + {"conversation_id": conversation_id}, + {"_id": 1}, + ): + task_ids.append(task["_id"]) + if task_ids: + await task_collection.delete_many({"conversation_id": conversation_id}) + except Exception: + logger.exception("[TaskManager] 删除ConversationID的Task信息失败") + return [] + + @staticmethod + async def delete_tasks_and_flow_context_by_conversation_id(conversation_id: str) -> None: """通过ConversationID删除Task信息""" mongo = MongoDB() task_collection = mongo.get_collection("task") @@ -162,49 +202,6 @@ class TaskManager: await task_collection.delete_many({"conversation_id": conversation_id}, session=session) await flow_context_collection.delete_many({"task_id": {"$in": task_ids}}, session=session) - @classmethod - async def get_task( - cls, - task_id: str | None = None, - session_id: str | None = None, - post_body: RequestData | None = None, - user_sub: str | None = None, - ) -> Task: - """获取任务块""" - if task_id: - try: - task = await cls.get_task_by_task_id(task_id) - if task: - return task - except Exception: - logger.exception("[TaskManager] 通过task_id获取任务失败") - - logger.info("[TaskManager] 未提供task_id,通过session_id获取任务") - if not session_id or not post_body: - err = ( - "session_id 和 conversation_id 或 group_id 和 conversation_id 是恢复/创建任务的必要条件。" - ) - raise ValueError(err) - - if post_body.group_id: - task = await cls.get_task_by_group_id(post_body.group_id, post_body.conversation_id) - else: - task = await cls.get_task_by_conversation_id(post_body.conversation_id) - - if task: - return task - return Task( - _id=str(uuid.uuid4()), - ids=TaskIds( - user_sub=user_sub if user_sub else "", - session_id=session_id if session_id else "", - conversation_id=post_body.conversation_id, - group_id=post_body.group_id if post_body.group_id else "", - ), - tokens=TaskTokens(), - runtime=TaskRuntime(), - ) - @classmethod async def save_task(cls, task_id: str, task: Task) -> None: """保存任务块""" -- Gitee From ea68121853eac7182da40f046682fc202aef8439 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 5 Aug 2025 11:14:56 +0800 Subject: [PATCH 51/60] =?UTF-8?q?=E5=AE=8C=E5=96=84Agent=20=E5=BC=80?= =?UTF-8?q?=E5=8F=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/executor/agent.py | 389 +++++++++++++++++++++++++- apps/scheduler/mcp/select.py | 6 - apps/scheduler/mcp_agent/host.py | 12 +- apps/scheduler/mcp_agent/plan.py | 220 ++++++++++----- apps/scheduler/mcp_agent/prompt.py | 340 ++++++++++++++-------- apps/scheduler/mcp_agent/select.py | 240 ++++++---------- apps/scheduler/scheduler/scheduler.py | 56 +++- apps/schemas/enum_var.py | 6 +- apps/schemas/mcp.py | 27 ++ apps/schemas/message.py | 4 +- apps/schemas/pool.py | 3 + apps/schemas/task.py | 6 +- 12 files changed, 933 insertions(+), 376 deletions(-) diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index 603ea65b..4db38587 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -2,21 +2,34 @@ """MCP Agent执行器""" import logging - +import uuid from pydantic import Field - +from typing import Any +from apps.llm.patterns.rewrite import QuestionRewrite from apps.llm.reasoning import ReasoningLLM from apps.scheduler.executor.base import BaseExecutor from apps.schemas.enum_var import EventType, SpecialCallType, FlowStatus, StepStatus from apps.scheduler.mcp_agent.host import MCPHost from apps.scheduler.mcp_agent.plan import MCPPlanner +from apps.scheduler.mcp_agent.select import FINAL_TOOL_ID, MCPSelector from apps.scheduler.pool.mcp.client import MCPClient -from apps.schemas.mcp import MCPCollection, MCPTool -from apps.schemas.task import ExecutorState, StepQueueItem +from apps.schemas.mcp import ( + GoalEvaluationResult, + RestartStepIndex, + ToolRisk, + ErrorType, + ToolExcutionErrorType, + MCPPlan, + MCPCollection, + MCPTool +) +from apps.schemas.task import ExecutorState, FlowStepHistory, StepQueueItem from apps.schemas.message import param from apps.services.task import TaskManager from apps.services.appcenter import AppCenterManager from apps.services.mcp_service import MCPServiceManager +from apps.services.task import TaskManager +from apps.services.user import UserManager logger = logging.getLogger(__name__) @@ -31,7 +44,9 @@ class MCPAgentExecutor(BaseExecutor): mcp_client: dict[str, MCPClient] = Field( description="MCP客户端列表,key为mcp_id", default={} ) - tool_list: list[MCPTool] = Field(description="MCP工具列表", default=[]) + tools: dict[str, MCPTool] = Field( + description="MCP工具列表,key为tool_id", default={} + ) params: param | None = Field( default=None, description="流执行过程中的参数补充", alias="params" ) @@ -65,10 +80,372 @@ class MCPAgentExecutor(BaseExecutor): self.mcp_list.append(mcp_service) self.mcp_client[mcp_id] = await MCPHost.get_client(self.task.ids.user_sub, mcp_id) - self.tool_list.extend(mcp_service.tools) + for tool in mcp_service.tools: + self.tools[tool.id] = tool + + async def plan(self, is_replan: bool = False, start_index: int | None = None) -> None: + if is_replan: + error_message = "之前的计划遇到以下报错\n\n"+self.task.state.error_message + else: + error_message = "初始化计划" + tools = MCPSelector.select_top_tool( + self.task.runtime.question, list(self.tools.values()), + additional_info=error_message, top_n=40) + if is_replan: + logger.info("[MCPAgentExecutor] 重新规划流程") + if not start_index: + start_index = await MCPPlanner.get_replan_start_step_index(self.task.runtime.question, + self.task.state.error_message, + self.task.runtime.temporary_plans, + self.resoning_llm) + current_plan = self.task.runtime.temporary_plans.plans[start_index:] + error_message = self.task.state.error_message + temporary_plans = await MCPPlanner.create_plan(self.task.runtime.question, + is_replan=is_replan, + error_message=error_message, + current_plan=current_plan, + tool_list=tools, + max_steps=self.max_steps-start_index-1, + reasoning_llm=self.resoning_llm + ) + self.msg_queue.push_output( + self.task, + EventType.STEP_CANCEL, + data={} + ) + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + self.task.context[-1].step_status = StepStatus.CANCELLED + self.task.runtime.temporary_plans = self.task.runtime.temporary_plans.plans[:start_index] + temporary_plans.plans + self.task.state.step_index = start_index + else: + start_index = 0 + self.task.runtime.temporary_plans = await MCPPlanner.create_plan(self.task.runtime.question, tool_list=tools, max_steps=self.max_steps, reasoning_llm=self.resoning_llm) + for i in range(start_index, len(self.task.runtime.temporary_plans.plans)): + self.task.runtime.temporary_plans.plans[i].step_id = str(uuid.uuid4()) + + async def get_tool_input_param(self, is_first: bool) -> dict[str, Any]: + if is_first: + # 获取第一个输入参数 + self.task.state.current_input = await MCPHost._get_first_input_params(self.tools[self.task.state.step_id], self.task.runtime.question, self.task) + else: + # 获取后续输入参数 + if isinstance(self.params, param): + params = self.params.content + params_description = self.params.description + else: + params = {} + params_description = "" + self.task.state.current_input = await MCPHost._fill_params(self.tools[self.task.state.step_id], self.task.state.current_input, self.task.state.error_message, params, params_description) + + async def reset_step_to_index(self, start_index: int) -> None: + """重置步骤到开始""" + logger.info("[MCPAgentExecutor] 重置步骤到索引 %d", start_index) + if self.task.runtime.temporary_plans: + self.task.state.flow_status = FlowStatus.RUNNING + self.task.state.step_id = self.task.runtime.temporary_plans.plans[start_index].step_id + self.task.state.step_index = 0 + self.task.state.step_name = self.task.runtime.temporary_plans.plans[start_index].tool + self.task.state.step_description = self.task.runtime.temporary_plans.plans[start_index].content + self.task.state.step_status = StepStatus.RUNNING + self.task.state.retry_times = 0 + else: + self.task.state.flow_status = FlowStatus.SUCCESS + self.task.state.step_id = FINAL_TOOL_ID + + async def confirm_before_step(self) -> None: + logger.info("[MCPAgentExecutor] 等待用户确认步骤 %d", self.task.state.step_index) + # 发送确认消息 + confirm_message = await MCPPlanner.get_tool_risk(self.tools[self.task.state.step_id], self.task.state.current_input, "", self.resoning_llm) + self.msg_queue.push_output(self.task, EventType.STEP_WAITING_FOR_START, confirm_message.model_dump( + exclude_none=True, by_alias=True)) + self.msg_queue.push_output(self.task, EventType.FLOW_STOP, {}) + self.task.state.flow_status = FlowStatus.WAITING + self.task.state.step_status = StepStatus.WAITING + self.task.context.append( + FlowStepHistory( + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=self.task.state.step_status, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data={}, + output_data={}, + ex_data=confirm_message.model_dump(exclude_none=True, by_alias=True), + ) + ) + + async def run_step(self): + """执行步骤""" + self.task.state.flow_status = FlowStatus.RUNNING + self.task.state.step_status = StepStatus.RUNNING + logger.info("[MCPAgentExecutor] 执行步骤 %d", self.task.state.step_index) + # 获取MCP客户端 + mcp_tool = self.tools[self.task.state.step_id] + mcp_client = self.mcp_client[mcp_tool.mcp_id] + if not mcp_client: + logger.error("[MCPAgentExecutor] MCP客户端未找到: %s", mcp_tool.mcp_id) + self.task.state.flow_status = FlowStatus.ERROR + error = "[MCPAgentExecutor] MCP客户端未找到: {}".format(mcp_tool.mcp_id) + raise Exception(error) + try: + output_params = await mcp_client.call_tool(mcp_tool.name, self.task.state.current_input) + self.msg_queue.push_output( + self.task, + EventType.STEP_INPUT, + self.task.state.current_input + ) + self.msg_queue.push_output( + self.task, + EventType.STEP_OUTPUT, + output_params + ) + self.task.context.append( + FlowStepHistory( + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=StepStatus.SUCCESS, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data=self.task.state.current_input, + output_data=output_params, + ) + ) + self.task.state.step_status = StepStatus.SUCCESS + except Exception as e: + logging.warning("[MCPAgentExecutor] 执行步骤 %s 失败: %s", mcp_tool.name, str(e)) + import traceback + self.task.state.error_message = traceback.format_exc() + self.task.state.step_status = StepStatus.ERROR + + async def generate_params_with_null(self) -> None: + """生成参数补充""" + mcp_tool = self.tools[self.task.state.step_id] + params_with_null = await MCPPlanner.get_missing_param( + mcp_tool, + self.task.state.current_input, + self.task.state.error_message, + self.resoning_llm + ) + self.msg_queue.push_output( + self.task, + EventType.STEP_WAITING_FOR_PARAM, + data={ + "message": "当运行产生如下报错:\n" + self.task.state.error_message, + "params": params_with_null + } + ) + self.task.state.flow_status = FlowStatus.WAITING + self.task.state.step_status = StepStatus.PARAM + self.task.context.append( + FlowStepHistory( + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=self.task.state.step_status, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data={}, + output_data={}, + ex_data={ + "message": "当运行产生如下报错:\n" + self.task.state.error_message, + "params": params_with_null + } + ) + ) + + async def get_next_step(self) -> None: + self.task.state.step_index += 1 + if self.task.state.step_index < len(self.task.runtime.temporary_plans): + if self.task.runtime.temporary_plans.plans[self.task.state.step_index].step_id == FINAL_TOOL_ID: + # 最后一步 + self.task.state.flow_status = FlowStatus.SUCCESS + self.task.state.step_status = StepStatus.SUCCESS + self.msg_queue.push_output( + self.task, + EventType.FLOW_SUCCESS, + data={} + ) + return + self.task.state.step_id = self.task.runtime.temporary_plans.plans[self.task.state.step_index].step_id + self.task.state.step_name = self.task.runtime.temporary_plans.plans[self.task.state.step_index].tool + self.task.state.step_description = self.task.runtime.temporary_plans.plans[self.task.state.step_index].content + self.task.state.step_status = StepStatus.INIT + self.task.state.current_input = {} + self.msg_queue.push_output( + self.task, + EventType.STEP_INIT, + data={} + ) + else: + # 没有下一步了,结束流程 + self.task.state.flow_status = FlowStatus.SUCCESS + self.task.state.step_status = StepStatus.SUCCESS + self.msg_queue.push_output( + self.task, + EventType.FLOW_SUCCESS, + data={} + ) + return + + async def error_handle_after_step(self) -> None: + """步骤执行失败后的错误处理""" + self.task.state.step_status = StepStatus.ERROR + self.task.state.flow_status = FlowStatus.ERROR + self.msg_queue.push_output( + self.task, + EventType.FLOW_FAILED, + data={} + ) + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + del self.task.context[-1] + self.task.context.append( + FlowStepHistory( + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=self.task.state.step_status, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data={}, + output_data={}, + ) + ) + + async def work(self) -> None: + """执行当前步骤""" + if self.task.state.step_status == StepStatus.INIT: + self.get_tool_input_param(is_first=True) + user_info = await UserManager.get_userinfo_by_user_sub(self.task.ids.user_sub) + if not user_info.auto_execute: + # 等待用户确认 + await self.confirm_before_step() + return + self.step.state.step_status = StepStatus.RUNNING + elif self.task.state.step_status in [StepStatus.PARAM, StepStatus.WAITING, StepStatus.RUNNING]: + if self.task.context[-1].step_status == StepStatus.PARAM: + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + del self.task.context[-1] + elif self.task.state.step_status == StepStatus.WAITING: + if self.params.content: + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + del self.task.context[-1] + else: + self.task.state.flow_status = FlowStatus.CANCELLED + self.task.state.step_status = StepStatus.CANCELLED + self.msg_queue.push_output( + self.task, + EventType.STEP_CANCEL, + data={} + ) + self.msg_queue.push_output( + self.task, + EventType.FLOW_CANCEL, + data={} + ) + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + self.task.context[-1].step_status = StepStatus.CANCELLED + if self.task.state.step_status == StepStatus.PARAM: + self.get_tool_input_param(is_first=False) + max_retry = 5 + for i in range(max_retry): + if i != 0: + self.get_tool_input_param(is_first=False) + await self.run_step() + if self.task.state.step_status == StepStatus.SUCCESS: + break + elif self.task.state.step_status == StepStatus.ERROR: + # 错误处理 + if self.task.state.retry_times >= 3: + await self.error_handle_after_step() + else: + user_info = await UserManager.get_userinfo_by_user_sub(self.task.ids.user_sub) + mcp_tool = self.tools[self.task.state.step_id] + error_type = await MCPPlanner.get_tool_execute_error_type( + self.task.runtime.question, + self.task.runtime.temporary_plans, + mcp_tool, + self.task.state.current_input, + self.task.state.error_message, + self.resoning_llm + ) + if error_type.type == ErrorType.DECORRECT_PLAN or user_info.auto_execute: + await self.plan(is_replan=True) + self.reset_step_to_index(self.task.state.step_index) + elif error_type.type == ErrorType.MISSING_PARAM: + await self.generate_params_with_null() + elif self.task.state.step_status == StepStatus.SUCCESS: + await self.get_next_step() + + async def summarize(self) -> None: + async for chunk in MCPPlanner.generate_answer( + self.task.runtime.question, + self.task.runtime.temporary_plans, + (await MCPHost.assemble_memory(self.task)), + self.resoning_llm + ): + self.msg_queue.push_output( + self.task, + EventType.TEXT_ADD, + data=chunk + ) + self.task.runtime.answer += chunk async def run(self) -> None: """执行MCP Agent的主逻辑""" # 初始化MCP服务 self.load_state() self.load_mcp() + if self.task.state.flow_status == FlowStatus.INIT: + # 初始化状态 + self.task.state.flow_id = str(uuid.uuid4()) + self.task.state.flow_name = await MCPPlanner.get_flow_name(self.task.runtime.question, self.resoning_llm) + self.task.runtime.temporary_plans = await self.plan(is_replan=False) + self.reset_step_to_index(0) + TaskManager.save_task(self.task.id, self.task) + self.task.state.flow_status = FlowStatus.RUNNING + self.msg_queue.push_output( + self.task, + EventType.FLOW_START, + data={} + ) + try: + while self.task.state.step_index < len(self.task.runtime.temporary_plans) and \ + self.task.state.flow_status == FlowStatus.RUNNING: + self.work() + TaskManager.save_task(self.task.id, self.task) + except Exception as e: + logger.error("[MCPAgentExecutor] 执行过程中发生错误: %s", str(e)) + self.task.state.flow_status = FlowStatus.ERROR + self.task.state.error_message = str(e) + self.task.state.step_status = StepStatus.ERROR + self.msg_queue.push_output( + self.task, + EventType.STEP_ERROR, + data={} + ) + self.msg_queue.push_output( + self.task, + EventType.FLOW_FAILED, + data={} + ) + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + del self.task.context[-1] + self.task.context.append( + FlowStepHistory( + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=self.task.state.step_status, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data={}, + output_data={}, + ) + ) diff --git a/apps/scheduler/mcp/select.py b/apps/scheduler/mcp/select.py index 2ff50344..f3d6e0d4 100644 --- a/apps/scheduler/mcp/select.py +++ b/apps/scheduler/mcp/select.py @@ -39,7 +39,6 @@ class MCPSelector: sql += f"'{mcp_id}', " return sql.rstrip(", ") + ")" - async def _get_top_mcp_by_embedding( self, query: str, @@ -72,7 +71,6 @@ class MCPSelector: }]) return llm_mcp_list - async def _get_mcp_by_llm( self, query: str, @@ -100,7 +98,6 @@ class MCPSelector: # 使用小模型提取JSON return await self._call_function_mcp(result, mcp_ids) - async def _call_reasoning(self, prompt: str) -> str: """调用大模型进行推理""" logger.info("[MCPHelper] 调用推理大模型") @@ -116,7 +113,6 @@ class MCPSelector: self.output_tokens += llm.output_tokens return result - async def _call_function_mcp(self, reasoning_result: str, mcp_ids: list[str]) -> MCPSelectResult: """调用结构化输出小模型提取JSON""" logger.info("[MCPHelper] 调用结构化输出小模型") @@ -136,7 +132,6 @@ class MCPSelector: raise return result - async def select_top_mcp( self, query: str, @@ -153,7 +148,6 @@ class MCPSelector: # 通过LLM选择最合适的 return await self._get_mcp_by_llm(query, llm_mcp_list, mcp_list) - @staticmethod async def select_top_tool(query: str, mcp_list: list[str], top_n: int = 10) -> list[MCPTool]: """选择最合适的工具""" diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py index 3217f539..ced175ef 100644 --- a/apps/scheduler/mcp_agent/host.py +++ b/apps/scheduler/mcp_agent/host.py @@ -60,7 +60,7 @@ class MCPHost: context_list=task.context, ) - async def _get_first_input_params(schema: dict[str, Any], query: str) -> dict[str, Any]: + async def _get_first_input_params(mcp_tool: MCPTool, query: str, task: Task) -> dict[str, Any]: """填充工具参数""" # 更清晰的输入·指令,这样可以调用generate llm_query = rf""" @@ -74,13 +74,13 @@ class MCPHost: llm_query, [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": await MCPHost.assemble_memory()}, + {"role": "user", "content": await MCPHost.assemble_memory(task)}, ], - schema, + mcp_tool.input_schema, ) return await json_generator.generate() - async def _fill_params(mcp_tool: MCPTool, schema: dict[str, Any], + async def _fill_params(mcp_tool: MCPTool, current_input: dict[str, Any], error_message: str = "", params: dict[str, Any] = {}, params_description: str = "") -> dict[str, Any]: @@ -88,7 +88,7 @@ class MCPHost: prompt = _env.from_string(REPAIR_PARAMS).render( tool_name=mcp_tool.name, tool_description=mcp_tool.description, - input_schema=schema, + input_schema=mcp_tool.input_schema, current_input=current_input, error_message=error_message, params=params, @@ -101,7 +101,7 @@ class MCPHost: {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}, ], - schema, + mcp_tool.input_schema, ) return await json_generator.generate() diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index 13e7a98d..91d293fb 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -9,38 +9,36 @@ from apps.llm.function import JsonGenerator from apps.scheduler.mcp_agent.prompt import ( EVALUATE_GOAL, GENERATE_FLOW_NAME, + GET_REPLAN_START_STEP_INDEX, CREATE_PLAN, RECREATE_PLAN, RISK_EVALUATE, + TOOL_EXECUTE_ERROR_TYPE_ANALYSIS, GET_MISSING_PARAMS, FINAL_ANSWER ) from apps.schemas.mcp import ( GoalEvaluationResult, + RestartStepIndex, ToolRisk, + ToolExcutionErrorType, MCPPlan, MCPTool ) from apps.scheduler.slot.slot import Slot +_env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=True, + trim_blocks=True, + lstrip_blocks=True, +) + class MCPPlanner: """MCP 用户目标拆解与规划""" - - def __init__(self, user_goal: str, resoning_llm: ReasoningLLM = None) -> None: - """初始化MCP规划器""" - self.user_goal = user_goal - self._env = SandboxedEnvironment( - loader=BaseLoader, - autoescape=True, - trim_blocks=True, - lstrip_blocks=True, - ) - self.resoning_llm = resoning_llm or ReasoningLLM() - self.input_tokens = 0 - self.output_tokens = 0 - - async def get_resoning_result(self, prompt: str) -> str: + @staticmethod + async def get_resoning_result(prompt: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str: """获取推理结果""" # 调用推理大模型 message = [ @@ -48,7 +46,7 @@ class MCPPlanner: {"role": "user", "content": prompt}, ] result = "" - async for chunk in self.resoning_llm.call( + async for chunk in resoning_llm.call( message, streaming=False, temperature=0.07, @@ -56,12 +54,10 @@ class MCPPlanner: ): result += chunk - # 保存token用量 - self.input_tokens += self.resoning_llm.input_tokens - self.output_tokens += self.resoning_llm.output_tokens return result - async def _parse_result(self, result: str, schema: dict[str, Any]) -> str: + @staticmethod + async def _parse_result(result: str, schema: dict[str, Any]) -> str: """解析推理结果""" json_generator = JsonGenerator( result, @@ -74,126 +70,210 @@ class MCPPlanner: json_result = await json_generator.generate() return json_result - async def evaluate_goal(self, tool_list: list[MCPTool]) -> GoalEvaluationResult: + @staticmethod + async def evaluate_goal( + tool_list: list[MCPTool], + resoning_llm: ReasoningLLM = ReasoningLLM()) -> GoalEvaluationResult: """评估用户目标的可行性""" # 获取推理结果 - result = await self._get_reasoning_evaluation(tool_list) + result = await MCPPlanner._get_reasoning_evaluation(tool_list, resoning_llm) # 解析为结构化数据 - evaluation = await self._parse_evaluation_result(result) + evaluation = await MCPPlanner._parse_evaluation_result(result) # 返回评估结果 return evaluation - async def _get_reasoning_evaluation(self, tool_list: list[MCPTool]) -> str: + @staticmethod + async def _get_reasoning_evaluation( + goal, tool_list: list[MCPTool], + resoning_llm: ReasoningLLM = ReasoningLLM()) -> str: """获取推理大模型的评估结果""" - template = self._env.from_string(EVALUATE_GOAL) + template = _env.from_string(EVALUATE_GOAL) prompt = template.render( - goal=self.user_goal, + goal=goal, tools=tool_list, ) - result = await self.get_resoning_result(prompt) + result = await MCPPlanner.get_resoning_result(prompt, resoning_llm) return result - async def _parse_evaluation_result(self, result: str) -> GoalEvaluationResult: + @staticmethod + async def _parse_evaluation_result(result: str) -> GoalEvaluationResult: """将推理结果解析为结构化数据""" schema = GoalEvaluationResult.model_json_schema() - evaluation = await self._parse_result(result, schema) + evaluation = await MCPPlanner._parse_result(result, schema) # 使用GoalEvaluationResult模型解析结果 return GoalEvaluationResult.model_validate(evaluation) - async def get_flow_name(self) -> str: + async def get_flow_name(user_goal: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str: """获取当前流程的名称""" - result = await self._get_reasoning_flow_name() + result = await MCPPlanner._get_reasoning_flow_name(user_goal, resoning_llm) return result - async def _get_reasoning_flow_name(self) -> str: + @staticmethod + async def _get_reasoning_flow_name(user_goal: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str: """获取推理大模型的流程名称""" - template = self._env.from_string(GENERATE_FLOW_NAME) - prompt = template.render(goal=self.user_goal) - result = await self.get_resoning_result(prompt) + template = _env.from_string(GENERATE_FLOW_NAME) + prompt = template.render(goal=user_goal) + result = await MCPPlanner.get_resoning_result(prompt, resoning_llm) return result - async def create_plan(self, tool_list: list[MCPTool], max_steps: int = 6) -> MCPPlan: + @staticmethod + async def get_replan_start_step_index( + user_goal: str, error_message: str, current_plan: MCPPlan | None = None, + history: str = "", + reasoning_llm: ReasoningLLM = ReasoningLLM()) -> MCPPlan: + """获取重新规划的步骤索引""" + # 获取推理结果 + template = _env.from_string(GET_REPLAN_START_STEP_INDEX) + prompt = template.render( + goal=user_goal, + error_message=error_message, + current_plan=current_plan.model_dump(exclude_none=True, by_alias=True), + history=history, + ) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + # 解析为结构化数据 + schema = RestartStepIndex.model_json_schema() + schema["properties"]["start_index"]["maximum"] = len(current_plan.plans) - 1 + schema["properties"]["start_index"]["minimum"] = 0 + restart_index = await MCPPlanner._parse_result(result, schema) + # 使用RestartStepIndex模型解析结果 + return RestartStepIndex.model_validate(restart_index) + + @staticmethod + async def create_plan( + user_goal: str, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan | None = None, + tool_list: list[MCPTool] = [], + max_steps: int = 6, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> MCPPlan: """规划下一步的执行流程,并输出""" # 获取推理结果 - result = await self._get_reasoning_plan(tool_list, max_steps) + result = await MCPPlanner._get_reasoning_plan(user_goal, is_replan, error_message, current_plan, tool_list, max_steps, reasoning_llm) # 解析为结构化数据 - return await self._parse_plan_result(result, max_steps) + return await MCPPlanner._parse_plan_result(result, max_steps) + @staticmethod async def _get_reasoning_plan( - self, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan = MCPPlan(), + user_goal: str, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan | None = None, tool_list: list[MCPTool] = [], - max_steps: int = 10) -> str: + max_steps: int = 10, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> str: """获取推理大模型的结果""" # 格式化Prompt if is_replan: - template = self._env.from_string(RECREATE_PLAN) + template = _env.from_string(RECREATE_PLAN) prompt = template.render( - current_plan=current_plan, + current_plan=current_plan.model_dump(exclude_none=True, by_alias=True), error_message=error_message, - goal=self.user_goal, + goal=user_goal, tools=tool_list, max_num=max_steps, ) else: - template = self._env.from_string(CREATE_PLAN) + template = _env.from_string(CREATE_PLAN) prompt = template.render( - goal=self.user_goal, + goal=user_goal, tools=tool_list, max_num=max_steps, ) - result = await self.get_resoning_result(prompt) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) return result - async def _parse_plan_result(self, result: str, max_steps: int) -> MCPPlan: + @staticmethod + async def _parse_plan_result(result: str, max_steps: int) -> MCPPlan: """将推理结果解析为结构化数据""" # 格式化Prompt schema = MCPPlan.model_json_schema() schema["properties"]["plans"]["maxItems"] = max_steps - plan = await self._parse_result(result, schema) + plan = await MCPPlanner._parse_result(result, schema) # 使用Function模型解析结果 return MCPPlan.model_validate(plan) - async def get_tool_risk(self, tool: MCPTool, input_parm: dict[str, Any], additional_info: str = "") -> ToolRisk: + @staticmethod + async def get_tool_risk( + tool: MCPTool, input_parm: dict[str, Any], + additional_info: str = "", resoning_llm: ReasoningLLM = ReasoningLLM()) -> ToolRisk: """获取MCP工具的风险评估结果""" # 获取推理结果 - result = await self._get_reasoning_risk(tool, input_parm, additional_info) + result = await MCPPlanner._get_reasoning_risk(tool, input_parm, additional_info, resoning_llm) # 解析为结构化数据 - risk = await self._parse_risk_result(result) + risk = await MCPPlanner._parse_risk_result(result) # 返回风险评估结果 return risk - async def _get_reasoning_risk(self, tool: MCPTool, input_param: dict[str, Any], additional_info: str) -> str: + @staticmethod + async def _get_reasoning_risk( + tool: MCPTool, input_param: dict[str, Any], + additional_info: str, resoning_llm: ReasoningLLM) -> str: """获取推理大模型的风险评估结果""" - template = self._env.from_string(RISK_EVALUATE) + template = _env.from_string(RISK_EVALUATE) prompt = template.render( tool_name=tool.name, tool_description=tool.description, input_param=input_param, additional_info=additional_info, ) - result = await self.get_resoning_result(prompt) + result = await MCPPlanner.get_resoning_result(prompt, resoning_llm) return result - async def _parse_risk_result(self, result: str) -> ToolRisk: + @staticmethod + async def _parse_risk_result(result: str) -> ToolRisk: """将推理结果解析为结构化数据""" schema = ToolRisk.model_json_schema() - risk = await self._parse_result(result, schema) + risk = await MCPPlanner._parse_result(result, schema) # 使用ToolRisk模型解析结果 return ToolRisk.model_validate(risk) + @staticmethod + async def _get_reasoning_tool_execute_error_type( + user_goal: str, current_plan: MCPPlan, + tool: MCPTool, input_param: dict[str, Any], + error_message: str, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> str: + """获取推理大模型的工具执行错误类型""" + template = _env.from_string(TOOL_EXECUTE_ERROR_TYPE_ANALYSIS) + prompt = template.render( + goal=user_goal, + current_plan=current_plan.model_dump(exclude_none=True, by_alias=True), + tool_name=tool.name, + tool_description=tool.description, + input_param=input_param, + error_message=error_message, + ) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + return result + + @staticmethod + async def _parse_tool_execute_error_type_result(result: str) -> ToolExcutionErrorType: + """将推理结果解析为工具执行错误类型""" + schema = ToolExcutionErrorType.model_json_schema() + error_type = await MCPPlanner._parse_result(result, schema) + # 使用ToolExcutionErrorType模型解析结果 + return ToolExcutionErrorType.model_validate(error_type) + + @staticmethod + async def get_tool_execute_error_type( + user_goal: str, current_plan: MCPPlan, + tool: MCPTool, input_param: dict[str, Any], + error_message: str, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> ToolExcutionErrorType: + """获取MCP工具执行错误类型""" + # 获取推理结果 + result = await MCPPlanner._get_reasoning_tool_execute_error_type( + user_goal, current_plan, tool, input_param, error_message, reasoning_llm) + error_type = await MCPPlanner._parse_tool_execute_error_type_result(result) + # 返回工具执行错误类型 + return error_type + + @staticmethod async def get_missing_param( - self, tool: MCPTool, schema: dict[str, Any], + tool: MCPTool, input_param: dict[str, Any], - error_message: str) -> list[str]: + error_message: str, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> list[str]: """获取缺失的参数""" - slot = Slot(schema=schema) + slot = Slot(schema=tool.input_schema) + template = _env.from_string(GET_MISSING_PARAMS) schema_with_null = slot.add_null_to_basic_types() - template = self._env.from_string(GET_MISSING_PARAMS) prompt = template.render( tool_name=tool.name, tool_description=tool.description, @@ -201,26 +281,26 @@ class MCPPlanner: schema=schema_with_null, error_message=error_message, ) - result = await self.get_resoning_result(prompt) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) # 解析为结构化数据 - input_param_with_null = await self._parse_result(result, schema_with_null) + input_param_with_null = await MCPPlanner._parse_result(result, schema_with_null) return input_param_with_null - async def generate_answer(self, plan: MCPPlan, memory: str) -> AsyncGenerator[str, None]: + @staticmethod + async def generate_answer( + user_goal: str, plan: MCPPlan, memory: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> AsyncGenerator[ + str, None]: """生成最终回答""" - template = self._env.from_string(FINAL_ANSWER) + template = _env.from_string(FINAL_ANSWER) prompt = template.render( - plan=plan, + plan=plan.model_dump(exclude_none=True, by_alias=True), memory=memory, - goal=self.user_goal, + goal=user_goal, ) - async for chunk in self.resoning_llm.call( + async for chunk in resoning_llm.call( [{"role": "user", "content": prompt}], streaming=False, temperature=0.07, ): yield chunk - - self.input_tokens = self.resoning_llm.input_tokens - self.output_tokens = self.resoning_llm.output_tokens diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py index 9cbc2f5b..b5bc085c 100644 --- a/apps/scheduler/mcp_agent/prompt.py +++ b/apps/scheduler/mcp_agent/prompt.py @@ -62,6 +62,62 @@ MCP_SELECT = dedent(r""" ### 请一步一步思考: """) +TOOL_SELECT = dedent(r""" + 你是一个乐于助人的智能助手。 + 你的任务是:根据当前目标,附加信息,选择最合适的MCP工具。 + ## 选择MCP工具时的注意事项: + 1. 确保充分理解当前目标,选择实现目标所需的MCP工具。 + 2. 请在给定的MCP工具列表中选择,不要自己生成MCP工具。 + 3. 可以选择一些辅助工具,但必须确保这些工具与当前目标相关。 + 必须按照以下格式生成选择结果,不要输出任何其他内容: + ```json + { + "tool_ids": ["工具ID1", "工具ID2", ...] + } + ``` + + # 示例 + ## 目标 + 调优mysql性能 + ## MCP工具列表 + + - mcp_tool_1 MySQL链接池工具;用于优化MySQL链接池 + - mcp_tool_2 MySQL性能调优工具;用于分析MySQL性能瓶颈 + - mcp_tool_3 MySQL查询优化工具;用于优化MySQL查询语句 + - mcp_tool_4 MySQL索引优化工具;用于优化MySQL索引 + - mcp_tool_5 文件存储工具;用于存储文件 + - mcp_tool_6 mongoDB工具;用于操作MongoDB数据库 + + ## 附加信息 + 1. 当前MySQL数据库的版本是8.0.26 + 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf,并含有以下配置项 + ```json + { + "max_connections": 1000, + "innodb_buffer_pool_size": "1G", + "query_cache_size": "64M" + } + ##输出 + ```json + { + "tool_ids": ["mcp_tool_1", "mcp_tool_2", "mcp_tool_3", "mcp_tool_4"] + } + ``` + # 现在开始! + ## 目标 + {{goal}} + ## MCP工具列表 + + {% for tool in tools %} + - {{tool.id}} {{tool.name}};{{tool.description}} + {% endfor %} + + ## 附加信息 + {{additional_info}} + # 输出 + """ + ) + EVALUATE_GOAL = dedent(r""" 你是一个计划评估器。 请根据用户的目标和当前的工具集合以及一些附加信息,判断基于当前的工具集合,是否能够完成用户的目标。 @@ -76,18 +132,18 @@ EVALUATE_GOAL = dedent(r""" ``` # 样例 - ## 目标 - 我需要扫描当前mysql数据库,分析性能瓶颈,并调优 + # 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 - ## 工具集合 - 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + # 工具集合 + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 - - mysql_analyzer分析MySQL数据库性能 - - performance_tuner调优数据库性能 - - Final结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 + - mysql_analyzer 分析MySQL数据库性能 + - performance_tuner 调优数据库性能 + - Final 结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 - ## 附加信息 + # 附加信息 1. 当前MySQL数据库的版本是8.0.26 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf @@ -100,17 +156,17 @@ EVALUATE_GOAL = dedent(r""" ``` # 目标 - {{ goal }} + {{goal}} # 工具集合 - {% for tool in tools %} - - {{ tool.id }}{{tool.name}};{{ tool.description }} - {% endfor %} + { % for tool in tools % } + - {{tool.id}} {{tool.name}};{{tool.description}} + { % endfor % } # 附加信息 - {{ additional_info }} + {{additional_info}} """) GENERATE_FLOW_NAME = dedent(r""" @@ -123,15 +179,79 @@ GENERATE_FLOW_NAME = dedent(r""" 4. 流程名称应该尽量简短,小于20个字或者单词。 5. 只输出流程名称,不要输出其他内容。 # 样例 - ## 目标 - 我需要扫描当前mysql数据库,分析性能瓶颈,并调优 - ## 输出 + # 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 输出 扫描MySQL数据库并分析性能瓶颈,进行调优 # 现在开始生成流程名称: # 目标 - {{ goal }} + {{goal}} # 输出 """) +GET_REPLAN_START_STEP_INDEX = dedent(r""" + 你是一个智能助手,你的任务是根据用户的目标、报错信息和当前计划和历史,获取重新规划的步骤起始索引。 + + # 样例 + # 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 报错信息 + 执行端口扫描命令时,出现了错误:`- bash: curl: command not found`。 + # 当前计划 + ```json + { + "plans": [ + { + "step_id": "step_1", + "content": "生成端口扫描命令", + "tool": "command_generator", + "instruction": "生成端口扫描命令:扫描 + }, + { + "step_id": "step_2", + "content": "在执行Result[0]生成的命令", + "tool": "command_executor", + "instruction": "执行端口扫描命令" + } + ] + } + # 历史 + [ + { + id: "0", + task_id: "task_1", + flow_id: "flow_1", + flow_name: "MYSQL性能调优", + flow_status: "RUNNING", + step_id: "step_1", + step_name: "生成端口扫描命令", + step_description: "生成端口扫描命令:扫描当前MySQL数据库的端口", + step_status: "FAILED", + input_data: { + "command": "nmap -p 3306 + "target": "localhost" + }, + output_data: { + "error": "- bash: curl: command not found" + } + } + ] + # 输出 + { + "start_index": 0, + "reasoning": "当前计划的第一步就失败了,报错信息显示curl命令未找到,可能是因为没有安装curl工具,因此需要从第一步重新规划。" + } + # 现在开始获取重新规划的步骤起始索引: + # 目标 + {{goal}} + # 报错信息 + {{error_message}} + # 当前计划 + {{current_plan}} + # 历史 + {{history}} + # 输出 + """) + CREATE_PLAN = dedent(r""" 你是一个计划生成器。 请分析用户的目标,并生成一个计划。你后续将根据这个计划,一步一步地完成用户的目标。 @@ -163,40 +283,38 @@ CREATE_PLAN = dedent(r""" } ``` - - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。\ -思考过程应放置在 XML标签中。 + - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。 +思考过程应放置在 XML标签中。 - 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。 - - 计划不得多于{{ max_num }}条,且每条计划内容应少于150字。 + - 计划不得多于{{max_num}}条,且每条计划内容应少于150字。 # 工具 - 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 - {% for tool in tools %} - - {{ tool.id }}{{tool.name}};{{ tool.description }} - {% endfor %} + { % for tool in tools % } + - {{tool.id}} {{tool.name}};{{tool.description}} + { % endfor % } # 样例 - ## 目标 + # 目标 - 在后台运行一个新的alpine:latest容器,将主机/root文件夹挂载至/data,并执行top命令。 + 在后台运行一个新的alpine: latest容器,将主机/root文件夹挂载至/data,并执行top命令。 - ## 计划 + # 计划 - 1. 这个目标需要使用Docker来完成,首先需要选择合适的MCP Server + 1. 这个目标需要使用Docker来完成, 首先需要选择合适的MCP Server 2. 目标可以拆解为以下几个部分: - - 运行alpine:latest容器 + - 运行alpine: latest容器 - 挂载主机目录 - 在后台运行 - 执行top命令 - 3. 需要先选择MCP Server,然后生成Docker命令,最后执行命令 - - - ```json + 3. 需要先选择MCP Server, 然后生成Docker命令, 最后执行命令 + ```json { "plans": [ { @@ -225,7 +343,7 @@ CREATE_PLAN = dedent(r""" # 现在开始生成计划: - ## 目标 + # 目标 {{goal}} @@ -263,26 +381,24 @@ RECREATE_PLAN = dedent(r""" } ``` - - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。\ -思考过程应放置在 XML标签中。 + - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。 +思考过程应放置在 XML标签中。 - 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。 - - 计划不得多于{{ max_num }}条,且每条计划内容应少于150字。 + - 计划不得多于{{max_num}}条,且每条计划内容应少于150字。 # 样例 - ## 目标 + # 目标 请帮我扫描一下192.168.1.1的这台机器的端口,看看有哪些端口开放。 - ## 工具 - 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + # 工具 + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 - - command_generator生成命令行指令 - - tool_selector选择合适的工具 - - command_executor执行命令行指令 - - Final结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 - - - ## 当前计划 + - command_generator 生成命令行指令 + - tool_selector 选择合适的工具 + - command_executor 执行命令行指令 + - Final 结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 + # 当前计划 ```json { "plans": [ @@ -304,25 +420,23 @@ RECREATE_PLAN = dedent(r""" ] } ``` - ## 运行报错 - 执行端口扫描命令时,出现了错误:`-bash: curl: command not found`。 - ## 重新生成的计划 + # 运行报错 + 执行端口扫描命令时,出现了错误:`- bash: curl: command not found`。 + # 重新生成的计划 - 1. 这个目标需要使用网络扫描工具来完成,首先需要选择合适的网络扫描工具 + 1. 这个目标需要使用网络扫描工具来完成, 首先需要选择合适的网络扫描工具 2. 目标可以拆解为以下几个部分: - 生成端口扫描命令 - 执行端口扫描命令 - 3.但是在执行端口扫描命令时,出现了错误:`-bash: curl: command not found`。 + 3.但是在执行端口扫描命令时,出现了错误:`- bash: curl: command not found`。 4.我将计划调整为: - 需要先生成一个命令,查看当前机器支持哪些网络扫描工具 - 执行这个命令,查看当前机器支持哪些网络扫描工具 - 然后从中选择一个网络扫描工具 - 基于选择的网络扫描工具,生成端口扫描命令 - 执行端口扫描命令 - - - ```json + ```json { "plans": [ { @@ -367,19 +481,19 @@ RECREATE_PLAN = dedent(r""" # 工具 - 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 - {% for tool in tools %} - - {{ tool.id }}{{tool.name}};{{ tool.description }} - {% endfor %} + { % for tool in tools % } + - {{tool.id}} {{tool.name}};{{tool.description}} + { % endfor % } # 当前计划 - {{ current_plan }} + {{current_plan}} # 运行报错 - {{ error_message }} + {{error_message}} # 重新生成的计划 """) @@ -393,18 +507,18 @@ RISK_EVALUATE = dedent(r""" } ``` # 样例 - ## 工具名称 + # 工具名称 mysql_analyzer - ## 工具描述 + # 工具描述 分析MySQL数据库性能 - ## 工具入参 + # 工具入参 { "host": "192.0.0.1", "port": 3306, "username": "root", "password": "password" } - ## 附加信息 + # 附加信息 1. 当前MySQL数据库的版本是8.0.26 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf,并含有以下配置项 ```ini @@ -412,7 +526,7 @@ RISK_EVALUATE = dedent(r""" innodb_buffer_pool_size=1G innodb_log_file_size=256M ``` - ## 输出 + # 输出 ```json { "risk": "中", @@ -421,35 +535,35 @@ RISK_EVALUATE = dedent(r""" ``` # 工具 - {{ tool_name }} - {{ tool_description }} + {{tool_name}} + {{tool_description}} # 工具入参 - {{ input_param }} + {{input_param}} # 附加信息 - {{ additional_info }} + {{additional_info}} # 输出 """ ) # 根据当前计划和报错信息决定下一步执行,具体计划有需要用户补充工具入参、重计划当前步骤、重计划接下来的所有计划 -JUDGE_NEXT_STEP = dedent(r""" +TOOL_EXECUTE_ERROR_TYPE_ANALYSIS = dedent(r""" 你是一个计划决策器。 - 你的任务是根据当前计划、当前使用的工具、工具入参和工具运行报错,决定下一步执行的操作。 + 你的任务是根据用户目标、当前计划、当前使用的工具、工具入参和工具运行报错,决定下一步执行的操作。 请根据以下规则进行判断: - 1. 仅通过补充工具入参来解决问题的,返回 fill_params; - 2. 需要重计划当前步骤的,返回 replan_current_step; - 3. 需要重计划接下来的所有计划的,返回 replan_all_steps; + 1. 仅通过补充工具入参来解决问题的,返回 missing_param; + 2. 需要重计划当前步骤的,返回 decorrect_plan + 3.推理过程必须清晰明了,能够让人理解你的判断依据,并且不超过100字。 你的输出要以json格式返回,格式如下: ```json { - "next_step": "fill_params/replan_current_step/replan_all_steps", - "reason": "你的判断依据" + "error_type": "missing_param/decorrect_plan, + "reason": "你的推理过程" } ``` - 注意: - reason字段必须清晰明了,能够让人理解你的判断依据,并且不超过50个中文字或者100个英文单词。 # 样例 - ## 当前计划 + # 用户目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 当前计划 {"plans": [ { "content": "生成端口扫描命令", @@ -467,38 +581,40 @@ JUDGE_NEXT_STEP = dedent(r""" "instruction": "" } ]} - ## 当前使用的工具 + # 当前使用的工具 - command_executor - 执行命令行指令 + command_executor + 执行命令行指令 - ## 工具入参 + # 工具入参 { "command": "nmap -sS -p--open 192.168.1.1" } - ## 工具运行报错 - 执行端口扫描命令时,出现了错误:`-bash: nmap: command not found`。 - ## 输出 + # 工具运行报错 + 执行端口扫描命令时,出现了错误:`- bash: nmap: command not found`。 + # 输出 ```json { - "next_step": "replan_all_steps", - "reason": "当前工具执行报错,提示nmap命令未找到,需要增加command_generator和command_executor的步骤,生成nmap安装命令并执行,之后再生成端口扫描命令并执行。" + "error_type": "decorrect_plan", + "reason": "当前计划的第二步执行失败,报错信息显示nmap命令未找到,可能是因为没有安装nmap工具,因此需要重计划当前步骤。" } ``` + # 用户目标 + {{goal}} # 当前计划 - {{ current_plan }} + {{current_plan}} # 当前使用的工具 - {{ tool_name }} - {{ tool_description }} + {{tool_name}} + {{tool_description}} # 工具入参 - {{ input_param }} + {{input_param}} # 工具运行报错 - {{ error_message }} + {{error_message}} # 输出 """ - ) + ) # 获取缺失的参数的json结构体 GET_MISSING_PARAMS = dedent(r""" 你是一个工具参数获取器。 @@ -570,10 +686,10 @@ GET_MISSING_PARAMS = dedent(r""" } ``` # 工具 - < tool > - < name > {{tool_name}} < /name > - < description > {{tool_description}} < /description > - < / tool > + + {{tool_name}} + {{tool_description}} + # 工具入参 {{input_param}} # 工具入参schema(部分字段允许为null) @@ -588,12 +704,12 @@ REPAIR_PARAMS = dedent(r""" 你的任务是根据当前的工具信息、工具入参的schema、工具当前的入参、工具的报错、补充的参数和补充的参数描述,修复当前工具的入参。 # 样例 - ## 工具信息 + # 工具信息 - mysql_analyzer - 分析MySQL数据库性能 + mysql_analyzer + 分析MySQL数据库性能 - ## 工具入参的schema + # 工具入参的schema { "type": "object", "properties": { @@ -616,21 +732,21 @@ REPAIR_PARAMS = dedent(r""" }, "required": ["host", "port", "username", "password"] } - ## 工具当前的入参 + # 工具当前的入参 { "host": "192.0.0.1", "port": 3306, "username": "root", "password": "password" } - ## 工具的报错 + # 工具的报错 执行端口扫描命令时,出现了错误:`password is not correct`。 - ## 补充的参数 + # 补充的参数 { "username": "admin", "password": "admin123" } - ## 补充的参数描述 + # 补充的参数描述 用户希望使用admin用户和admin123密码来连接MySQL数据库。 # 输出 ```json @@ -643,8 +759,8 @@ REPAIR_PARAMS = dedent(r""" ``` # 工具 - {{tool_name}} - {{tool_description}} + {{tool_name}} + {{tool_description}} # 工具入参scheme {{input_schema}} @@ -664,17 +780,17 @@ FINAL_ANSWER = dedent(r""" # 用户目标 - {{ goal }} + {{goal}} # 计划执行情况 为了完成上述目标,你实施了以下计划: - {{ memory }} + {{memory}} # 其他背景信息: - {{ status }} + {{status}} # 现在,请根据以上信息,向用户报告目标的完成情况: diff --git a/apps/scheduler/mcp_agent/select.py b/apps/scheduler/mcp_agent/select.py index 37d1e752..933527c3 100644 --- a/apps/scheduler/mcp_agent/select.py +++ b/apps/scheduler/mcp_agent/select.py @@ -2,7 +2,7 @@ """选择MCP Server及其工具""" import logging -import uuid +import random from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment from typing import AsyncGenerator @@ -13,176 +13,94 @@ from apps.common.mongo import MongoDB from apps.llm.embedding import Embedding from apps.llm.function import FunctionLLM from apps.llm.reasoning import ReasoningLLM -from apps.scheduler.mcp.prompt import ( - MCP_SELECT, -) +from apps.llm.token import TokenCalculator +from apps.scheduler.mcp_agent.prompt import TOOL_SELECT from apps.schemas.mcp import ( + BaseModel, MCPCollection, MCPSelectResult, MCPTool, + MCPToolIdsSelectResult ) - +from apps.common.config import Config logger = logging.getLogger(__name__) +_env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=True, + trim_blocks=True, + lstrip_blocks=True, +) -class MCPSelector: - """MCP选择器""" - - def __init__(self, resoning_llm: ReasoningLLM = None) -> None: - """初始化助手类""" - self.resoning_llm = resoning_llm or ReasoningLLM() - self.input_tokens = 0 - self.output_tokens = 0 - - @staticmethod - def _assemble_sql(mcp_list: list[str]) -> str: - """组装SQL""" - sql = "(" - for mcp_id in mcp_list: - sql += f"'{mcp_id}', " - return sql.rstrip(", ") + ")" - - async def _get_top_mcp_by_embedding( - self, - query: str, - mcp_list: list[str], - ) -> list[dict[str, str]]: - """通过向量检索获取Top5 MCP Server""" - logger.info("[MCPHelper] 查询MCP Server向量: %s, %s", query, mcp_list) - mcp_table = await LanceDB().get_table("mcp") - query_embedding = await Embedding.get_embedding([query]) - mcp_vecs = await (await mcp_table.search( - query=query_embedding, - vector_column_name="embedding", - )).where(f"id IN {MCPSelector._assemble_sql(mcp_list)}").limit(5).to_list() - - # 拿到名称和description - logger.info("[MCPHelper] 查询MCP Server名称和描述: %s", mcp_vecs) - mcp_collection = MongoDB().get_collection("mcp") - llm_mcp_list: list[dict[str, str]] = [] - for mcp_vec in mcp_vecs: - mcp_id = mcp_vec["id"] - mcp_data = await mcp_collection.find_one({"_id": mcp_id}) - if not mcp_data: - logger.warning("[MCPHelper] 查询MCP Server名称和描述失败: %s", mcp_id) - continue - mcp_data = MCPCollection.model_validate(mcp_data) - llm_mcp_list.extend([{ - "id": mcp_id, - "name": mcp_data.name, - "description": mcp_data.description, - }]) - return llm_mcp_list - - async def _get_mcp_by_llm( - self, - query: str, - mcp_list: list[dict[str, str]], - mcp_ids: list[str], - ) -> MCPSelectResult: - """通过LLM选择最合适的MCP Server""" - # 初始化jinja2环境 - env = SandboxedEnvironment( - loader=BaseLoader, - autoescape=True, - trim_blocks=True, - lstrip_blocks=True, - ) - template = env.from_string(MCP_SELECT) - # 渲染模板 - mcp_prompt = template.render( - mcp_list=mcp_list, - goal=query, - ) - - # 调用大模型进行推理 - result = await self._call_reasoning(mcp_prompt) - - # 使用小模型提取JSON - return await self._call_function_mcp(result, mcp_ids) - - async def _call_reasoning(self, prompt: str) -> AsyncGenerator[str, None]: - """调用大模型进行推理""" - logger.info("[MCPHelper] 调用推理大模型") - message = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt}, - ] - async for chunk in self.resoning_llm.call(message): - yield chunk - - async def _call_function_mcp(self, reasoning_result: str, mcp_ids: list[str]) -> MCPSelectResult: - """调用结构化输出小模型提取JSON""" - logger.info("[MCPHelper] 调用结构化输出小模型") - llm = FunctionLLM() - message = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": reasoning_result}, - ] - schema = MCPSelectResult.model_json_schema() - # schema中加入选项 - schema["properties"]["mcp_id"]["enum"] = mcp_ids - result = await llm.call(messages=message, schema=schema) - try: - result = MCPSelectResult.model_validate(result) - except Exception: - logger.exception("[MCPHelper] 解析MCP Select Result失败") - raise - return result - - async def select_top_mcp( - self, - query: str, - mcp_list: list[str], - ) -> MCPSelectResult: - """ - 选择最合适的MCP Server +FINAL_TOOL_ID = "FIANL" +SUMMARIZE_TOOL_ID = "SUMMARIZE" - 先通过Embedding选择Top5,然后通过LLM选择Top 1 - """ - # 通过向量检索获取Top5 - llm_mcp_list = await self._get_top_mcp_by_embedding(query, mcp_list) - # 通过LLM选择最合适的 - return await self._get_mcp_by_llm(query, llm_mcp_list, mcp_list) +class MCPSelector: + """MCP选择器""" @staticmethod - async def select_top_tool(query: str, mcp_list: list[str], top_n: int = 10) -> list[MCPTool]: + async def select_top_tool( + goal: str, tool_list: list[MCPTool], + additional_info: str | None = None, top_n: int | None = None) -> list[MCPTool]: """选择最合适的工具""" - tool_vector = await LanceDB().get_table("mcp_tool") - query_embedding = await Embedding.get_embedding([query]) - tool_vecs = await (await tool_vector.search( - query=query_embedding, - vector_column_name="embedding", - )).where(f"mcp_id IN {MCPSelector._assemble_sql(mcp_list)}").limit(top_n).to_list() - - # 拿到工具 - tool_collection = MongoDB().get_collection("mcp") - llm_tool_list = [] - - for tool_vec in tool_vecs: - # 到MongoDB里找对应的工具 - logger.info("[MCPHelper] 查询MCP Tool名称和描述: %s", tool_vec["mcp_id"]) - tool_data = await tool_collection.aggregate([ - {"$match": {"_id": tool_vec["mcp_id"]}}, - {"$unwind": "$tools"}, - {"$match": {"tools.id": tool_vec["id"]}}, - {"$project": {"_id": 0, "tools": 1}}, - {"$replaceRoot": {"newRoot": "$tools"}}, - ]) - async for tool in tool_data: - tool_obj = MCPTool.model_validate(tool) - llm_tool_list.append(tool_obj) - llm_tool_list.append( - MCPTool( - id="00000000-0000-0000-0000-000000000000", - name="Final", - description="It is the final step, indicating the end of the plan execution.") - ) - llm_tool_list.append( - MCPTool( - id="00000000-0000-0000-0000-000000000001", - name="Chat", - description="It is a chat tool to communicate with the user.") - ) - return llm_tool_list + random.shuffle(tool_list) + max_tokens = Config().get_config().function_call.max_tokens + template = _env.from_string(TOOL_SELECT) + if TokenCalculator.calculate_token_length( + messages=[{"role": "user", "content": template.render( + goal=goal, tools=[], additional_info=additional_info + )}], + pure_text=True) > max_tokens: + logger.warning("[MCPSelector] 工具选择模板长度超过最大令牌数,无法进行选择") + return [] + llm = FunctionLLM() + current_index = 0 + tool_ids = [] + while current_index < len(tool_list): + index = current_index + sub_tools = [] + while index < len(tool_list): + tool = tool_list[index] + tokens = TokenCalculator.calculate_token_length( + messages=[{"role": "user", "content": template.render( + goal=goal, tools=[tool], + additional_info=additional_info + )}], + pure_text=True + ) + if tokens > max_tokens: + continue + sub_tools.append(tool) + + tokens = TokenCalculator.calculate_token_length(messages=[{"role": "user", "content": template.render( + goal=goal, tools=sub_tools, additional_info=additional_info)}, ], pure_text=True) + if tokens > max_tokens: + del sub_tools[-1] + break + else: + index += 1 + current_index = index + if sub_tools: + message = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": template.render(tools=sub_tools)}, + ] + schema = MCPToolIdsSelectResult.model_json_schema() + schema["properties"]["tool_ids"]["enum"] = [tool.id for tool in sub_tools] + result = await llm.call(messages=message, schema=schema) + try: + result = MCPToolIdsSelectResult.model_validate(result) + tool_ids.extend(result.tool_ids) + except Exception: + logger.exception("[MCPSelector] 解析MCP工具ID选择结果失败") + continue + mcp_tools = [tool for tool in tool_list if tool.id in tool_ids] + + if top_n is not None: + mcp_tools = mcp_tools[:top_n] + mcp_tools.append(MCPTool(id=FINAL_TOOL_ID, name="Final", + description="终止", mcp_id=FINAL_TOOL_ID, input_schema={})) + # mcp_tools.append(MCPTool(id=SUMMARIZE_TOOL_ID, name="Summarize", + # description="总结工具", mcp_id=SUMMARIZE_TOOL_ID, input_schema={})) + return mcp_tools diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index f6325369..b8144847 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -4,7 +4,9 @@ import asyncio import logging from datetime import UTC, datetime - +from apps.llm.reasoning import ReasoningLLM +from apps.schemas.config import LLMConfig +from apps.llm.patterns.rewrite import QuestionRewrite from apps.common.config import Config from apps.common.mongo import MongoDB from apps.common.queue import MessageQueue @@ -67,8 +69,8 @@ class Scheduler: except Exception as e: logger.error(f"[Scheduler] 活动监控过程中发生错误: {e}") - async def run(self) -> None: # noqa: PLR0911 - """运行调度器""" + async def get_llm_use_in_chat_with_rag(self) -> LLM: + """获取RAG大模型""" try: # 获取当前会话使用的大模型 llm_id = await LLMManager.get_llm_id_by_conversation_id( @@ -97,14 +99,25 @@ class Scheduler: logger.exception("[Scheduler] 获取大模型失败") await self.queue.close() return + + async def get_kb_ids_use_in_chat_with_rag(self) -> list[str]: + """获取知识库ID列表""" try: - # 获取当前会话使用的知识库 kb_ids = await KnowledgeBaseManager.get_kb_ids_by_conversation_id( - self.task.ids.user_sub, self.task.ids.conversation_id) + self.task.ids.user_sub, self.task.ids.conversation_id, + ) + if not kb_ids: + logger.error("[Scheduler] 获取知识库ID失败") + await self.queue.close() + return [] + return kb_ids except Exception: logger.exception("[Scheduler] 获取知识库ID失败") await self.queue.close() - return + return [] + + async def run(self) -> None: # noqa: PLR0911 + """运行调度器""" try: # 获取当前问答可供关联的文档 docs, doc_ids = await get_docs(self.task.ids.user_sub, self.post_body) @@ -114,13 +127,18 @@ class Scheduler: return history, _ = await get_context(self.task.ids.user_sub, self.post_body, 3) # 已使用文档 - # 如果是智能问答,直接执行 logger.info("[Scheduler] 开始执行") # 创建用于通信的事件 kill_event = asyncio.Event() monitor = asyncio.create_task(self._monitor_activity(kill_event, self.task.ids.user_sub)) if not self.post_body.app or self.post_body.app.app_id == "": + llm = await self.get_llm_use_in_chat_with_rag() + kb_ids = await self.get_kb_ids_use_in_chat_with_rag() + if not llm: + logger.error("[Scheduler] 获取大模型失败") + await self.queue.close() + return self.task = await push_init_message(self.task, self.queue, 3, is_flow=False) rag_data = RAGQueryReq( kbIds=kb_ids, @@ -199,6 +217,27 @@ class Scheduler: if not app_metadata: logger.error("[Scheduler] 未找到Agent应用") return + llm = await LLMManager.get_llm_by_id( + self.task.ids.user_sub, app_metadata.llm_id, + ) + if not llm: + logger.error("[Scheduler] 获取大模型失败") + await self.queue.close() + return + reasion_llm = ReasoningLLM( + LLMConfig( + endpoint=llm.openai_base_url, + key=llm.openai_api_key, + model=llm.model_name, + max_tokens=llm.max_tokens, + ) + ) + if background.conversation: + try: + question_obj = QuestionRewrite() + post_body.question = await question_obj.generate(history=background.conversation, question=post_body.question, llm=reasion_llm) + except Exception: + logger.exception("[Scheduler] 问题重写失败") if app_metadata.app_type == AppType.FLOW.value: logger.info("[Scheduler] 获取工作流元数据") flow_info = await Pool().get_flow_metadata(app_info.app_id) @@ -229,8 +268,6 @@ class Scheduler: # 初始化Executor logger.info("[Scheduler] 初始化Executor") - logger.error(f"{flow_data}") - logger.error(f"{self.task}") flow_exec = FlowExecutor( flow_id=flow_id, flow=flow_data, @@ -258,6 +295,7 @@ class Scheduler: servers_id=servers_id, background=background, agent_id=app_info.app_id, + params=post_body.app.params ) # 开始运行 logger.info("[Scheduler] 运行Executor") diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index 3fb65028..3bcabd57 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -15,6 +15,7 @@ class SlotType(str, Enum): class StepStatus(str, Enum): """步骤状态""" UNKNOWN = "unknown" + INIT = "init" WAITING = "waiting" RUNNING = "running" SUCCESS = "success" @@ -55,12 +56,15 @@ class EventType(str, Enum): STEP_WAITING_FOR_START = "step.waiting_for_start" STEP_WAITING_FOR_PARAM = "step.waiting_for_param" FLOW_START = "flow.start" + STEP_INIT = "step.init" STEP_INPUT = "step.input" STEP_OUTPUT = "step.output" + STEP_CANCEL = "step.cancel" + STEP_ERROR = "step.error" FLOW_STOP = "flow.stop" FLOW_FAILED = "flow.failed" FLOW_SUCCESS = "flow.success" - FLOW_CANCELLED = "flow.cancelled" + FLOW_CANCEL = "flow.cancel" DONE = "done" diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index 368865ac..21c403d4 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -111,6 +111,13 @@ class GoalEvaluationResult(BaseModel): reason: str = Field(description="评估原因") +class RestartStepIndex(BaseModel): + """MCP重新规划的步骤索引""" + + start_index: int = Field(description="重新规划的起始步骤索引") + reasoning: str = Field(description="重新规划的原因") + + class Risk(str, Enum): """MCP工具风险类型""" @@ -126,6 +133,20 @@ class ToolRisk(BaseModel): reason: str = Field(description="风险原因", default="") +class ErrorType(str, Enum): + """MCP工具错误类型""" + + MISSING_PARAM = "missing_param" + DECORRECT_PLAN = "decorrect_plan" + + +class ToolExcutionErrorType(BaseModel): + """MCP工具执行错误""" + + type: ErrorType = Field(description="错误类型", default=ErrorType.MISSING_PARAM) + reason: str = Field(description="错误原因", default="") + + class MCPSelectResult(BaseModel): """MCP选择结果""" @@ -138,6 +159,12 @@ class MCPToolSelectResult(BaseModel): name: str = Field(description="工具名称") +class MCPToolIdsSelectResult(BaseModel): + """MCP工具ID选择结果""" + + tool_ids: list[str] = Field(description="工具ID列表") + + class MCPPlanItem(BaseModel): """MCP 计划""" step_id: str = Field(description="步骤的ID", default="") diff --git a/apps/schemas/message.py b/apps/schemas/message.py index e7341324..1f46ff57 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -84,7 +84,7 @@ class FlowStartContent(BaseModel): """flow.start消息的content""" question: str = Field(description="用户问题") - params: dict[str, Any] = Field(description="预先提供的参数") + params: dict[str, Any] | None = Field(description="预先提供的参数", default=None) class MessageBase(HeartbeatData): @@ -95,5 +95,5 @@ class MessageBase(HeartbeatData): conversation_id: str = Field(min_length=36, max_length=36, alias="conversationId") task_id: str = Field(min_length=36, max_length=36, alias="taskId") flow: MessageFlow | None = None - content: dict[str, Any] = {} + content: Any | None = Field(default=None, description="消息内容") metadata: MessageMetadata diff --git a/apps/schemas/pool.py b/apps/schemas/pool.py index 27e16b37..7df6dab8 100644 --- a/apps/schemas/pool.py +++ b/apps/schemas/pool.py @@ -110,3 +110,6 @@ class AppPool(BaseData): flows: list[AppFlow] = Field(description="Flow列表", default=[]) hashes: dict[str, str] = Field(description="关联文件的hash值", default={}) mcp_service: list[str] = Field(default=[], alias="mcpService", description="MCP服务id列表") + llm_id: str = Field( + default="empty", alias="llmId", description="应用使用的大模型ID(如果有的话)" + ) diff --git a/apps/schemas/task.py b/apps/schemas/task.py index eccc95a5..336bfedc 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -30,6 +30,7 @@ class FlowStepHistory(BaseModel): step_status: StepStatus = Field(description="当前步骤状态") input_data: dict[str, Any] = Field(description="当前Step执行的输入", default={}) output_data: dict[str, Any] = Field(description="当前Step执行后的结果", default={}) + ex_data: dict[str, Any] | None = Field(description="额外数据", default=None) created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) @@ -43,14 +44,13 @@ class ExecutorState(BaseModel): flow_status: FlowStatus = Field(description="Flow状态", default=FlowStatus.INIT) # 任务级数据 step_id: str = Field(description="当前步骤ID", default="") + step_index: int = Field(description="当前步骤索引", default=0) step_name: str = Field(description="当前步骤名称", default="") step_status: StepStatus = Field(description="当前步骤状态", default=StepStatus.UNKNOWN) step_description: str = Field(description="当前步骤描述", default="") app_id: str = Field(description="应用ID", default="") current_input: dict[str, Any] = Field(description="当前输入数据", default={}) - params: dict[str, Any] = Field(description="补充的参数", default={}) - params_description: str = Field(description="补充的参数描述", default="") - error_info: str = Field(description="错误信息", default="") + error_message: str = Field(description="错误信息", default="") retry_times: int = Field(description="当前步骤重试次数", default=0) -- Gitee From 842104a97f5ffb8096a9ab4d7a16e224cbf95981 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 5 Aug 2025 14:41:59 +0800 Subject: [PATCH 52/60] =?UTF-8?q?=E5=AE=8C=E5=96=84app=E8=BF=94=E5=9B=9E?= =?UTF-8?q?=E7=9A=84=E6=95=B0=E6=8D=AE=E6=95=B0=E6=8D=AE=E7=BB=93=E6=9E=84?= =?UTF-8?q?&record=E7=9A=84rask=E5=8F=AF=E4=BB=A5=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E4=B8=BAnone?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/appcenter.py | 15 ++++++++++++--- apps/routers/mcp_service.py | 2 ++ apps/schemas/appcenter.py | 10 +++++++++- apps/schemas/record.py | 4 ++-- apps/services/mcp_service.py | 6 ++++++ 5 files changed, 31 insertions(+), 6 deletions(-) diff --git a/apps/routers/appcenter.py b/apps/routers/appcenter.py index 0ec4db91..df540eaf 100644 --- a/apps/routers/appcenter.py +++ b/apps/routers/appcenter.py @@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse from apps.dependency.user import get_user, verify_user from apps.exceptions import InstancePermissionError -from apps.schemas.appcenter import AppFlowInfo, AppPermissionData +from apps.schemas.appcenter import AppFlowInfo, AppMcpServiceInfo, AppPermissionData from apps.schemas.enum_var import AppFilterType, AppType from apps.schemas.request_data import CreateAppRequest, ModFavAppRequest from apps.schemas.response_data import ( @@ -25,7 +25,7 @@ from apps.schemas.response_data import ( ResponseData, ) from apps.services.appcenter import AppCenterManager - +from apps.services.mcp_service import MCPServiceManager logger = logging.getLogger(__name__) router = APIRouter( prefix="/api/app", @@ -214,6 +214,15 @@ async def get_application( ) for flow in app_data.flows ] + mcp_service = [] + if app_data.mcp_service: + for service in app_data.mcp_service: + mcp_collection = await MCPServiceManager.get_mcp_service(service) + mcp_service.append(AppMcpServiceInfo( + id=mcp_collection.id, + name=mcp_collection.name, + description=mcp_collection.description, + )) return JSONResponse( status_code=status.HTTP_200_OK, content=GetAppPropertyRsp( @@ -234,7 +243,7 @@ async def get_application( authorizedUsers=app_data.permission.users, ), workflows=workflows, - mcpService=app_data.mcp_service, + mcpService=mcp_service, ), ).model_dump(exclude_none=True, by_alias=True), ) diff --git a/apps/routers/mcp_service.py b/apps/routers/mcp_service.py index de484e78..848780e2 100644 --- a/apps/routers/mcp_service.py +++ b/apps/routers/mcp_service.py @@ -53,6 +53,7 @@ async def get_mcpservice_list( ] = SearchType.ALL, keyword: Annotated[str | None, Query(..., alias="keyword", description="搜索关键字")] = None, page: Annotated[int, Query(..., alias="page", ge=1, description="页码")] = 1, + is_active: Annotated[bool | None, Query(None, alias="isActive", description="是否激活")] = None, ) -> JSONResponse: """获取服务列表""" try: @@ -61,6 +62,7 @@ async def get_mcpservice_list( user_sub, keyword, page, + is_active ) except Exception as e: err = f"[MCPServiceCenter] 获取MCP服务列表失败: {e}" diff --git a/apps/schemas/appcenter.py b/apps/schemas/appcenter.py index a89f39df..a65fffb2 100644 --- a/apps/schemas/appcenter.py +++ b/apps/schemas/appcenter.py @@ -50,6 +50,14 @@ class AppFlowInfo(BaseModel): debug: bool = Field(..., description="是否经过调试") +class AppMcpServiceInfo(BaseModel): + """应用关联的MCP服务信息""" + + id: str = Field(..., description="MCP服务ID") + name: str = Field(..., description="MCP服务名称") + description: str = Field(..., description="MCP服务简介") + + class AppData(BaseModel): """应用信息数据结构""" @@ -64,4 +72,4 @@ class AppData(BaseModel): permission: AppPermissionData = Field( default_factory=lambda: AppPermissionData(authorizedUsers=None), description="权限配置") workflows: list[AppFlowInfo] = Field(default=[], description="工作流信息列表") - mcp_service: list[str] = Field(default=[], alias="mcpService", description="MCP服务id列表") + mcp_service: list[AppMcpServiceInfo] = Field(default=[], alias="mcpService", description="MCP服务id列表") diff --git a/apps/schemas/record.py b/apps/schemas/record.py index 6a394375..144a6c57 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -130,8 +130,8 @@ class Record(RecordData): user_sub: str key: dict[str, Any] = {} - task_id: str - content: str + task_id: str | None = Field(default=None, description="任务ID") + content: str = Field(default="", description="Record内容,已加密") comment: RecordComment = Field(default=RecordComment()) flow: FlowHistory = Field( default=FlowHistory(), description="Flow执行历史信息") diff --git a/apps/services/mcp_service.py b/apps/services/mcp_service.py index 2c84a211..ba510350 100644 --- a/apps/services/mcp_service.py +++ b/apps/services/mcp_service.py @@ -78,6 +78,7 @@ class MCPServiceManager: user_sub: str, keyword: str | None, page: int, + is_active: bool | None = None, ) -> list[MCPServiceCardItem]: """ 获取所有MCP服务列表 @@ -89,6 +90,11 @@ class MCPServiceManager: :return: MCP服务列表 """ filters = MCPServiceManager._build_filters(search_type, keyword) + if is_active is not None: + if is_active: + filters["activated"] = {"$in": [user_sub]} + else: + filters["activated"] = {"$nin": [user_sub]} mcpservice_pools = await MCPServiceManager._search_mcpservice(filters, page) return [ MCPServiceCardItem( -- Gitee From a537bf5ea71f958fdaa47b72997ea4b8d9381b9e Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 5 Aug 2025 14:58:39 +0800 Subject: [PATCH 53/60] =?UTF-8?q?=E5=8E=BB=E9=99=A4chat=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=E7=9A=84new=5Ftask=E5=AD=97=E6=AE=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/chat.py | 2 +- apps/schemas/request_data.py | 1 - apps/services/record.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 06bc2dd7..f92efd45 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -41,7 +41,7 @@ async def init_task(post_body: RequestData, user_sub: str) -> Task: post_body.group_id = str(uuid.uuid4()) # 更改信息并刷新数据库 - if post_body.new_task: + if post_body.task_id is None: conversation = await ConversationManager.get_conversation_by_conversation_id( user_sub=user_sub, conversation_id=post_body.conversation_id, diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 8719c2e9..8d053e1c 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -48,7 +48,6 @@ class RequestData(BaseModel): app: RequestDataApp | None = Field(default=None, description="应用") debug: bool = Field(default=False, description="是否调试") task_id: str | None = Field(default=None, alias="taskId", description="任务ID") - new_task: bool = Field(default=True, description="是否新建任务") class QuestionBlacklistRequest(BaseModel): diff --git a/apps/services/record.py b/apps/services/record.py index 6b61f91e..cf8373b0 100644 --- a/apps/services/record.py +++ b/apps/services/record.py @@ -142,7 +142,7 @@ class RecordManager: record_group_collection = MongoDB().get_collection("record_group") try: await record_group_collection.update_many( - {"records.flow.flow_id": {"$in": task_ids}, "records.flow.flow_status": {"$nin": [FlowStatus.ERROR.value, FlowStatus.SUCCESS.value]}}, + {"records.task_id": {"$in": task_ids}, "records.flow.flow_status": {"$nin": [FlowStatus.ERROR.value, FlowStatus.SUCCESS.value]}}, {"$set": {"records.$[elem].flow.flow_status": FlowStatus.CANCELLED}}, array_filters=[{"elem.flow.flow_id": {"$in": task_ids}}], ) -- Gitee From 22cea94abf7a27e1de9044b3eee0dffae0c67421 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 5 Aug 2025 15:02:10 +0800 Subject: [PATCH 54/60] fix bug --- apps/routers/mcp_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/routers/mcp_service.py b/apps/routers/mcp_service.py index 848780e2..82fa72de 100644 --- a/apps/routers/mcp_service.py +++ b/apps/routers/mcp_service.py @@ -53,7 +53,7 @@ async def get_mcpservice_list( ] = SearchType.ALL, keyword: Annotated[str | None, Query(..., alias="keyword", description="搜索关键字")] = None, page: Annotated[int, Query(..., alias="page", ge=1, description="页码")] = 1, - is_active: Annotated[bool | None, Query(None, alias="isActive", description="是否激活")] = None, + is_active: Annotated[bool | None, Query(..., alias="isActive", description="是否激活")] = None, ) -> JSONResponse: """获取服务列表""" try: -- Gitee From 4a8bae424599fa3fec6c3c2e92c1097d52a862a5 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 5 Aug 2025 15:04:25 +0800 Subject: [PATCH 55/60] fix bug --- apps/schemas/record.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index 144a6c57..dbc06b10 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -94,7 +94,7 @@ class RecordData(BaseModel): id: str group_id: str = Field(alias="groupId") conversation_id: str = Field(alias="conversationId") - task_id: str = Field(alias="taskId") + task_id: str | None = Field(default=None, alias="taskId") document: list[RecordDocument] = [] flow: RecordFlow | None = None content: RecordContent -- Gitee From 1000061b59f8de5b4d1001cb7f8a965fb8945e54 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 5 Aug 2025 15:13:05 +0800 Subject: [PATCH 56/60] fix bug --- apps/routers/chat.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/apps/routers/chat.py b/apps/routers/chat.py index f92efd45..dac7afe6 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -34,7 +34,7 @@ router = APIRouter( ) -async def init_task(post_body: RequestData, user_sub: str) -> Task: +async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> Task: """初始化Task""" # 生成group_id if not post_body.group_id: @@ -51,7 +51,7 @@ async def init_task(post_body: RequestData, user_sub: str) -> Task: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=err) task_ids = await TaskManager.delete_tasks_by_conversation_id(post_body.conversation_id) await RecordManager.update_record_flow_status_to_cancelled_by_task_ids(task_ids) - task = await TaskManager.init_new_task(user_sub=user_sub, conversation_id=post_body.conversation_id, post_body=post_body) + task = await TaskManager.init_new_task(user_sub=user_sub, session_id=session_id, post_body=post_body) else: if not post_body.task_id: err = "[Chat] task_id 不可为空!" @@ -130,6 +130,7 @@ async def chat_generator(post_body: RequestData, user_sub: str) -> AsyncGenerato async def chat( post_body: RequestData, user_sub: Annotated[str, Depends(get_user)], + session_id: Annotated[str, Depends(get_session)], ) -> StreamingResponse: """LLM流式对话接口""" # 问题黑名单检测 @@ -142,7 +143,7 @@ async def chat( if await Activity.is_active(user_sub): raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests") - res = chat_generator(post_body, user_sub) + res = chat_generator(post_body, user_sub, session_id) return StreamingResponse( content=res, media_type="text/event-stream", -- Gitee From 611005dff5486c93592d68b41193f3e92912c779 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 5 Aug 2025 15:25:00 +0800 Subject: [PATCH 57/60] fix bug --- apps/scheduler/executor/agent.py | 46 +++++++++++--------------------- 1 file changed, 16 insertions(+), 30 deletions(-) diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index 4db38587..a77478a9 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -108,8 +108,7 @@ class MCPAgentExecutor(BaseExecutor): max_steps=self.max_steps-start_index-1, reasoning_llm=self.resoning_llm ) - self.msg_queue.push_output( - self.task, + self.push_message( EventType.STEP_CANCEL, data={} ) @@ -156,9 +155,9 @@ class MCPAgentExecutor(BaseExecutor): logger.info("[MCPAgentExecutor] 等待用户确认步骤 %d", self.task.state.step_index) # 发送确认消息 confirm_message = await MCPPlanner.get_tool_risk(self.tools[self.task.state.step_id], self.task.state.current_input, "", self.resoning_llm) - self.msg_queue.push_output(self.task, EventType.STEP_WAITING_FOR_START, confirm_message.model_dump( + self.push_message(EventType.STEP_WAITING_FOR_START, confirm_message.model_dump( exclude_none=True, by_alias=True)) - self.msg_queue.push_output(self.task, EventType.FLOW_STOP, {}) + self.push_message(EventType.FLOW_STOP, {}) self.task.state.flow_status = FlowStatus.WAITING self.task.state.step_status = StepStatus.WAITING self.task.context.append( @@ -191,13 +190,11 @@ class MCPAgentExecutor(BaseExecutor): raise Exception(error) try: output_params = await mcp_client.call_tool(mcp_tool.name, self.task.state.current_input) - self.msg_queue.push_output( - self.task, + self.push_message( EventType.STEP_INPUT, self.task.state.current_input ) - self.msg_queue.push_output( - self.task, + self.push_message( EventType.STEP_OUTPUT, output_params ) @@ -230,8 +227,7 @@ class MCPAgentExecutor(BaseExecutor): self.task.state.error_message, self.resoning_llm ) - self.msg_queue.push_output( - self.task, + self.push_message( EventType.STEP_WAITING_FOR_PARAM, data={ "message": "当运行产生如下报错:\n" + self.task.state.error_message, @@ -265,8 +261,7 @@ class MCPAgentExecutor(BaseExecutor): # 最后一步 self.task.state.flow_status = FlowStatus.SUCCESS self.task.state.step_status = StepStatus.SUCCESS - self.msg_queue.push_output( - self.task, + self.push_message( EventType.FLOW_SUCCESS, data={} ) @@ -276,8 +271,7 @@ class MCPAgentExecutor(BaseExecutor): self.task.state.step_description = self.task.runtime.temporary_plans.plans[self.task.state.step_index].content self.task.state.step_status = StepStatus.INIT self.task.state.current_input = {} - self.msg_queue.push_output( - self.task, + self.push_message( EventType.STEP_INIT, data={} ) @@ -285,8 +279,7 @@ class MCPAgentExecutor(BaseExecutor): # 没有下一步了,结束流程 self.task.state.flow_status = FlowStatus.SUCCESS self.task.state.step_status = StepStatus.SUCCESS - self.msg_queue.push_output( - self.task, + self.push_message( EventType.FLOW_SUCCESS, data={} ) @@ -296,8 +289,7 @@ class MCPAgentExecutor(BaseExecutor): """步骤执行失败后的错误处理""" self.task.state.step_status = StepStatus.ERROR self.task.state.flow_status = FlowStatus.ERROR - self.msg_queue.push_output( - self.task, + self.push_message( EventType.FLOW_FAILED, data={} ) @@ -338,13 +330,11 @@ class MCPAgentExecutor(BaseExecutor): else: self.task.state.flow_status = FlowStatus.CANCELLED self.task.state.step_status = StepStatus.CANCELLED - self.msg_queue.push_output( - self.task, + self.push_message( EventType.STEP_CANCEL, data={} ) - self.msg_queue.push_output( - self.task, + self.push_message( EventType.FLOW_CANCEL, data={} ) @@ -389,8 +379,7 @@ class MCPAgentExecutor(BaseExecutor): (await MCPHost.assemble_memory(self.task)), self.resoning_llm ): - self.msg_queue.push_output( - self.task, + self.push_message( EventType.TEXT_ADD, data=chunk ) @@ -409,8 +398,7 @@ class MCPAgentExecutor(BaseExecutor): self.reset_step_to_index(0) TaskManager.save_task(self.task.id, self.task) self.task.state.flow_status = FlowStatus.RUNNING - self.msg_queue.push_output( - self.task, + self.push_message( EventType.FLOW_START, data={} ) @@ -424,13 +412,11 @@ class MCPAgentExecutor(BaseExecutor): self.task.state.flow_status = FlowStatus.ERROR self.task.state.error_message = str(e) self.task.state.step_status = StepStatus.ERROR - self.msg_queue.push_output( - self.task, + self.push_message( EventType.STEP_ERROR, data={} ) - self.msg_queue.push_output( - self.task, + self.push_message( EventType.FLOW_FAILED, data={} ) -- Gitee From 71fa94780210b602f2b42fe415713f5510bb377b Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 5 Aug 2025 15:26:19 +0800 Subject: [PATCH 58/60] fix bug --- apps/scheduler/executor/base.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py index cf2f4e68..d3563634 100644 --- a/apps/scheduler/executor/base.py +++ b/apps/scheduler/executor/base.py @@ -52,17 +52,16 @@ class BaseExecutor(BaseModel, ABC): elif event_type == EventType.FLOW_STOP.value: data = {} elif event_type == EventType.TEXT_ADD.value and isinstance(data, str): - data=TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True) + data = TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True) if data is None: data = {} elif isinstance(data, str): data = TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True) - await self.msg_queue.push_output( - self.task, + await self.push_message( event_type=event_type, - data=data, # type: ignore[arg-type] + data=data, # type: ignore[arg-type] ) @abstractmethod -- Gitee From abc1e2f16abb37df6bd35c78bd5d561a7edb8b97 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 5 Aug 2025 15:40:14 +0800 Subject: [PATCH 59/60] =?UTF-8?q?=E5=AE=8C=E5=96=84token=E6=B6=88=E8=80=97?= =?UTF-8?q?=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/executor/agent.py | 16 +++++++++++++++- apps/scheduler/executor/base.py | 5 ++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index a77478a9..f4fc2061 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -55,6 +55,12 @@ class MCPAgentExecutor(BaseExecutor): description="推理大模型", ) + async def update_tokens(self) -> None: + """更新令牌数""" + self.task.tokens.input_tokens = self.resoning_llm.input_tokens + self.task.tokens.output_tokens = self.resoning_llm.output_tokens + await TaskManager.save_task(self.task.id, self.task) + async def load_state(self) -> None: """从数据库中加载FlowExecutor的状态""" logger.info("[FlowExecutor] 加载Executor状态") @@ -108,6 +114,7 @@ class MCPAgentExecutor(BaseExecutor): max_steps=self.max_steps-start_index-1, reasoning_llm=self.resoning_llm ) + self.update_tokens() self.push_message( EventType.STEP_CANCEL, data={} @@ -155,6 +162,7 @@ class MCPAgentExecutor(BaseExecutor): logger.info("[MCPAgentExecutor] 等待用户确认步骤 %d", self.task.state.step_index) # 发送确认消息 confirm_message = await MCPPlanner.get_tool_risk(self.tools[self.task.state.step_id], self.task.state.current_input, "", self.resoning_llm) + self.update_tokens() self.push_message(EventType.STEP_WAITING_FOR_START, confirm_message.model_dump( exclude_none=True, by_alias=True)) self.push_message(EventType.FLOW_STOP, {}) @@ -187,9 +195,10 @@ class MCPAgentExecutor(BaseExecutor): logger.error("[MCPAgentExecutor] MCP客户端未找到: %s", mcp_tool.mcp_id) self.task.state.flow_status = FlowStatus.ERROR error = "[MCPAgentExecutor] MCP客户端未找到: {}".format(mcp_tool.mcp_id) - raise Exception(error) + self.task.state.error_message = error try: output_params = await mcp_client.call_tool(mcp_tool.name, self.task.state.current_input) + self.update_tokens() self.push_message( EventType.STEP_INPUT, self.task.state.current_input @@ -227,6 +236,7 @@ class MCPAgentExecutor(BaseExecutor): self.task.state.error_message, self.resoning_llm ) + self.update_tokens() self.push_message( EventType.STEP_WAITING_FOR_PARAM, data={ @@ -234,6 +244,10 @@ class MCPAgentExecutor(BaseExecutor): "params": params_with_null } ) + self.push_message( + EventType.FLOW_STOP, + data={} + ) self.task.state.flow_status = FlowStatus.WAITING self.task.state.step_status = StepStatus.PARAM self.task.context.append( diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py index d3563634..8dcb99c7 100644 --- a/apps/scheduler/executor/base.py +++ b/apps/scheduler/executor/base.py @@ -49,8 +49,6 @@ class BaseExecutor(BaseModel, ABC): question=self.question, params=self.task.runtime.filled, ).model_dump(exclude_none=True, by_alias=True) - elif event_type == EventType.FLOW_STOP.value: - data = {} elif event_type == EventType.TEXT_ADD.value and isinstance(data, str): data = TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True) @@ -59,7 +57,8 @@ class BaseExecutor(BaseModel, ABC): elif isinstance(data, str): data = TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True) - await self.push_message( + await self.msg_queue.push_output( + self.task, event_type=event_type, data=data, # type: ignore[arg-type] ) -- Gitee From 246a193d9b3ffff89708461ee338e5c582f006aa Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 6 Aug 2025 11:11:22 +0800 Subject: [PATCH 60/60] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dscheduler=E5=92=8Ctask?= =?UTF-8?q?=E4=B8=AD=E5=AD=98=E5=9C=A8=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/chat.py | 7 +++++-- apps/scheduler/executor/agent.py | 19 +++++++++---------- apps/scheduler/scheduler/context.py | 1 - apps/scheduler/scheduler/scheduler.py | 23 ++++++++--------------- apps/services/activity.py | 4 +++- apps/services/task.py | 2 +- 6 files changed, 26 insertions(+), 30 deletions(-) diff --git a/apps/routers/chat.py b/apps/routers/chat.py index dac7afe6..5e035475 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -52,6 +52,8 @@ async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> T task_ids = await TaskManager.delete_tasks_by_conversation_id(post_body.conversation_id) await RecordManager.update_record_flow_status_to_cancelled_by_task_ids(task_ids) task = await TaskManager.init_new_task(user_sub=user_sub, session_id=session_id, post_body=post_body) + task.runtime.question = post_body.question + task.ids.group_id = post_body.group_id else: if not post_body.task_id: err = "[Chat] task_id 不可为空!" @@ -60,7 +62,7 @@ async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> T return task -async def chat_generator(post_body: RequestData, user_sub: str) -> AsyncGenerator[str, None]: +async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) -> AsyncGenerator[str, None]: """进行实际问答,并从MQ中获取消息""" try: await Activity.set_active(user_sub) @@ -72,7 +74,7 @@ async def chat_generator(post_body: RequestData, user_sub: str) -> AsyncGenerato await Activity.remove_active(user_sub) return - task = await init_task(post_body, user_sub) + task = await init_task(post_body, user_sub, session_id) # 创建queue;由Scheduler进行关闭 queue = MessageQueue() @@ -80,6 +82,7 @@ async def chat_generator(post_body: RequestData, user_sub: str) -> AsyncGenerato # 在单独Task中运行Scheduler,拉齐queue.get的时机 scheduler = Scheduler(task, queue, post_body) + logger.info(f"[Chat] 用户是否活跃: {await Activity.is_active(user_sub)}") scheduler_task = asyncio.create_task(scheduler.run()) # 处理每一条消息 diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index f4fc2061..fc799fa1 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -28,7 +28,6 @@ from apps.schemas.message import param from apps.services.task import TaskManager from apps.services.appcenter import AppCenterManager from apps.services.mcp_service import MCPServiceManager -from apps.services.task import TaskManager from apps.services.user import UserManager logger = logging.getLogger(__name__) @@ -129,7 +128,7 @@ class MCPAgentExecutor(BaseExecutor): for i in range(start_index, len(self.task.runtime.temporary_plans.plans)): self.task.runtime.temporary_plans.plans[i].step_id = str(uuid.uuid4()) - async def get_tool_input_param(self, is_first: bool) -> dict[str, Any]: + async def get_tool_input_param(self, is_first: bool) -> None: if is_first: # 获取第一个输入参数 self.task.state.current_input = await MCPHost._get_first_input_params(self.tools[self.task.state.step_id], self.task.runtime.question, self.task) @@ -222,7 +221,7 @@ class MCPAgentExecutor(BaseExecutor): ) self.task.state.step_status = StepStatus.SUCCESS except Exception as e: - logging.warning("[MCPAgentExecutor] 执行步骤 %s 失败: %s", mcp_tool.name, str(e)) + logger.warning("[MCPAgentExecutor] 执行步骤 %s 失败: %s", mcp_tool.name, str(e)) import traceback self.task.state.error_message = traceback.format_exc() self.task.state.step_status = StepStatus.ERROR @@ -326,13 +325,13 @@ class MCPAgentExecutor(BaseExecutor): async def work(self) -> None: """执行当前步骤""" if self.task.state.step_status == StepStatus.INIT: - self.get_tool_input_param(is_first=True) + await self.get_tool_input_param(is_first=True) user_info = await UserManager.get_userinfo_by_user_sub(self.task.ids.user_sub) if not user_info.auto_execute: # 等待用户确认 await self.confirm_before_step() return - self.step.state.step_status = StepStatus.RUNNING + self.task.state.step_status = StepStatus.RUNNING elif self.task.state.step_status in [StepStatus.PARAM, StepStatus.WAITING, StepStatus.RUNNING]: if self.task.context[-1].step_status == StepStatus.PARAM: if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: @@ -355,11 +354,11 @@ class MCPAgentExecutor(BaseExecutor): if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: self.task.context[-1].step_status = StepStatus.CANCELLED if self.task.state.step_status == StepStatus.PARAM: - self.get_tool_input_param(is_first=False) - max_retry = 5 + await self.get_tool_input_param(is_first=False) + max_retry = 5 for i in range(max_retry): if i != 0: - self.get_tool_input_param(is_first=False) + await self.get_tool_input_param(is_first=False) await self.run_step() if self.task.state.step_status == StepStatus.SUCCESS: break @@ -408,7 +407,7 @@ class MCPAgentExecutor(BaseExecutor): # 初始化状态 self.task.state.flow_id = str(uuid.uuid4()) self.task.state.flow_name = await MCPPlanner.get_flow_name(self.task.runtime.question, self.resoning_llm) - self.task.runtime.temporary_plans = await self.plan(is_replan=False) + await self.plan(is_replan=False) self.reset_step_to_index(0) TaskManager.save_task(self.task.id, self.task) self.task.state.flow_status = FlowStatus.RUNNING @@ -419,7 +418,7 @@ class MCPAgentExecutor(BaseExecutor): try: while self.task.state.step_index < len(self.task.runtime.temporary_plans) and \ self.task.state.flow_status == FlowStatus.RUNNING: - self.work() + await self.work() TaskManager.save_task(self.task.id, self.task) except Exception as e: logger.error("[MCPAgentExecutor] 执行过程中发生错误: %s", str(e)) diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 3b26f42f..dc35d4bd 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -158,7 +158,6 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: facts=task.runtime.facts, data={}, ) - try: # 加密Record数据 encrypt_data, encrypt_config = Security.encrypt(record_content.model_dump_json(by_alias=True)) diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index b8144847..8ea3619b 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -56,7 +56,7 @@ class Scheduler: while not kill_event.is_set(): # 检查用户活动状态 is_active = await Activity.is_active(user_sub) - + logger.info(f"[Scheduler] 用户 {user_sub} 活动状态: {is_active}") if not is_active: logger.warning("[Scheduler] 用户 %s 不活跃,终止工作流", user_sub) kill_event.set() @@ -78,8 +78,7 @@ class Scheduler: ) if not llm_id: logger.error("[Scheduler] 获取大模型ID失败") - await self.queue.close() - return + return None if llm_id == "empty": llm = LLM( _id="empty", @@ -89,16 +88,16 @@ class Scheduler: model_name=Config().get_config().llm.model, max_tokens=Config().get_config().llm.max_tokens, ) + return llm else: llm = await LLMManager.get_llm_by_id(self.task.ids.user_sub, llm_id) if not llm: logger.error("[Scheduler] 获取大模型失败") - await self.queue.close() - return + return None + return llm except Exception: logger.exception("[Scheduler] 获取大模型失败") - await self.queue.close() - return + return None async def get_kb_ids_use_in_chat_with_rag(self) -> list[str]: """获取知识库ID列表""" @@ -106,10 +105,6 @@ class Scheduler: kb_ids = await KnowledgeBaseManager.get_kb_ids_by_conversation_id( self.task.ids.user_sub, self.task.ids.conversation_id, ) - if not kb_ids: - logger.error("[Scheduler] 获取知识库ID失败") - await self.queue.close() - return [] return kb_ids except Exception: logger.exception("[Scheduler] 获取知识库ID失败") @@ -131,14 +126,12 @@ class Scheduler: logger.info("[Scheduler] 开始执行") # 创建用于通信的事件 kill_event = asyncio.Event() + logger.info(f"[Chat] 用户是否活跃: {await Activity.is_active(self.task.ids.user_sub)}") monitor = asyncio.create_task(self._monitor_activity(kill_event, self.task.ids.user_sub)) + logger.info(f"[Chat] 用户是否活跃: {await Activity.is_active(self.task.ids.user_sub)}") if not self.post_body.app or self.post_body.app.app_id == "": llm = await self.get_llm_use_in_chat_with_rag() kb_ids = await self.get_kb_ids_use_in_chat_with_rag() - if not llm: - logger.error("[Scheduler] 获取大模型失败") - await self.queue.close() - return self.task = await push_init_message(self.task, self.queue, 3, is_flow=False) rag_data = RAGQueryReq( kbIds=kb_ids, diff --git a/apps/services/activity.py b/apps/services/activity.py index 299a49a6..88142b9e 100644 --- a/apps/services/activity.py +++ b/apps/services/activity.py @@ -3,11 +3,13 @@ import uuid from datetime import UTC, datetime - +import logging from apps.common.mongo import MongoDB from apps.constants import SLIDE_WINDOW_QUESTION_COUNT, SLIDE_WINDOW_TIME from apps.exceptions import ActivityError +logger = logging.getLogger(__name__) + class Activity: """用户活动控制,限制单用户同一时间只能提问一个问题""" diff --git a/apps/services/task.py b/apps/services/task.py index 2f75a8c3..eec4e197 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -115,7 +115,6 @@ class TaskManager: @staticmethod async def init_new_task( - cls, user_sub: str, session_id: str | None = None, post_body: RequestData | None = None, @@ -180,6 +179,7 @@ class TaskManager: task_ids.append(task["_id"]) if task_ids: await task_collection.delete_many({"conversation_id": conversation_id}) + return task_ids except Exception: logger.exception("[TaskManager] 删除ConversationID的Task信息失败") return [] -- Gitee