From ded858e346674ffa95131c691dcfd4eddeedd914 Mon Sep 17 00:00:00 2001 From: CandiceGuo Date: Tue, 29 Jul 2025 11:25:13 +0800 Subject: [PATCH] =?UTF-8?q?feat:=E5=A2=9E=E5=8A=A0TaskContext=E7=9A=84?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0=EF=BC=8C=E5=A2=9E=E5=8A=A0TaskContext?= =?UTF-8?q?=E4=B8=8EworkflowContext=E7=9A=84=E8=BD=AC=E6=8D=A2=E5=85=B3?= =?UTF-8?q?=E7=B3=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- jiuwen/agent/react_agent.py | 7 ++- jiuwen/core/agent/controller/base.py | 6 +-- .../core/agent/controller/react_controller.py | 11 ++-- jiuwen/core/agent/handler/base.py | 3 +- jiuwen/core/agent/task/task.py | 4 +- jiuwen/core/agent/task/task_context.py | 51 +++++++++++++++++++ jiuwen/core/agent/task/task_manager.py | 9 +--- jiuwen/core/component/questioner_comp.py | 1 - jiuwen/core/context/agent_context.py | 5 +- jiuwen/core/context/context.py | 10 +--- jiuwen/core/context/state.py | 11 +--- .../system_tests/agent/test_workflow_agent.py | 10 ++-- .../workflow/test_real_workflow.py | 13 +++-- tests/unit_tests/tracer/test_agent.py | 19 ++----- .../workflow/test_questioner_comp.py | 17 +++---- tests/unit_tests/workflow/test_tool_comp.py | 9 ++-- tests/unit_tests/workflow/test_workflow.py | 5 +- 17 files changed, 101 insertions(+), 90 deletions(-) create mode 100644 jiuwen/core/agent/task/task_context.py diff --git a/jiuwen/agent/react_agent.py b/jiuwen/agent/react_agent.py index e6cfb32..a5d5fcb 100644 --- a/jiuwen/agent/react_agent.py +++ b/jiuwen/agent/react_agent.py @@ -15,8 +15,8 @@ from jiuwen.core.agent.handler.base import AgentHandlerImpl, AgentHandlerInputs from jiuwen.agent.state.react_state import ReActState from jiuwen.core.agent.task.sub_task import SubTask from jiuwen.core.agent.task.task import Task +from jiuwen.core.agent.task.task_context import TaskContext from jiuwen.core.component.common.configs.model_config import ModelConfig -from jiuwen.core.context.context import Context from jiuwen.core.context.controller_context.controller_context_manager import ControllerContextMgr from jiuwen.core.utils.llm.messages import ToolMessage from jiuwen.core.utils.tool.base import Tool @@ -92,7 +92,7 @@ class ReActAgent(Agent): async def stream(self, inputs: Dict) -> Iterator[Any]: pass - async def _execute_sub_tasks(self, context: Context): + async def _execute_sub_tasks(self, context: TaskContext): to_exec_sub_tasks = self._state.sub_tasks completed_sub_tasks = [] for st in to_exec_sub_tasks: @@ -103,7 +103,7 @@ class ReActAgent(Agent): self._update_chat_history_in_context(completed_sub_tasks, context) return completed_sub_tasks - def _load_state_from_context(self, context: Context): + def _load_state_from_context(self, context: TaskContext): state_dict = context.state().get(REACT_AGENT_STATE_KEY) if state_dict: self._state = ReActState.deserialize(state_dict) @@ -113,7 +113,6 @@ class ReActAgent(Agent): def _store_state_to_context(self, context): state_dict = self._state.serialize() context.state().update({REACT_AGENT_STATE_KEY: state_dict}) - context.state().commit() @staticmethod def _update_chat_history_in_context(completed_sub_tasks: List[SubTask], context): diff --git a/jiuwen/core/agent/controller/base.py b/jiuwen/core/agent/controller/base.py index 1c5259a..584cd4b 100644 --- a/jiuwen/core/agent/controller/base.py +++ b/jiuwen/core/agent/controller/base.py @@ -2,13 +2,11 @@ # coding: utf-8 # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved """Controller of Agent""" -from typing import Optional from pydantic import BaseModel, Field from jiuwen.agent.config.base import AgentConfig -from jiuwen.core.agent.task.task import Task -from jiuwen.core.context.context import Context +from jiuwen.core.agent.task.task_context import TaskContext class ControllerOutput(BaseModel): @@ -25,7 +23,7 @@ class Controller: self._agent_handler = None self._context_mgr = context_mgr - def invoke(self, inputs: ControllerInput, context: Context) -> ControllerOutput: + def invoke(self, inputs: ControllerInput, context: TaskContext) -> ControllerOutput: pass def should_continue(self, output) -> bool: diff --git a/jiuwen/core/agent/controller/react_controller.py b/jiuwen/core/agent/controller/react_controller.py index eb614dd..1741a9e 100644 --- a/jiuwen/core/agent/controller/react_controller.py +++ b/jiuwen/core/agent/controller/react_controller.py @@ -14,9 +14,9 @@ from jiuwen.core.agent.controller.base import ControllerOutput, ControllerInput, from jiuwen.core.agent.handler.base import AgentHandler from jiuwen.agent.config.base import AgentConfig from jiuwen.core.agent.task.sub_task import SubTask +from jiuwen.core.agent.task.task_context import TaskContext from jiuwen.core.common.exception.exception import JiuWenBaseException from jiuwen.core.common.exception.status_code import StatusCode -from jiuwen.core.context.context import Context from jiuwen.core.context.controller_context.controller_context_manager import ControllerContextMgr from jiuwen.core.utils.llm.messages import BaseMessage, ToolInfo, Function, Parameters, HumanMessage, AIMessage, \ ToolCall @@ -63,7 +63,7 @@ class ReActControllerUtils: pass @staticmethod - def get_dialogue_history_from_context(context: Context): + def get_dialogue_history_from_context(context: TaskContext): if hasattr(context, "context_manager"): chat_history = context.context_manager.message_mgr.get_chat_history() else: @@ -78,7 +78,6 @@ class ReActControllerUtils: else: # TODO: 临时存储对话历史 context.state().update({DIALOGUE_HISTORY_KEY: current_messages}) - context.state().commit() @staticmethod def json_loads(arguments: str): @@ -99,7 +98,7 @@ class ReActController(Controller): self._model = self._init_model() self._output_parser = self._init_output_parser() - def invoke(self, inputs: ReActControllerInput, context: Context) -> ReActControllerOutput: + def invoke(self, inputs: ReActControllerInput, context: TaskContext) -> ReActControllerOutput: query = inputs.query user_fields = inputs.user_fields chat_history = self._get_latest_chat_history(context) @@ -113,7 +112,7 @@ class ReActController(Controller): def set_agent_handler(self, agent_handler: AgentHandler): self._agent_handler = agent_handler - def _get_latest_chat_history(self, context: Context) -> List[BaseMessage]: + def _get_latest_chat_history(self, context: TaskContext) -> List[BaseMessage]: chat_history = ReActControllerUtils.get_dialogue_history_from_context(context) chat_history = chat_history[-2 * self._config.constrain.reserved_max_chat_rounds:] return chat_history @@ -189,7 +188,7 @@ class ReActController(Controller): model_info=self._config.model.model_info) @staticmethod - def _update_llm_response_to_context(llm_output: AIMessage, chat_history: List[BaseMessage], context: Context): + def _update_llm_response_to_context(llm_output: AIMessage, chat_history: List[BaseMessage], context: TaskContext): if llm_output: chat_history.append(llm_output) ReActControllerUtils.set_dialogue_history_to_context(chat_history, context) diff --git a/jiuwen/core/agent/handler/base.py b/jiuwen/core/agent/handler/base.py index dda79f7..ce62726 100644 --- a/jiuwen/core/agent/handler/base.py +++ b/jiuwen/core/agent/handler/base.py @@ -59,12 +59,11 @@ class AgentHandlerImpl(AgentHandler): async def invoke_workflow(self, inputs: AgentHandlerInputs): context = inputs.context workflow_name = inputs.name - context_manager = context.controller_context_manager() workflow_manager = context_manager.workflow_mgr workflow_metadata = self._search_workflow_metadata_by_workflow_name(workflow_name) workflow = workflow_manager.find_workflow_by_id_and_version(workflow_metadata.id, workflow_metadata.version) - workflow_result = await workflow.invoke(inputs.arguments, context) + workflow_result = await workflow.invoke(inputs.arguments, context.create_workflow_context()) return workflow_result async def invoke_plugin(self, inputs: AgentHandlerInputs): diff --git a/jiuwen/core/agent/task/task.py b/jiuwen/core/agent/task/task.py index 56da9af..3ac1a6a 100644 --- a/jiuwen/core/agent/task/task.py +++ b/jiuwen/core/agent/task/task.py @@ -1,9 +1,9 @@ from jiuwen.agent.common.enum import TaskStatus -from jiuwen.core.context.context import Context +from jiuwen.core.agent.task.task_context import TaskContext class Task: - def __init__(self, task_id: str, context: Context): + def __init__(self, task_id: str, context: TaskContext): self.task_id = task_id self.context = context self.status: TaskStatus = TaskStatus.PENDING diff --git a/jiuwen/core/agent/task/task_context.py b/jiuwen/core/agent/task/task_context.py new file mode 100644 index 0000000..2289c51 --- /dev/null +++ b/jiuwen/core/agent/task/task_context.py @@ -0,0 +1,51 @@ +#!/usr/bin/python3.10 +# coding: utf-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved +from typing import Any + +from jiuwen.core.context.config import Config +from jiuwen.core.context.context import WorkflowContext +from jiuwen.core.context.state import InMemoryState, State, CommitState, StateLike, InMemoryCommitState, \ + InMemoryStateLike +from jiuwen.core.context.store import Store +from jiuwen.core.runtime.callback_manager import CallbackManager +from jiuwen.core.stream.base import BaseStreamMode +from jiuwen.core.stream.emitter import StreamEmitter +from jiuwen.core.stream.manager import StreamWriterManager +from jiuwen.core.tracer.tracer import Tracer + + +class TaskContext: + def __init__(self, id: str, store: Store = None, controller_context_manager: Any = None): + self.__id = id + self.__global_state = InMemoryStateLike() + self.__store = store + self.__controller_context_manager = controller_context_manager + self.__stream_writer_manager: StreamWriterManager = StreamWriterManager(StreamEmitter(),[BaseStreamMode.TRACE]) + self.__callback_manager = CallbackManager() + self.__tracer = Tracer() + self.__tracer.init(self.__stream_writer_manager, self.__callback_manager) + + def state(self) -> StateLike: + return self.__global_state + + def set_controller_context_manager(self, controller_context_manager: Any): + self.__controller_context_manager = controller_context_manager + + def controller_context_manager(self) -> Any: + return self.__controller_context_manager + + def tracer(self) -> Tracer: + return self.__tracer + + def create_workflow_context(self) -> WorkflowContext: + return WorkflowContext( + state=InMemoryState(InMemoryCommitState(self.__global_state)), + store=self.__store, + tracer=self.__tracer, + config=Config(), + session_id=self.__id, + controller_context_manager=self.__controller_context_manager) + + def stream_writer_manager(self) -> StreamWriterManager: + return self.__stream_writer_manager diff --git a/jiuwen/core/agent/task/task_manager.py b/jiuwen/core/agent/task/task_manager.py index 85bf1aa..1f58dd5 100644 --- a/jiuwen/core/agent/task/task_manager.py +++ b/jiuwen/core/agent/task/task_manager.py @@ -2,10 +2,8 @@ from typing import List, Optional, Any, Dict from enum import Enum from jiuwen.core.agent.task.task import Task +from jiuwen.core.agent.task.task_context import TaskContext from jiuwen.core.context.agent_context import AgentContext -from jiuwen.core.context.config import Config -from jiuwen.core.context.context import NodeContext, WorkflowContext -from jiuwen.core.context.state import InMemoryState class TaskStatus(Enum): @@ -31,11 +29,8 @@ class TaskManager: # 复用已有context if task_id in self._tasks: return self._tasks[task_id] - # 新建context - context = WorkflowContext(config=Config(), - state=InMemoryState(), - store=self.agent_context.store) + context = TaskContext(id=task_id, store=self.agent_context.store) task = Task(task_id, context) self._tasks[task_id] = task diff --git a/jiuwen/core/component/questioner_comp.py b/jiuwen/core/component/questioner_comp.py index 74f453b..334ba73 100644 --- a/jiuwen/core/component/questioner_comp.py +++ b/jiuwen/core/component/questioner_comp.py @@ -397,7 +397,6 @@ class QuestionerExecutable(Executable): def _store_state_to_context(state: QuestionerState, context): state_dict = state.serialize() context.state().update({QUESTIONER_STATE_KEY: state_dict}) - context.state().commit() def state(self, state: QuestionerState): self._state = state diff --git a/jiuwen/core/context/agent_context.py b/jiuwen/core/context/agent_context.py index 3c0ef98..64bc67f 100644 --- a/jiuwen/core/context/agent_context.py +++ b/jiuwen/core/context/agent_context.py @@ -1,8 +1,7 @@ -from jiuwen.core.context.context import Context -from jiuwen.core.context.state import State +from jiuwen.core.agent.task.task_context import TaskContext from jiuwen.core.context.store import Store class AgentContext: - context_map: dict[str, Context] = {} + context_map: dict[str, TaskContext] = {} store: Store = None diff --git a/jiuwen/core/context/context.py b/jiuwen/core/context/context.py index c074e15..e3caae9 100644 --- a/jiuwen/core/context/context.py +++ b/jiuwen/core/context/context.py @@ -63,20 +63,17 @@ class Context(ABC): def set_queue_manager(self, queue_manager: MessageQueueManager): return - def clone(self) -> Self: - return None - class WorkflowContext(Context): def __init__(self, state: State, config: Config = Config(), store: Store = None, tracer: Tracer = None, - session_id: str = None): + session_id: str = None, controller_context_manager: Any = None): self.__config = config self.__state = state self.__store = store self.__tracer = tracer self.__callback_manager = CallbackManager() self.__stream_writer_manager: StreamWriterManager = None - self.__controller_context_manager = None + self.__controller_context_manager = controller_context_manager self.__session_id = session_id if session_id else uuid.uuid4().hex self.__queue_manager: MessageQueueManager = None @@ -123,9 +120,6 @@ class WorkflowContext(Context): def session_id(self) -> str: return self.__session_id - def clone(self) -> Self: - return WorkflowContext(state=self.state().clone(), session_id=self.session_id()) - class NodeContext(Context): def __init__(self, context: Context, node_id: str): diff --git a/jiuwen/core/context/state.py b/jiuwen/core/context/state.py index d2b22b8..ecdbbea 100644 --- a/jiuwen/core/context/state.py +++ b/jiuwen/core/context/state.py @@ -1,7 +1,6 @@ #!/usr/bin/python3.10 # coding: utf-8 # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved -import uuid from abc import ABC, abstractmethod from copy import deepcopy from typing import Any, Union, Optional, Callable, Self @@ -179,9 +178,6 @@ class State(ABC): self._global_state.set_updates(updates.get(GLOBAL_STATE_UPDATES_KEY)) self._comp_state.set_updates(updates.get(COMP_STATE_UPDATES_KEY)) - def clone(self) -> Self: - pass - class InMemoryStateLike(StateLike): def __init__(self): @@ -210,8 +206,8 @@ class InMemoryStateLike(StateLike): class InMemoryCommitState(CommitState): - def __init__(self): - self._state = InMemoryStateLike() + def __init__(self, state: StateLike = None): + self._state = state if not state else InMemoryStateLike() self._updates: dict[str, list[dict]] = dict() def update(self, node_id: str, data: dict) -> None: @@ -259,6 +255,3 @@ class InMemoryState(State): global_state=global_state, trace_state=dict(), comp_state=InMemoryCommitState()) - - def clone(self) -> Self: - return InMemoryState(global_state=self._global_state) diff --git a/tests/system_tests/agent/test_workflow_agent.py b/tests/system_tests/agent/test_workflow_agent.py index bdbe865..e8196a1 100644 --- a/tests/system_tests/agent/test_workflow_agent.py +++ b/tests/system_tests/agent/test_workflow_agent.py @@ -4,10 +4,10 @@ from datetime import datetime import unittest import pytest -from unittest.mock import patch from jiuwen.agent.common.schema import WorkflowSchema from jiuwen.agent.config.workflow_config import WorkflowAgentConfig +from jiuwen.core.agent.task.task_context import TaskContext from jiuwen.core.component.common.configs.model_config import ModelConfig from jiuwen.core.component.end_comp import End from jiuwen.core.component.intent_detection_comp import IntentDetectionComponent, IntentDetectionConfig @@ -16,9 +16,7 @@ from jiuwen.core.component.questioner_comp import QuestionerComponent, Questione from jiuwen.core.component.start_comp import Start from jiuwen.core.component.tool_comp import ToolComponent, ToolComponentConfig from jiuwen.core.context.agent_context import AgentContext -from jiuwen.core.context.config import Config -from jiuwen.core.context.context import Context, WorkflowContext -from jiuwen.core.context.state import InMemoryState +from jiuwen.core.context.context import Context from jiuwen.core.utils.llm.base import BaseModelInfo from jiuwen.core.utils.prompt.template.template import Template from jiuwen.core.utils.tool.service_api.param import Param @@ -236,7 +234,7 @@ class WorkflowAgentTest(unittest.IsolatedAsyncioTestCase): workflow_config=workflow_config, graph=PregelGraph(), ) - context = WorkflowContext(config=Config(), state=InMemoryState(), store=None) + context = TaskContext(id="test") # 2. 实例化各组件 start = self._create_start_component() @@ -283,7 +281,7 @@ class WorkflowAgentTest(unittest.IsolatedAsyncioTestCase): flow.add_connection("questioner", "plugin") flow.add_connection("plugin", "end") - return context, flow + return context.create_workflow_context(), flow @staticmethod def _create_workflow_schema(id, name: str, version: str) -> WorkflowSchema: diff --git a/tests/system_tests/workflow/test_real_workflow.py b/tests/system_tests/workflow/test_real_workflow.py index f732be4..e5f08d9 100644 --- a/tests/system_tests/workflow/test_real_workflow.py +++ b/tests/system_tests/workflow/test_real_workflow.py @@ -19,6 +19,7 @@ import os import unittest from unittest.mock import patch +from jiuwen.core.agent.task.task_context import TaskContext from jiuwen.core.component.branch_comp import BranchComponent from jiuwen.core.component.common.configs.model_config import ModelConfig from jiuwen.core.component.end_comp import End @@ -34,9 +35,7 @@ from jiuwen.core.component.questioner_comp import ( ) from jiuwen.core.component.start_comp import Start from jiuwen.core.component.tool_comp import ToolComponent, ToolComponentConfig -from jiuwen.core.context.config import Config -from jiuwen.core.context.context import Context, WorkflowContext -from jiuwen.core.context.state import InMemoryState +from jiuwen.core.context.context import Context from jiuwen.core.stream.writer import CustomSchema from jiuwen.core.utils.llm.base import BaseModelInfo from jiuwen.core.utils.prompt.template.template import Template @@ -239,7 +238,7 @@ class RealWorkflowTest(unittest.TestCase): workflow_config=WorkflowConfig(), graph=PregelGraph(), ) - context = WorkflowContext(config=Config(), state=InMemoryState(), store=None) + context = TaskContext(id="test") # 3. 实例化各组件 start = MockStartNode("start") @@ -293,7 +292,7 @@ class RealWorkflowTest(unittest.TestCase): flow.add_connection("questioner", "plugin") flow.add_connection("plugin", "end") - return context, flow + return context.create_workflow_context(), flow # ------------------------------------------------------------------ # # 测试用例本身 # @@ -327,7 +326,7 @@ class RealWorkflowTest(unittest.TestCase): """ 测试LLM组件通过StreamWriter流出数据 """ - context = WorkflowContext(config=Config(), state=InMemoryState(), store=None) + context = TaskContext(id="test") flow = Workflow(workflow_config=WorkflowConfig(), graph=PregelGraph()) start_component = Start("s", @@ -361,5 +360,5 @@ class RealWorkflowTest(unittest.TestCase): inputs = {"query": "写一个笑话。注意:不要超过20个字!"} writer_chunks = [] - self.loop.run_until_complete(self._async_stream_workflow_for_stream_writer(flow, inputs, context, writer_chunks)) + self.loop.run_until_complete(self._async_stream_workflow_for_stream_writer(flow, inputs, context.create_workflow_context(), writer_chunks)) print(writer_chunks) diff --git a/tests/unit_tests/tracer/test_agent.py b/tests/unit_tests/tracer/test_agent.py index a8da558..101e3a0 100644 --- a/tests/unit_tests/tracer/test_agent.py +++ b/tests/unit_tests/tracer/test_agent.py @@ -1,12 +1,8 @@ import asyncio import unittest +from jiuwen.core.agent.task.task_context import TaskContext from jiuwen.core.common.logging.base import logger -from jiuwen.core.context.config import Config -from jiuwen.core.context.context import WorkflowContext -from jiuwen.core.context.state import InMemoryState -from jiuwen.core.stream.emitter import StreamEmitter -from jiuwen.core.stream.manager import StreamWriterManager from jiuwen.core.stream.writer import TraceSchema, CustomSchema from jiuwen.core.tracer.tracer import Tracer from tests.unit_tests.tracer.test_mock_node_with_tracer import StreamNodeWithTracer @@ -69,8 +65,7 @@ class MockAgent(unittest.TestCase): """ # workflow与agent共用一个tracer - context = WorkflowContext(config=Config(), state=InMemoryState(), store=None) - context.set_tracer(tracer) + context = TaskContext(id="test") # async def stream_workflow(): flow = create_flow() @@ -115,7 +110,7 @@ class MockAgent(unittest.TestCase): } index_dict = {key: 0 for key in expected_datas_model.keys()} - async for chunk in flow.stream({"a": 1, "b": "haha"}, context): + async for chunk in flow.stream({"a": 1, "b": "haha"}, context.create_workflow_context()): if not isinstance(chunk, TraceSchema): node_id = chunk.node_id index = index_dict[node_id] @@ -128,12 +123,8 @@ class MockAgent(unittest.TestCase): async def run_agent_workflow_seq_exec_stream_workflow_with_tracer(self): # context手动初始化tracer,agent和workflow共用一个tracer - context = WorkflowContext(config=Config(), state=InMemoryState(), store=None) - context.set_stream_writer_manager(StreamWriterManager(StreamEmitter())) - tracer = Tracer() - tracer.init(context.stream_writer_manager(), context.callback_manager()) - context.set_tracer(tracer) - self.tracer = tracer + context = TaskContext(id="test") + self.tracer = context.tracer() agent_span = self.tracer.tracer_agent_span_manager.create_agent_span() try: diff --git a/tests/unit_tests/workflow/test_questioner_comp.py b/tests/unit_tests/workflow/test_questioner_comp.py index 27b6e97..462fa5c 100644 --- a/tests/unit_tests/workflow/test_questioner_comp.py +++ b/tests/unit_tests/workflow/test_questioner_comp.py @@ -2,13 +2,11 @@ import asyncio import unittest from unittest.mock import patch +from jiuwen.core.agent.task.task_context import TaskContext from jiuwen.core.component.common.configs.model_config import ModelConfig from jiuwen.core.component.end_comp import End from jiuwen.core.component.questioner_comp import FieldInfo, QuestionerConfig, QuestionerComponent from jiuwen.core.component.start_comp import Start -from jiuwen.core.context.config import Config -from jiuwen.core.context.context import WorkflowContext, Context -from jiuwen.core.context.state import InMemoryState from jiuwen.core.graph.executable import Input from jiuwen.core.graph.interrupt.interactive_input import InteractiveInput from jiuwen.core.stream.writer import TraceSchema @@ -26,15 +24,15 @@ class QuestionerTest(unittest.TestCase): asyncio.set_event_loop(self.loop) @staticmethod - def invoke_workflow(inputs: Input, context: Context, flow: Workflow): + def invoke_workflow(inputs: Input, context: TaskContext, flow: Workflow): loop = asyncio.get_event_loop() - feature = asyncio.ensure_future(flow.invoke(inputs=inputs, context=context)) + feature = asyncio.ensure_future(flow.invoke(inputs=inputs, context=context.create_workflow_context())) loop.run_until_complete(feature) return feature.result() @staticmethod def _create_context(session_id): - return WorkflowContext(config=Config(), state=InMemoryState(), store=None, session_id=session_id) + return TaskContext(id=session_id) @patch("jiuwen.core.component.questioner_comp.QuestionerDirectReplyHandler._invoke_llm_for_extraction") @patch("jiuwen.core.component.questioner_comp.QuestionerDirectReplyHandler._build_llm_inputs") @@ -51,7 +49,7 @@ class QuestionerTest(unittest.TestCase): mock_llm_inputs.return_value = mock_prompt_template mock_extraction.return_value = dict(location="hangzhou") - context = WorkflowContext(config=Config(), state=InMemoryState(), store=None) + context = TaskContext(id = "test") flow = create_flow() key_fields = [ @@ -185,7 +183,7 @@ class QuestionerTest(unittest.TestCase): mock_llm_inputs.return_value = mock_prompt_template mock_extraction.return_value = dict(location="hangzhou") - context = WorkflowContext(config=Config(), state=InMemoryState(), store=None) + context = TaskContext(id = "test") flow = create_flow() key_fields = [ @@ -227,6 +225,7 @@ class QuestionerTest(unittest.TestCase): _tracer_chunks.append(chunk) tracer_chunks = [] - self.loop.run_until_complete(_async_stream_workflow_for_tracer(flow, {"query": "查询杭州的天气"}, context, + self.loop.run_until_complete(_async_stream_workflow_for_tracer(flow, {"query": "查询杭州的天气"}, + context.create_workflow_context(), tracer_chunks)) print(tracer_chunks) diff --git a/tests/unit_tests/workflow/test_tool_comp.py b/tests/unit_tests/workflow/test_tool_comp.py index b5e4b53..f348d7b 100644 --- a/tests/unit_tests/workflow/test_tool_comp.py +++ b/tests/unit_tests/workflow/test_tool_comp.py @@ -2,10 +2,9 @@ from unittest.mock import patch, Mock, MagicMock import pytest +from jiuwen.core.agent.task.task_context import TaskContext from jiuwen.core.component.tool_comp import ToolComponentConfig, ToolExecutable, ToolComponent -from jiuwen.core.context.config import Config -from jiuwen.core.context.context import Context, WorkflowContext -from jiuwen.core.context.state import InMemoryState +from jiuwen.core.context.context import Context from jiuwen.core.utils.tool.service_api.param import Param from jiuwen.core.utils.tool.service_api.restful_api import RestfulApi from tests.unit_tests.workflow.test_mock_node import MockStartNode, MockEndNode @@ -77,7 +76,7 @@ async def test_tool_comp_invoke(mock_get_tool, mock_request, mock_tool, mock_too async def test_tool_comp_in_workflow(mock_get_tool, mock_invoke, mock_tool, mock_tool_config, fake_ctx): mock_get_tool.return_value = mock_tool mock_invoke.return_value = 'res' - context = WorkflowContext(config=Config(), state=InMemoryState(), store=None, tracer=None) + context = TaskContext(id="test") flow = create_flow() start_component = MockStartNode("s") @@ -91,4 +90,4 @@ async def test_tool_comp_in_workflow(mock_get_tool, mock_invoke, mock_tool, mock flow.add_connection("s", "tool") flow.add_connection("tool", "e") - await flow.invoke({}, context) + await flow.invoke({}, context.create_workflow_context()) diff --git a/tests/unit_tests/workflow/test_workflow.py b/tests/unit_tests/workflow/test_workflow.py index 7902fae..87972cc 100644 --- a/tests/unit_tests/workflow/test_workflow.py +++ b/tests/unit_tests/workflow/test_workflow.py @@ -19,13 +19,12 @@ from jiuwen.core.context.state import ReadableStateLike from jiuwen.core.graph.base import Graph from jiuwen.core.graph.executable import Input from jiuwen.core.graph.graph_state import GraphState -from jiuwen.core.workflow.base import Workflow -from jiuwen.core.workflow.workflow_config import WorkflowConfig, ComponentAbility +from jiuwen.core.workflow.workflow_config import ComponentAbility from jiuwen.core.stream.base import BaseStreamMode from jiuwen.core.stream.writer import CustomSchema from jiuwen.core.workflow.base import WorkflowConfig, Workflow from jiuwen.graph.pregel.graph import PregelGraph -from test_mock_node import SlowNode, CountNode, StreamCompNode, CollectCompNode, MultiCollectCompNode, TransformCompNode +from test_mock_node import SlowNode, CountNode, StreamCompNode, CollectCompNode, TransformCompNode from test_node import AddTenNode, CommonNode from tests.unit_tests.workflow.test_mock_node import MockStartNode, MockEndNode, Node1, StreamNode -- Gitee