From ef8d5b5847224f2fdb3c7e5b5b648ebffee88946 Mon Sep 17 00:00:00 2001 From: laihongsen Date: Wed, 23 Jul 2025 09:49:32 +0800 Subject: [PATCH 1/6] =?UTF-8?q?feat(core):=20=E5=AE=9E=E7=8E=B0=E6=B5=81?= =?UTF-8?q?=E5=BC=8F=E5=A4=84=E7=90=86=E8=83=BD=E5=8A=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 ExecWorkflowBase 接口和 StreamActor 类- 实现消息队列管理器和流式处理相关的逻辑 - 更新工作流配置,支持流式处理能力配置 - 修改执行图和顶点类以支持流式处理 --- jiuwen/core/common/constants/constant.py | 9 +- jiuwen/core/component/exec_workflow_base.py | 7 + jiuwen/core/component/workflow_comp.py | 3 +- jiuwen/core/context/config.py | 5 + jiuwen/core/context/context.py | 10 ++ jiuwen/core/context/mq_manager.py | 104 +++++++++++++ jiuwen/core/graph/base.py | 7 +- jiuwen/core/graph/executable.py | 17 +-- jiuwen/core/graph/vertex.py | 159 ++++++++++++++++---- jiuwen/core/stream_actor/__init__.py | 0 jiuwen/core/stream_actor/base.py | 23 +++ jiuwen/core/workflow/base.py | 43 +++++- jiuwen/core/workflow/workflow_config.py | 23 ++- jiuwen/graph/pregel/graph.py | 9 +- tests/unit_tests/workflow/test_llm_comp.py | 9 ++ tests/unit_tests/workflow/test_mock_node.py | 94 ++++++++++++ tests/unit_tests/workflow/test_workflow.py | 55 ++++++- 17 files changed, 521 insertions(+), 56 deletions(-) create mode 100644 jiuwen/core/component/exec_workflow_base.py create mode 100644 jiuwen/core/context/mq_manager.py create mode 100644 jiuwen/core/stream_actor/__init__.py create mode 100644 jiuwen/core/stream_actor/base.py diff --git a/jiuwen/core/common/constants/constant.py b/jiuwen/core/common/constants/constant.py index 2cee1d2..6ca78c5 100644 --- a/jiuwen/core/common/constants/constant.py +++ b/jiuwen/core/common/constants/constant.py @@ -8,4 +8,11 @@ SYSTEM_FIELDS = "systemFields" INTERACTION = sys.intern("__interaction__") # for dynamic interaction raised by nodes -INTERACTIVE_INPUT = sys.intern("__interactive_input__") \ No newline at end of file +INTERACTIVE_INPUT = sys.intern("__interactive_input__") + +INPUTS_KEY = "inputs" +CONFIG_KEY = "config" + +END_FRAME = "all streaming outputs finish" + +END_NODE_STREAM = "end node stream" \ No newline at end of file diff --git a/jiuwen/core/component/exec_workflow_base.py b/jiuwen/core/component/exec_workflow_base.py new file mode 100644 index 0000000..d1535f9 --- /dev/null +++ b/jiuwen/core/component/exec_workflow_base.py @@ -0,0 +1,7 @@ +from abc import ABC + + +class ExecWorkflowBase: + """ + ExecWorkflowBase + """ \ No newline at end of file diff --git a/jiuwen/core/component/workflow_comp.py b/jiuwen/core/component/workflow_comp.py index ffb1854..df7055c 100644 --- a/jiuwen/core/component/workflow_comp.py +++ b/jiuwen/core/component/workflow_comp.py @@ -4,13 +4,14 @@ from typing import AsyncIterator from jiuwen.core.component.base import WorkflowComponent +from jiuwen.core.component.exec_workflow_base import ExecWorkflowBase from jiuwen.core.context.context import Context from jiuwen.core.graph.base import INPUTS_KEY, CONFIG_KEY from jiuwen.core.graph.executable import Executable, Input, Output from jiuwen.core.workflow.base import Workflow -class ExecWorkflowComponent(WorkflowComponent, Executable): +class ExecWorkflowComponent(WorkflowComponent, Executable, ExecWorkflowBase): def __init__(self, node_id: str, sub_workflow: Workflow): super().__init__() self.node_id = node_id diff --git a/jiuwen/core/context/config.py b/jiuwen/core/context/config.py index e352ab6..d1d1656 100644 --- a/jiuwen/core/context/config.py +++ b/jiuwen/core/context/config.py @@ -43,7 +43,12 @@ class Config(ABC): self._workflow_config = workflow_config else: self._workflow_config.comp_configs.update(workflow_config.comp_configs) + self._workflow_config.comp_stream_configs.update(workflow_config.comp_stream_configs) self._workflow_config.stream_edges.update(workflow_config.stream_edges) + self._workflow_config.comp_abilities.update(workflow_config.comp_abilities) + + def get_workflow_config(self) -> WorkflowConfig: + return self._workflow_config def set_comp_io_config(self, node_id: str, comp_io_config: CompIOConfig) -> None: """ diff --git a/jiuwen/core/context/context.py b/jiuwen/core/context/context.py index e5f4223..1c85240 100644 --- a/jiuwen/core/context/context.py +++ b/jiuwen/core/context/context.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod from typing import Any, Self from jiuwen.core.context.config import Config +from jiuwen.core.context.mq_manager import MessageQueueManager from jiuwen.core.context.state import State from jiuwen.core.context.store import Store from jiuwen.core.runtime.callback_manager import CallbackManager @@ -55,6 +56,9 @@ class Context(ABC): def set_stream_writer_manager(self, stream_writer_manager: StreamWriterManager) -> None: return + def set_queue_manager(self, queue_manager: MessageQueueManager): + return + def clone(self) -> Self: return None @@ -70,6 +74,7 @@ class WorkflowContext(Context): self.__stream_writer_manager: StreamWriterManager = None self.__controller_context_manager = None self.__session_id = session_id if session_id else uuid.uuid4().hex + self.queue_manager: MessageQueueManager = None def set_stream_writer_manager(self, stream_writer_manager: StreamWriterManager) -> None: if self.__stream_writer_manager is not None: @@ -103,6 +108,11 @@ class WorkflowContext(Context): def controller_context_manager(self): return self.__controller_context_manager + def set_queue_manager(self, queue_manager: MessageQueueManager): + if self.queue_manager is not None: + return + self.queue_manager = queue_manager + def session_id(self) -> str: return self.__session_id diff --git a/jiuwen/core/context/mq_manager.py b/jiuwen/core/context/mq_manager.py new file mode 100644 index 0000000..8c3951a --- /dev/null +++ b/jiuwen/core/context/mq_manager.py @@ -0,0 +1,104 @@ +from typing import Dict, Any, AsyncIterator + +from jiuwen.core.common.logging.base import logger +from jiuwen.core.context.state import Transformer +from jiuwen.core.context.utils import get_by_schema +from jiuwen.core.stream.emitter import AsyncStreamQueue +from jiuwen.core.workflow.workflow_config import ComponentAbility + + +class StreamTransform: + def get_by_defined_transformer(self, origin_message: dict, transformer: Transformer) -> dict: + return transformer(origin_message) + + def get_by_default_transformer(self, origin_message: dict, stream_inputs_schema: dict) -> dict: + # nodeC schema: {"ca": "${nodeA.a_output}", "cb": "${nodeB.b_output}" + # 根据stream_inputs_schema进行封装数据{"nodeA": {"a_output": "a"}} + # 或者{"nodeB": {"b_output": "b"}}转换成{"ca": "a"}或者{"cb": "b"} + return get_by_schema(stream_inputs_schema, origin_message) + + +class MessageQueueManager: + def __init__(self, stream_edges: dict[str, list[str]], comp_abilities: dict[str, list[ComponentAbility]], + sub_graph: bool): + self._stream_edges: Dict[str, list[str]] = {} + self._streams: Dict[str, dict[ComponentAbility, AsyncStreamQueue]] = {} + self._streams_transform = StreamTransform() + for producer_id, consumer_ids in stream_edges.items(): + self._stream_edges[producer_id] = consumer_ids + for consumer_id in consumer_ids: + consumer_stream_ability = [ability for ability in comp_abilities[consumer_id] if + ability in [ComponentAbility.COLLECT, ComponentAbility.TRANSFORM]] + self._streams[consumer_id] = {ability: AsyncStreamQueue(maxsize=10 * 1024) + for ability in consumer_stream_ability} + self._sub_graph = sub_graph + self._sub_workflow_stream = AsyncStreamQueue(maxsize=10 * 1024) if sub_graph else None + + @property + def sub_workflow_stream(self): + if not self._sub_graph: + raise RuntimeError("only sub graph has sub_workflow_stream") + return self._sub_workflow_stream + + def _get_queue(self, consumer_id: str) -> dict[ComponentAbility, AsyncStreamQueue]: + return self._streams[consumer_id] + + @property + def stream_transform(self): + return self._streams_transform + + async def produce(self, producer_id: str, message_content: Any): + consumer_ids = self._stream_edges.get(producer_id) + if consumer_ids: + for consumer_id in consumer_ids: + stream_queues = self._get_queue(consumer_id) + for _, queue in stream_queues.items(): + await queue.send({producer_id: message_content}) + logger.debug(f"===produce message {producer_id} {consumer_id} {message_content}") + + async def end_message(self, producer_id: str): + end_message_content = f"END_{producer_id}" + await self.produce(producer_id, end_message_content) + + def _is_end_message(self, message: dict[str, Any], ended_producers: set) -> bool: + if not isinstance(message, dict) or len(message) != 1: + raise ValueError("message is invalid") + produce_id = next(iter(message)) + message_content = message[produce_id] + if isinstance(message_content, str) and message_content.startswith("END_"): + ended_producers.add(produce_id) + return True + return False + + async def consume(self, consumer_id: str, ability: ComponentAbility) -> AsyncIterator[dict[str, Any]]: + stream_queues = self._get_queue(consumer_id) + queue = stream_queues[ability] + if queue is not None: + ended_producers = set() + while True: + message = await queue.receive() + logger.debug(f"===consume message {consumer_id} {ability} {message}") + if message is None: + continue + if self._is_end_message(message, ended_producers): + if ended_producers == set(key for key, value in self._stream_edges.items() if consumer_id in value): + await self.close_stream(consumer_id) + logger.debug(f"===consumer end {consumer_id} {ability}") + break + else: + continue + yield message + + async def close_stream(self, consumer_id: str): + if consumer_id in self._streams: + stream_queues = self._streams.pop(consumer_id) + for _, queue in stream_queues.items(): + await queue.close() + + async def close_all_streams(self): + for consumer_id in list(self._streams.keys()): + await self.close_stream(consumer_id) + self._streams.clear() + + def is_empty(self, node_id) -> bool: + return self._streams[node_id] is None \ No newline at end of file diff --git a/jiuwen/core/graph/base.py b/jiuwen/core/graph/base.py index 9f3ce7f..e961cec 100644 --- a/jiuwen/core/graph/base.py +++ b/jiuwen/core/graph/base.py @@ -6,12 +6,10 @@ from typing import Self, Union, Any, AsyncIterator, Hashable, Callable, Awaitabl from langchain_core.runnables import Runnable +from jiuwen.core.common.constants.constant import INPUTS_KEY, CONFIG_KEY from jiuwen.core.context.context import Context from jiuwen.core.graph.executable import Executable, Output, Input -INPUTS_KEY = "inputs" -CONFIG_KEY = "config" - class ExecutableGraph(Executable[Input, Output]): async def invoke(self, inputs: Input, context: Context) -> Output: @@ -59,3 +57,6 @@ class Graph(ABC): def compile(self, context: Context) -> ExecutableGraph: pass + + def get_nodes(self) -> dict: + pass \ No newline at end of file diff --git a/jiuwen/core/graph/executable.py b/jiuwen/core/graph/executable.py index 8e05b26..a8ddcd9 100644 --- a/jiuwen/core/graph/executable.py +++ b/jiuwen/core/graph/executable.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from functools import partial from typing import TypeVar, Generic, Iterator, AsyncIterator, Any -from jiuwen.core.common.exception.exception import InterruptException +from jiuwen.core.common.exception.exception import InterruptException, JiuWenBaseException from jiuwen.core.common.exception.status_code import StatusCode from jiuwen.core.context.context import Context @@ -15,7 +15,7 @@ Input = TypeVar("Input", contravariant=True) Output = TypeVar("Output", contravariant=True) -class Executable(Generic[Input, Output], ABC): +class Executable(Generic[Input, Output]): memory: "ConversationMemory" = None memory_auto_save: bool = True local_params: dict = dict() @@ -23,23 +23,18 @@ class Executable(Generic[Input, Output], ABC): is_global: bool = False global_var_name: str = "" - @abstractmethod async def invoke(self, inputs: Input, context: Context) -> Output: - pass + raise JiuWenBaseException(-1, "Invoke is not supported") - @abstractmethod async def stream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: - pass + raise JiuWenBaseException(-1, "Invoke is not supported") - @abstractmethod async def collect(self, inputs: AsyncIterator[Input], contex: Context) -> Output: - pass + raise JiuWenBaseException(-1, "Invoke is not supported") - @abstractmethod async def transform(self, inputs: AsyncIterator[Input], context: Context) -> AsyncIterator[Output]: - pass + raise JiuWenBaseException(-1, "Invoke is not supported") - @abstractmethod async def interrupt(self, message: dict): raise InterruptException( error_code=StatusCode.CONTROLLER_INTERRUPTED_ERROR.code, diff --git a/jiuwen/core/graph/vertex.py b/jiuwen/core/graph/vertex.py index 1447d7b..29c1446 100644 --- a/jiuwen/core/graph/vertex.py +++ b/jiuwen/core/graph/vertex.py @@ -1,19 +1,22 @@ #!/usr/bin/python3.10 # coding: utf-8 # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved -from typing import Any, Optional +import asyncio +from typing import Any, Optional, AsyncIterator -from jiuwen.core.common.constants.constant import INTERACTIVE_INPUT +from jiuwen.core.common.constants.constant import INTERACTIVE_INPUT, END_NODE_STREAM, INPUTS_KEY, CONFIG_KEY from jiuwen.core.common.exception.exception import JiuWenBaseException from jiuwen.core.common.logging.base import logger from jiuwen.core.component.condition.condition import INDEX from jiuwen.core.component.loop_callback.loop_id import LOOP_ID +from jiuwen.core.component.exec_workflow_base import ExecWorkflowBase from jiuwen.core.component.workflow_comp import ExecWorkflowComponent from jiuwen.core.context.context import Context, NodeContext from jiuwen.core.context.utils import get_by_schema, NESTED_PATH_SPLIT -from jiuwen.core.graph.base import INPUTS_KEY, CONFIG_KEY +from jiuwen.core.graph.base import ExecutableGraph from jiuwen.core.graph.executable import Executable, Output from jiuwen.core.graph.graph_state import GraphState +from jiuwen.core.workflow.workflow_config import ComponentAbility class Vertex: @@ -21,33 +24,51 @@ class Vertex: self._node_id = node_id self._executable = executable self._context: NodeContext = None + # if stream_call is available, call should wait for it + self._stream_done = asyncio.Event() + self._stream_called = False def init(self, context: Context) -> bool: self._context = NodeContext(context, self._node_id) return True - async def __call__(self, state: GraphState, config: Any = None) -> Output: - if self._context is None or self._executable is None: - raise JiuWenBaseException(1, "vertex is not initialized, node is is " + self._node_id) - inputs = await self.__pre_invoke__() - logger.info("vertex[%s] inputs %s", self._context.executable_id(), inputs) - is_stream = self.__is_stream__(state) - if self._executable.graph_invoker(): - inputs = {INPUTS_KEY: inputs, CONFIG_KEY: config} + def get_executable(self) -> Executable: + return self._executable + + async def _run_executable(self, ability: ComponentAbility, is_subgraph: bool = False, config: Any = None): + if ability == ComponentAbility.INVOKE: + batch_inputs = await self._pre_invoke() + logger.info("====_run_executable39: vertex[%s] invoke inputs %s", self._node_id, batch_inputs) + if is_subgraph: + batch_inputs = {INPUTS_KEY: batch_inputs, CONFIG_KEY: config} + results = await self._executable.invoke(batch_inputs, context=self._context) + logger.info("====_run_executable43: vertex[%s] invoke result %s", self._node_id, results) + await self._post_invoke(results) + elif ability == ComponentAbility.STREAM: + batch_inputs = await self._pre_invoke() + logger.info("====_run_executable47: vertex[%s] stream inputs %s", self._node_id, batch_inputs) + if is_subgraph: + batch_inputs = {INPUTS_KEY: batch_inputs, CONFIG_KEY: config} + result_iter = self._executable.stream(batch_inputs, context=self._context) + logger.info("====_run_executable51: vertex[%s] stream result %s", self._node_id, result_iter) + await self._post_stream(result_iter) + elif ability == ComponentAbility.COLLECT: + collect_iter = self._pre_stream(ability) + batch_output = await self._executable.collect(collect_iter, self._context) + logger.info("====_run_executable55: vertex[%s] collect result %s", self._node_id, batch_output) + await self._post_invoke(batch_output) + elif ability == ComponentAbility.TRANSFORM: + transform_iter = self._pre_stream(ability) + output_iter = self._executable.transform(transform_iter, self._context) + await self._post_stream(output_iter) + else: + logger.error(f"error ComponentAbility: {ability.name}") - try: - if is_stream: - result_iter = await self._executable.stream(inputs, context=self._context) - self.__post_stream__(result_iter) - else: - results = await self._executable.invoke(inputs, context=self._context) - outputs = await self.__post_invoke__(results) - logger.info("vertex[%s] outputs %s", self._context.executable_id(), outputs) - except JiuWenBaseException as e: - raise JiuWenBaseException(e.error_code, "failed to invoke, caused by " + e.message) + async def __call__(self, state: GraphState, config: Any = None) -> Output: + await self.call(state, config) return {"source_node_id": [self._node_id]} - async def __pre_invoke__(self) -> Optional[dict]: + async def _pre_invoke(self) -> Optional[dict]: inputs_transformer = self._context.config().get_input_transformer(self._node_id) if inputs_transformer is None: inputs_schema = self._context.config().get_inputs_schema(self._node_id) @@ -55,10 +76,10 @@ class Vertex: else: inputs = self._context.state().get_inputs_by_transformer(inputs_transformer) if self._context.tracer() is not None: - await self.__trace_inputs__(inputs) + await self._trace_inputs(inputs) return inputs - async def __post_invoke__(self, results: Optional[dict]) -> Any: + async def _post_invoke(self, results: Optional[dict]) -> Any: output_transformer = self._context.config().get_output_transformer(self._node_id) if output_transformer is None: output_schema = self._context.config().get_outputs_schema(self._node_id) @@ -72,8 +93,50 @@ class Vertex: self.__clear_interactive__() return results - def __post_stream__(self, results_iter: Any) -> None: - pass + async def _pre_stream(self, ability: ComponentAbility) -> AsyncIterator[dict]: + queue_manager = self._context.queue_manager + workflow_config = self._context.config.get_workflow_config() + inputs_transformer = workflow_config.comp_stream_configs[self._node_id].inputs_transformer + inputs_schema = workflow_config.comp_stream_configs[self._node_id].inputs_schema + async for message in queue_manager.consume(self._node_id, ability): + # message 是{id: content} + if inputs_transformer is None: + inputs = queue_manager.stream_transform.get_by_default_transformer(message, inputs_schema)\ + if inputs_schema else message + else: + inputs = queue_manager.stream_transform.get_by_defined_transformer(message, inputs_transformer) + yield inputs + + async def _post_stream(self, results_iter: AsyncIterator) -> None: + queue_manager = self._context.queue_manager + workflow_config = self._context.config.get_workflow_config() + output_transformer = workflow_config.comp_stream_configs[self._node_id].outputs_transformer + output_schema = workflow_config.comp_stream_configs[self._node_id].outputs_schema + end_stream_index = 0 + async for chunk in results_iter: + if output_transformer is None: + message = queue_manager.stream_transform.get_by_default_transformer(chunk, output_schema) \ + if output_schema else chunk + else: + message = queue_manager.stream_transform.get_by_defined_transformer(chunk, output_transformer) + await self._process_chunk(end_stream_index, message) + await queue_manager.end_message(self._node_id) + + async def _process_chunk(self, end_stream_index: int, message: Any) -> None: + end_node = False + sub_graph = self._context.parent_id is not None + if end_node and not sub_graph: + message_stream_data = { + "type": END_NODE_STREAM, + "index": ++end_stream_index, + "payload": message + } + await self._context.stream_writer_manager.get_output_writer().write(message_stream_data) + elif end_node and sub_graph: + await self._context.queue_manager.sub_workflow_stream.send(message) + else: + await self._context.queue_manager.produce(self._node_id, message) + def __clear_interactive__(self) -> None: if self._context.state().get_comp(INTERACTIVE_INPUT): @@ -93,6 +156,47 @@ class Vertex: if isinstance(self._executable, ExecWorkflowComponent): self._context.tracer().register_workflow_span_manager(self._context.executable_id()) + async def call(self, state: GraphState, config: Any = None, ExecGraphComponent=None): + if self._context is None or self._executable is None: + raise JiuWenBaseException(1, "vertex is not initialized, node is is " + self._node_id) + + is_subgraph = self._executable.graph_invoker() + + try: + workflow_config = self._context.config.get_workflow_config() + component_ability = workflow_config.comp_abilities.get(self._node_id, []) + call_ability = [ability for ability in component_ability if + ability in [ComponentAbility.INVOKE, ComponentAbility.STREAM]] + for ability in call_ability: + await self._run_executable(ability, is_subgraph, config) + + except JiuWenBaseException as e: + raise JiuWenBaseException(e.error_code, "failed to invoke, caused by " + e.message) + + # 仅当 stream_call 被调用时才等待 + if self._stream_called: + await self._stream_done.wait() + logger.debug("node [%s] call finished", self._node_id) + + async def stream_call(self): + self._stream_called = True # 标记 stream_call 已被调用 + self._stream_done.clear() # 清除之前的完成状态 + + if self._context is None or self._context.queue_manager is None: + raise JiuWenBaseException(1, "queue manager is not initialized") + + try: + component_ability = self._context.config.get_workflow_config().comp_abilities[self._node_id] + call_ability = [ability for ability in component_ability if + ability in [ComponentAbility.COLLECT, ComponentAbility.TRANSFORM]] + for ability in call_ability: + await self._run_executable(ability) + except JiuWenBaseException as e: + raise JiuWenBaseException(e.error_code, "failed to stream, caused by " + e.message) + finally: + self._stream_done.set() # 标记完成 + logger.info("end to stream call, node %s", self._node_id) + async def __trace_outputs__(self, outputs: Optional[dict] = None) -> None: if self._executable.skip_trace(): return @@ -113,6 +217,3 @@ class Vertex: }) self._context.tracer().pop_workflow_span(self._context.executable_id(), self._context.parent_id()) return component_metadata - - def __is_stream__(self, state: GraphState) -> bool: - return False diff --git a/jiuwen/core/stream_actor/__init__.py b/jiuwen/core/stream_actor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jiuwen/core/stream_actor/base.py b/jiuwen/core/stream_actor/base.py new file mode 100644 index 0000000..ed51700 --- /dev/null +++ b/jiuwen/core/stream_actor/base.py @@ -0,0 +1,23 @@ +import asyncio + +from jiuwen.core.context.context import Context +from jiuwen.core.graph.executable import Executable +from jiuwen.core.graph.vertex import Vertex + + +class StreamActor: + def __init__(self): + self.loop = asyncio.get_event_loop() + self._stream_nodes: dict[str, Vertex] = {} + + def init(self, context: Context): + for _, node in self._stream_nodes.items(): + node.init(context) + + def add_stream_consumer(self, consumer: Vertex, node_id: str): + if node_id not in self._stream_nodes.keys(): + self._stream_nodes[node_id] = consumer + + async def run(self): + streams = [node.stream_call() for node_id, node in self._stream_nodes.items()] + return asyncio.gather(*streams) \ No newline at end of file diff --git a/jiuwen/core/workflow/base.py b/jiuwen/core/workflow/base.py index d2e5dfc..3f0b8fe 100644 --- a/jiuwen/core/workflow/base.py +++ b/jiuwen/core/workflow/base.py @@ -12,15 +12,17 @@ from jiuwen.core.component.base import WorkflowComponent from jiuwen.core.component.end_comp import End from jiuwen.core.component.start_comp import Start from jiuwen.core.context.config import CompIOConfig, Transformer -from jiuwen.core.context.context import Context +from jiuwen.core.context.context import Context, ExecutableContext +from jiuwen.core.context.mq_manager import MessageQueueManager from jiuwen.core.graph.base import Graph, Router, INPUTS_KEY, CONFIG_KEY, ExecutableGraph from jiuwen.core.graph.executable import Executable, Input, Output from jiuwen.core.stream.base import StreamMode, BaseStreamMode from jiuwen.core.stream.emitter import StreamEmitter from jiuwen.core.stream.manager import StreamWriterManager from jiuwen.core.stream.writer import OutputSchema +from jiuwen.core.stream_actor.base import StreamActor from jiuwen.core.tracer.tracer import Tracer -from jiuwen.core.workflow.workflow_config import WorkflowConfig +from jiuwen.core.workflow.workflow_config import WorkflowConfig, ComponentAbility class WorkflowOutput(BaseModel): @@ -38,6 +40,7 @@ class BaseWorkFlow: def __init__(self, workflow_config: WorkflowConfig, new_graph: Graph): self._graph = new_graph self._workflow_config = workflow_config + self._stream_actor = StreamActor() def config(self): return self._workflow_config @@ -52,7 +55,11 @@ class BaseWorkFlow: outputs_schema: dict = None, inputs_transformer: Transformer = None, outputs_transformer: Transformer = None, - + stream_inputs_schema: dict = None, + stream_outputs_schema: dict = None, + stream_inputs_transformer: Transformer = None, + stream_outputs_transformer: Transformer = None, + comp_ability: list[ComponentAbility] = None ) -> Self: if not isinstance(workflow_comp, WorkflowComponent): workflow_comp = self._convert_to_component(workflow_comp) @@ -61,6 +68,12 @@ class BaseWorkFlow: outputs_schema=outputs_schema, inputs_transformer=inputs_transformer, outputs_transformer=outputs_transformer) + self._workflow_config.comp_stream_configs[comp_id] = CompIOConfig(inputs_schema=stream_inputs_schema, + outputs_schema=stream_outputs_schema, + inputs_transformer=stream_inputs_transformer, + outputs_transformer=stream_outputs_transformer) + self._workflow_config.comp_abilities[ + comp_id] = comp_ability if comp_ability is not None else [ComponentAbility.INVOKE] return self def start_comp( @@ -83,6 +96,8 @@ class BaseWorkFlow: def add_stream_connection(self, src_comp_id: str, target_comp_id: str) -> Self: self._graph.add_edge(src_comp_id, target_comp_id) + stream_executables = self._graph.get_nodes() + self._stream_actor.add_stream_consumer(stream_executables[target_comp_id], target_comp_id) if target_comp_id not in self._workflow_config.stream_edges: self._workflow_config.stream_edges[src_comp_id] = [target_comp_id] else: @@ -126,12 +141,21 @@ class Workflow(BaseWorkFlow): inputs_schema: dict = None, outputs_schema: dict = None, inputs_transformer: Transformer = None, - outputs_transformer: Transformer = None + outputs_transformer: Transformer = None, + stream_inputs_schema: dict = None, + stream_outputs_schema: dict = None, + stream_inputs_transformer: Transformer = None, + stream_outputs_transformer: Transformer = None, ) -> Self: self.add_workflow_comp(end_comp_id, component, wait_for_all=False, inputs_schema=inputs_schema, outputs_schema=outputs_schema, inputs_transformer=inputs_transformer, - outputs_transformer=outputs_transformer) + outputs_transformer=outputs_transformer, + stream_inputs_schema=stream_inputs_schema, + stream_outputs_schema=stream_outputs_schema, + stream_inputs_transformer=stream_inputs_transformer, + stream_outputs_transformer=stream_outputs_transformer + ) self.end_comp(end_comp_id) self._end_comp_id = end_comp_id return self @@ -167,15 +191,22 @@ class Workflow(BaseWorkFlow): context: Context, stream_modes: list[StreamMode] = None ) -> AsyncIterator[WorkflowChunk]: + sub_graph = isinstance(context, ExecutableContext) + mq_manager = MessageQueueManager(self._workflow_config.stream_edges, self._workflow_config.comp_abilities, + sub_graph) + context.set_queue_manager(mq_manager) context.set_stream_writer_manager(StreamWriterManager(stream_emitter=StreamEmitter(), modes=stream_modes)) if context.tracer() is None and (stream_modes is None or BaseStreamMode.TRACE in stream_modes): tracer = Tracer() tracer.init(context.stream_writer_manager(), context.callback_manager()) context.set_tracer(tracer) compiled_graph = self.compile(context) - + self._stream_actor.init(context) async def stream_process(): try: + logger.info("Starting stream process") + await self._stream_actor.run() + logger.info("after stream actor run") await compiled_graph.invoke({INPUTS_KEY: inputs, CONFIG_KEY: None}, context) finally: await context.stream_writer_manager().stream_emitter.close() diff --git a/jiuwen/core/workflow/workflow_config.py b/jiuwen/core/workflow/workflow_config.py index 92907b6..f4fc045 100644 --- a/jiuwen/core/workflow/workflow_config.py +++ b/jiuwen/core/workflow/workflow_config.py @@ -1,6 +1,7 @@ #!/usr/bin/python3.10 # coding: utf-8 # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved +from enum import Enum from typing import Optional, Dict, Any, List from pydantic import BaseModel, Field @@ -15,4 +16,24 @@ class WorkflowMetadata(BaseModel): class WorkflowConfig(BaseModel): metadata: Optional[WorkflowMetadata] = Field(default=None) comp_configs: Dict[str, Any] = Field(default_factory=dict) - stream_edges: Dict[str, List[str]] = Field(default_factory=dict) + comp_stream_configs: Dict[str, Any] = Field(default_factory=dict) + stream_edges: Dict[str, list[str]] = Field(default_factory=dict) + comp_abilities: Dict[str, list[Any]] = Field(default_factory=dict) + +class ComponentAbility(Enum): + INVOKE = ("invoke", "batch in, batch out") + STREAM = ("stream", "batch in, stream out") + COLLECT = ("collect", "stream in, batch out") + TRANSFORM = ("transform", "stream in, stream out") + + def __init__(self, name: str, desc: str): + self._name = name + self._desc = desc + + @property + def name(self) -> str: + return self._name + + @property + def desc(self) -> str: + return self._desc \ No newline at end of file diff --git a/jiuwen/graph/pregel/graph.py b/jiuwen/graph/pregel/graph.py index 548ef26..d3e0051 100644 --- a/jiuwen/graph/pregel/graph.py +++ b/jiuwen/graph/pregel/graph.py @@ -46,7 +46,7 @@ class PregelGraph(Graph): self.compiledStateGraph = None self.edges: list[Union[str, list[str]], str] = [] self.waits: set[str] = set() - self.nodes: list[Vertex] = [] + self.nodes: dict[str, Vertex] = {} self.checkpoint_saver = None def start_node(self, node_id: str) -> Self: @@ -59,12 +59,15 @@ class PregelGraph(Graph): def add_node(self, node_id: str, node: Executable, *, wait_for_all: bool = False) -> Self: vertex_node = Vertex(node_id, node) - self.nodes.append(vertex_node) + self.nodes[node_id] = vertex_node self.pregel.add_node(node_id, vertex_node) if wait_for_all: self.waits.add(node_id) return self + def get_nodes(self) -> dict[str, Vertex]: + return {key: vertex for key, vertex in self.nodes.items()} + def add_edge(self, source_node_id: Union[str, list[str]], target_node_id: str) -> Self: self.edges.append((source_node_id, target_node_id)) return self @@ -74,7 +77,7 @@ class PregelGraph(Graph): return self def compile(self, context: Context) -> ExecutableGraph: - for node in self.nodes: + for node_id, node in self.nodes.items(): node.init(context) if self.compiledStateGraph is None: self._pre_compile() diff --git a/tests/unit_tests/workflow/test_llm_comp.py b/tests/unit_tests/workflow/test_llm_comp.py index adf602e..cf85f94 100644 --- a/tests/unit_tests/workflow/test_llm_comp.py +++ b/tests/unit_tests/workflow/test_llm_comp.py @@ -157,3 +157,12 @@ class TestLLMExecutableInvoke: # 3. 直接异步调用 result = await flow.invoke(inputs={"a": 2}, context=context) assert result is not None + + async def test_llm_in_workflow_stream( + self, + mock_get_model, + fake_model_config, + ): + context = create_context() + + fake_llm = AsyncMock() \ No newline at end of file diff --git a/tests/unit_tests/workflow/test_mock_node.py b/tests/unit_tests/workflow/test_mock_node.py index e957893..5b5c9be 100644 --- a/tests/unit_tests/workflow/test_mock_node.py +++ b/tests/unit_tests/workflow/test_mock_node.py @@ -199,3 +199,97 @@ class InteractiveNode4StreamCp(MockNodeBase): loop.run_until_complete( stream_writer.write(OutputSchema(type="output", index=0, payload=(self.node_id, result)))) return result + +class StreamCompNode(MockNodeBase): + def __init__(self, node_id: str): + super().__init__(node_id) + self._node_id = node_id + + async def stream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: + logger.info(f"===StreamCompNode[{self._node_id}], input: {inputs}") + if inputs is None: + yield 1 + else: + for i in range(1, 3): + yield {"value": i * inputs["value"]} + +class CollectCompNode(MockNodeBase): + def __init__(self, node_id: str): + super().__init__(node_id) + self._node_id = node_id + + async def collect(self, inputs: AsyncIterator[Input], context: Context) -> Output: + logger.info(f"===CollectCompNode[{self._node_id}], input stream started") + result = 0 + try: + async for input in inputs: + try: + value = input.get("value") + if value is None: + logger.warning(f"===CollectCompNode[{self._node_id}], missing 'value' in input: {input}") + continue + result += value + logger.info(f"===CollectCompNode[{self._node_id}], processed input: {input}") + except Exception as e: + logger.error(f"===CollectCompNode[{self._node_id}], error processing input: {input}, error: {e}") + continue # 可选:继续处理下一个输入 + return {"value": result} + except Exception as e: + logger.error(f"===CollectCompNode[{self._node_id}], critical error in collect: {e}") + raise # 重新抛出关键异常,如流中断 + +class TransformCompNode(MockNodeBase): + def __init__(self, node_id: str): + super().__init__(node_id) + self._node_id = node_id + + async def transform(self, inputs: AsyncIterator[Input], context: Context) -> AsyncIterator[Output]: + logger.debug(f"===TransformCompNode[{self._node_id}], input stream started") + try: + async for input in inputs: + try: + value = input.get("value") + logger.debug(f"===TransformCompNode[{self._node_id}], processed input: {value}") + yield {"value": value} + except Exception as e: + logger.error(f"===TransformCompNode[{self._node_id}], error processing input: {input}, error: {e}") + # 可选:继续处理下一个输入,或重新抛出异常以终止流 + continue + except Exception as e: + logger.error(f"===TransformCompNode[{self._node_id}], critical error in transform: {e}") + raise # 重新抛出关键异常(如流中断) + +class MultiCollectCompNode(MockNodeBase): + def __init__(self, node_id: str): + super().__init__(node_id) + self._node_id = node_id + self._is_stream_end = False + + async def invoke(self, inputs: Input, context: Context) -> Output: + while True: + if self._is_stream_end: + break + await asyncio.sleep(0.1) + + async def collect(self, inputs: AsyncIterator[Input], context: Context) -> Output: + logger.info(f"===CollectCompNode[{self._node_id}], input: {inputs}") + a_collect = 0 + b_collect = 0 + try: + async for input in inputs: + logger.info(f"===CollectCompNode[{self._node_id}], input: {input}") + a_value = input.get("a", {}).get("value") + if a_value is not None: + a_collect += a_value + + b_value = input.get("b", {}).get("value") + if b_value is not None: + b_collect += b_value + except Exception as e: + logger.error(f"Error during collection: {e}") + raise + # result = result + input["value"] + result = {"a_collect": a_collect, "b_collect": b_collect} + logger.info(f"===CollectCompNode243 [{self._node_id}], output: {result}") + self._is_stream_end = True + return result \ No newline at end of file diff --git a/tests/unit_tests/workflow/test_workflow.py b/tests/unit_tests/workflow/test_workflow.py index da42c15..97adda3 100644 --- a/tests/unit_tests/workflow/test_workflow.py +++ b/tests/unit_tests/workflow/test_workflow.py @@ -19,11 +19,13 @@ from jiuwen.core.context.state import ReadableStateLike from jiuwen.core.graph.base import Graph from jiuwen.core.graph.executable import Input from jiuwen.core.graph.graph_state import GraphState +from jiuwen.core.workflow.base import Workflow +from jiuwen.core.workflow.workflow_config import WorkflowConfig, ComponentAbility from jiuwen.core.stream.base import BaseStreamMode from jiuwen.core.stream.writer import CustomSchema from jiuwen.core.workflow.base import WorkflowConfig, Workflow from jiuwen.graph.pregel.graph import PregelGraph -from test_mock_node import SlowNode, CountNode +from test_mock_node import SlowNode, CountNode, StreamCompNode, CollectCompNode, MultiCollectCompNode, TransformCompNode from test_node import AddTenNode, CommonNode from tests.unit_tests.workflow.test_mock_node import MockStartNode, MockEndNode, Node1, StreamNode @@ -548,3 +550,54 @@ class WorkflowTest(unittest.TestCase): flow1.add_connection("a1", "end") flow1.add_connection("composite", "end") self.assert_workflow_invoke({"a1": 1, "a2": 2}, create_context(), flow1, expect_results={"b1": 1, "b2": 2}) + + def test_stream_comp_workflow(self): + # start -> a ---> b -> end + flow = create_flow() + flow.set_start_comp("start", MockStartNode("start"), inputs_schema={"a": "${a}"}) + flow.add_workflow_comp("a", StreamCompNode("a"), inputs_schema={"value": "${start.a}"}, comp_ability=[ComponentAbility.STREAM]) + flow.add_workflow_comp("b", CollectCompNode("b"), inputs_schema={"value": "${a.value}"}, stream_inputs_schema={"value": "${a.value}"}, comp_ability=[ComponentAbility.COLLECT]) + flow.set_end_comp("end", MockEndNode("end"), inputs_schema={"result": "${b.value}"}) + flow.add_connection("start", "a") + flow.add_stream_connection("a", "b") + flow.add_connection("b", "end") + idx = 1 + self.assert_workflow_invoke({"a": idx}, create_context(), flow, expect_results={"result": idx * sum(range(1, 3))}) + + def test_transform_workflow(self): + # start -> a ---> b ---> c -> end + flow = create_flow() + flow.set_start_comp("start", MockStartNode("start"), inputs_schema={"a": "${a}"}) + # a: throw 2 frames: {value: 1}, {value: 2} + flow.add_workflow_comp("a", StreamCompNode("a"), inputs_schema={"value": "${start.a}"}, comp_ability=[ComponentAbility.STREAM]) + # b: transform 2 frames to c + flow.add_workflow_comp("b", TransformCompNode("b"), inputs_schema={"value": "${a.value}"}, stream_inputs_schema={"value": "${a.value}"}, comp_ability=[ComponentAbility.TRANSFORM]) + # c: value = sum(value of frames) + flow.add_workflow_comp("c", CollectCompNode("c"), inputs_schema={"value": "${b.value}"}, stream_inputs_schema={"value": "${b.value}"}, comp_ability=[ComponentAbility.COLLECT]) + flow.set_end_comp("end", MockEndNode("end"), inputs_schema={"result": "${c.value}"}) + flow.add_connection("start", "a") + flow.add_stream_connection("a", "b") + flow.add_stream_connection("b", "c") + flow.add_connection("c", "end") + + self.assert_workflow_invoke({"a": 1}, create_context(), flow, expect_results={"result": 3}) + + def test_multi_stream_workflow(self): + # TODO: end 节点会比 c 节点先执行,需要解决 + # start -> a ---> c -> end + # | ^ + # v | + # b ------------ + flow = create_flow() + flow.set_start_comp("start", MockStartNode("a"), inputs_schema={"a": "${a}", "b": "${b}"}) + flow.add_workflow_comp("a", StreamCompNode("a"), inputs_schema={"value": "${start.a}"}, comp_ability=[ComponentAbility.STREAM]) + flow.add_workflow_comp("b", StreamCompNode("b"), inputs_schema={"value": "${start.b}"}, comp_ability=[ComponentAbility.STREAM]) + flow.add_workflow_comp("c", MultiCollectCompNode("c"), inputs_schema={"a_collect": "${a.value}", "b_collect": "${b.value}"}, comp_ability=[ComponentAbility.INVOKE, ComponentAbility.COLLECT]) + flow.add_workflow_comp("end", MockEndNode("end"), inputs_schema={"result": "${c.a_collect}" + "${c.b_collect}"}) + flow.add_connection("start", "a") + flow.add_connection("start", "b") + flow.add_stream_connection("a", "c") + flow.add_stream_connection("b", "c") + flow.add_connection("c", "end") + idx = 1 + self.assert_workflow_invoke({"a": idx, "b": idx}, create_context(), flow, expect_results={"result": idx * sum(range(1, 3)) * 2}) \ No newline at end of file -- Gitee From 763f6c435090f8e374cfaacf47f2df656262c943 Mon Sep 17 00:00:00 2001 From: laihongsen Date: Thu, 24 Jul 2025 15:37:11 +0800 Subject: [PATCH 2/6] =?UTF-8?q?refactor(core):=20=E7=BB=84=E4=BB=B6?= =?UTF-8?q?=E9=97=B4=E6=B5=81=E5=BC=8F=E9=80=82=E9=85=8D=20context=20-=20?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BA=86=20vertex.py=20=E4=B8=AD=E7=9A=84?= =?UTF-8?q?=E6=B5=81=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=20parent=5Fcontext()=20=E6=9B=BF=E4=BB=A3=E7=9B=B4?= =?UTF-8?q?=E6=8E=A5=E8=AE=BF=E9=97=AE=20context=20=E7=9A=84=E5=B1=9E?= =?UTF-8?q?=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- jiuwen/core/context/mq_manager.py | 3 -- jiuwen/core/graph/vertex.py | 37 ++++++++++------------ jiuwen/core/stream_actor/base.py | 1 - jiuwen/core/workflow/base.py | 7 ++-- tests/unit_tests/workflow/test_workflow.py | 22 +------------ 5 files changed, 19 insertions(+), 51 deletions(-) diff --git a/jiuwen/core/context/mq_manager.py b/jiuwen/core/context/mq_manager.py index 8c3951a..b9137de 100644 --- a/jiuwen/core/context/mq_manager.py +++ b/jiuwen/core/context/mq_manager.py @@ -12,9 +12,6 @@ class StreamTransform: return transformer(origin_message) def get_by_default_transformer(self, origin_message: dict, stream_inputs_schema: dict) -> dict: - # nodeC schema: {"ca": "${nodeA.a_output}", "cb": "${nodeB.b_output}" - # 根据stream_inputs_schema进行封装数据{"nodeA": {"a_output": "a"}} - # 或者{"nodeB": {"b_output": "b"}}转换成{"ca": "a"}或者{"cb": "b"} return get_by_schema(stream_inputs_schema, origin_message) diff --git a/jiuwen/core/graph/vertex.py b/jiuwen/core/graph/vertex.py index 29c1446..db60162 100644 --- a/jiuwen/core/graph/vertex.py +++ b/jiuwen/core/graph/vertex.py @@ -10,10 +10,8 @@ from jiuwen.core.common.logging.base import logger from jiuwen.core.component.condition.condition import INDEX from jiuwen.core.component.loop_callback.loop_id import LOOP_ID from jiuwen.core.component.exec_workflow_base import ExecWorkflowBase -from jiuwen.core.component.workflow_comp import ExecWorkflowComponent from jiuwen.core.context.context import Context, NodeContext from jiuwen.core.context.utils import get_by_schema, NESTED_PATH_SPLIT -from jiuwen.core.graph.base import ExecutableGraph from jiuwen.core.graph.executable import Executable, Output from jiuwen.core.graph.graph_state import GraphState from jiuwen.core.workflow.workflow_config import ComponentAbility @@ -34,28 +32,23 @@ class Vertex: def get_executable(self) -> Executable: return self._executable - + async def _run_executable(self, ability: ComponentAbility, is_subgraph: bool = False, config: Any = None): if ability == ComponentAbility.INVOKE: batch_inputs = await self._pre_invoke() - logger.info("====_run_executable39: vertex[%s] invoke inputs %s", self._node_id, batch_inputs) if is_subgraph: batch_inputs = {INPUTS_KEY: batch_inputs, CONFIG_KEY: config} results = await self._executable.invoke(batch_inputs, context=self._context) - logger.info("====_run_executable43: vertex[%s] invoke result %s", self._node_id, results) await self._post_invoke(results) elif ability == ComponentAbility.STREAM: batch_inputs = await self._pre_invoke() - logger.info("====_run_executable47: vertex[%s] stream inputs %s", self._node_id, batch_inputs) if is_subgraph: batch_inputs = {INPUTS_KEY: batch_inputs, CONFIG_KEY: config} result_iter = self._executable.stream(batch_inputs, context=self._context) - logger.info("====_run_executable51: vertex[%s] stream result %s", self._node_id, result_iter) await self._post_stream(result_iter) elif ability == ComponentAbility.COLLECT: collect_iter = self._pre_stream(ability) batch_output = await self._executable.collect(collect_iter, self._context) - logger.info("====_run_executable55: vertex[%s] collect result %s", self._node_id, batch_output) await self._post_invoke(batch_output) elif ability == ComponentAbility.TRANSFORM: transform_iter = self._pre_stream(ability) @@ -76,7 +69,7 @@ class Vertex: else: inputs = self._context.state().get_inputs_by_transformer(inputs_transformer) if self._context.tracer() is not None: - await self._trace_inputs(inputs) + await self.__trace_inputs__(inputs) return inputs async def _post_invoke(self, results: Optional[dict]) -> Any: @@ -94,8 +87,8 @@ class Vertex: return results async def _pre_stream(self, ability: ComponentAbility) -> AsyncIterator[dict]: - queue_manager = self._context.queue_manager - workflow_config = self._context.config.get_workflow_config() + queue_manager = self._context.parent_context().queue_manager + workflow_config = self._context.parent_context().config().get_workflow_config() inputs_transformer = workflow_config.comp_stream_configs[self._node_id].inputs_transformer inputs_schema = workflow_config.comp_stream_configs[self._node_id].inputs_schema async for message in queue_manager.consume(self._node_id, ability): @@ -108,8 +101,8 @@ class Vertex: yield inputs async def _post_stream(self, results_iter: AsyncIterator) -> None: - queue_manager = self._context.queue_manager - workflow_config = self._context.config.get_workflow_config() + queue_manager = self._context.parent_context().queue_manager + workflow_config = self._context.parent_context().config().get_workflow_config() output_transformer = workflow_config.comp_stream_configs[self._node_id].outputs_transformer output_schema = workflow_config.comp_stream_configs[self._node_id].outputs_schema end_stream_index = 0 @@ -131,11 +124,11 @@ class Vertex: "index": ++end_stream_index, "payload": message } - await self._context.stream_writer_manager.get_output_writer().write(message_stream_data) + await self._context.parent_context().stream_writer_manager.get_output_writer().write(message_stream_data) elif end_node and sub_graph: - await self._context.queue_manager.sub_workflow_stream.send(message) + await self._context.parent_context().queue_manager.sub_workflow_stream.send(message) else: - await self._context.queue_manager.produce(self._node_id, message) + await self._context.parent_context().queue_manager.produce(self._node_id, message) def __clear_interactive__(self) -> None: @@ -153,7 +146,7 @@ class Vertex: self._context.state().update_trace(self._context.tracer().get_workflow_span(self._context.executable_id(), self._context.parent_id())) - if isinstance(self._executable, ExecWorkflowComponent): + if isinstance(self._executable, ExecWorkflowBase): self._context.tracer().register_workflow_span_manager(self._context.executable_id()) async def call(self, state: GraphState, config: Any = None, ExecGraphComponent=None): @@ -163,8 +156,9 @@ class Vertex: is_subgraph = self._executable.graph_invoker() try: - workflow_config = self._context.config.get_workflow_config() - component_ability = workflow_config.comp_abilities.get(self._node_id, []) + workflow_config = self._context.parent_context().config().get_workflow_config() + component_ability = workflow_config.comp_abilities.get(self._node_id) + component_ability = component_ability if component_ability else [ComponentAbility.INVOKE] call_ability = [ability for ability in component_ability if ability in [ComponentAbility.INVOKE, ComponentAbility.STREAM]] for ability in call_ability: @@ -182,11 +176,12 @@ class Vertex: self._stream_called = True # 标记 stream_call 已被调用 self._stream_done.clear() # 清除之前的完成状态 - if self._context is None or self._context.queue_manager is None: + if self._context is None or self._context.parent_context().queue_manager is None: raise JiuWenBaseException(1, "queue manager is not initialized") try: - component_ability = self._context.config.get_workflow_config().comp_abilities[self._node_id] + workflow_config = self._context.parent_context().config().get_workflow_config() + component_ability = workflow_config.comp_abilities.get(self._node_id) call_ability = [ability for ability in component_ability if ability in [ComponentAbility.COLLECT, ComponentAbility.TRANSFORM]] for ability in call_ability: diff --git a/jiuwen/core/stream_actor/base.py b/jiuwen/core/stream_actor/base.py index ed51700..c58d364 100644 --- a/jiuwen/core/stream_actor/base.py +++ b/jiuwen/core/stream_actor/base.py @@ -1,7 +1,6 @@ import asyncio from jiuwen.core.context.context import Context -from jiuwen.core.graph.executable import Executable from jiuwen.core.graph.vertex import Vertex diff --git a/jiuwen/core/workflow/base.py b/jiuwen/core/workflow/base.py index 3f0b8fe..8054321 100644 --- a/jiuwen/core/workflow/base.py +++ b/jiuwen/core/workflow/base.py @@ -12,7 +12,7 @@ from jiuwen.core.component.base import WorkflowComponent from jiuwen.core.component.end_comp import End from jiuwen.core.component.start_comp import Start from jiuwen.core.context.config import CompIOConfig, Transformer -from jiuwen.core.context.context import Context, ExecutableContext +from jiuwen.core.context.context import Context from jiuwen.core.context.mq_manager import MessageQueueManager from jiuwen.core.graph.base import Graph, Router, INPUTS_KEY, CONFIG_KEY, ExecutableGraph from jiuwen.core.graph.executable import Executable, Input, Output @@ -191,9 +191,8 @@ class Workflow(BaseWorkFlow): context: Context, stream_modes: list[StreamMode] = None ) -> AsyncIterator[WorkflowChunk]: - sub_graph = isinstance(context, ExecutableContext) mq_manager = MessageQueueManager(self._workflow_config.stream_edges, self._workflow_config.comp_abilities, - sub_graph) + False) context.set_queue_manager(mq_manager) context.set_stream_writer_manager(StreamWriterManager(stream_emitter=StreamEmitter(), modes=stream_modes)) if context.tracer() is None and (stream_modes is None or BaseStreamMode.TRACE in stream_modes): @@ -204,9 +203,7 @@ class Workflow(BaseWorkFlow): self._stream_actor.init(context) async def stream_process(): try: - logger.info("Starting stream process") await self._stream_actor.run() - logger.info("after stream actor run") await compiled_graph.invoke({INPUTS_KEY: inputs, CONFIG_KEY: None}, context) finally: await context.stream_writer_manager().stream_emitter.close() diff --git a/tests/unit_tests/workflow/test_workflow.py b/tests/unit_tests/workflow/test_workflow.py index 97adda3..c825b9b 100644 --- a/tests/unit_tests/workflow/test_workflow.py +++ b/tests/unit_tests/workflow/test_workflow.py @@ -580,24 +580,4 @@ class WorkflowTest(unittest.TestCase): flow.add_stream_connection("b", "c") flow.add_connection("c", "end") - self.assert_workflow_invoke({"a": 1}, create_context(), flow, expect_results={"result": 3}) - - def test_multi_stream_workflow(self): - # TODO: end 节点会比 c 节点先执行,需要解决 - # start -> a ---> c -> end - # | ^ - # v | - # b ------------ - flow = create_flow() - flow.set_start_comp("start", MockStartNode("a"), inputs_schema={"a": "${a}", "b": "${b}"}) - flow.add_workflow_comp("a", StreamCompNode("a"), inputs_schema={"value": "${start.a}"}, comp_ability=[ComponentAbility.STREAM]) - flow.add_workflow_comp("b", StreamCompNode("b"), inputs_schema={"value": "${start.b}"}, comp_ability=[ComponentAbility.STREAM]) - flow.add_workflow_comp("c", MultiCollectCompNode("c"), inputs_schema={"a_collect": "${a.value}", "b_collect": "${b.value}"}, comp_ability=[ComponentAbility.INVOKE, ComponentAbility.COLLECT]) - flow.add_workflow_comp("end", MockEndNode("end"), inputs_schema={"result": "${c.a_collect}" + "${c.b_collect}"}) - flow.add_connection("start", "a") - flow.add_connection("start", "b") - flow.add_stream_connection("a", "c") - flow.add_stream_connection("b", "c") - flow.add_connection("c", "end") - idx = 1 - self.assert_workflow_invoke({"a": idx, "b": idx}, create_context(), flow, expect_results={"result": idx * sum(range(1, 3)) * 2}) \ No newline at end of file + self.assert_workflow_invoke({"a": 1}, create_context(), flow, expect_results={"result": 3}) \ No newline at end of file -- Gitee From 8596fead06d818369af3d4c8437eed11fa4783d0 Mon Sep 17 00:00:00 2001 From: laihongsen Date: Fri, 25 Jul 2025 09:53:45 +0800 Subject: [PATCH 3/6] =?UTF-8?q?test(workflow):=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E5=A4=9A=E8=8A=82=E7=82=B9transform=E5=B7=A5=E4=BD=9C=E6=B5=81?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 test_five_transform_workflow 测试函数 --- tests/unit_tests/workflow/test_workflow.py | 30 ++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/unit_tests/workflow/test_workflow.py b/tests/unit_tests/workflow/test_workflow.py index c825b9b..9388850 100644 --- a/tests/unit_tests/workflow/test_workflow.py +++ b/tests/unit_tests/workflow/test_workflow.py @@ -580,4 +580,34 @@ class WorkflowTest(unittest.TestCase): flow.add_stream_connection("b", "c") flow.add_connection("c", "end") + self.assert_workflow_invoke({"a": 1}, create_context(), flow, expect_results={"result": 3}) + + def test_five_transform_workflow(self): + # start -> a ---> b ---> c ---> d ---> e ---> f ---> g -> end + flow = create_flow() + flow.set_start_comp("start", MockStartNode("start"), inputs_schema={"a": "${a}"}) + # a: throw 2 frames: {value: 1}, {value: 2} + flow.add_workflow_comp("a", StreamCompNode("a"), inputs_schema={"value": "${start.a}"}, comp_ability=[ComponentAbility.STREAM]) + # b: transform frame to c + flow.add_workflow_comp("b", TransformCompNode("b"), inputs_schema={"value": "${a.value}"}, stream_inputs_schema={"value": "${a.value}"}, comp_ability=[ComponentAbility.TRANSFORM]) + # c: transform frame to d + flow.add_workflow_comp("c", TransformCompNode("c"), inputs_schema={"value": "${b.value}"}, stream_inputs_schema={"value": "${b.value}"}, comp_ability=[ComponentAbility.TRANSFORM]) + # d: transform frame to e + flow.add_workflow_comp("d", TransformCompNode("d"), inputs_schema={"value": "${c.value}"}, stream_inputs_schema={"value": "${c.value}"}, comp_ability=[ComponentAbility.TRANSFORM]) + # e: transform frame to f + flow.add_workflow_comp("e", TransformCompNode("e"), inputs_schema={"value": "${d.value}"}, stream_inputs_schema={"value": "${d.value}"}, comp_ability=[ComponentAbility.TRANSFORM]) + # f: transform frame to g + flow.add_workflow_comp("f", TransformCompNode("f"), inputs_schema={"value": "${e.value}"}, stream_inputs_schema={"value": "${e.value}"}, comp_ability=[ComponentAbility.TRANSFORM]) + # g: collect all frames + flow.add_workflow_comp("g", CollectCompNode("g"), inputs_schema={"value": "${f.value}"}, stream_inputs_schema={"value": "${f.value}"}, comp_ability=[ComponentAbility.COLLECT]) + flow.set_end_comp("end", MockEndNode("end"), inputs_schema={"result": "${g.value}"}) + flow.add_connection("start", "a") + flow.add_stream_connection("a", "b") + flow.add_stream_connection("b", "c") + flow.add_stream_connection("c", "d") + flow.add_stream_connection("d", "e") + flow.add_stream_connection("e", "f") + flow.add_stream_connection("f", "g") + flow.add_connection("g", "end") + self.assert_workflow_invoke({"a": 1}, create_context(), flow, expect_results={"result": 3}) \ No newline at end of file -- Gitee From 47b6e418e7ffa196ddaab4f9b07584f8aabd81cc Mon Sep 17 00:00:00 2001 From: laihongsen Date: Fri, 25 Jul 2025 10:19:39 +0800 Subject: [PATCH 4/6] =?UTF-8?q?fix(core):=20=E4=BF=AE=E5=A4=8D=E6=B5=81?= =?UTF-8?q?=E5=BC=8F=E7=BB=84=E4=BB=B6=E7=9A=84=E7=AD=89=E5=BE=85=E7=AD=96?= =?UTF-8?q?=E7=95=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 Workflow 类中添加了对流式组件等待策略的检查 - 引入了 JiuWenBaseException异常类用于处理错误情况- 修改了测试用例,为流式组件添加 wait_for_all 参数 --- jiuwen/core/workflow/base.py | 5 ++++ tests/unit_tests/workflow/test_workflow.py | 34 +++++++++++----------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/jiuwen/core/workflow/base.py b/jiuwen/core/workflow/base.py index 8054321..67649a5 100644 --- a/jiuwen/core/workflow/base.py +++ b/jiuwen/core/workflow/base.py @@ -7,6 +7,7 @@ from typing import Self, Dict, Any, Union, AsyncIterator from pydantic import BaseModel from jiuwen.core.common.constants.constant import INTERACTION +from jiuwen.core.common.exception.exception import JiuWenBaseException from jiuwen.core.common.logging.base import logger from jiuwen.core.component.base import WorkflowComponent from jiuwen.core.component.end_comp import End @@ -74,6 +75,10 @@ class BaseWorkFlow: outputs_transformer=stream_outputs_transformer) self._workflow_config.comp_abilities[ comp_id] = comp_ability if comp_ability is not None else [ComponentAbility.INVOKE] + for ability in self._workflow_config.comp_abilities[comp_id]: + if ability in [ComponentAbility.STREAM, ComponentAbility.TRANSFORM, ComponentAbility.COLLECT]: + if not wait_for_all: + raise JiuWenBaseException(-1, "stream components need to wait for all") return self def start_comp( diff --git a/tests/unit_tests/workflow/test_workflow.py b/tests/unit_tests/workflow/test_workflow.py index 9388850..cdd809b 100644 --- a/tests/unit_tests/workflow/test_workflow.py +++ b/tests/unit_tests/workflow/test_workflow.py @@ -553,27 +553,27 @@ class WorkflowTest(unittest.TestCase): def test_stream_comp_workflow(self): # start -> a ---> b -> end - flow = create_flow() + flow = Workflow(WorkflowConfig(), create_graph()) flow.set_start_comp("start", MockStartNode("start"), inputs_schema={"a": "${a}"}) - flow.add_workflow_comp("a", StreamCompNode("a"), inputs_schema={"value": "${start.a}"}, comp_ability=[ComponentAbility.STREAM]) - flow.add_workflow_comp("b", CollectCompNode("b"), inputs_schema={"value": "${a.value}"}, stream_inputs_schema={"value": "${a.value}"}, comp_ability=[ComponentAbility.COLLECT]) - flow.set_end_comp("end", MockEndNode("end"), inputs_schema={"result": "${b.value}"}) + flow.add_workflow_comp("a", StreamCompNode("a"), inputs_schema={"value": "${start.a}"}, comp_ability=[ComponentAbility.STREAM], wait_for_all=True) + flow.add_workflow_comp("b", CollectCompNode("b"), inputs_schema={"value": "${a.value}"}, stream_inputs_schema={"value": "${a.value}"}, comp_ability=[ComponentAbility.COLLECT], wait_for_all=True) + flow.set_end_comp("end", MockEndNode("end"), inputs_schema={"result1": "${b.value}"}) flow.add_connection("start", "a") flow.add_stream_connection("a", "b") flow.add_connection("b", "end") idx = 1 - self.assert_workflow_invoke({"a": idx}, create_context(), flow, expect_results={"result": idx * sum(range(1, 3))}) + self.assert_workflow_invoke({"a": idx}, create_context(), flow, expect_results={"result1": idx * sum(range(1, 3))}) def test_transform_workflow(self): # start -> a ---> b ---> c -> end - flow = create_flow() + flow = Workflow(WorkflowConfig(), create_graph()) flow.set_start_comp("start", MockStartNode("start"), inputs_schema={"a": "${a}"}) # a: throw 2 frames: {value: 1}, {value: 2} - flow.add_workflow_comp("a", StreamCompNode("a"), inputs_schema={"value": "${start.a}"}, comp_ability=[ComponentAbility.STREAM]) + flow.add_workflow_comp("a", StreamCompNode("a"), inputs_schema={"value": "${start.a}"}, comp_ability=[ComponentAbility.STREAM], wait_for_all=True) # b: transform 2 frames to c - flow.add_workflow_comp("b", TransformCompNode("b"), inputs_schema={"value": "${a.value}"}, stream_inputs_schema={"value": "${a.value}"}, comp_ability=[ComponentAbility.TRANSFORM]) + flow.add_workflow_comp("b", TransformCompNode("b"), inputs_schema={"value": "${a.value}"}, stream_inputs_schema={"value": "${a.value}"}, comp_ability=[ComponentAbility.TRANSFORM], wait_for_all=True) # c: value = sum(value of frames) - flow.add_workflow_comp("c", CollectCompNode("c"), inputs_schema={"value": "${b.value}"}, stream_inputs_schema={"value": "${b.value}"}, comp_ability=[ComponentAbility.COLLECT]) + flow.add_workflow_comp("c", CollectCompNode("c"), inputs_schema={"value": "${b.value}"}, stream_inputs_schema={"value": "${b.value}"}, comp_ability=[ComponentAbility.COLLECT], wait_for_all=True) flow.set_end_comp("end", MockEndNode("end"), inputs_schema={"result": "${c.value}"}) flow.add_connection("start", "a") flow.add_stream_connection("a", "b") @@ -584,22 +584,22 @@ class WorkflowTest(unittest.TestCase): def test_five_transform_workflow(self): # start -> a ---> b ---> c ---> d ---> e ---> f ---> g -> end - flow = create_flow() + flow = Workflow(WorkflowConfig(), create_graph()) flow.set_start_comp("start", MockStartNode("start"), inputs_schema={"a": "${a}"}) # a: throw 2 frames: {value: 1}, {value: 2} - flow.add_workflow_comp("a", StreamCompNode("a"), inputs_schema={"value": "${start.a}"}, comp_ability=[ComponentAbility.STREAM]) + flow.add_workflow_comp("a", StreamCompNode("a"), inputs_schema={"value": "${start.a}"}, comp_ability=[ComponentAbility.STREAM], wait_for_all=True) # b: transform frame to c - flow.add_workflow_comp("b", TransformCompNode("b"), inputs_schema={"value": "${a.value}"}, stream_inputs_schema={"value": "${a.value}"}, comp_ability=[ComponentAbility.TRANSFORM]) + flow.add_workflow_comp("b", TransformCompNode("b"), inputs_schema={"value": "${a.value}"}, stream_inputs_schema={"value": "${a.value}"}, comp_ability=[ComponentAbility.TRANSFORM], wait_for_all=True) # c: transform frame to d - flow.add_workflow_comp("c", TransformCompNode("c"), inputs_schema={"value": "${b.value}"}, stream_inputs_schema={"value": "${b.value}"}, comp_ability=[ComponentAbility.TRANSFORM]) + flow.add_workflow_comp("c", TransformCompNode("c"), inputs_schema={"value": "${b.value}"}, stream_inputs_schema={"value": "${b.value}"}, comp_ability=[ComponentAbility.TRANSFORM], wait_for_all=True) # d: transform frame to e - flow.add_workflow_comp("d", TransformCompNode("d"), inputs_schema={"value": "${c.value}"}, stream_inputs_schema={"value": "${c.value}"}, comp_ability=[ComponentAbility.TRANSFORM]) + flow.add_workflow_comp("d", TransformCompNode("d"), inputs_schema={"value": "${c.value}"}, stream_inputs_schema={"value": "${c.value}"}, comp_ability=[ComponentAbility.TRANSFORM], wait_for_all=True) # e: transform frame to f - flow.add_workflow_comp("e", TransformCompNode("e"), inputs_schema={"value": "${d.value}"}, stream_inputs_schema={"value": "${d.value}"}, comp_ability=[ComponentAbility.TRANSFORM]) + flow.add_workflow_comp("e", TransformCompNode("e"), inputs_schema={"value": "${d.value}"}, stream_inputs_schema={"value": "${d.value}"}, comp_ability=[ComponentAbility.TRANSFORM], wait_for_all=True) # f: transform frame to g - flow.add_workflow_comp("f", TransformCompNode("f"), inputs_schema={"value": "${e.value}"}, stream_inputs_schema={"value": "${e.value}"}, comp_ability=[ComponentAbility.TRANSFORM]) + flow.add_workflow_comp("f", TransformCompNode("f"), inputs_schema={"value": "${e.value}"}, stream_inputs_schema={"value": "${e.value}"}, comp_ability=[ComponentAbility.TRANSFORM], wait_for_all=True) # g: collect all frames - flow.add_workflow_comp("g", CollectCompNode("g"), inputs_schema={"value": "${f.value}"}, stream_inputs_schema={"value": "${f.value}"}, comp_ability=[ComponentAbility.COLLECT]) + flow.add_workflow_comp("g", CollectCompNode("g"), inputs_schema={"value": "${f.value}"}, stream_inputs_schema={"value": "${f.value}"}, comp_ability=[ComponentAbility.COLLECT], wait_for_all=True) flow.set_end_comp("end", MockEndNode("end"), inputs_schema={"result": "${g.value}"}) flow.add_connection("start", "a") flow.add_stream_connection("a", "b") -- Gitee From 949491be4415ceec151d226ca2b29c4bf90db7d5 Mon Sep 17 00:00:00 2001 From: laihongsen Date: Fri, 25 Jul 2025 10:43:14 +0800 Subject: [PATCH 5/6] =?UTF-8?q?test:=20=E5=88=A0=E9=99=A4=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E4=BB=A3=E7=A0=81=E4=B8=AD=E7=9A=84=E6=9C=AA=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 删除了 test_llm_comp.py 文件中未使用的异步函数 test_llm_in_workflow_stream。。 --- tests/unit_tests/workflow/test_llm_comp.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/tests/unit_tests/workflow/test_llm_comp.py b/tests/unit_tests/workflow/test_llm_comp.py index cf85f94..8c51e27 100644 --- a/tests/unit_tests/workflow/test_llm_comp.py +++ b/tests/unit_tests/workflow/test_llm_comp.py @@ -156,13 +156,4 @@ class TestLLMExecutableInvoke: # 3. 直接异步调用 result = await flow.invoke(inputs={"a": 2}, context=context) - assert result is not None - - async def test_llm_in_workflow_stream( - self, - mock_get_model, - fake_model_config, - ): - context = create_context() - - fake_llm = AsyncMock() \ No newline at end of file + assert result is not None \ No newline at end of file -- Gitee From cea4e44f54c02a89cb5553cf3ea7ed1c9a4ba818 Mon Sep 17 00:00:00 2001 From: laihongsen Date: Fri, 25 Jul 2025 17:22:45 +0800 Subject: [PATCH 6/6] =?UTF-8?q?refactor(core):=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E9=98=9F=E5=88=97=E7=AE=A1=E7=90=86=E5=99=A8?= =?UTF-8?q?=E5=92=8C=E5=9B=BE=E5=BD=A2=E6=89=A7=E8=A1=8C=E7=9B=B8=E5=85=B3?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 status_code.py 中添加消息队列管理器相关的错误代码 - 优化 MessageQueueManager 类的实现,移除冗余代码 - 更新 Executable 类中的异常处理方式 - 调整 Vertex 类中的 _run_executable 方法 - 修改测试用例以适应新的执行逻辑 --- jiuwen/core/common/exception/status_code.py | 3 +++ jiuwen/core/context/mq_manager.py | 12 ++++++------ jiuwen/core/graph/executable.py | 11 ++++------- jiuwen/core/graph/vertex.py | 3 --- tests/unit_tests/workflow/test_mock_node.py | 12 ++---------- 5 files changed, 15 insertions(+), 26 deletions(-) diff --git a/jiuwen/core/common/exception/status_code.py b/jiuwen/core/common/exception/status_code.py index 8d1349f..dae2471 100644 --- a/jiuwen/core/common/exception/status_code.py +++ b/jiuwen/core/common/exception/status_code.py @@ -62,6 +62,9 @@ class StatusCode(Enum): WORKFLOW_INTENT_DETECTION_LLM_INVOKE_ERROR = (101096, "Model invoke failed with error message = {error_msg}") WORKFLOW_INTENT_DETECTION_PROMPT_INVOKE_ERROR = (101098, "Prompt invoke failed with error message = {error_msg}") + # message queue manager 101,711-101,719 + WORKFLOW_MESSAGE_QUEUE_MANAGER_ERROR = (101711, "Message queue manager error: {error_msg}") + @property def code(self): return self.value[0] diff --git a/jiuwen/core/context/mq_manager.py b/jiuwen/core/context/mq_manager.py index b9137de..f850e96 100644 --- a/jiuwen/core/context/mq_manager.py +++ b/jiuwen/core/context/mq_manager.py @@ -1,5 +1,7 @@ from typing import Dict, Any, AsyncIterator +from jiuwen.core.common.exception.exception import JiuWenBaseException +from jiuwen.core.common.exception.status_code import StatusCode from jiuwen.core.common.logging.base import logger from jiuwen.core.context.state import Transformer from jiuwen.core.context.utils import get_by_schema @@ -18,11 +20,10 @@ class StreamTransform: class MessageQueueManager: def __init__(self, stream_edges: dict[str, list[str]], comp_abilities: dict[str, list[ComponentAbility]], sub_graph: bool): - self._stream_edges: Dict[str, list[str]] = {} + self._stream_edges = stream_edges self._streams: Dict[str, dict[ComponentAbility, AsyncStreamQueue]] = {} self._streams_transform = StreamTransform() for producer_id, consumer_ids in stream_edges.items(): - self._stream_edges[producer_id] = consumer_ids for consumer_id in consumer_ids: consumer_stream_ability = [ability for ability in comp_abilities[consumer_id] if ability in [ComponentAbility.COLLECT, ComponentAbility.TRANSFORM]] @@ -34,7 +35,9 @@ class MessageQueueManager: @property def sub_workflow_stream(self): if not self._sub_graph: - raise RuntimeError("only sub graph has sub_workflow_stream") + raise JiuWenBaseException( + error_code=StatusCode.WORKFLOW_MESSAGE_QUEUE_MANAGER_ERROR.code, + message=f"only sub graph has sub_workflow_stream") return self._sub_workflow_stream def _get_queue(self, consumer_id: str) -> dict[ComponentAbility, AsyncStreamQueue]: @@ -96,6 +99,3 @@ class MessageQueueManager: for consumer_id in list(self._streams.keys()): await self.close_stream(consumer_id) self._streams.clear() - - def is_empty(self, node_id) -> bool: - return self._streams[node_id] is None \ No newline at end of file diff --git a/jiuwen/core/graph/executable.py b/jiuwen/core/graph/executable.py index a8ddcd9..272669f 100644 --- a/jiuwen/core/graph/executable.py +++ b/jiuwen/core/graph/executable.py @@ -1,11 +1,8 @@ #!/usr/bin/python3.10 # coding: utf-8 # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved -import asyncio import json -from abc import ABC, abstractmethod -from functools import partial -from typing import TypeVar, Generic, Iterator, AsyncIterator, Any +from typing import TypeVar, Generic, AsyncIterator, Any from jiuwen.core.common.exception.exception import InterruptException, JiuWenBaseException from jiuwen.core.common.exception.status_code import StatusCode @@ -27,13 +24,13 @@ class Executable(Generic[Input, Output]): raise JiuWenBaseException(-1, "Invoke is not supported") async def stream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: - raise JiuWenBaseException(-1, "Invoke is not supported") + raise JiuWenBaseException(-1, "Stream is not supported") async def collect(self, inputs: AsyncIterator[Input], contex: Context) -> Output: - raise JiuWenBaseException(-1, "Invoke is not supported") + raise JiuWenBaseException(-1, "Collect is not supported") async def transform(self, inputs: AsyncIterator[Input], context: Context) -> AsyncIterator[Output]: - raise JiuWenBaseException(-1, "Invoke is not supported") + raise JiuWenBaseException(-1, "Transform is not supported") async def interrupt(self, message: dict): raise InterruptException( diff --git a/jiuwen/core/graph/vertex.py b/jiuwen/core/graph/vertex.py index db60162..1e6b3f6 100644 --- a/jiuwen/core/graph/vertex.py +++ b/jiuwen/core/graph/vertex.py @@ -30,9 +30,6 @@ class Vertex: self._context = NodeContext(context, self._node_id) return True - def get_executable(self) -> Executable: - return self._executable - async def _run_executable(self, ability: ComponentAbility, is_subgraph: bool = False, config: Any = None): if ability == ComponentAbility.INVOKE: batch_inputs = await self._pre_invoke() diff --git a/tests/unit_tests/workflow/test_mock_node.py b/tests/unit_tests/workflow/test_mock_node.py index 5b5c9be..c1c2673 100644 --- a/tests/unit_tests/workflow/test_mock_node.py +++ b/tests/unit_tests/workflow/test_mock_node.py @@ -263,13 +263,6 @@ class MultiCollectCompNode(MockNodeBase): def __init__(self, node_id: str): super().__init__(node_id) self._node_id = node_id - self._is_stream_end = False - - async def invoke(self, inputs: Input, context: Context) -> Output: - while True: - if self._is_stream_end: - break - await asyncio.sleep(0.1) async def collect(self, inputs: AsyncIterator[Input], context: Context) -> Output: logger.info(f"===CollectCompNode[{self._node_id}], input: {inputs}") @@ -278,11 +271,11 @@ class MultiCollectCompNode(MockNodeBase): try: async for input in inputs: logger.info(f"===CollectCompNode[{self._node_id}], input: {input}") - a_value = input.get("a", {}).get("value") + a_value = input.get("value", {}).get("a") if a_value is not None: a_collect += a_value - b_value = input.get("b", {}).get("value") + b_value = input.get("value", {}).get("b") if b_value is not None: b_collect += b_value except Exception as e: @@ -291,5 +284,4 @@ class MultiCollectCompNode(MockNodeBase): # result = result + input["value"] result = {"a_collect": a_collect, "b_collect": b_collect} logger.info(f"===CollectCompNode243 [{self._node_id}], output: {result}") - self._is_stream_end = True return result \ No newline at end of file -- Gitee