diff --git a/jiuwen/core/component/condition/array.py b/jiuwen/core/component/condition/array.py index c57155e4072b9ce9f51a48acc737e315b27150ee..64103a732622cba8c1cf3c69d66485c1ab51f143 100644 --- a/jiuwen/core/component/condition/array.py +++ b/jiuwen/core/component/condition/array.py @@ -14,7 +14,6 @@ DEFAULT_PATH_ARRAY_LOOP_VAR = "arrLoopVar" class ArrayCondition(Condition): def __init__(self, node_id: str, arrays: dict[str, Union[str, list[Any]]]): super().__init__() - self._node_id = node_id self._arrays = arrays self._index_path = node_id + NESTED_PATH_SPLIT + INDEX self._arrays_root = node_id + NESTED_PATH_SPLIT + DEFAULT_PATH_ARRAY_LOOP_VAR diff --git a/jiuwen/core/component/condition/number.py b/jiuwen/core/component/condition/number.py index 1c3b9dcdd75545f46caec30a86ee61f86964367e..5dc56e60f2477fb5c7156955cd25bb3493ea1eb6 100644 --- a/jiuwen/core/component/condition/number.py +++ b/jiuwen/core/component/condition/number.py @@ -8,9 +8,9 @@ from jiuwen.core.context.utils import NESTED_PATH_SPLIT class NumberCondition(Condition): - def __init__(self, node_id: str, limit: Union[str, int], index_path: str = None): + def __init__(self, node_id: str, limit: Union[str, int]): super().__init__() - self._index_path = index_path if index_path else node_id + NESTED_PATH_SPLIT + INDEX + self._index_path = node_id + NESTED_PATH_SPLIT + INDEX self._limit = limit self._node_id = node_id diff --git a/jiuwen/core/component/loop_callback/intermediate_loop_var.py b/jiuwen/core/component/loop_callback/intermediate_loop_var.py index 3fb72f0ced67624fca5102daf766c635eebd750f..cdb942cea8cd3b35098e17113ef0a978f7cf81e6 100644 --- a/jiuwen/core/component/loop_callback/intermediate_loop_var.py +++ b/jiuwen/core/component/loop_callback/intermediate_loop_var.py @@ -3,18 +3,15 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. from typing import Union, Any -from jiuwen.core.component.loop_callback.loop_callback import LoopCallback +from jiuwen.core.component.loop_callback.loop_callback import LoopCallback, INTERMEDIATE_LOOP_VAR from jiuwen.core.context.utils import NESTED_PATH_SPLIT, is_ref_path, extract_origin_key class IntermediateLoopVarCallback(LoopCallback): - def __init__(self, node_id: str, intermediate_loop_var: dict[str, Union[str, Any]], - intermediate_loop_var_root: str = None): + def __init__(self, node_id: str, intermediate_loop_var: dict[str, Union[str, Any]]): super().__init__() - self._node_id = node_id self._intermediate_loop_var = intermediate_loop_var - self._intermediate_loop_var_root = intermediate_loop_var_root if intermediate_loop_var_root \ - else node_id + NESTED_PATH_SPLIT + "intermediateLoopVar" + self._intermediate_loop_var_root = node_id + NESTED_PATH_SPLIT + INTERMEDIATE_LOOP_VAR def first_in_loop(self): for key, value in self._intermediate_loop_var.items(): diff --git a/jiuwen/core/component/loop_callback/loop_callback.py b/jiuwen/core/component/loop_callback/loop_callback.py index 0db2a63783fff21913be97a3630d10e44e3f83a8..8bc4dc87d64a8a2a5955d92df8e4c8ab4aeb13d5 100644 --- a/jiuwen/core/component/loop_callback/loop_callback.py +++ b/jiuwen/core/component/loop_callback/loop_callback.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from jiuwen.core.context.context import ContextSetter +INTERMEDIATE_LOOP_VAR = "intermediateLoopVar" class LoopCallback(ContextSetter, ABC): diff --git a/jiuwen/core/component/loop_callback/output.py b/jiuwen/core/component/loop_callback/output.py index 613c0cae1f0e6d0385a03c485043d7a38031c759..ef282703fe9bb90117bce5a66d4e1b88e885098e 100644 --- a/jiuwen/core/component/loop_callback/output.py +++ b/jiuwen/core/component/loop_callback/output.py @@ -3,19 +3,19 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. from typing import Any -from jiuwen.core.component.loop_callback.loop_callback import LoopCallback +from jiuwen.core.component.loop_callback.loop_callback import LoopCallback, INTERMEDIATE_LOOP_VAR from jiuwen.core.context.utils import is_ref_path, extract_origin_key, NESTED_PATH_SPLIT +ROUND = "round" class OutputCallback(LoopCallback): - def __init__(self, node_id: str, outputs_format: dict[str, Any], round_result_root: str = None, - result_root: str = None, intermediate_loop_var_root: str = None): + def __init__(self, node_id: str, outputs_format: dict[str, Any]): super().__init__() self._node_id = node_id self._outputs_format = outputs_format - self._round_result_root = round_result_root if round_result_root else node_id + NESTED_PATH_SPLIT + "round" - self._result_root = result_root if result_root else node_id - self._intermediate_loop_var_root = intermediate_loop_var_root if intermediate_loop_var_root else node_id + NESTED_PATH_SPLIT + "intermediateLoopVar" + self._round_result_root = node_id + NESTED_PATH_SPLIT + ROUND + self._result_root = node_id + self._intermediate_loop_var_root = node_id + NESTED_PATH_SPLIT + INTERMEDIATE_LOOP_VAR def _generate_results(self, results: list[(str, Any)]): for key, value in self._outputs_format.items(): diff --git a/jiuwen/core/component/loop_comp.py b/jiuwen/core/component/loop_comp.py index c324d206bd2da26e046067ede20e71eef922335a..8e9b427e9ba702eec41010d6da5972d140a124f0 100644 --- a/jiuwen/core/component/loop_comp.py +++ b/jiuwen/core/component/loop_comp.py @@ -11,7 +11,6 @@ from jiuwen.core.component.condition.condition import Condition, AlwaysTrue, Fun from jiuwen.core.component.condition.expression import ExpressionCondition from jiuwen.core.component.loop_callback.loop_callback import LoopCallback from jiuwen.core.component.loop_callback.loop_id import LoopIdCallback -from jiuwen.core.component.set_variable_comp import SetVariableComponent from jiuwen.core.context.config import WorkflowConfig from jiuwen.core.context.context import Context, ContextSetter, NodeContext from jiuwen.core.context.utils import NESTED_PATH_SPLIT @@ -92,8 +91,7 @@ class LoopComponent(WorkflowComponent, LoopController, ContextSetter, Executable def __init__(self, node_id: str, body: Executable, new_graph: Graph, condition: Union[str, Callable[[], bool], Condition] = None, - break_nodes: list[BreakComponent] = None, callbacks: list[LoopCallback] = None, - set_variable_components: list[SetVariableComponent] = None): + break_nodes: list[BreakComponent] = None, callbacks: list[LoopCallback] = None): ContextSetter.__init__(self) self._node_id = node_id self._body = body @@ -131,7 +129,6 @@ class LoopComponent(WorkflowComponent, LoopController, ContextSetter, Executable self._context_setters: list[ContextSetter] = [self, self._condition] self._context_setters.extend(self._callbacks) - self._context_setters.extend(set_variable_components) def init(self): self._context.state().update_comp({self._context_root + NESTED_PATH_SPLIT + BROKEN: False}) diff --git a/jiuwen/core/component/set_variable_comp.py b/jiuwen/core/component/set_variable_comp.py index d6bca83722ebe109fa5fd1362a63f6def0c68ff8..a83d91825dfceb9a1a368344eefa90a5ced732ce 100644 --- a/jiuwen/core/component/set_variable_comp.py +++ b/jiuwen/core/component/set_variable_comp.py @@ -11,16 +11,12 @@ from jiuwen.core.context.utils import extract_origin_key, is_ref_path from jiuwen.core.graph.executable import Executable, Input, Output -class SetVariableComponent(WorkflowComponent, Executable, ContextSetter): +class SetVariableComponent(WorkflowComponent, Executable): - def __init__(self, node_id: str, variable_mapping: dict[str, Any]): + def __init__(self, variable_mapping: dict[str, Any]): super().__init__() - self._node_id = node_id self._variable_mapping = variable_mapping - def set_context(self, context: Context): - self._context = NodeContext(context, self._node_id) - async def invoke(self, inputs: Input, context: Context) -> Output: for left, right in self._variable_mapping.items(): left_ref_str = extract_origin_key(left) @@ -28,9 +24,9 @@ class SetVariableComponent(WorkflowComponent, Executable, ContextSetter): left_ref_str = left if isinstance(right, str) and is_ref_path(right): ref_str = extract_origin_key(right) - self._context.state().update_io({left_ref_str: self._context.state().get(ref_str)}) + context.state().update_io({left_ref_str: context.state().get(ref_str)}) continue - self._context.state().update_io({left_ref_str: right}) + context.state().update_io({left_ref_str: right}) return None diff --git a/jiuwen/core/component/workflow_comp.py b/jiuwen/core/component/workflow_comp.py index df7055c282483ea441980e6812104e16f640bbd7..f86e70d7df7ac8eea5e9bc92e399037d96f4d590 100644 --- a/jiuwen/core/component/workflow_comp.py +++ b/jiuwen/core/component/workflow_comp.py @@ -12,9 +12,8 @@ from jiuwen.core.workflow.base import Workflow class ExecWorkflowComponent(WorkflowComponent, Executable, ExecWorkflowBase): - def __init__(self, node_id: str, sub_workflow: Workflow): + def __init__(self, sub_workflow: Workflow): super().__init__() - self.node_id = node_id self._sub_workflow = sub_workflow async def invoke(self, inputs: Input, context: Context) -> Output: diff --git a/tests/unit_tests/tracer/test_mock_node_with_tracer.py b/tests/unit_tests/tracer/test_mock_node_with_tracer.py index 07100f7f2077e104acdf3557252f86caed2b6aaf..757c57247a4699878ccf9e09d22b03e7649a620e 100644 --- a/tests/unit_tests/tracer/test_mock_node_with_tracer.py +++ b/tests/unit_tests/tracer/test_mock_node_with_tracer.py @@ -39,12 +39,3 @@ class StreamNodeWithTracer(MockNodeBase): print("StreamNode: output = " + str(inputs)) return inputs - -class CompositeWorkflowNode(MockNodeBase): - def __init__(self, node_id: str, sub_workflow: Workflow): - super().__init__(node_id) - self._node_id = node_id - self._sub_workflow = sub_workflow - - async def invoke(self, inputs: Input, context: Context) -> Output: - return await self._sub_workflow.invoke(inputs, context) diff --git a/tests/unit_tests/tracer/test_workflow.py b/tests/unit_tests/tracer/test_workflow.py index b21b69095729be14d84fd9f267cd7d8abf69151c..7b09af023b24f9082c1900f39619b35899fbd0c1 100644 --- a/tests/unit_tests/tracer/test_workflow.py +++ b/tests/unit_tests/tracer/test_workflow.py @@ -21,7 +21,7 @@ fake_exception_module.JiuWenBaseException = Mock() sys.modules["jiuwen.core.common.logging.base"] = fake_base sys.modules["jiuwen.core.common.exception.base"] = fake_exception_module -from tests.unit_tests.tracer.test_mock_node_with_tracer import StreamNodeWithTracer, CompositeWorkflowNode +from tests.unit_tests.tracer.test_mock_node_with_tracer import StreamNodeWithTracer from jiuwen.core.common.logging.base import logger import asyncio @@ -250,7 +250,7 @@ class WorkflowTest(unittest.TestCase): "c": 1, "d": [1, 2, 3]}) - main_workflow.add_workflow_comp("a", ExecWorkflowComponent("a", sub_workflow), + main_workflow.add_workflow_comp("a", ExecWorkflowComponent(sub_workflow), inputs_schema={ "aa": "${start.a}", "ac": "${start.c}"}) @@ -314,7 +314,7 @@ class WorkflowTest(unittest.TestCase): "c": 1, "d": [1, 2, 3]}) - main_workflow.add_workflow_comp("a", ExecWorkflowComponent("a", sub_workflow), + main_workflow.add_workflow_comp("a", ExecWorkflowComponent(sub_workflow), inputs_schema={ "aa": "${start.a}", "ac": "${start.c}"}) @@ -414,7 +414,7 @@ class WorkflowTest(unittest.TestCase): "c": 1, "d": [1, 2, 3]}) - main_workflow.add_workflow_comp("a", ExecWorkflowComponent("a", sub_workflow), + main_workflow.add_workflow_comp("a", ExecWorkflowComponent(sub_workflow), inputs_schema={ "aa": "${start.a}", "ac": "${start.c}"}) @@ -424,7 +424,7 @@ class WorkflowTest(unittest.TestCase): {"node_id": "b", "id": 2, "data": "2"}, ] - main_workflow.add_workflow_comp("b", ExecWorkflowComponent("b", sub_workflow_2), + main_workflow.add_workflow_comp("b", ExecWorkflowComponent(sub_workflow_2), inputs_schema={ "aa": "${start.a}", "ac": "${start.c}"}) @@ -466,9 +466,8 @@ class WorkflowTest(unittest.TestCase): loop_group.add_workflow_comp("1", AddTenNode("1"), inputs_schema={"source": "${l.arrLoopVar.item}"}) loop_group.add_workflow_comp("2", AddTenNode("2"), inputs_schema={"source": "${l.intermediateLoopVar.user_var}"}) - set_variable_component = SetVariableComponent("3", - {"${l.intermediateLoopVar.user_var}": "${2.result}"}) - loop_group.add_workflow_comp("3", set_variable_component) + loop_group.add_workflow_comp("3", SetVariableComponent( + {"${l.intermediateLoopVar.user_var}": "${2.result}"})) loop_group.start_comp("1") loop_group.end_comp("3") loop_group.add_connection("1", "2") @@ -480,8 +479,7 @@ class WorkflowTest(unittest.TestCase): {"user_var": "${input_number}"}) loop = LoopComponent("l", loop_group, PregelGraph(), ArrayCondition("l", {"item": "${a.array}"}), - callbacks=[output_callback, intermediate_callback], - set_variable_components=[set_variable_component]) + callbacks=[output_callback, intermediate_callback]) flow.add_workflow_comp("l", loop) diff --git a/tests/unit_tests/workflow/test_checkpoint.py b/tests/unit_tests/workflow/test_checkpoint.py index 43d87a3955acafa588f9cd1dd0ef252a3da29e03..ca5646a15041bee7d2521f07369fc91b4d57e09c 100644 --- a/tests/unit_tests/workflow/test_checkpoint.py +++ b/tests/unit_tests/workflow/test_checkpoint.py @@ -139,7 +139,7 @@ class CheckpointTest(unittest.TestCase): "b": "${user.inputs.b}", "c": 1, "d": [1, 2, 3]}) - flow.add_workflow_comp("a", ExecWorkflowComponent("a", subflow), + flow.add_workflow_comp("a", ExecWorkflowComponent(subflow), inputs_schema={ "aa": "${start.a}", "ac": "${start.c}"}) @@ -179,9 +179,8 @@ class CheckpointTest(unittest.TestCase): loop_group.add_workflow_comp("1", AddTenNode("1"), inputs_schema={"source": "${l.arrLoopVar.item}"}) loop_group.add_workflow_comp("2", AddTenNode4Cp("2"), inputs_schema={"source": "${l.intermediateLoopVar.user_var}"}) - set_variable_component = SetVariableComponent("3", - {"${l.intermediateLoopVar.user_var}": "${2.result}"}) - loop_group.add_workflow_comp("3", set_variable_component) + loop_group.add_workflow_comp("3", SetVariableComponent( + {"${l.intermediateLoopVar.user_var}": "${2.result}"})) loop_group.start_comp("1") loop_group.end_comp("3") loop_group.add_connection("1", "2") @@ -192,8 +191,7 @@ class CheckpointTest(unittest.TestCase): {"user_var": "${input_number}"}) loop = LoopComponent("l", loop_group, PregelGraph(), ArrayCondition("l", {"item": "${a.array}"}), - callbacks=[output_callback, intermediate_callback], - set_variable_components=[set_variable_component]) + callbacks=[output_callback, intermediate_callback]) flow.add_workflow_comp("l", loop) @@ -269,9 +267,8 @@ class CheckpointTest(unittest.TestCase): loop_group.add_workflow_comp("1", AddTenNode("1"), inputs_schema={"source": "${l.arrLoopVar.item}"}) loop_group.add_workflow_comp("2", InteractiveNode4Cp("2"), inputs_schema={"source": "${l.intermediateLoopVar.user_var}"}) - set_variable_component = SetVariableComponent("3", - {"${l.intermediateLoopVar.user_var}": "${2.result}"}) - loop_group.add_workflow_comp("3", set_variable_component) + loop_group.add_workflow_comp("3", SetVariableComponent( + {"${l.intermediateLoopVar.user_var}": "${2.result}"})) loop_group.start_comp("1") loop_group.end_comp("3") loop_group.add_connection("1", "2") @@ -282,8 +279,7 @@ class CheckpointTest(unittest.TestCase): {"user_var": "${input_number}"}) loop = LoopComponent("l", loop_group, PregelGraph(), ArrayCondition("l", {"item": "${a.array}"}), - callbacks=[output_callback, intermediate_callback], - set_variable_components=[set_variable_component]) + callbacks=[output_callback, intermediate_callback]) flow.add_workflow_comp("l", loop) diff --git a/tests/unit_tests/workflow/test_workflow.py b/tests/unit_tests/workflow/test_workflow.py index cdd809b2a54a4c52a1154f53ef649307c6003367..7902fae2d7d30022f76d391de61e10f5e5e77eaa 100644 --- a/tests/unit_tests/workflow/test_workflow.py +++ b/tests/unit_tests/workflow/test_workflow.py @@ -196,32 +196,29 @@ class WorkflowTest(unittest.TestCase): inputs_schema={"array_result": "${b.array_result}", "user_var": "${b.user_var}"}) flow.add_workflow_comp("a", CommonNode("a"), inputs_schema={"array": "${input_array}"}) - flow.add_workflow_comp("b", CommonNode("b"), - inputs_schema={"array_result": "${l.results}", "user_var": "${l.user_var}"}) # create loop: (1->2->3) loop_group = LoopGroup(WorkflowConfig(), PregelGraph()) loop_group.add_workflow_comp("1", AddTenNode("1"), inputs_schema={"source": "${l.arrLoopVar.item}"}) loop_group.add_workflow_comp("2", AddTenNode("2"), inputs_schema={"source": "${l.intermediateLoopVar.user_var}"}) - set_variable_component = SetVariableComponent("3", - {"${l.intermediateLoopVar.user_var}": "${2.result}"}) - loop_group.add_workflow_comp("3", set_variable_component) + loop_group.add_workflow_comp("3", SetVariableComponent( + {"${l.intermediateLoopVar.user_var}": "${2.result}"})) loop_group.start_comp("1") loop_group.end_comp("3") loop_group.add_connection("1", "2") loop_group.add_connection("2", "3") - output_callback = OutputCallback("l", - {"results": "${1.result}", "user_var": "${l.intermediateLoopVar.user_var}"}) intermediate_callback = IntermediateLoopVarCallback("l", {"user_var": "${input_number}"}) + output_callback = OutputCallback("l", + {"results": "${1.result}", + "user_var": "${l.intermediateLoopVar.user_var}"}) + flow.add_workflow_comp("l", LoopComponent("l", loop_group, PregelGraph(), + ArrayCondition("l", {"item": "${a.array}"}), + callbacks=[output_callback, intermediate_callback])) - loop = LoopComponent("l", loop_group, PregelGraph(), ArrayCondition("l", {"item": "${a.array}"}), - callbacks=[output_callback, intermediate_callback], - set_variable_components=[set_variable_component]) - - flow.add_workflow_comp("l", loop) - + flow.add_workflow_comp("b", CommonNode("b"), + inputs_schema={"array_result": "${l.results}", "user_var": "${l.user_var}"}) # s->a->(1->2->3)->b->e flow.add_connection("s", "a") flow.add_connection("a", "l") @@ -249,9 +246,8 @@ class WorkflowTest(unittest.TestCase): loop_group.add_workflow_comp("1", AddTenNode("1"), inputs_schema={"source": "${l.arrLoopVar.item}"}) loop_group.add_workflow_comp("2", AddTenNode("2"), inputs_schema={"source": "${l.intermediateLoopVar.user_var}"}) - set_variable_component = SetVariableComponent("3", - {"${l.intermediateLoopVar.user_var}": "${2.result}"}) - loop_group.add_workflow_comp("3", set_variable_component) + loop_group.add_workflow_comp("3", SetVariableComponent( + {"${l.intermediateLoopVar.user_var}": "${2.result}"})) break_node = BreakComponent() loop_group.add_workflow_comp("4", break_node) loop_group.start_nodes(["1"]) @@ -265,8 +261,7 @@ class WorkflowTest(unittest.TestCase): {"user_var": "${input_number}"}) loop = LoopComponent("l", loop_group, PregelGraph(), ArrayCondition("l", {"item": "${a.array}"}), - callbacks=[output_callback, intermediate_callback], break_nodes=[break_node], - set_variable_components=[set_variable_component]) + callbacks=[output_callback, intermediate_callback], break_nodes=[break_node]) flow.add_workflow_comp("l", loop) @@ -296,9 +291,8 @@ class WorkflowTest(unittest.TestCase): loop_group.add_workflow_comp("1", AddTenNode("1"), inputs_schema={"source": "${l.index}"}) loop_group.add_workflow_comp("2", AddTenNode("2"), inputs_schema={"source": "${l.intermediateLoopVar.user_var}"}) - set_variable_component = SetVariableComponent("3", - {"${l.intermediateLoopVar.user_var}": "${2.result}"}) - loop_group.add_workflow_comp("3", set_variable_component) + loop_group.add_workflow_comp("3", SetVariableComponent( + {"${l.intermediateLoopVar.user_var}": "${2.result}"})) loop_group.start_nodes(["1"]) loop_group.end_nodes(["3"]) loop_group.add_connection("1", "2") @@ -309,8 +303,7 @@ class WorkflowTest(unittest.TestCase): {"user_var": "${input_number}"}) loop = LoopComponent("l", loop_group, PregelGraph(), NumberCondition("l", "${loop_number}"), - callbacks=[output_callback, intermediate_callback], - set_variable_components=[set_variable_component]) + callbacks=[output_callback, intermediate_callback]) flow.add_workflow_comp("l", loop) @@ -505,7 +498,7 @@ class WorkflowTest(unittest.TestCase): "c": 1, "d": [1, 2, 3]}) - main_workflow.add_workflow_comp("a", ExecWorkflowComponent("a", sub_workflow), + main_workflow.add_workflow_comp("a", ExecWorkflowComponent(sub_workflow), inputs_schema={ "aa": "${start.a}", "ac": "${start.c}"}) @@ -541,7 +534,7 @@ class WorkflowTest(unittest.TestCase): # flow2: start->a1|composite->end flow1.add_workflow_comp("a1", Node1("a1"), inputs_schema={"value": "${start.a1}"}) - flow1.add_workflow_comp("composite", ExecWorkflowComponent("composite", flow2), + flow1.add_workflow_comp("composite", ExecWorkflowComponent(flow2), inputs_schema={"result": "${start.a2}"}) flow1.set_end_comp("end", MockEndNode("end"), inputs_schema={"b1": "${a1.value}", "b2": "${composite.result}"})