diff --git a/jiuwen/core/context/state.py b/jiuwen/core/context/state.py index 09f7707c84cc6840e7a3856cfc9263776f669494..ac5d11fbd2d6748307cb4ac0d547eceda79ebb1d 100644 --- a/jiuwen/core/context/state.py +++ b/jiuwen/core/context/state.py @@ -79,8 +79,8 @@ class State(ABC): return return self._io_state.get(key) - def update_trace(self, span): - self._trace_state.update({self._node_id: span}) + def update_trace(self, invoke_id: str, span): + self._trace_state.update({invoke_id: span}) def update_comp(self, data: dict) -> None: if self._comp_state is None: @@ -134,4 +134,3 @@ class State(ABC): "global": self._global_state.get_updates(self._node_id), "comp": self._comp_state.get_updates(self._node_id) } - diff --git a/jiuwen/core/graph/vertex.py b/jiuwen/core/graph/vertex.py index fdf3c20b52510d9a44fe7798f67dd93089406c70..2687912adbdfd9dc405ff8dfa4efefb4d3cd239e 100644 --- a/jiuwen/core/graph/vertex.py +++ b/jiuwen/core/graph/vertex.py @@ -11,6 +11,8 @@ from jiuwen.core.context.utils import get_by_schema from jiuwen.core.graph.base import ExecutableGraph, INPUTS_KEY, CONFIG_KEY from jiuwen.core.graph.executable import Executable, Output from jiuwen.core.graph.graph_state import GraphState +from jiuwen.core.tracer.tracer import Tracer + class Vertex: def __init__(self, node_id: str, executable: Executable = None): @@ -72,28 +74,28 @@ class Vertex: async def __trace_inputs__(self, inputs: Optional[dict]) -> None: # TODO 组件信息 - # TODO 传入嵌套组件的invoke_id作为被嵌套组件的parent_invoke_id - # # - """ - EI: workflow: - start -> 嵌套-> llm -> end, start.parent_invoke_id="", llm.parent_invoke_id="start", end.parent_invoke_id="llm" - 子workflow: - start -> llm -> end, start.parent_invoke_id="", llm.parent_invoke_id="start", end.parent_invoke_id="llm" - parentNodeId = "父workflow中的嵌套组件id" - (子workflow情况,通过parentNodeId字段记录触发调用的component_id,获取嵌套关系) - """ - # TODO 如果当前组件是workflow 并且 - trace_workflow_span_manager = self._context.tracer.tracer_workflow_span_manager() - tracer_workflow_span = trace_workflow_span_manager.create_workflow_span(trace_workflow_span_manager.last_span) - await self._context.tracer.trigger("tracer_workflow", "on_pre_invoke", span=tracer_workflow_span, inputs=inputs, - component_metadata={"component_type": "component_type"}) - self._context.state.update_trace(tracer_workflow_span) + + await self._context.tracer.trigger("tracer_workflow", "on_pre_invoke", invoke_id=self._context.executable_id, + inputs=inputs, + component_metadata={"component_type": self._context.executable_id}) + self._context.state.update_trace(self._node_id, + self._context.tracer.tracer_workflow_span_manager.get_span(self._node_id)) + + if isinstance(self._executable, ExecWorkflowComponent): + self._origin_tracer = self._context.tracer + sub_tracer = Tracer(tracer_id=self._context.tracer._trace_id, parent_node_id=self._context.executable_id) + sub_tracer.init(self._context.stream_writer_manager, self._origin_tracer._callback_manager) + self._context.set_tracer(sub_tracer) async def __trace_outputs__(self, outputs: Optional[dict] = None) -> None: - trace_workflow_span = self._context.tracer.tracer_workflow_span_manager().last_span - await self._context.tracer.trigger("tracer_workflow", "on_post_invoke", span=trace_workflow_span, + if isinstance(self._executable, ExecWorkflowComponent): + self._context.set_tracer(self._origin_tracer) + + await self._context.tracer.trigger("tracer_workflow", "on_post_invoke", invoke_id=self._context.executable_id, outputs=outputs) - self._context.state.update_trace(trace_workflow_span) + self._context.state.update_trace(self._context.executable_id, + self._context.tracer.tracer_workflow_span_manager.get_span( + self._context.executable_id)) def __is_stream__(self, state: GraphState) -> bool: return False diff --git a/jiuwen/core/runtime/callback_manager.py b/jiuwen/core/runtime/callback_manager.py index 68f60954f5129df90b2717c27006316372248c17..fa8a5144c87835892b2ca50b6e52f72580c2bab7 100644 --- a/jiuwen/core/runtime/callback_manager.py +++ b/jiuwen/core/runtime/callback_manager.py @@ -48,7 +48,7 @@ class CallbackManager: if handler_class_name not in self._trigger_events or event_name not in self._trigger_events[ handler_class_name ]: - raise TypeError("event name not exists") + raise TypeError(f"event name not exists: {handler_class_name}, {event_name}") handler = self._handlers[handler_class_name] if hasattr(handler, event_name): method = getattr(handler, event_name) diff --git a/jiuwen/core/tracer/handler.py b/jiuwen/core/tracer/handler.py index a99018e4fd0fb8f2f3d6933faf5d5a425acc5c6e..3dba858f8eeef6dc13ee729710594e9198ddab71 100644 --- a/jiuwen/core/tracer/handler.py +++ b/jiuwen/core/tracer/handler.py @@ -65,6 +65,12 @@ class TraceAgentHandler(TraceBaseHandler): def _format_data(self, span: TraceAgentSpan) -> dict: return {"type": self.event_name(), "payload": span.model_dump(by_alias=True)} + def _get_tracer_agent_span(self, invoke_id: str) -> TraceAgentSpan: + span = self._span_manager.get_span(invoke_id) + if span is not None: + return span + return self._span_manager.create_agent_span(self._span_manager.last_span) + def _update_start_trace_data(self, span: TraceAgentSpan, invoke_type: str, inputs: Any, instance_info: dict, **kwargs): try: @@ -226,10 +232,16 @@ class TraceWorkflowHandler(TraceBaseHandler): return NodeStatus.FINISH.value if span.outputs else NodeStatus.RUNNING.value return NodeStatus.START.value + def _get_tracer_workflow_span(self, invoke_id: str) -> TraceWorkflowSpan: + span = self._span_manager.get_span(invoke_id) + if span is not None: + return span + return self._span_manager.create_workflow_span(invoke_id, self._span_manager.last_span) + @trigger_event - async def on_pre_invoke(self, span: TraceWorkflowSpan, inputs: Any, component_metadata: dict, + async def on_pre_invoke(self, invoke_id: str, inputs: Any, component_metadata: dict, **kwargs): - + span = self._get_tracer_workflow_span(invoke_id) try: meta_data = json.dumps({ "component_id": component_metadata.get("component_id", ""), @@ -252,8 +264,8 @@ class TraceWorkflowHandler(TraceBaseHandler): await self._send_data(span) @trigger_event - async def on_invoke(self, span: TraceWorkflowSpan, on_invoke_data: dict, exception: dict = None, **kwargs): - + async def on_invoke(self, invoke_id: str, on_invoke_data: dict, exception: dict = None, **kwargs): + span = self._get_tracer_workflow_span(invoke_id) update_data = {} end_time = datetime.now(tz=tzlocal()).replace(tzinfo=None) if exception is not None: @@ -282,8 +294,8 @@ class TraceWorkflowHandler(TraceBaseHandler): self._span_manager.update_span(span, {}) @trigger_event - async def on_post_invoke(self, span: TraceWorkflowSpan, outputs, inputs=None, **kwargs): - + async def on_post_invoke(self, invoke_id: str, outputs, inputs=None, **kwargs): + span = self._get_tracer_workflow_span(invoke_id) end_time = datetime.now(tz=tzlocal()).replace(tzinfo=None) update_data = { "outputs": outputs, diff --git a/jiuwen/core/tracer/span.py b/jiuwen/core/tracer/span.py index 56a02088f6ba1b4fdb812acf305ae2e7dc15da59..91c7aaa9aeb97fb424a3c590dd0a21feda8e656e 100644 --- a/jiuwen/core/tracer/span.py +++ b/jiuwen/core/tracer/span.py @@ -3,6 +3,7 @@ from datetime import datetime from typing import Optional, Dict, List, Callable from pydantic import ConfigDict, Field, BaseModel + class Span(BaseModel): trace_id: str start_time: datetime = Field(default=None, alias="startTime") @@ -13,74 +14,88 @@ class Span(BaseModel): invoke_id: str = Field(default=None, alias="invokeId") parent_invoke_id: Optional[str] = Field(default=None, alias="parentInvokeId") child_invokes_id: List[str] = Field(default=[], alias="childInvokes") - + model_config = ConfigDict(populate_by_name=True) - + def update(self, data: dict): for attr_name, value in data.items(): if not hasattr(self, attr_name): continue setattr(self, attr_name, value) - + + class TraceAgentSpan(Span): invoke_type: str = Field(default=None, alias="invokeType") name: str = Field(default=None, alias="name") elapsed_time: Optional[str] = Field(default=None, alias="elapsedTime") - meta_data: Optional[dict] = Field(default=None, alias="metaData") # include llm function tools and token infos - + meta_data: Optional[dict] = Field(default=None, alias="metaData") # include llm function tools and token infos + + class TraceWorkflowSpan(Span): execution_id: str = Field(default="", alias="executionId") conversation_id: str = Field(default="", alias="conversationId") - on_invoke_data: List[dict] = Field(default=[], alias="onInvokeData") # 用于记录当前组件执行时间的中间过程信息 + on_invoke_data: List[dict] = Field(default=[], alias="onInvokeData") # 用于记录当前组件执行时间的中间过程信息 agent_id: str = Field(default="", alias="agentId") - component_id: str = Field(default="", alias="componentId") # 放到metadata - component_name: str = Field(default="", alias="componentName") # 放到metadata - component_type: str = Field(default="", alias="componentType") # 即invoke_type - agent_parent_invoke_id: str = Field(default="", alias="agentParentInvokeId") # 给未来适配workflow节点中嵌套workflow预留 - meta_data: Optional[dict] = Field(default=None, alias="metaData") # 包括:模型的输入的function tools信息,模型的token使用信息 + component_id: str = Field(default="", alias="componentId") # 放到metadata + component_name: str = Field(default="", alias="componentName") # 放到metadata + component_type: str = Field(default="", alias="componentType") # 即invoke_type + agent_parent_invoke_id: str = Field(default="", alias="agentParentInvokeId") # 给未来适配workflow节点中嵌套workflow预留 + meta_data: Optional[dict] = Field(default=None, alias="metaData") # 包括:模型的输入的function tools信息,模型的token使用信息 # for loop component loop_node_id: Optional[str] = Field(default=None, alias="loopNodeId") loop_index: Optional[int] = Field(default=None, alias="loopIndex") # node status status: Optional[str] = Field(default=None, alias="status") # for llm invoke data - llm_invoke_data: Dict[str, dict] = Field(default=[], exclude=True) # 模型数据,临时存储 - parent_component_id: str = Field(default="", exclude=True) # 用于嵌套情况记录父组件 - + llm_invoke_data: Dict[str, dict] = Field(default=[], exclude=True) # 模型数据,临时存储 + # for subworkflow + parent_node_id: str = Field(default="", alias="parentNodeId") + + class SpanManager: """用于管理tracer handler运行期间的span""" - def __init__(self, trace_id: str): + + def __init__(self, trace_id: str, parent_node_id: str = ""): self._trace_id = trace_id + self._parent_node_id = parent_node_id self._order = [] self._runtime_spans = {} - + + def get_span(self, invoke_id: str): + if invoke_id not in self._order: + return None + return self._runtime_spans.get(invoke_id, None) + def refresh_span_record(self, invoke_id: str, runtime_span: Dict[str, Span]): if invoke_id not in self._order: self._order.append(invoke_id) self._runtime_spans[invoke_id] = runtime_span[invoke_id] - - def _create_span(self, span_class: Callable, parent_span = None): - invoke_id = str(uuid.uuid4()) - span = span_class(invoke_id=invoke_id, parent_invoke_id=parent_span.invoke_id if parent_span else None, - trace_id=self._trace_id) + def _refresh_parent_child_span(self, span, parent_span=None): if parent_span: parent_span.child_invokes_id.append(span.invoke_id) self.refresh_span_record(parent_span.invoke_id, {parent_span.invoke_id: parent_span}) - self.refresh_span_record(invoke_id, {invoke_id: span}) - - return span - + self.refresh_span_record(span.invoke_id, {span.invoke_id: span}) + def create_agent_span(self, parent_span: Optional[TraceAgentSpan] = None) -> TraceAgentSpan: - return self._create_span(TraceAgentSpan, parent_span) - - def create_workflow_span(self, parent_span: Optional[TraceWorkflowSpan] = None) -> TraceWorkflowSpan: - return self._create_span(TraceWorkflowSpan, parent_span) - + invoke_id = str(uuid.uuid4()) + span = TraceAgentSpan(invoke_id=invoke_id, parent_invoke_id=parent_span.invoke_id if parent_span else None, + trace_id=self._trace_id) + self._refresh_parent_child_span(span, parent_span) + return span + + def create_workflow_span(self, invoke_id: str, + parent_span: Optional[TraceWorkflowSpan] = None) -> TraceWorkflowSpan: + span = TraceWorkflowSpan(invoke_id=invoke_id, parent_invoke_id=parent_span.invoke_id if parent_span else None, + trace_id=self._trace_id, parent_node_id=self._parent_node_id, + execution_id=self._trace_id) + self._refresh_parent_child_span(span, parent_span) + return span + def update_span(self, span: Span, data: dict): span.update(data) self.refresh_span_record(span.invoke_id, {span.invoke_id: span}) - + def end_span(self): pass @@ -92,4 +107,3 @@ class SpanManager: if last_span_id not in self._runtime_spans: return None return self._runtime_spans[last_span_id] - \ No newline at end of file diff --git a/jiuwen/core/tracer/tracer.py b/jiuwen/core/tracer/tracer.py index ad5de56b2988f475e41d395722336257b864427a..5ea7ff09e0a35716e746a1885df391027f650f29 100644 --- a/jiuwen/core/tracer/tracer.py +++ b/jiuwen/core/tracer/tracer.py @@ -5,27 +5,30 @@ from jiuwen.core.tracer.span import SpanManager class Tracer: - def __init__(self): + def __init__(self, tracer_id=None, parent_node_id=""): self._callback_manager = None - self._trace_id = str(uuid.uuid4()) - self._tracer_agent_span_manager = SpanManager(self._trace_id) - self._tracer_workflow_span_manager = SpanManager(self._trace_id) + self._trace_id = str(uuid.uuid4()) if tracer_id is None else tracer_id + self.tracer_agent_span_manager = SpanManager(self._trace_id) + self.tracer_workflow_span_manager = SpanManager(self._trace_id, parent_node_id=parent_node_id) + self._parent_node_id = parent_node_id def init(self, stream_writer_manager, callback_manager): - trace_agent_handler = TraceAgentHandler(callback_manager, stream_writer_manager, - self._tracer_agent_span_manager) - trace_workflow_handler = TraceWorkflowHandler(callback_manager, stream_writer_manager, - self._tracer_workflow_span_manager) - callback_manager.register_handler({TracerHandlerName.TRACE_AGENT.value: trace_agent_handler}) - callback_manager.register_handler({TracerHandlerName.TRACER_WORKFLOW.value: trace_workflow_handler}) + # 用于注册子workflow tracer handler,子workflow中使用新的tracer handler + if self._parent_node_id != "": + trace_workflow_handler = TraceWorkflowHandler(callback_manager, stream_writer_manager, + self.tracer_workflow_span_manager) + callback_manager.register_handler( + {TracerHandlerName.TRACER_WORKFLOW.value + "." + self._parent_node_id: trace_workflow_handler}) + else: + trace_agent_handler = TraceAgentHandler(callback_manager, stream_writer_manager, + self.tracer_agent_span_manager) + trace_workflow_handler = TraceWorkflowHandler(callback_manager, stream_writer_manager, + self.tracer_workflow_span_manager) + callback_manager.register_handler({TracerHandlerName.TRACE_AGENT.value: trace_agent_handler}) + callback_manager.register_handler({TracerHandlerName.TRACER_WORKFLOW.value: trace_workflow_handler}) self._callback_manager = callback_manager async def trigger(self, handler_class_name: str, event_name: str, **kwargs): + handler_class_name += "." + self._parent_node_id if self._parent_node_id != "" else "" await self._callback_manager.trigger(handler_class_name, event_name, **kwargs) - - def tracer_agent_span_manager(self): - return self._tracer_agent_span_manager - - def tracer_workflow_span_manager(self): - return self._tracer_workflow_span_manager diff --git a/tests/unit_tests/tracer/test_mock_node_with_tracer.py b/tests/unit_tests/tracer/test_mock_node_with_tracer.py index 3c8b41c963f979cebc43462f0e44d4c54b283569..96ee4628e5594562b85f56ea185e37964874dfee 100644 --- a/tests/unit_tests/tracer/test_mock_node_with_tracer.py +++ b/tests/unit_tests/tracer/test_mock_node_with_tracer.py @@ -1,8 +1,8 @@ import asyncio -from jiuwen.core.context.config import Config +import random + from jiuwen.core.context.context import Context -from jiuwen.core.context.memory.base import InMemoryState -from jiuwen.core.graph.executable import Executable, Input, Output +from jiuwen.core.graph.executable import Input, Output from jiuwen.core.workflow.base import Workflow from tests.unit_tests.workflow.test_mock_node import MockNodeBase @@ -15,11 +15,11 @@ class StreamNodeWithTracer(MockNodeBase): async def invoke(self, inputs: Input, context: Context) -> Output: context.state.set_outputs(self.node_id, inputs) - trace_workflow_span = context.tracer.tracer_workflow_span_manager().last_span - await context.tracer.trigger("tracer_workflow", "on_invoke", span=trace_workflow_span, + await context.tracer.trigger("tracer_workflow", "on_invoke", invoke_id=context.executable_id, on_invoke_data={"on_invoke_data": "mock with" + str(inputs)}) - context.state.update_trace(trace_workflow_span) - await asyncio.sleep(5) + context.state.update_trace(context.executable_id, + context.tracer.tracer_workflow_span_manager.get_span(context.executable_id)) + await asyncio.sleep(random.randint(0, 5)) for data in self._datas: await asyncio.sleep(1) await context.stream_writer_manager.get_custom_writer().write(data) diff --git a/tests/unit_tests/tracer/test_workflow.py b/tests/unit_tests/tracer/test_workflow.py index 8f88f43aaf8a1c2b1a14a8b223b60035de1c125d..2a5421ffe41a639edbf5e352fbfa5a970a626561 100644 --- a/tests/unit_tests/tracer/test_workflow.py +++ b/tests/unit_tests/tracer/test_workflow.py @@ -1,7 +1,10 @@ +import json import sys import types from unittest.mock import Mock +from jiuwen.core.component.workflow_comp import ExecWorkflowComponent + fake_base = types.ModuleType("base") fake_base.logger = Mock() @@ -26,12 +29,10 @@ from jiuwen.core.workflow.base import WorkflowConfig, Workflow from jiuwen.core.stream.writer import CustomSchema from jiuwen.graph.pregel.graph import PregelGraph from tests.unit_tests.workflow.test_mock_node import MockStartNode, MockEndNode -from jiuwen.core.tracer.tracer import Tracer from jiuwen.core.stream.writer import TraceSchema def create_context_with_tracer() -> Context: - tracer = Tracer() return Context(config=Config(), state=InMemoryState(), store=None) @@ -46,6 +47,17 @@ def create_flow() -> Workflow: DEFAULT_WORKFLOW_CONFIG = WorkflowConfig() +def record_tracer_info(tracer_chunks, file_path): + try: + with open(file_path, "w", encoding="utf-8") as f: + for chunk in tracer_chunks: + json_data = json.dumps(chunk.model_dump(), default=str, ensure_ascii=False) + f.write(json_data + "\n") + print(f"调测信息已保存到文件:{file_path}") + except Exception as e: + print(f"调测信息保存失败:{e}") + + class WorkflowTest(unittest.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() @@ -63,8 +75,12 @@ class WorkflowTest(unittest.TestCase): elif checker is not None: checker(self.invoke_workflow(inputs, context, flow)) - def test_seq_exec_stream_workflow_with_tracer(self): + """ + start -> a -> b -> end + """ + tracer_chunks = [] + async def stream_workflow(): flow = create_flow() flow.set_start_comp("start", MockStartNode("start"), @@ -102,6 +118,70 @@ class WorkflowTest(unittest.TestCase): flow.add_connection("a", "b") flow.add_connection("b", "end") + expected_datas_model = { + "a": node_a_expected_datas_model, + "b": node_b_expected_datas_model + } + index_dict = {key: 0 for key in expected_datas_model.keys()} + + async for chunk in flow.stream({"a": 1, "b": "haha"}, create_context_with_tracer()): + if not isinstance(chunk, TraceSchema): + node_id = chunk.node_id + index = index_dict[node_id] + assert chunk == expected_datas_model[node_id][index], f"Mismatch at node {node_id} index {index}" + logger.info(f"stream chunk: {chunk}") + index_dict[node_id] = index_dict[node_id] + 1 + else: + print(f"stream chunk: {chunk}") + tracer_chunks.append(chunk) + + self.loop.run_until_complete(stream_workflow()) + record_tracer_info(tracer_chunks, "test_seq_exec_stream_workflow_with_tracer.json") + + def test_parallel_exec_stream_workflow_with_tracer(self): + """ + start -> a | b -> end + """ + tracer_chunks = [] + + async def stream_workflow(): + flow = create_flow() + flow.set_start_comp("start", MockStartNode("start"), + inputs_schema={ + "a": "${user.inputs.a}", + "b": "${user.inputs.b}", + "c": 1, + "d": [1, 2, 3]}) + + node_a_expected_datas = [ + {"node_id": "a", "id": 1, "data": "1"}, + {"node_id": "a", "id": 2, "data": "2"}, + ] + node_a_expected_datas_model = [CustomSchema(**item) for item in node_a_expected_datas] + flow.add_workflow_comp("a", StreamNodeWithTracer("a", node_a_expected_datas), + inputs_schema={ + "aa": "${start.a}", + "ac": "${start.c}"}) + + node_b_expected_datas = [ + {"node_id": "b", "id": 1, "data": "1"}, + {"node_id": "b", "id": 2, "data": "2"}, + ] + node_b_expected_datas_model = [CustomSchema(**item) for item in node_b_expected_datas] + flow.add_workflow_comp("b", StreamNodeWithTracer("b", node_b_expected_datas), + inputs_schema={ + "ba": "${start.b}", + "bc": "${start.d}"}) + + flow.set_end_comp("end", MockEndNode("end"), + inputs_schema={ + "result": "${b.ba}"}) + + flow.add_connection("start", "a") + flow.add_connection("start", "b") + flow.add_connection("a", "end") + flow.add_connection("b", "end") + expected_datas_model = { "a": node_a_expected_datas_model, "b": node_b_expected_datas_model @@ -116,5 +196,168 @@ class WorkflowTest(unittest.TestCase): index_dict[node_id] = index_dict[node_id] + 1 else: print(f"stream chunk: {chunk}") + tracer_chunks.append(chunk) + + self.loop.run_until_complete(stream_workflow()) + record_tracer_info(tracer_chunks, "test_parallel_exec_stream_workflow_with_tracer.json") + + def test_sub_stream_workflow_with_tracer(self): + """ + main_workflow: start -> a(sub_workflow) -> end + sub_workflow: sub_start -> sub_a -> sub_end + """ + tracer_chunks = [] + + async def stream_workflow(): + # sub_workflow: start->a(stream out)->end + sub_workflow = create_flow() + sub_workflow.set_start_comp("sub_start", MockStartNode("start"), + inputs_schema={ + "a": "${a}", + "b": "${b}", + "c": 1, + "d": [1, 2, 3]}) + expected_datas = [ + {"node_id": "sub_start", "id": 1, "data": "1"}, + {"node_id": "sub_start", "id": 2, "data": "2"}, + ] + expected_datas_model = [CustomSchema(**item) for item in expected_datas] + + sub_workflow.add_workflow_comp("sub_a", StreamNodeWithTracer("a", expected_datas), + inputs_schema={ + "aa": "${sub_start.a}", + "ac": "${sub_start.c}"}) + sub_workflow.set_end_comp("sub_end", MockEndNode("end"), + inputs_schema={ + "result": "${sub_a.aa}"}) + sub_workflow.add_connection("sub_start", "sub_a") + sub_workflow.add_connection("sub_a", "sub_end") + + # main_workflow: start->a(sub workflow)->end + main_workflow = create_flow() + main_workflow.set_start_comp("start", MockStartNode("start"), + inputs_schema={ + "a": "${a}", + "b": "${b}", + "c": 1, + "d": [1, 2, 3]}) + + main_workflow.add_workflow_comp("a", ExecWorkflowComponent("a", sub_workflow), + inputs_schema={ + "aa": "${start.a}", + "ac": "${start.c}"}) + main_workflow.set_end_comp("end", MockEndNode("end"), + inputs_schema={ + "result": "${a.aa}"}) + main_workflow.add_connection("start", "a") + main_workflow.add_connection("a", "end") + + index = 0 + async for chunk in main_workflow.stream({"a": 1, "b": "haha"}, create_context_with_tracer()): + if not isinstance(chunk, TraceSchema): + assert chunk == expected_datas_model[index], f"Mismatch at index {index}" + logger.info(f"stream chunk: {chunk}") + index += 1 + else: + print(f"stream chunk: {chunk}") + tracer_chunks.append(chunk) + + self.loop.run_until_complete(stream_workflow()) + record_tracer_info(tracer_chunks, "test_sub_stream_workflow_with_tracer.json") + + def test_nested_stream_workflow_with_tracer(self): + """ + main_workflow: start -> a(sub_workflow) | b -> end + sub_workflow: sub_start -> sub_a -> sub_end + """ + tracer_chunks = [] + + async def stream_workflow(): + # sub_workflow: start->a(stream out)->end + sub_workflow = create_flow() + sub_workflow.set_start_comp("sub_start", MockStartNode("start"), + inputs_schema={ + "a": "${a}", + "b": "${b}", + "c": 1, + "d": [1, 2, 3]}) + expected_datas = [ + {"node_id": "sub_start", "id": 1, "data": "1"}, + {"node_id": "sub_start", "id": 2, "data": "2"}, + ] + expected_datas_model = [CustomSchema(**item) for item in expected_datas] + + sub_workflow.add_workflow_comp("sub_a", StreamNodeWithTracer("a", expected_datas), + inputs_schema={ + "aa": "${sub_start.a}", + "ac": "${sub_start.c}"}) + sub_workflow.set_end_comp("sub_end", MockEndNode("end"), + inputs_schema={ + "result": "${sub_a.aa}"}) + sub_workflow.add_connection("sub_start", "sub_a") + sub_workflow.add_connection("sub_a", "sub_end") + + # main_workflow: start->a(sub workflow) | b ->end + main_workflow = create_flow() + main_workflow.set_start_comp("start", MockStartNode("start"), + inputs_schema={ + "a": "${a}", + "b": "${b}", + "c": 1, + "d": [1, 2, 3]}) + + main_workflow.add_workflow_comp("a", ExecWorkflowComponent("a", sub_workflow), + inputs_schema={ + "aa": "${start.a}", + "ac": "${start.c}"}) + + node_b_expected_datas = [ + {"node_id": "b", "id": 1, "data": "1"}, + {"node_id": "b", "id": 2, "data": "2"}, + ] + node_b_expected_datas_model = [CustomSchema(**item) for item in node_b_expected_datas] + main_workflow.add_workflow_comp("b", StreamNodeWithTracer("b", node_b_expected_datas), + inputs_schema={ + "ba": "${start.b}", + "bc": "${start.d}"}) + + main_workflow.set_end_comp("end", MockEndNode("end"), + inputs_schema={ + "result": "${a.aa}"}) + main_workflow.add_connection("start", "a") + main_workflow.add_connection("a", "end") + main_workflow.add_connection("start", "b") + main_workflow.add_connection("b", "end") + + async for chunk in main_workflow.stream({"a": 1, "b": "haha"}, create_context_with_tracer()): + if isinstance(chunk, TraceSchema): + print(f"stream chunk: {chunk}") + tracer_chunks.append(chunk) self.loop.run_until_complete(stream_workflow()) + for chunk in tracer_chunks: + payload = chunk.payload + payload.get("parentInvokeId") + payload.get("parentNodeId") + if payload.get("invokeId") == "start": + assert payload.get("parentInvokeId") == None, f"start node parent_invoke_id should be None" + assert payload.get("parentNodeId") == "", f"a node parent_node_id should be ''" + elif payload.get("invokeId") == "a": + assert payload.get("parentInvokeId") == "start", f"a node parent_invoke_id should be start" + assert payload.get("parentNodeId") == "", f"a node parent_node_id should be ''" + elif payload.get("invokeId") == "b": + assert payload.get("parentInvokeId") == "a", f"b node parent_invoke_id should be a" + assert payload.get("parentNodeId") == "", f"b node parent_node_id should be ''" + elif payload.get("invokeId") == "end": + assert payload.get("parentInvokeId") == "b", f"b node parent_invoke_id should be a" + assert payload.get("parentNodeId") == "", f"b node parent_node_id should be ''" + elif payload.get("invokeId") == "sub_start": + assert payload.get("parentInvokeId") == None, f"sub_start node parent_invoke_id should be None" + assert payload.get("parentNodeId") == "a", f"sub_start node parent_node_id should be a" + elif payload.get("invokeId") == "sub_a": + assert payload.get("parentInvokeId") == "sub_start", f"sub_a node parent_invoke_id should be sub_start" + assert payload.get("parentNodeId") == "a", f"sub_a node parent_node_id should be a" + elif payload.get("invokeId") == "sub_end": + assert payload.get("parentInvokeId") == "sub_a", f"sub_end node parent_invoke_id should be sub_a" + assert payload.get("parentNodeId") == "a", f"sub_end node parent_node_id should be a" + record_tracer_info(tracer_chunks, "test_nested_stream_workflow_with_tracer.json")