diff --git a/apps/llm/generator.py b/apps/llm/generator.py index 170c21e7e9033389ff9b64df84cdc105f3fa6f95..b62bcb29909206dee8d8169499763e5e7a7b3220 100644 --- a/apps/llm/generator.py +++ b/apps/llm/generator.py @@ -243,7 +243,7 @@ class JsonGenerator: err_info = err_info.split("\n\n")[0] function_name = function["name"] - result_json = json.dumps(e, ensure_ascii=False) + result_json = json.dumps(err_info, ensure_ascii=False) retry_messages.append({ "role": "assistant", diff --git a/apps/models/__init__.py b/apps/models/__init__.py index 4c0301575ba0bd2f7d23a5b79e3f447e6292551c..e5951bd96ca657ebedc6f960d8fdf9dd8436279c 100644 --- a/apps/models/__init__.py +++ b/apps/models/__init__.py @@ -3,7 +3,6 @@ from .app import App, AppACL, AppHashes, AppMCP, AppType, PermissionType from .base import Base -from .blacklist import Blacklist from .comment import Comment, CommentType from .conversation import ConvDocAssociated, Conversation, ConversationDocument from .document import Document @@ -35,7 +34,6 @@ __all__ = [ "AppMCP", "AppType", "Base", - "Blacklist", "Comment", "CommentType", "ConvDocAssociated", diff --git a/apps/models/blacklist.py b/apps/models/blacklist.py deleted file mode 100644 index 878b41d4b3148a4be1270dbee69e0b22c1341b1d..0000000000000000000000000000000000000000 --- a/apps/models/blacklist.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. -"""黑名单 数据库表结构""" - -import uuid -from datetime import UTC, datetime - -from sqlalchemy import BigInteger, Boolean, DateTime, ForeignKey, String, Text -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import Mapped, mapped_column - -from .base import Base - - -class Blacklist(Base): - """黑名单""" - - __tablename__ = "framework_blacklist" - id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) - """主键ID""" - recordId: Mapped[uuid.UUID] = mapped_column( # noqa: N815 - UUID(as_uuid=True), ForeignKey("framework_record.id"), nullable=False, index=True, - ) - """关联的问答对ID""" - question: Mapped[str] = mapped_column(Text, nullable=False) - """黑名单问题""" - answer: Mapped[str | None] = mapped_column(Text, default=None, nullable=False) - """应做出的固定回答""" - isAudited: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) # noqa: N815 - """黑名单是否生效""" - reasonType: Mapped[str] = mapped_column(String(255), default="", nullable=False) # noqa: N815 - """举报类型""" - reason: Mapped[str] = mapped_column(Text, default="", nullable=False) - """举报原因""" - updatedAt: Mapped[DateTime] = mapped_column( # noqa: N815 - DateTime, - default_factory=lambda: datetime.now(tz=UTC), - init=False, - nullable=False, - ) - """更新时间""" diff --git a/apps/models/task.py b/apps/models/task.py index 28ff783bd4771b8213f6a93319e2d932c8b17e9c..f458cd9b5682802022e48ac7df73d7cb0fb141ff 100644 --- a/apps/models/task.py +++ b/apps/models/task.py @@ -76,6 +76,10 @@ class Task(Base): nullable=True, default=None, ) """检查点ID""" + recordId: Mapped[uuid.UUID | None] = mapped_column( # noqa: N815 + UUID(as_uuid=True), nullable=True, default=None, + ) + """记录ID""" id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default_factory=uuid.uuid4) """任务ID""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 @@ -153,6 +157,8 @@ class ExecutorCheckpoint(Base): """步骤额外数据""" errorMessage: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False, default_factory=dict) # noqa: N815 """错误信息""" + extraData: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False, default_factory=dict) # noqa: N815 + """执行器额外数据""" class ExecutorHistory(Base): diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index f49a81d297c268a9cb51e9f2a612141fff8cfdbb..e6650f8fb32840430931ad06a94ad6334c8eb72c 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -3,6 +3,7 @@ import logging import uuid +from typing import Any from mcp.types import TextContent from pydantic import Field, ValidationError @@ -17,8 +18,7 @@ from apps.scheduler.pool.mcp.pool import mcp_pool from apps.schemas.enum_var import EventType from apps.schemas.flow import AgentAppMetadata from apps.schemas.mcp import MCPRiskConfirm, Step -from apps.schemas.message import FlowParams -from apps.schemas.task import AgentHistoryExtra +from apps.schemas.task import AgentCheckpointExtra, AgentHistoryExtra from apps.services.appcenter import AppCenterManager from apps.services.mcp_service import MCPServiceManager from apps.services.user import UserManager @@ -30,7 +30,7 @@ class MCPAgentExecutor(BaseExecutor): agent_id: uuid.UUID = Field(default=uuid.uuid4(), description="App ID作为Agent ID") agent_description: str = Field(default="", description="Agent描述") - params: FlowParams | bool | None = Field( + params: dict[str, Any] | None = Field( default=None, description="流执行过程中的参数补充", alias="params", @@ -76,6 +76,30 @@ class MCPAgentExecutor(BaseExecutor): stepType="", ) + # 从state.extraData恢复状态(如果存在) + self._restore_extra_data() + + def _restore_extra_data(self) -> None: + """从 task.state.extraData 恢复所有状态""" + if not self.task.state or not self.task.state.extraData: + return + + try: + checkpoint_extra = AgentCheckpointExtra.model_validate(self.task.state.extraData) + self._current_input = checkpoint_extra.current_input + self._retry_times = checkpoint_extra.retry_times + self._current_goal = checkpoint_extra.step_goal + self._step_cnt = checkpoint_extra.step_cnt + _logger.info( + "[MCPAgentExecutor] 从checkpoint恢复extraData - " + "retry_times: %s, step_goal: %s, step_cnt: %s", + self._retry_times, + self._current_goal, + self._step_cnt, + ) + except (ValidationError, KeyError, TypeError) as e: + _logger.warning("[MCPAgentExecutor] 从checkpoint恢复extraData失败: %s", e) + async def load_mcp(self) -> None: """加载MCP服务器列表""" _logger.info("[MCPAgentExecutor] 加载MCP服务器列表") @@ -116,12 +140,6 @@ class MCPAgentExecutor(BaseExecutor): raise RuntimeError(err) # 获取输入参数 - if isinstance(self.params, FlowParams): - params = self.params.content - params_description = self.params.description - else: - params = {} - params_description = "" self._current_tool = self._tool_list[self.task.state.stepName] # 对于首次调用,使用空的current_input @@ -131,8 +149,8 @@ class MCPAgentExecutor(BaseExecutor): self._current_tool, self.task, current_input, - params, - params_description, + self.params, + self._current_goal, ) def _get_error_message_str(self, error_message: dict | str | None) -> str: @@ -217,6 +235,30 @@ class MCPAgentExecutor(BaseExecutor): if len(self.task.context) and self.task.context[-1].stepId == self.task.state.stepId: del self.task.context[-1] + def _update_checkpoint_extra_data(self) -> None: + """更新checkpoint的extraData""" + if not self.task.state: + err = "[MCPAgentExecutor] 任务状态不存在" + _logger.error(err) + raise RuntimeError(err) + + # 使用AgentCheckpointExtra数据结构保存checkpoint的extra_data + checkpoint_extra = AgentCheckpointExtra( + current_input=self._current_input, + error_message=self._get_error_message_str(self.task.state.errorMessage), + retry_times=self._retry_times, + step_goal=self._current_goal, + step_cnt=self._step_cnt, + ) + self.task.state.extraData = checkpoint_extra.model_dump() + _logger.info( + "[MCPAgentExecutor] 更新checkpoint extraData - " + "retry_times: %s, step_goal: %s, step_cnt: %s", + self._retry_times, + self._current_goal, + self._step_cnt, + ) + async def _handle_step_error_and_continue(self) -> None: """处理步骤错误并继续下一步""" if not self.task.state: @@ -227,6 +269,8 @@ class MCPAgentExecutor(BaseExecutor): # 先更新stepStatus self.task.state.stepStatus = StepStatus.ERROR + # 增加重试次数 + self._retry_times += 1 await self._push_message( EventType.STEP_OUTPUT, @@ -273,6 +317,8 @@ class MCPAgentExecutor(BaseExecutor): step_goal=self._current_goal, ), ) + # 进入等待状态前保存checkpoint + self._update_checkpoint_extra_data() async def run_step(self) -> None: """执行步骤""" @@ -360,6 +406,8 @@ class MCPAgentExecutor(BaseExecutor): step_goal=self._current_goal, ), ) + # 进入等待状态前保存checkpoint + self._update_checkpoint_extra_data() async def get_next_step(self) -> None: """获取下一步""" @@ -389,6 +437,10 @@ class MCPAgentExecutor(BaseExecutor): self.task.state.stepStatus = StepStatus.INIT # 保存步骤目标 self._current_goal = step.description + # 重置重试次数 + self._retry_times = 0 + # 更新checkpoint的extraData + self._update_checkpoint_extra_data() else: # 没有下一步了,结束流程 self.task.state.stepName = AGENT_FINAL_STEP_NAME @@ -438,6 +490,7 @@ class MCPAgentExecutor(BaseExecutor): if risk_confirm.confirm: # 用户确认继续执行 self._remove_last_context_if_same_step() + await self.get_tool_input_param(is_first=False) else: # 用户拒绝执行 should_cancel = True @@ -575,6 +628,8 @@ class MCPAgentExecutor(BaseExecutor): await self._push_message(EventType.STEP_OUTPUT, data=error_output) await self._add_error_to_context(self.task.state.stepStatus) finally: + # 更新checkpoint的extraData(统一在执行结束时更新) + self._update_checkpoint_extra_data() # 清理MCP客户端 for mcp_service in self._mcp_list: try: diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py index 2ca2c82fbdbbf3dc7bb4c467ad33c82bfbf89e6d..72b5b3559ae289e156442d8314cb5e9bd14abb3f 100644 --- a/apps/scheduler/mcp_agent/host.py +++ b/apps/scheduler/mcp_agent/host.py @@ -35,7 +35,7 @@ class MCPHost(MCPBase): task: TaskData, current_input: dict[str, Any], params: dict[str, Any] | None = None, - params_description: str = "", + current_goal: str = "", ) -> dict[str, Any]: """生成工具参数""" llm_query = _LLM_QUERY_FIX[task.runtime.language] @@ -43,13 +43,12 @@ class MCPHost(MCPBase): prompt = _env.from_string(await self._load_prompt("gen_params")).render( tool_name=mcp_tool.toolName, goal=task.runtime.userInput, - current_goal=task.runtime.userInput, + current_goal=current_goal, tool_description=mcp_tool.description, input_schema=mcp_tool.inputSchema, input_params=current_input, error_message=error_message, params=params, - params_description=params_description, ) # 组装OpenAI Function标准的Function结构 diff --git a/apps/scheduler/scheduler/data.py b/apps/scheduler/scheduler/data.py index 122a3444c0a3ca4df8b60df74c16e8c530e110c5..f01cc00ffdc3c5e4478c3a174688c36b90be9510 100644 --- a/apps/scheduler/scheduler/data.py +++ b/apps/scheduler/scheduler/data.py @@ -2,6 +2,7 @@ """数据管理相关的Mixin类""" import logging +import uuid from datetime import UTC, datetime from typing import Any from uuid import UUID @@ -56,19 +57,27 @@ class DataMixin: current_time: float, ) -> tuple[PgRecord, PgRecordMetadata]: """构建记录对象和元数据对象""" - task = self.task - user_id = task.metadata.userId - record_id = task.metadata.id + # 若Executor状态为WAITING,则使用task中的recordId;否则生成新UUID + if ( + self.task.state + and self.task.state.executorStatus == ExecutorStatus.WAITING + and self.task.metadata.recordId + ): + record_id = self.task.metadata.recordId + else: + record_id = uuid.uuid4() + # 更新task中的recordId + self.task.metadata.recordId = record_id - if task.metadata.conversationId is None: + if self.task.metadata.conversationId is None: msg = "conversationId cannot be None" raise ValueError(msg) pg_record = PgRecord( id=record_id, - conversationId=task.metadata.conversationId, - taskId=task.metadata.id, - userId=user_id, + conversationId=self.task.metadata.conversationId, + taskId=self.task.metadata.id, + userId=self.task.metadata.userId, content=encrypt_data, key=encrypt_config, createdAt=datetime.fromtimestamp(current_time, tz=UTC), diff --git a/apps/schemas/message.py b/apps/schemas/message.py index 7a505ceee8371ec863608c11d66d6f99325c0687..5b7cb10f7132770ac318110ffd4e3c72fce30bbf 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -12,13 +12,6 @@ from apps.schemas.enum_var import EventType from .record import RecordMetadata -class FlowParams(BaseModel): - """流执行过程中的参数补充""" - - content: dict[str, Any] = Field(default={}, description="流执行过程中的参数补充内容") - description: str = Field(default="", description="流执行过程中的参数补充描述") - - class HeartbeatData(BaseModel): """心跳事件的数据结构""" diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 3884a135396252e1cb29cd2c4c7300b4000760fb..8f966f3f5b701a1c56119e1482ede18843cd56d2 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -9,7 +9,6 @@ from pydantic import BaseModel, Field from apps.models import LanguageType, LLMProvider from .flow_topology import FlowItem -from .message import FlowParams class RequestDataApp(BaseModel): @@ -17,7 +16,9 @@ class RequestDataApp(BaseModel): app_id: uuid.UUID = Field(description="应用ID", alias="appId") flow_id: str | None = Field(default=None, description="Flow ID", alias="flowId") - params: FlowParams | None = Field(default=None, description="流执行过程中的参数补充", alias="params") + params: dict[str, Any] | None = Field( + default=None, description="流执行过程中的参数补充", alias="params", + ) class RequestData(BaseModel): diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 999034199b95cd538151e56a2846eeff360aac3f..4805465a5f66541922b9c9ff99869122afab62fc 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -26,6 +26,8 @@ class AgentCheckpointExtra(BaseModel): current_input: dict[str, Any] = Field(description="当前输入数据", default={}) error_message: str = Field(description="错误信息", default="") retry_times: int = Field(description="当前步骤重试次数", default=0) + step_goal: str = Field(description="当前步骤目标", default="") + step_cnt: int = Field(description="已执行步骤数", default=0) class AgentHistoryExtra(BaseModel): diff --git a/apps/services/record.py b/apps/services/record.py index 8e05259ed4afd697051d9ec9b6420cf56bb953be..3e069cab0da9b27d56f6ccf9edf10240c1ffc62f 100644 --- a/apps/services/record.py +++ b/apps/services/record.py @@ -21,6 +21,48 @@ logger = logging.getLogger(__name__) class RecordManager: """问答对相关操作""" + @staticmethod + async def get_conversation_and_record( + user_id: str, + conversation_id: uuid.UUID, + record_id: uuid.UUID | None = None, + ) -> tuple[bool, PgRecord | None]: + """ + 验证对话是否存在且获取已存在的记录 + + :param user_id: 用户ID + :param conversation_id: 会话ID + :param record_id: 记录ID(可选) + :return: (对话是否存在, 已存在的记录或None) + """ + async with postgres.session() as session: + # 验证对话是否存在 + conv = (await session.scalars( + select(Conversation).where( + and_( + Conversation.id == conversation_id, + Conversation.userId == user_id, + ), + ), + )).one_or_none() + + if not conv: + return False, None + + # 如果提供了record_id,则查找已存在的记录 + if record_id: + existing_record = (await session.scalars( + select(PgRecord).where( + and_( + PgRecord.id == record_id, + PgRecord.conversationId == conversation_id, + ), + ), + )).one_or_none() + return True, existing_record + + return True, None + @staticmethod async def verify_record_in_conversation(record_id: uuid.UUID, user_id: str, conversation_id: uuid.UUID) -> bool: """ @@ -51,18 +93,28 @@ class RecordManager: metadata: PgRecordMetadata, ) -> uuid.UUID | None: """Record和RecordMetadata插入PostgreSQL""" + # 验证对话存在并检查是否有同ID的记录 + conv_exists, existing_record = await RecordManager.get_conversation_and_record( + user_id, + conversation_id, + record.id, + ) + + if not conv_exists: + logger.error("[RecordManager] 对话不存在: %s", conversation_id) + return None + async with postgres.session() as session: - conv = (await session.scalars( - select(Conversation).where( - and_( - Conversation.id == conversation_id, - Conversation.userId == user_id, - ), - ), - )).one_or_none() - if not conv: - logger.error("[RecordManager] 对话不存在: %s", conversation_id) - return None + if existing_record: + logger.warning("[RecordManager] 记录已存在,删除旧记录后重新保存: %s", record.id) + # 删除已存在的记录及其元数据 + existing_metadata = (await session.scalars( + select(PgRecordMetadata).where(PgRecordMetadata.recordId == record.id), + )).one_or_none() + if existing_metadata: + await session.delete(existing_metadata) + await session.delete(existing_record) + await session.flush() session.add(record) session.add(metadata) @@ -79,6 +131,16 @@ class RecordManager: order: Literal["desc", "asc"] = "desc", ) -> list[RecordData]: """查询ConversationID的最后n条问答对""" + # 验证对话是否存在 + conv_exists, _ = await RecordManager.get_conversation_and_record( + user_id, + conversation_id, + ) + + if not conv_exists: + logger.error("[RecordManager] 对话不存在: %s", conversation_id) + return [] + async with postgres.session() as session: sql = select(PgRecord).where( and_( diff --git a/apps/services/task.py b/apps/services/task.py index b4bfa494bbcdbb7708072a435e14da9720b1f83c..552f50af14eab968bd6207e1779f525b7cf72b50 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -98,15 +98,17 @@ class TaskManager: async def delete_task_by_task_id(task_id: uuid.UUID) -> None: """通过task_id删除Task信息""" async with postgres.session() as session: - await session.execute( - delete(Task).where(Task.id == task_id), - ) + # Delete child tables first to avoid foreign key constraint violations await session.execute( delete(TaskRuntime).where(TaskRuntime.taskId == task_id), ) await session.execute( delete(ExecutorCheckpoint).where(ExecutorCheckpoint.taskId == task_id), ) + # Delete parent table last + await session.execute( + delete(Task).where(Task.id == task_id), + ) await session.commit()