diff --git a/jiuwen/agent/common/schema.py b/jiuwen/agent/common/schema.py index 32c7726c98f2d83055abdfef0496c1efe9fc7521..03b4b7e51c70e8a304d26fef75e9b59d16954501 100644 --- a/jiuwen/agent/common/schema.py +++ b/jiuwen/agent/common/schema.py @@ -19,6 +19,7 @@ class WorkflowSchema(BaseModel): class PluginSchema(BaseModel): id: str = Field(default="") + version: str = Field(default="") name: str = Field(default="") description: str = Field(default="") inputs: Dict[str, Any] = Field(default_factory=dict) diff --git a/jiuwen/agent/config/workflow_config.py b/jiuwen/agent/config/workflow_config.py index 95361507c7228586b0b9c8f7a8b8d34014354d4d..5f035dd649917e9f4b7b51a9597aa4cf9dc56de5 100644 --- a/jiuwen/agent/config/workflow_config.py +++ b/jiuwen/agent/config/workflow_config.py @@ -17,3 +17,7 @@ class WorkflowAgentConfig(AgentConfig): global_variables: List[dict] = Field(default_factory=list) # 全局参数模板(可选) global_params: Dict[str, Any] = Field(default_factory=dict) + + @property + def is_single_workflow(self) -> bool: + return len(self.workflows) == 1 diff --git a/jiuwen/agent/react_agent.py b/jiuwen/agent/react_agent.py index 2807ff4ba47f4d4707e988e4d920995af65826eb..6407b1069728518accc16f1dbf5d4f1e9b20361f 100644 --- a/jiuwen/agent/react_agent.py +++ b/jiuwen/agent/react_agent.py @@ -13,7 +13,8 @@ from jiuwen.core.agent.controller.react_controller import ReActController, ReAct ReActControllerInput from jiuwen.core.agent.handler.base import AgentHandlerImpl, AgentHandlerInputs from jiuwen.agent.state.react_state import ReActState -from jiuwen.core.agent.task.task import SubTask +from jiuwen.core.agent.task.sub_task import SubTask +from jiuwen.core.agent.task.task import Task from jiuwen.core.component.common.configs.model_config import ModelConfig from jiuwen.core.context.context import Context from jiuwen.core.context.controller_context.controller_context_manager import ControllerContextMgr @@ -21,9 +22,9 @@ from jiuwen.core.utils.llm.messages import ToolMessage from jiuwen.core.utils.tool.base import Tool from jiuwen.core.workflow.base import Workflow - REACT_AGENT_STATE_KEY = "react_agent_state" + def create_react_agent_config(agent_id: str, agent_version: str, description: str, @@ -66,7 +67,9 @@ class ReActAgent(Agent): def _init_controller_context_manager(self) -> ControllerContextMgr: return ControllerContextMgr(self._config) - def invoke(self, inputs: Dict, context: Context) -> Dict: + async def invoke(self, inputs: Dict) -> Dict: + task: Task = self._task_manager.create_task(inputs.get("conversation_id")) + context = task.context context.set_controller_context_manager(self._controller_context_manager) self._load_state_from_context(context) controller_output = ReActControllerOutput() @@ -75,7 +78,7 @@ class ReActAgent(Agent): self._state.handle_llm_response_event(controller_output.llm_output, controller_output.sub_tasks) self._store_state_to_context(context) if controller_output.should_continue: - completed_sub_tasks = self._execute_sub_tasks(context) + completed_sub_tasks = await self._execute_sub_tasks(context) else: break self._state.handle_tool_invoked_event(completed_sub_tasks) @@ -86,15 +89,15 @@ class ReActAgent(Agent): self._store_state_to_context(context) return dict(output=self._state.final_result) - def stream(self, inputs: Dict, context: Context) -> Iterator[Any]: + async def stream(self, inputs: Dict) -> Iterator[Any]: pass - def _execute_sub_tasks(self, context: Context): + async def _execute_sub_tasks(self, context: Context): to_exec_sub_tasks = self._state.sub_tasks completed_sub_tasks = [] for st in to_exec_sub_tasks: inputs = AgentHandlerInputs(context=context, name=st.func_name, arguments=st.func_args) - exec_result = self._agent_handler.invoke(st.sub_task_type, inputs) + exec_result = await self._agent_handler.invoke(st.sub_task_type, inputs) st.result = json.dumps(exec_result, ensure_ascii=False) if isinstance(exec_result, dict) else exec_result completed_sub_tasks.append(st) self._update_chat_history_in_context(completed_sub_tasks, context) diff --git a/jiuwen/agent/state/react_state.py b/jiuwen/agent/state/react_state.py index bb2a76f7509cb8715b6a8f66b156da1ce6b2b267..393897acf74602d312abcbd3a66bb506ac54e005 100644 --- a/jiuwen/agent/state/react_state.py +++ b/jiuwen/agent/state/react_state.py @@ -7,7 +7,7 @@ from typing import Optional, List from pydantic import BaseModel, Field from jiuwen.agent.common.enum import ReActStatus -from jiuwen.core.agent.task.task import SubTask +from jiuwen.core.agent.task.sub_task import SubTask from jiuwen.core.utils.llm.messages import AIMessage diff --git a/jiuwen/agent/workflow_agent.py b/jiuwen/agent/workflow_agent.py index 81c4b5af40d52b415e0c330b73f1c2e4050bd69e..27267f4be6e72b74da84c7926b556d2792709955 100644 --- a/jiuwen/agent/workflow_agent.py +++ b/jiuwen/agent/workflow_agent.py @@ -1,38 +1,51 @@ from typing import Dict from jiuwen.agent.config.workflow_config import WorkflowAgentConfig +from jiuwen.core.agent.task.task import Task +from jiuwen.core.context.agent_context import AgentContext from jiuwen.core.context.controller_context.controller_context_manager import ControllerContextMgr from jiuwen.core.agent.controller.workflow_controller import WorkflowController, WorkflowControllerOutput from jiuwen.core.agent.agent import Agent -from jiuwen.core.agent.handler.base import AgentHandlerImpl -from jiuwen.core.context.config import Config -from jiuwen.core.context.context import Context -from jiuwen.core.context.memory.base import InMemoryState +from jiuwen.core.agent.handler.base import AgentHandlerImpl, AgentHandlerInputs class WorkflowAgent(Agent): - def __init__(self, agent_config: WorkflowAgentConfig): - super().__init__(agent_config) + def __init__(self, agent_config: WorkflowAgentConfig, agent_context: AgentContext): + super().__init__(agent_config, agent_context) + self._config = agent_config def _init_controller(self): return WorkflowController(self._config, self._controller_context_manager) def _init_agent_handler(self): - return AgentHandlerImpl() + return AgentHandlerImpl(self._config) def _init_controller_context_manager(self) -> ControllerContextMgr: - context = Context(config=Config(), state=InMemoryState(), store=None, tracer=None) - return ControllerContextMgr(self._config, context) - - def _init_task_manager(self): - return None - - def invoke(self, inputs: Dict, context: Context) -> Dict: - output: WorkflowControllerOutput = self._controller.invoke(inputs) - - outputs = [self._agent_handler.invoke(st) for st in output.sub_tasks] - - return {"outputs": outputs} - - def stream(self, inputs: Dict, context: Context): + return ControllerContextMgr(self._config) + + async def invoke(self, inputs: Dict) -> Dict: + task: Task = self._task_manager.create_task(inputs.get("conversation_id")) + context = task.context + context.set_controller_context_manager(self._controller_context_manager) + current_inputs = inputs + + while True: + controller_output: WorkflowControllerOutput = self._controller.invoke(current_inputs, context) + results = {} + if controller_output.sub_tasks: + for sub_task in controller_output.sub_tasks: + inputs = AgentHandlerInputs(context=context, name=sub_task.func_name, arguments=sub_task.func_args) + result = await self._agent_handler.invoke(sub_task.sub_task_type, inputs) + results[sub_task.func_name] = result + if not self._controller.should_continue(controller_output): + output = self.handle_workflow_results(results) + return output + current_inputs = results + + async def stream(self, inputs: Dict): pass + + def handle_workflow_results(self, results): + if self._config.is_single_workflow: + return results[self._config.workflows[0].name] + raise Exception("Multi-workflow not implemented yet") diff --git a/jiuwen/core/agent/agent.py b/jiuwen/core/agent/agent.py index ef30a4197800152d8636708ab5eedc2b2c365211..bbecc9cc16ebfc7a017b2cee9a0ffa8c8c381d4d 100644 --- a/jiuwen/core/agent/agent.py +++ b/jiuwen/core/agent/agent.py @@ -2,6 +2,9 @@ from abc import ABC, abstractmethod from typing import Any, Iterator, Optional, Dict, List from jiuwen.agent.config.base import AgentConfig +from jiuwen.core.agent.task.task_manager import TaskManager +from jiuwen.core.context.agent_context import AgentContext +from jiuwen.core.context.controller_context.controller_context_manager import ControllerContextMgr from jiuwen.core.context.context import Context from jiuwen.core.context.controller_context.controller_context_manager import ControllerContextMgr from jiuwen.core.utils.tool.base import Tool @@ -16,33 +19,13 @@ class Agent(ABC): - stream : 流式调用 """ - def __init__(self, agent_config: "AgentConfig"): + def __init__(self, agent_config: "AgentConfig", agent_context: "AgentContext" = None) -> None: self._config = agent_config self._controller_context_manager: Optional["ControllerContextMgr"] = \ self._init_controller_context_manager() self._controller: "Controller | None" = self._init_controller() self._agent_handler: "AgentHandler | None" = self._init_agent_handler() - self._task_manager: "TaskManager | None" = self._init_task_manager() - - @abstractmethod - def invoke(self, inputs: Dict, context: Context) -> Dict: - """ - 同步调用,一次性返回最终结果 - """ - pass - - @abstractmethod - def stream(self, inputs: Dict, context: Context) -> Iterator[Any]: - """ - 流式调用,逐个 yield 中间结果 - """ - pass - - def bind_workflows(self, workflows: List[Workflow]): - self._controller_context_manager.workflow_mgr.add_workflows(workflows) - - def bind_tools(self, tools: List[Tool]): - self._controller_context_manager.workflow_mgr.add_tools(tools) + self._task_manager: "TaskManager | None" = self._init_task_manager(agent_context) def _init_controller(self) -> "Controller | None": """ @@ -56,11 +39,13 @@ class Agent(ABC): """ return None - def _init_task_manager(self) -> "TaskManager | None": + def _init_task_manager(self, agent_context: AgentContext) -> "TaskManager | None": """ 留给子类按需实例化 TaskManager;默认返回 None """ - return None + if not agent_context: + agent_context = AgentContext() + return TaskManager(agent_context) def _init_controller_context_manager(self) -> Optional["ControllerContextMgr"]: """ @@ -68,3 +53,23 @@ class Agent(ABC): 默认返回 None,表示无需上下文管理。 """ return None + + @abstractmethod + async def invoke(self, inputs: Dict) -> Dict: + """ + 同步调用,一次性返回最终结果 + """ + pass + + @abstractmethod + async def stream(self, inputs: Dict) -> Iterator[Any]: + """ + 流式调用,逐个 yield 中间结果 + """ + pass + + def bind_workflows(self, workflows: List[Workflow]): + self._controller_context_manager.workflow_mgr.add_workflows(workflows) + + def bind_tools(self, tools: List[Tool]): + self._controller_context_manager.workflow_mgr.add_tools(tools) diff --git a/jiuwen/core/agent/controller/base.py b/jiuwen/core/agent/controller/base.py index 5937f56ad6771549e8d3dfea8c3c78d836b789f5..1c5259a6b41dd30b70039eb37eb2e5acaff99b90 100644 --- a/jiuwen/core/agent/controller/base.py +++ b/jiuwen/core/agent/controller/base.py @@ -2,14 +2,17 @@ # coding: utf-8 # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved """Controller of Agent""" +from typing import Optional + from pydantic import BaseModel, Field from jiuwen.agent.config.base import AgentConfig +from jiuwen.core.agent.task.task import Task from jiuwen.core.context.context import Context class ControllerOutput(BaseModel): - ... + is_task: bool = False class ControllerInput(BaseModel): @@ -24,3 +27,6 @@ class Controller: def invoke(self, inputs: ControllerInput, context: Context) -> ControllerOutput: pass + + def should_continue(self, output) -> bool: + pass diff --git a/jiuwen/core/agent/controller/react_controller.py b/jiuwen/core/agent/controller/react_controller.py index ca1982411580e598cabf20d13a425cb37bbeaa1a..7a1e9888eccd65865cd117233458af6e7e3e08c5 100644 --- a/jiuwen/core/agent/controller/react_controller.py +++ b/jiuwen/core/agent/controller/react_controller.py @@ -13,7 +13,7 @@ from jiuwen.agent.common.schema import WorkflowSchema, PluginSchema from jiuwen.core.agent.controller.base import ControllerOutput, ControllerInput, Controller from jiuwen.core.agent.handler.base import AgentHandler from jiuwen.agent.config.base import AgentConfig -from jiuwen.core.agent.task.task import SubTask +from jiuwen.core.agent.task.sub_task import SubTask from jiuwen.core.common.exception.exception import JiuWenBaseException from jiuwen.core.common.exception.status_code import StatusCode from jiuwen.core.context.context import Context diff --git a/jiuwen/core/agent/controller/workflow_controller.py b/jiuwen/core/agent/controller/workflow_controller.py index 3747c8a0f8eeab17e552355291aa73ba2d33a2b6..a835733682b23a81eed1957f15bba90ed6b2493b 100644 --- a/jiuwen/core/agent/controller/workflow_controller.py +++ b/jiuwen/core/agent/controller/workflow_controller.py @@ -1,19 +1,23 @@ -from typing import List +from typing import List, Union, Iterator, Any, Dict from pydantic import Field -from jiuwen.agent.config.base import AgentConfig +from jiuwen.agent.common.enum import SubTaskType from jiuwen.core.agent.controller.base import Controller, ControllerOutput, ControllerInput -from jiuwen.core.agent.task.task import SubTask -from jiuwen.core.context.context import Context +from jiuwen.core.agent.task.sub_task import SubTask + + +class Message: + ... class WorkflowControllerOutput(ControllerOutput): sub_tasks: List[SubTask] = Field(default_factory=list) + messages: Any = Field(default_factory=list) class WorkflowControllerInput(ControllerInput): - ... + workflow_inputs: dict = Field(default_factory=dict) class WorkflowController(Controller): @@ -21,8 +25,59 @@ class WorkflowController(Controller): 根据输入生成 WorkflowControllerOutput """ - def __init__(self, config: AgentConfig, context_mgr): - super().__init__(config, context_mgr) + @staticmethod + def _filter_inputs(schema: dict, user_data: dict) -> dict: + """ + 根据 schema 过滤并校验用户输入 + :param schema: workflow.inputs 的 schema,形如 {"query": {"type": "string", "required": True}} + :param user_data: 用户实际传入的数据,形如 {"query": "你好", "foo": "bar"} + :return: 仅保留 schema 中声明的字段 + :raises KeyError: 缺失必填字段时抛出 + """ + if not schema: + return {} + + required_fields = { + k for k, v in schema.items() + if isinstance(v, dict) and v.get("required") is True + } + + filtered = {} + for k in schema: + if k not in user_data: + if k in required_fields: + raise KeyError(f"缺少必填参数: {k}") + continue + filtered[k] = user_data[k] + + return filtered + + def invoke( + self, inputs: Dict, context + ) -> WorkflowControllerOutput: + if len(self._config.workflows) > 1: + raise NotImplementedError("Multi-workflow not implemented yet") + + workflow = self._config.workflows[0] + + filtered_inputs = self._filter_inputs( + schema=workflow.inputs or {}, + user_data=inputs + ) + + sub_tasks = [ + SubTask( + sub_task_type=SubTaskType.WORKFLOW, + func_name=workflow.name, + func_id=f"{workflow.id}_{workflow.version}", + func_args=filtered_inputs, + ) + ] + + return WorkflowControllerOutput(is_task=True, sub_tasks=sub_tasks) - def invoke(self, inputs: WorkflowControllerInput, context: Context) -> WorkflowControllerOutput: - return WorkflowControllerOutput() + def should_continue(self, output: WorkflowControllerOutput) -> bool: + """ + 当且仅当 output 是 Task 时继续下一轮 + """ + return not output.is_task diff --git a/jiuwen/core/agent/handler/base.py b/jiuwen/core/agent/handler/base.py index 812de3b27662e724462f79a7e4b1f72df8bf811b..1932de21a78def7a495fec6bb959f53f999dfabd 100644 --- a/jiuwen/core/agent/handler/base.py +++ b/jiuwen/core/agent/handler/base.py @@ -2,7 +2,7 @@ # coding: utf-8 # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved """Handler of Agent""" -from typing import Dict, Callable, Any +from typing import Dict, Callable, Any, Awaitable from pydantic import BaseModel, Field @@ -21,28 +21,28 @@ class AgentHandlerInputs(BaseModel): class AgentHandler: def __init__(self, agent_config: AgentConfig): - self._function_map: Dict[SubTaskType, Callable[[AgentHandlerInputs], dict]] = { + self._function_map: Dict[SubTaskType, Callable[[AgentHandlerInputs], Awaitable[dict]]] = { SubTaskType.WORKFLOW: self.invoke_workflow, SubTaskType.PLUGIN: self.invoke_plugin } self._config = agent_config - def invoke(self, sub_task_type: SubTaskType, inputs: AgentHandlerInputs): + async def invoke(self, sub_task_type: SubTaskType, inputs: AgentHandlerInputs): handler = self._function_map.get(sub_task_type) if not handler: raise JiuWenBaseException() - return handler(inputs) + return await handler(inputs) - def invoke_workflow(self, inputs: AgentHandlerInputs): + async def invoke_workflow(self, inputs: AgentHandlerInputs): return dict() - def invoke_plugin(self, inputs: AgentHandlerInputs): + async def invoke_plugin(self, inputs: AgentHandlerInputs): return dict() - def invoke_llm(self, inputs: AgentHandlerInputs): + async def invoke_llm(self, inputs: AgentHandlerInputs): return dict() - def send_message(self, inputs: AgentHandlerInputs): + async def send_message(self, inputs: AgentHandlerInputs): return dict() @@ -50,31 +50,29 @@ class AgentHandlerImpl(AgentHandler): def __init__(self, agent_config: AgentConfig): super().__init__(agent_config) - def invoke(self, sub_task_type: SubTaskType, inputs: AgentHandlerInputs): + async def invoke(self, sub_task_type: SubTaskType, inputs: AgentHandlerInputs): handler = self._function_map.get(sub_task_type) if not handler: raise JiuWenBaseException() - return handler(inputs) + return await handler(inputs) - def invoke_workflow(self, inputs: AgentHandlerInputs): + async def invoke_workflow(self, inputs: AgentHandlerInputs): context = inputs.context - query = inputs.query workflow_name = inputs.name - context_manager = context._controller_context_manager + context_manager = context.controller_context_manager workflow_manager = context_manager.workflow_mgr workflow_metadata = self._search_workflow_metadata_by_workflow_name(workflow_name) workflow = workflow_manager.find_workflow_by_id_and_version(workflow_metadata.id, workflow_metadata.version) - workflow_inputs = dict(query=query, userFields=inputs.arguments) - workflow_result = workflow.invoke(workflow_inputs, context) + workflow_result = await workflow.invoke(inputs.arguments, context) return workflow_result - def invoke_plugin(self, inputs: AgentHandlerInputs): + async def invoke_plugin(self, inputs: AgentHandlerInputs): context = inputs.context plugin_name = inputs.name plugin_args = inputs.arguments - context_manager = context._controller_context_manager + context_manager = context.controller_context_manager workflow_manager = context_manager.workflow_mgr plugin = workflow_manager.find_tool_by_name(plugin_name) plugin_result = plugin.invoke(plugin_args) diff --git a/jiuwen/core/agent/task/sub_task.py b/jiuwen/core/agent/task/sub_task.py new file mode 100644 index 0000000000000000000000000000000000000000..32423fb86d769822886c2b26d74312379950baf3 --- /dev/null +++ b/jiuwen/core/agent/task/sub_task.py @@ -0,0 +1,15 @@ +from typing import Any, Optional + +from pydantic import BaseModel, Field + +from jiuwen.agent.common.enum import SubTaskType + + +class SubTask(BaseModel): + id: str = Field(default="") + sub_task_type: SubTaskType = Field(default=SubTaskType.UNDEFINED) + func_id: str = Field(default="") + func_name: str = Field(default="") + func_args: dict = Field(default_factory=dict) + result: Optional[str] = Field(default=None) + sub_task_context: Any = Field(default=None) diff --git a/jiuwen/core/agent/task/task.py b/jiuwen/core/agent/task/task.py index 3a5e7105be2ec5381d14ea871442c2a7c8948953..56da9af75335c3e6f82d95227954dbffb1d13d27 100644 --- a/jiuwen/core/agent/task/task.py +++ b/jiuwen/core/agent/task/task.py @@ -1,28 +1,13 @@ -import uuid -from typing import Any, Dict, Optional, List, Union - -from pydantic import BaseModel, Field - -from jiuwen.agent.common.enum import SubTaskType, TaskStatus - - -class SubTask(BaseModel): - id: str = Field(default="") - sub_task_type: SubTaskType = Field(default=SubTaskType.UNDEFINED) - func_name: str = Field(default="") - func_args: dict = Field(default_factory=dict) - result: Optional[str] = Field(default=None) +from jiuwen.agent.common.enum import TaskStatus +from jiuwen.core.context.context import Context class Task: - def __init__(self, payload: Dict[str, Any], task_id: Optional[str] = None): - self.id: str = task_id or str(uuid.uuid4()) - self.payload: Dict[str, Any] = payload - self.sub_tasks: List[SubTask] = [] + def __init__(self, task_id: str, context: Context): + self.task_id = task_id + self.context = context self.status: TaskStatus = TaskStatus.PENDING - def add_sub_task(self, sub_task: SubTask) -> str: - if not sub_task.id: # 若调用方没给 id,自动生成 - sub_task.id = f"{self.id}_{len(self.sub_tasks)}" - self.sub_tasks.append(sub_task) - return sub_task.id + # 便捷方法:一键设置整体状态 + def set_status(self, status: TaskStatus) -> None: + self.status = status diff --git a/jiuwen/core/agent/task/task_manager.py b/jiuwen/core/agent/task/task_manager.py index e7123d7b94c3de59fd94f0b51c5f9e1e589662ae..0b72ebc35e15382e3f22a8eea124b19395179f1e 100644 --- a/jiuwen/core/agent/task/task_manager.py +++ b/jiuwen/core/agent/task/task_manager.py @@ -1,27 +1,53 @@ -import threading -from typing import Dict, Any, Optional +from typing import List, Optional, Any, Dict +from enum import Enum -from jiuwen.agent.common.enum import TaskStatus from jiuwen.core.agent.task.task import Task +from jiuwen.core.context.agent_context import AgentContext +from jiuwen.core.context.config import Config +from jiuwen.core.context.context import Context, ExecutableContext +from jiuwen.core.context.memory.base import InMemoryState + + +class TaskStatus(Enum): + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + UNKNOWN = "unknown" class TaskManager: - def __init__(self): - self._lock = threading.Lock() + def __init__(self, agent_context: "AgentContext") -> None: + self.agent_context = agent_context self._tasks: Dict[str, Task] = {} - def submit(self, payload: Dict[str, Any], task_id: Optional[str] = None) -> str: - with self._lock: - task = Task(payload, task_id) - self._tasks[task.id] = task - return task.id - - def get(self, task_id: str) -> Optional[Task]: - with self._lock: - return self._tasks.get(task_id) - - def update_status(self, task_id: str, status: TaskStatus) -> None: - with self._lock: - t = self._tasks.get(task_id) - if t: - t.status = status + def create_task(self, conversation_id: str) -> Task: + """ + 如果 conversation_id 已存在则复用其 Context, + 否则新建 Context 并返回新的 Task。 + """ + task_id = conversation_id # 直接使用 conversation_id 作为 task_id + + # 复用已有context + if task_id in self._tasks: + return self._tasks[task_id] + + # 新建context + context = Context(config=Config(), + state=InMemoryState(), + store=self.agent_context.store) + executable_context = ExecutableContext(context=context, node_id="agent") + task = Task(task_id, executable_context) + + self._tasks[task_id] = task + self.agent_context.context_map[task_id] = executable_context + return task + + def get_task(self, conversation_id: str) -> Optional[Task]: + return self._tasks.get(conversation_id) + + def remove_task(self, conversation_id: str) -> None: + task = self._tasks.pop(conversation_id, None) + if task is None: + return + self.agent_context.context_map.pop(conversation_id, None) diff --git a/jiuwen/core/common/exception/status_code.py b/jiuwen/core/common/exception/status_code.py index 12ba8ad55a7e256b18693c0328eaa181261cf7ab..85f9d21bcf89a63567527adf71a5cdbca4c59e62 100644 --- a/jiuwen/core/common/exception/status_code.py +++ b/jiuwen/core/common/exception/status_code.py @@ -5,6 +5,9 @@ from enum import Enum class StatusCode(Enum): + # Agent模块 103025~103050 + AGENT_SUB_TASK_TYPE_ERROR = (103032, "SubTask type {msg} is not supported") + CONTROLLER_INTERRUPTED_ERROR = (10312, "controller interrupted error") PROMPT_JSON_SCHEMA_ERROR = (102056, "Invalid json schema, root cause = {error_msg}.") diff --git a/jiuwen/core/component/llm_comp.py b/jiuwen/core/component/llm_comp.py index bca67fedb276d55bbd4805a4dc8d8b44d6608e78..aca71ae94a8ea2497931875e4113d3699ca0025d 100644 --- a/jiuwen/core/component/llm_comp.py +++ b/jiuwen/core/component/llm_comp.py @@ -145,6 +145,7 @@ class LLMExecutable(Executable): logger.info("[%s] model inputs %s", self._context.executable_id, model_inputs) llm_response = await self._llm.ainvoke(model_inputs) response = llm_response.content + self._context.state.update({"response": response}) logger.info("[%s] model outputs %s", self._context.executable_id, response) return self._create_output(response) except JiuWenBaseException: diff --git a/jiuwen/core/context/agent_context.py b/jiuwen/core/context/agent_context.py new file mode 100644 index 0000000000000000000000000000000000000000..3c0ef984849923ee3af9933a7da2869d4f73256f --- /dev/null +++ b/jiuwen/core/context/agent_context.py @@ -0,0 +1,8 @@ +from jiuwen.core.context.context import Context +from jiuwen.core.context.state import State +from jiuwen.core.context.store import Store + + +class AgentContext: + context_map: dict[str, Context] = {} + store: Store = None diff --git a/jiuwen/core/context/context.py b/jiuwen/core/context/context.py index 685ece8e18b2c3e4c1e37fb8680505efdeb4b436..19af7eceaa782c25d572a9b7ebcab676bc275b6e 100644 --- a/jiuwen/core/context/context.py +++ b/jiuwen/core/context/context.py @@ -57,6 +57,10 @@ class Context(ABC): def callback_manager(self) -> CallbackManager: return self._callback_manager + @property + def controller_context_manager(self): + return self._controller_context_manager + def create_executable_context(self, node_id: str) -> Self: context = ExecutableContext(self, node_id) context.set_stream_writer_manager(self._stream_writer_manager) diff --git a/jiuwen/core/context/controller_context/controller_context_manager.py b/jiuwen/core/context/controller_context/controller_context_manager.py index 90c998eb530d113fc93da4d1d29d1d87aa3def19..803f4f6e97ee72dda5599fbcef6f20b7d7605b64 100644 --- a/jiuwen/core/context/controller_context/controller_context_manager.py +++ b/jiuwen/core/context/controller_context/controller_context_manager.py @@ -9,9 +9,8 @@ class ControllerContextMgr: """ Agent上下文管理器: """ - def __init__(self, agent_config: AgentConfig): - self.workflow_mgr = WorkflowMgr() + self.workflow_mgr = WorkflowMgr(agent_config) self.tool_mgr = ToolMgr() self.model_mgr = ModelMgr() self.message_mgr = MessageMgr() diff --git a/jiuwen/core/context/controller_context/workflow_manager.py b/jiuwen/core/context/controller_context/workflow_manager.py index 5e4d0f474ee9c09d54c46bdf655f79007e0c365f..a61f6eacb9b9819ede1d135eba7d8c4dcf66b669 100644 --- a/jiuwen/core/context/controller_context/workflow_manager.py +++ b/jiuwen/core/context/controller_context/workflow_manager.py @@ -6,10 +6,13 @@ from jiuwen.core.workflow.base import Workflow class WorkflowMgr: - def __init__(self): + def __init__(self, agent_config): self._workflows: Dict[str, Workflow] = dict() self._tools: Dict[str, Tool] = dict() + def get_workflow(self, workflow_instance_id: str) -> Workflow: + return self._workflows.get(workflow_instance_id) + def add_workflows(self, workflows: List[Workflow]): if not workflows: return diff --git a/tests/system_tests/agent/test_react_agent.py b/tests/system_tests/agent/test_react_agent.py index f2dfd9877a08c7e72a96e29fd285008e241535eb..323b80bcf2b8c069b89aa66dbc56f5405cf431a2 100644 --- a/tests/system_tests/agent/test_react_agent.py +++ b/tests/system_tests/agent/test_react_agent.py @@ -1,6 +1,6 @@ import os import unittest -from unittest.mock import patch +from unittest.mock import patch, AsyncMock from jiuwen.agent.common.schema import WorkflowSchema, PluginSchema from jiuwen.agent.react_agent import create_react_agent_config, create_react_agent @@ -25,34 +25,40 @@ MODEL_PROVIDER = os.getenv("MODEL_PROVIDER", "") USER_PROMPT_FOR_TRIP_PLANNING = "帮我生成一份{{location}}的旅行攻略,同时旅行时间为{{duration}}。注意:若未明确指定旅行时间,则默认为一天!" -class ReActAgentTest(unittest.TestCase): +class ReActAgentTest(unittest.IsolatedAsyncioTestCase): # ① 关键改动 DEFAULT_TEMPLATE = [ dict(role="system", content="你是一个AI助手,在适当的时候调用合适的工具,帮助我完成任务!") ] @patch("jiuwen.core.utils.tool.service_api.restful_api.RestfulApi.invoke") - def test_react_agent_invoke(self, mock_restfulapi_invoke): + async def test_react_agent_invoke(self, mock_restfulapi_invoke): mock_restfulapi_invoke.return_value = {"result": "杭州今天天气晴,温度35度;注意局部地区有雷阵雨"} + # 下面所有代码无需再改动,已经是协程环境 workflows_schema = [self._create_workflow_schema()] tools_schema = [self._create_tool_schema()] model_config = self._create_model() prompt_template = self.DEFAULT_TEMPLATE - react_agent_config = create_react_agent_config(agent_id="react_agent_123", agent_version="0.0.1", - description="AI助手", plugins=tools_schema, workflows=[], - model=model_config, prompt_template=prompt_template) - workflow = self._create_workflow() + react_agent_config = create_react_agent_config( + agent_id="react_agent_123", + agent_version="0.0.1", + description="AI助手", + plugins=tools_schema, + workflows=[], + model=model_config, + prompt_template=prompt_template + ) + tool = self._create_tool() - react_agent = create_react_agent(agent_config=react_agent_config, - workflows=[], - tools=[tool]) - - context = Context(config=Config(), state=InMemoryState(), store=None) - executable_context = ExecutableContext(context=context, node_id="react_agent") - # inputs = dict(query="生成杭州一日游") - inputs = dict(query="查询杭州今天的天气") - result = react_agent.invoke(inputs, executable_context) + react_agent = create_react_agent( + agent_config=react_agent_config, + workflows=[], + tools=[tool] + ) + inputs = {"query": "查询杭州今天的天气"} + + result = await react_agent.invoke(inputs) # ③ 已经是 await print(result) @staticmethod @@ -97,22 +103,22 @@ class ReActAgentTest(unittest.TestCase): llm_component = ReActAgentTest._create_llm_component() end_component = MockEndNode("end") workflow.set_start_comp("start", start_component, - inputs_schema={ - "location": "${userFields.location}", - "duration": "${userFields.duration}" - }) + inputs_schema={ + "location": "${userFields.location}", + "duration": "${userFields.duration}" + }) workflow.add_workflow_comp("llm", llm_component, - inputs_schema={ - "userFields": { - "location": "${start.location}", - "duration": "{start.duration}" - } - }) + inputs_schema={ + "userFields": { + "location": "${start.location}", + "duration": "{start.duration}" + } + }) workflow.set_end_comp("end", end_component, - inputs_schema={ - "output": "${llm.userFields}" - }) + inputs_schema={ + "output": "${llm.userFields}" + }) workflow.add_connection("start", "llm") workflow.add_connection("llm", "end") return workflow @@ -120,14 +126,14 @@ class ReActAgentTest(unittest.TestCase): @staticmethod def _create_llm_component(): model_config = ModelConfig(model_provider=MODEL_PROVIDER, - model_info=BaseModelInfo( - model=MODEL_NAME, - api_base=API_BASE, - api_key=API_KEY, - temperature=0.7, - top_p=0.9, - timeout=30 # 添加超时设置 - )) + model_info=BaseModelInfo( + model=MODEL_NAME, + api_base=API_BASE, + api_key=API_KEY, + temperature=0.7, + top_p=0.9, + timeout=30 # 添加超时设置 + )) config = LLMCompConfig( model=model_config, template_content=[{"role": "user", "content": USER_PROMPT_FOR_TRIP_PLANNING}], @@ -160,17 +166,17 @@ class ReActAgentTest(unittest.TestCase): inputs={ "type": "object", "properties": { - "location": { - "type": "string", - "description": "天气查询的地点", - "required": True - }, - "date": { - "type": "string", - "description": "天气查询的日期", - "required": True - } + "location": { + "type": "string", + "description": "天气查询的地点", + "required": True + }, + "date": { + "type": "string", + "description": "天气查询的日期", + "required": True } + } } ) return tool_info diff --git a/tests/unit_tests/agent/test_workflow_agent.py b/tests/unit_tests/agent/test_workflow_agent.py index 16d3b68025f4b8f8a5d2d49dd49a265ab1210d90..42c8bcc991dd9093f3f12839fdc431911584cdb0 100644 --- a/tests/unit_tests/agent/test_workflow_agent.py +++ b/tests/unit_tests/agent/test_workflow_agent.py @@ -1,33 +1,67 @@ import pytest -from unittest.mock import Mock, patch - +from jiuwen.agent.common.schema import WorkflowSchema from jiuwen.agent.config.workflow_config import WorkflowAgentConfig from jiuwen.agent.workflow_agent import WorkflowAgent +from jiuwen.core.context.agent_context import AgentContext +from jiuwen.core.context.config import WorkflowConfig +from jiuwen.core.workflow.base import Workflow +from jiuwen.core.workflow.workflow_config import WorkflowMetadata +from jiuwen.graph.pregel.graph import PregelGraph +from tests.unit_tests.workflow.test_mock_node import MockStartNode, Node1, MockEndNode class TestWorkflowAgent: - # 在所有测试方法前一次性打补丁 - @pytest.fixture(autouse=True, scope="class") - def _patch_deps(self): - # 把私有工厂方法替换掉 - with patch.object( - WorkflowAgent, "_init_agent_handler", autospec=True - ) as mock_handler, patch.object( - WorkflowAgent, "_init_controller_context_manager", autospec=True - ): - # 返回统一的 mock handler - handler = Mock() - handler.invoke = Mock(side_effect=lambda st: {"mock": st.func_name}) - mock_handler.return_value = handler - yield + @staticmethod + def _build_workflow(name, id, version): + workflow_config = WorkflowConfig( + metadata=WorkflowMetadata( + id=id, + version=version, + name=name, + ) + ) + flow = Workflow(workflow_config=workflow_config, graph=PregelGraph()) + flow.set_start_comp("start", MockStartNode("start"), + inputs_schema={ + "query": "${query}"}) + flow.add_workflow_comp("node_a", Node1("node_a"), + inputs_schema={ + "output": "${start.query}"}) + flow.set_end_comp("end", MockEndNode("end"), + inputs_schema={ + "result": "${node_a.output}"}) + flow.add_connection("start", "node_a") + flow.add_connection("node_a", "end") + return flow # 真正实例化 @pytest.fixture(scope="class") def agent(self): - return WorkflowAgent(WorkflowAgentConfig()) + agent_context = AgentContext() + id = "test_workflow" + name = "test_workflow" + version = "1" + description = "test_workflow" + workflow1 = self._build_workflow(name, id, version) + test_workflow_schema = WorkflowSchema( + id=id, + version=version, + name=name, + description=description, + inputs={"query": { + "type": "string", + }}, + ) + workflow_config = WorkflowAgentConfig( + workflows=[test_workflow_schema] + ) + agent = WorkflowAgent(workflow_config, agent_context) + agent.bind_workflows([workflow1]) + return agent # ---------- 测试用例 ---------- - def test_invoke_single(self, agent): - inputs = {"workflows": [{"name": "echo", "params": {"text": "hi"}}]} - result = agent.invoke(inputs) - assert result == {"outputs": []} + @pytest.mark.asyncio + async def test_invoke_single(self, agent): + inputs = {"query": "hi"} + result = await agent.invoke(inputs) # ✅ 使用 await + assert result == {'result': 'hi'}