diff --git a/jiuwen/core/component/base.py b/jiuwen/core/component/base.py index c7d837a54ce3c55a673f31c8be8592da9e0ac39e..4c48274fc2f5b8c26bbd9784bb208eacd3c70a79 100644 --- a/jiuwen/core/component/base.py +++ b/jiuwen/core/component/base.py @@ -40,24 +40,3 @@ class WorkflowComponent(ABC): def to_executable(self) -> Executable: pass - - - -class ComponentAbility(Enum): - INVOKE = ("invoke", "batch in, batch out") - STREAM = ("stream", "batch in, stream out") - COLLECT = ("collect", "stream in, batch out") - TRANSFORM = ("transform", "stream in, stream out") - - def __init__(self, name: str, desc: str): - self._name = name - self._desc = desc - - @property - def name(self) -> str: - return self._name - - @property - def desc(self) -> str: - return self._desc - diff --git a/jiuwen/core/context/context.py b/jiuwen/core/context/context.py index 1c8524045363396403e2dac1bbfe3c2fe914291a..c074e157ef89d8a00a12397d3c463d0837bbc991 100644 --- a/jiuwen/core/context/context.py +++ b/jiuwen/core/context/context.py @@ -43,6 +43,10 @@ class Context(ABC): def controller_context_manager(self): pass + @abstractmethod + def queue_manager(self) -> MessageQueueManager: + pass + @abstractmethod def session_id(self) -> str: pass @@ -74,7 +78,7 @@ class WorkflowContext(Context): self.__stream_writer_manager: StreamWriterManager = None self.__controller_context_manager = None self.__session_id = session_id if session_id else uuid.uuid4().hex - self.queue_manager: MessageQueueManager = None + self.__queue_manager: MessageQueueManager = None def set_stream_writer_manager(self, stream_writer_manager: StreamWriterManager) -> None: if self.__stream_writer_manager is not None: @@ -109,9 +113,12 @@ class WorkflowContext(Context): return self.__controller_context_manager def set_queue_manager(self, queue_manager: MessageQueueManager): - if self.queue_manager is not None: + if self.__queue_manager is not None: return - self.queue_manager = queue_manager + self.__queue_manager = queue_manager + + def queue_manager(self) -> MessageQueueManager: + return self.__queue_manager def session_id(self) -> str: return self.__session_id @@ -158,6 +165,9 @@ class NodeContext(Context): def controller_context_manager(self): return self.__context.controller_context_manager() + def queue_manager(self) -> MessageQueueManager: + return self.__context.queue_manager() + def session_id(self) -> str: return self.__context.session_id() diff --git a/jiuwen/core/graph/vertex.py b/jiuwen/core/graph/vertex.py index 0e3676215426be8ea780a2bc1147818b39c0fd1b..330f9236fffd680f30ef321f235cd6cf1549be41 100644 --- a/jiuwen/core/graph/vertex.py +++ b/jiuwen/core/graph/vertex.py @@ -8,6 +8,7 @@ from jiuwen.core.common.constants.constant import INTERACTIVE_INPUT, END_NODE_ST from jiuwen.core.common.exception.exception import JiuWenBaseException from jiuwen.core.common.logging.base import logger from jiuwen.core.component.condition.condition import INDEX +from jiuwen.core.component.end_comp import End from jiuwen.core.component.loop_callback.loop_id import LOOP_ID from jiuwen.core.component.exec_workflow_base import ExecWorkflowBase from jiuwen.core.context.context import Context, NodeContext @@ -86,8 +87,8 @@ class Vertex: return results async def _pre_stream(self, ability: ComponentAbility) -> AsyncIterator[dict]: - queue_manager = self._context.parent_context().queue_manager - workflow_config = self._context.parent_context().config().get_workflow_config() + queue_manager = self._context.queue_manager() + workflow_config = self._context.config().get_workflow_config() inputs_transformer = workflow_config.comp_stream_configs[self._node_id].inputs_transformer inputs_schema = workflow_config.comp_stream_configs[self._node_id].inputs_schema async for message in queue_manager.consume(self._node_id, ability): @@ -100,8 +101,8 @@ class Vertex: yield inputs async def _post_stream(self, results_iter: AsyncIterator) -> None: - queue_manager = self._context.parent_context().queue_manager - workflow_config = self._context.parent_context().config().get_workflow_config() + queue_manager = self._context.queue_manager() + workflow_config = self._context.config().get_workflow_config() output_transformer = workflow_config.comp_stream_configs[self._node_id].outputs_transformer output_schema = workflow_config.comp_stream_configs[self._node_id].outputs_schema end_stream_index = 0 @@ -115,19 +116,19 @@ class Vertex: await queue_manager.end_message(self._node_id) async def _process_chunk(self, end_stream_index: int, message: Any) -> None: - end_node = False - sub_graph = self._context.parent_id is not None + end_node = isinstance(self._executable, End) + sub_graph = self._context.parent_id() is not '' if end_node and not sub_graph: message_stream_data = { "type": END_NODE_STREAM, "index": ++end_stream_index, "payload": message } - await self._context.parent_context().stream_writer_manager.get_output_writer().write(message_stream_data) + await self._context.stream_writer_manager().get_output_writer().write(message_stream_data) elif end_node and sub_graph: - await self._context.parent_context().queue_manager.sub_workflow_stream.send(message) + await self._context.queue_manager().sub_workflow_stream.send(message) else: - await self._context.parent_context().queue_manager.produce(self._node_id, message) + await self._context.queue_manager().produce(self._node_id, message) def __clear_interactive__(self) -> None: @@ -148,14 +149,14 @@ class Vertex: if isinstance(self._executable, ExecWorkflowBase): self._context.tracer().register_workflow_span_manager(self._context.executable_id()) - async def call(self, state: GraphState, config: Any = None, ExecGraphComponent=None): + async def call(self, state: GraphState, config: Any = None): if self._context is None or self._executable is None: raise JiuWenBaseException(1, "vertex is not initialized, node is is " + self._node_id) is_subgraph = self._executable.graph_invoker() try: - workflow_config = self._context.parent_context().config().get_workflow_config() + workflow_config = self._context.config().get_workflow_config() component_ability = workflow_config.comp_abilities.get(self._node_id) component_ability = component_ability if component_ability else [ComponentAbility.INVOKE] call_ability = [ability for ability in component_ability if @@ -175,11 +176,11 @@ class Vertex: self._stream_called = True # 标记 stream_call 已被调用 self._stream_done.clear() # 清除之前的完成状态 - if self._context is None or self._context.parent_context().queue_manager is None: + if self._context is None or self._context.queue_manager() is None: raise JiuWenBaseException(1, "queue manager is not initialized") try: - workflow_config = self._context.parent_context().config().get_workflow_config() + workflow_config = self._context.config().get_workflow_config() component_ability = workflow_config.comp_abilities.get(self._node_id) call_ability = [ability for ability in component_ability if ability in [ComponentAbility.COLLECT, ComponentAbility.TRANSFORM]]