diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py index 59c82a8e02d83d3989263dd08fdbb17a401d03af..a5ccdf19a1dad51e5f578df0c43f1ad6344eaa61 100644 --- a/apps/scheduler/executor/base.py +++ b/apps/scheduler/executor/base.py @@ -1,6 +1,8 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """Executor基类""" +import asyncio +import logging from abc import ABC, abstractmethod from typing import Any @@ -16,6 +18,8 @@ from apps.schemas.scheduler import ExecutorBackground from apps.schemas.task import TaskData from apps.services.record import RecordManager +_logger = logging.getLogger(__name__) + class BaseExecutor(BaseModel, ABC): """Executor基类""" @@ -98,6 +102,18 @@ class BaseExecutor(BaseModel, ABC): data=data, ) + async def _check_cancelled(self) -> None: + """ + 检查当前任务是否已被取消,如果已取消则抛出CancelledError + + :raises asyncio.CancelledError: 当任务已被取消时 + """ + try: + await asyncio.sleep(0) + except asyncio.CancelledError: + _logger.warning("[%s] 检测到取消信号,终止执行", self.__class__.__name__) + raise + @abstractmethod async def run(self) -> None: """运行Executor""" diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index 643afaab7c53f840a6e58918237745f027b935f8..391b033538cd186093394af456d811a94e9edfa8 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -25,7 +25,7 @@ from .prompt import FLOW_ERROR_PROMPT from .step import StepExecutor logger = logging.getLogger(__name__) -# 开始前的固定步骤 + FIXED_STEPS_BEFORE_START = [ { LanguageType.CHINESE: Step( @@ -42,7 +42,7 @@ FIXED_STEPS_BEFORE_START = [ ), }, ] -# 结束后的固定步骤 + FIXED_STEPS_AFTER_END = [ { LanguageType.CHINESE: Step( @@ -61,7 +61,6 @@ FIXED_STEPS_AFTER_END = [ ] -# 单个流的执行工具 class FlowExecutor(BaseExecutor): """用于执行工作流的Executor""" @@ -75,12 +74,10 @@ class FlowExecutor(BaseExecutor): logger.info("[FlowExecutor] 加载Executor状态") await self._load_history() - # 尝试恢复State if ( not self.task.state or self.task.state.executorStatus == ExecutorStatus.INIT ): - # 创建ExecutorState self.task.state = ExecutorCheckpoint( taskId=self.task.metadata.id, appId=self.post_body_app.app_id, @@ -90,17 +87,14 @@ class FlowExecutor(BaseExecutor): stepStatus=StepStatus.RUNNING, stepId=self.flow.basicConfig.startStep, stepName=self.flow.steps[self.flow.basicConfig.startStep].name, - # 先转换为StepType,再转换为str,确定Flow的类型在其中 stepType=str(StepType(self.flow.steps[self.flow.basicConfig.startStep].type)), ) - # 是否到达Flow结束终点(变量) self._reached_end: bool = False self.step_queue: deque[StepQueueItem] = deque() async def _invoke_runner(self) -> None: """单一Step执行""" - # 创建步骤Runner step_runner = StepExecutor( msg_queue=self.msg_queue, task=self.task, @@ -110,9 +104,7 @@ class FlowExecutor(BaseExecutor): llm=self.llm, ) - # 初始化步骤 await step_runner.init() - # 运行Step await step_runner.run() # 更新Task(已存过库) @@ -127,7 +119,6 @@ class FlowExecutor(BaseExecutor): except IndexError: break - # 执行Step await self._invoke_runner() @@ -147,11 +138,9 @@ class FlowExecutor(BaseExecutor): logger.error(err) raise RuntimeError(err) - # 如果当前步骤为结束,则直接返回 if self.task.state.stepId == "end" or not self.task.state.stepId: return [] if self.current_step.step.type == SpecialCallType.CHOICE.value: - # 如果是choice节点,获取分支ID branch_id = self.task.context[-1].outputData["branch_id"] if branch_id: next_steps = await self._find_next_id(str(self.task.state.stepId) + "." + branch_id) @@ -187,19 +176,18 @@ class FlowExecutor(BaseExecutor): 数据通过向Queue发送消息的方式传输 """ logger.info("[FlowExecutor] 运行工作流") + await self._check_cancelled() if not self.task.state: err = "[FlowExecutor] 任务状态不存在" logger.error(err) raise RuntimeError(err) - # 获取首个步骤 first_step = StepQueueItem( step_id=self.task.state.stepId, step=self.flow.steps[self.task.state.stepId], ) - # 头插开始前的系统步骤,并执行 for step in FIXED_STEPS_BEFORE_START: self.step_queue.append( StepQueueItem( @@ -211,14 +199,12 @@ class FlowExecutor(BaseExecutor): ) await self._step_process() - # 插入首个步骤 self.step_queue.append(first_step) self.task.state.executorStatus = ExecutorStatus.RUNNING - # 运行Flow(未达终点) is_error = False while not self._reached_end: - # 如果当前步骤出错,执行错误处理步骤 + await self._check_cancelled() if self.task.state.stepStatus == StepStatus.ERROR: logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤") self.step_queue.clear() @@ -246,27 +232,21 @@ class FlowExecutor(BaseExecutor): ), ) is_error = True - # 错误处理后结束 self._reached_end = True - # 执行步骤 await self._step_process() - # 查找下一个节点 next_step = await self._find_flow_next() if not next_step: - # 没有下一个节点,结束 self._reached_end = True for step in next_step: self.step_queue.append(step) - # 更新Task状态 if is_error: self.task.state.executorStatus = ExecutorStatus.ERROR else: self.task.state.executorStatus = ExecutorStatus.SUCCESS - # 尾插运行结束后的系统步骤 for step in FIXED_STEPS_AFTER_END: self.step_queue.append( StepQueueItem( @@ -278,5 +258,4 @@ class FlowExecutor(BaseExecutor): # FlowStop需要返回总时间,需要倒推最初的开始时间(当前时间减去当前已用总时间) self.task.runtime.time = round(datetime.now(UTC).timestamp(), 2) - self.task.runtime.fullTime - # 推送Flow停止消息 await self._push_message(EventType.EXECUTOR_STOP.value) diff --git a/apps/scheduler/executor/qa.py b/apps/scheduler/executor/qa.py index 4599b296e800b436ac3b4cca8da87be4e1c71763..be8f1676b041d4bae6d4ec2ae06a5631ec2e8c7c 100644 --- a/apps/scheduler/executor/qa.py +++ b/apps/scheduler/executor/qa.py @@ -209,6 +209,8 @@ class QAExecutor(BaseExecutor): async def _execute_remaining_steps(self) -> bool: """执行剩余步骤:问题推荐和记忆存储""" _logger.info("[QAExecutor] 开始执行问题推荐步骤") + + await self._check_cancelled() suggestion_exec = StepExecutor( msg_queue=self.msg_queue, task=self.task, @@ -230,6 +232,7 @@ class QAExecutor(BaseExecutor): self.task.state.executorStatus = ExecutorStatus.ERROR return False + await self._check_cancelled() _logger.info("[QAExecutor] 开始执行记忆存储步骤") facts_exec = StepExecutor( msg_queue=self.msg_queue, @@ -262,21 +265,25 @@ class QAExecutor(BaseExecutor): error = Exception("[QAExecutor] task.state不存在,无法执行") raise error + await self._check_cancelled() rag_success = await self._execute_rag_step() if not rag_success: _logger.error("[QAExecutor] RAG检索步骤失败,终止执行") return + await self._check_cancelled() llm_success = await self._execute_llm_step() if not llm_success: _logger.error("[QAExecutor] LLM问答步骤失败,终止执行") return + await self._check_cancelled() remaining_success = await self._execute_remaining_steps() if not remaining_success: _logger.error("[QAExecutor] 剩余步骤失败,终止执行") return + await self._check_cancelled() self.task.runtime.fullTime = round(datetime.now(UTC).timestamp(), 2) - self.task.runtime.time self.task.state.executorStatus = ExecutorStatus.SUCCESS diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 2d9d72d87077a455cb04d806d8bbb295960b01b1..6f42987fd6adb09eda058733e24c449a9939955c 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -128,6 +128,7 @@ class StepExecutor(BaseExecutor): # 判断是否需要进行自动参数填充 if not self.obj.enable_filling: return + await self._check_cancelled() if not self.task.state: err = "[StepExecutor] 任务状态不存在" @@ -158,6 +159,7 @@ class StepExecutor(BaseExecutor): # 运行填参 iterator = slot_obj.exec(self, slot_obj.input) async for chunk in iterator: + await self._check_cancelled() result: SlotOutput = SlotOutput.model_validate(chunk.content) # 如果没有填全,则状态设置为待填参 @@ -183,8 +185,12 @@ class StepExecutor(BaseExecutor): ) -> str | dict[str, Any]: """处理Chunk""" content: str | dict[str, Any] = "" + chunk_count = 0 async for chunk in iterator: + chunk_count += 1 + if chunk_count % 10 == 0: + await self._check_cancelled() if not isinstance(chunk, CallOutputChunk): err = "[StepExecutor] 返回结果类型错误" logger.error(err) @@ -210,6 +216,7 @@ class StepExecutor(BaseExecutor): async def run(self) -> None: """运行单个步骤""" logger.info("[StepExecutor] 运行步骤 %s", self.step.step.name) + await self._check_cancelled() if not self.task.state: err = "[StepExecutor] 任务状态不存在" @@ -224,6 +231,7 @@ class StepExecutor(BaseExecutor): self.task.runtime.time = round(datetime.now(UTC).timestamp(), 2) # 推送输入 await self._push_message(EventType.STEP_INPUT.value, self.obj.input) + await self._check_cancelled() # 执行步骤 iterator = self.obj.exec(self, self.obj.input) diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 7f44e5177b76c48ddcebe81d4793cb7977aa566a..b302fe8260d70c6da1a794a4316245603afe6eb3 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -2,6 +2,7 @@ """调度器;负责任务的分发与执行""" import asyncio +import contextlib import logging from apps.common.queue import MessageQueue @@ -70,13 +71,27 @@ class Scheduler( await self._push_executor_start_message() - done, pending = await asyncio.wait( - [main_task, monitor], - return_when=asyncio.FIRST_COMPLETED, - ) + while not main_task.done() and not monitor.done(): + if kill_event.is_set(): + await self._handle_task_cancellation(main_task) + monitor.cancel() + break - if kill_event.is_set(): + done, _ = await asyncio.wait( + [main_task, monitor], + timeout=0.1, # 100ms检查一次 + return_when=asyncio.FIRST_COMPLETED, + ) + + if done: + break + + if kill_event.is_set() and not main_task.done(): await self._handle_task_cancellation(main_task) + if not monitor.done(): + monitor.cancel() + with contextlib.suppress(asyncio.CancelledError): + await monitor await self._check_and_handle_executor_result() diff --git a/apps/scheduler/scheduler/util.py b/apps/scheduler/scheduler/util.py index 3b2fd009807d3ae22f249be4f42c68f7a8e58a42..d074e97d2cb49cc4a94774c83d3265ddce0cbbb2 100644 --- a/apps/scheduler/scheduler/util.py +++ b/apps/scheduler/scheduler/util.py @@ -56,9 +56,9 @@ class UtilMixin: check_interval = 0.5 while not kill_event.is_set(): - can_active = await Activity.can_active(user_id) + is_active = await Activity.is_active(user_id) - if not can_active: + if not is_active: _logger.warning("[Scheduler] 用户 %s 不活跃,终止工作流", user_id) kill_event.set() break diff --git a/apps/services/activity.py b/apps/services/activity.py index 5ef0b80c8ac9f2b1ba064bacff13d03724eae39a..69c599bc69414fa30fc66f3020f142fbb7cf4eaf 100644 --- a/apps/services/activity.py +++ b/apps/services/activity.py @@ -51,6 +51,15 @@ class Activity: session.add(SessionActivity(userId=user_id, timestamp=time)) await session.commit() + @staticmethod + async def is_active(user_id: str) -> bool: + """判断用户是否仍然活跃(即是否有活动记录)""" + async with postgres.session() as session: + count = (await session.scalars( + select(func.count(SessionActivity.id)) + .where(SessionActivity.userId == user_id), + )).one() + return count > 0 @staticmethod async def remove_active(user_id: str) -> None: