diff --git a/jiuwen/core/component/end_comp.py b/jiuwen/core/component/end_comp.py index 29e4e9552a245b75bf974304d25fa97674fa857a..9bc8b93ee44251d83e0b71c2f92023fac0299e14 100644 --- a/jiuwen/core/component/end_comp.py +++ b/jiuwen/core/component/end_comp.py @@ -1,10 +1,6 @@ #!/usr/bin/python3.10 # coding: utf-8 # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved -import uuid -from abc import ABC -from copy import deepcopy -from dataclasses import dataclass, field from typing import AsyncIterator from jiuwen.core.common.logging.base import logger @@ -15,14 +11,8 @@ from jiuwen.core.component.base import WorkflowComponent from jiuwen.core.context.context import Context from jiuwen.core.graph.executable import Executable, Input, Output from jiuwen.core.stream.base import StreamCode -import time -async def get_stream_data(stream_code: str, data: dict, index: int, context: Context,): - - - stream_final_data = dict(type=stream_code, index = index, payload=data) - - await context.stream_writer_manager.get_output_writer().write(stream_final_data) +STREAM_CACHE_KEY = "_stream_cache_key" class End(Executable,WorkflowComponent): @@ -31,50 +21,66 @@ class End(Executable,WorkflowComponent): self.node_id = node_id self.node_name = node_name self.conf = conf - self.template = conf["responseTemplate"] + self.template = conf["responseTemplate"] if "responseTemplate" in conf and len(conf["responseTemplate"])>0 else None async def invoke(self, inputs: Input, context: Context) -> Output: - answer = TemplateUtils.render_template(self.template, inputs.get(USER_FIELDS)) + user_fields = inputs.get(USER_FIELDS) + if self.template: + answer = TemplateUtils.render_template(self.template, user_fields) + output = {} + else: + answer = "" + output = user_fields + return { + "responseContent": answer, + "output": output + } + - final_output = dict(responseContent=answer) + async def stream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: try: - response_mode = inputs.get("response_mode") - if response_mode is not None and response_mode == "streaming": - response_list = TemplateUtils.render_template_to_list(self.template) - index = 0 - for res in response_list: - if res.startswith("{{") and res.endswith("}}"): + if self.template: + response_list = TemplateUtils.render_template_to_list(self.template) + index = 0 + for res in response_list: + if res.startswith("{{") and res.endswith("}}"): param_name = res[2:-2] param_value = inputs.get(USER_FIELDS).get(param_name) if param_value is None: - continue - await get_stream_data(StreamCode.PARTIAL_CONTENT.name,dict(answer=param_value), index, context) - - - else: - await get_stream_data(StreamCode.PARTIAL_CONTENT.name, - dict(answer=res), index, context) - - index += 1 - + continue + yield dict(type=StreamCode.PARTIAL_CONTENT.name, index=index, payload=dict(answer=param_value)) + else: + yield dict(type=StreamCode.PARTIAL_CONTENT.name, index=index, payload=dict(answer=res)) + index += 1 + final_output = TemplateUtils.render_template(self.template, inputs.get(USER_FIELDS)) + else: + index = 0 + for res in inputs.get(USER_FIELDS): + yield dict(type=StreamCode.PARTIAL_CONTENT.name, index=index, payload=dict(outputs={USER_FIELDS: res})) + index += 1 + final_output = dict(outputs={USER_FIELDS: inputs.get(USER_FIELDS)}) final_index = 0 - await get_stream_data(StreamCode.MESSAGE_END.name, dict(answer=final_output), final_index, context) - await get_stream_data(StreamCode.WORKFLOW_END.name, dict(answer=final_output), final_index, context) - await get_stream_data(StreamCode.FINISH.name, dict(answer=final_output), final_index, context) + + yield dict(type=StreamCode.MESSAGE_END.name, index=final_index, payload=dict(outputs={USER_FIELDS: final_output})) + yield dict(type=StreamCode.WORKFLOW_END.name, index=final_index, payload=dict(outputs={USER_FIELDS: final_output})) + yield dict(type=StreamCode.FINISH.name, index=final_index, payload=dict(outputs={USER_FIELDS: final_output})) except Exception as e: logger.info("stream output error: {}".format(e)) - - return final_output - - async def stream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: - pass - async def collect(self, inputs: AsyncIterator[Input], contex: Context) -> Output: pass async def transform(self, inputs: AsyncIterator[Input], context: Context) -> AsyncIterator[Output]: - pass + # 异步遍历输入迭代器 + index = 0 + stream_cache_key = self.node_id + STREAM_CACHE_KEY + stream_cache_value = {} + async for input_item in inputs: + # 将当前输入项存入context + stream_cache_value.update(input_item) + index += 1 + yield dict(type=StreamCode.PARTIAL_CONTENT.name, index=index, payload=dict(answer=input_item)) + context.state().update({stream_cache_key: stream_cache_value}) async def interrupt(self, message: dict): pass diff --git a/jiuwen/core/workflow/base.py b/jiuwen/core/workflow/base.py index 67649a521306234d81e267de1afb1c39a1699e7e..14ffac4ab866e4c164a487bf02cdefbcea0924bc 100644 --- a/jiuwen/core/workflow/base.py +++ b/jiuwen/core/workflow/base.py @@ -60,7 +60,8 @@ class BaseWorkFlow: stream_outputs_schema: dict = None, stream_inputs_transformer: Transformer = None, stream_outputs_transformer: Transformer = None, - comp_ability: list[ComponentAbility] = None + comp_ability: list[ComponentAbility] = None, + response_mode: str = None ) -> Self: if not isinstance(workflow_comp, WorkflowComponent): workflow_comp = self._convert_to_component(workflow_comp) @@ -79,6 +80,12 @@ class BaseWorkFlow: if ability in [ComponentAbility.STREAM, ComponentAbility.TRANSFORM, ComponentAbility.COLLECT]: if not wait_for_all: raise JiuWenBaseException(-1, "stream components need to wait for all") + if response_mode is not None: + if "streaming" == response_mode: + self._workflow_config.comp_abilities[ + comp_id] = [ComponentAbility.STREAM, ComponentAbility.TRANSFORM] + else: + self._workflow_config.comp_abilities[comp_id] = [ComponentAbility.INVOKE] return self def start_comp( @@ -151,6 +158,7 @@ class Workflow(BaseWorkFlow): stream_outputs_schema: dict = None, stream_inputs_transformer: Transformer = None, stream_outputs_transformer: Transformer = None, + response_mode: str = None, ) -> Self: self.add_workflow_comp(end_comp_id, component, wait_for_all=False, inputs_schema=inputs_schema, outputs_schema=outputs_schema, @@ -159,7 +167,8 @@ class Workflow(BaseWorkFlow): stream_inputs_schema=stream_inputs_schema, stream_outputs_schema=stream_outputs_schema, stream_inputs_transformer=stream_inputs_transformer, - stream_outputs_transformer=stream_outputs_transformer + stream_outputs_transformer=stream_outputs_transformer, + response_mode=response_mode ) self.end_comp(end_comp_id) self._end_comp_id = end_comp_id diff --git a/tests/unit_tests/workflow/test_end.py b/tests/unit_tests/workflow/test_end.py index 38a2c28334cf24ead5bac2f26e5dbf5f1db7ebbb..2e33a3a887e40ccdcc8da092fa737dd8e57f5549 100644 --- a/tests/unit_tests/workflow/test_end.py +++ b/tests/unit_tests/workflow/test_end.py @@ -9,17 +9,18 @@ from jiuwen.core.component.start_comp import Start from jiuwen.core.context.config import Config -from jiuwen.core.context.context import Context -from jiuwen.core.context.memory.base import InMemoryState +from jiuwen.core.context.context import Context, WorkflowContext +from jiuwen.core.context.state import InMemoryState from jiuwen.core.graph.base import Graph from jiuwen.core.workflow.base import WorkflowConfig, Workflow +from jiuwen.core.workflow.workflow_config import ComponentAbility from jiuwen.graph.pregel.graph import PregelGraph from tests.unit_tests.tracer.test_workflow import create_context_with_tracer -from tests.unit_tests.workflow.test_mock_node import MockStartNode, MockEndNode, Node1, StreamNode +from tests.unit_tests.workflow.test_mock_node import MockStartNode, MockEndNode, Node1, StreamNode, StreamCompNode def create_context() -> Context: - return Context(config=Config(), state=InMemoryState(), store=None) + return WorkflowContext(config=Config(), state=InMemoryState(), store=None) def create_graph() -> Graph: @@ -50,7 +51,7 @@ class EndNodeTest(unittest.TestCase): elif checker is not None: checker(self.invoke_workflow(inputs, context, flow)) - def test_simple_workflow(self): + def test_simple_template_workflow(self): # flow1: start -> a -> end flow = create_flow() flow.set_start_comp("start", Start("start",{"userFields":{"inputs":[],"outputs":[]},"systemFields":{"input":[{"id":"query","type":"String","required":"true","sourceType":"ref"}]}}), @@ -69,40 +70,82 @@ class EndNodeTest(unittest.TestCase): "response_mode": "${start.userFields.response_node}"}) flow.add_connection("start", "a") flow.add_connection("a", "end") - self.assert_workflow_invoke({"a": 1, "b": "haha"}, create_context(), flow, expect_results={'responseContent': 'hello:haha'}) + self.assert_workflow_invoke({"a": 1, "b": "haha"}, create_context(), flow, expect_results={'output': {}, 'responseContent': 'hello:haha'}) + def test_simple_output_schema_workflow(self): + # flow1: start -> a -> end + flow = create_flow() + flow.set_start_comp("start", Start("start",{"userFields":{"inputs":[],"outputs":[]},"systemFields":{"input":[{"id":"query","type":"String","required":"true","sourceType":"ref"}]}}), + inputs_schema={ + "systemFields": {"query": "${a}"}, + "userFields": {}, + "response_node": "${response_mode}", + "d": "${b}"}) + flow.add_workflow_comp("a", Node1("a"), + inputs_schema={ + "aa": "${start.d}", + "ac": "${start.d}"}) + flow.set_end_comp("end", End("end", "end",{}), + inputs_schema={ + "userFields": {"end_input": "${start.userFields.d}"}, + "response_mode": "${start.userFields.response_node}"}, + ) + flow.add_connection("start", "a") + flow.add_connection("a", "end") + self.assert_workflow_invoke({"a": 1, "b": "haha"}, create_context(), flow, expect_results={'output': {'end_input': 'haha'}, 'responseContent': ''}) - def test_simple_stream_workflow(self): + def test_end_from_invoke_stream_workflow(self): async def stream_workflow(): flow = create_flow() flow.set_start_comp("start", Start("start", {"userFields": {"inputs": [], "outputs": []}, "systemFields": { "input": [{"id": "query", "type": "String", "required": "true", "sourceType": "ref"}]}}), inputs_schema={ "systemFields": {"query": "${a}"}, - "userFields": {}, + "userFields": {"d": "${a}"}, "response_node": "${response_mode}", - "d": "${user.inputs.a}"}) - expected_datas = [ - {"id": 1, "data": "1"}, - {"id": 2, "data": "2"}, - ] - - flow.add_workflow_comp("a", StreamNode("a", expected_datas), - inputs_schema={ - "aa": "${start.a}", - "ac": "${start.c}"}) + "d": "${a}"}) + + flow.add_workflow_comp("a", StreamCompNode("a"), inputs_schema={"value": "${a}"}, + comp_ability=[ComponentAbility.STREAM], wait_for_all=True) + + flow.set_end_comp("end", End("end", "end", {"responseTemplate": "hello:{{end_input}}"}), inputs_schema={ - "userFields": {"end_input": "${start.userFields.d}"}, - "response_mode": "${start.userFields.response_node}"}) + "userFields": {"end_input": "${start.userFields.d}"}},response_mode="streaming") flow.add_connection("start", "a") - flow.add_connection("a", "end") + flow.add_stream_connection("a", "end") index = 0 - async for chunk in flow.stream({"a": 1, "b": "haha","response_mode":"streaming"}, create_context_with_tracer()): + async for chunk in flow.stream({"a": 1, "b": "haha"}, create_context_with_tracer()): logger.info("stream chunk: {%s}", chunk) index += 1 self.loop.run_until_complete(stream_workflow()) + def test_end_transform_workflow(self): + async def stream_workflow(): + flow = create_flow() + flow.set_start_comp("start", Start("start", {"userFields": {"inputs": [], "outputs": []}, "systemFields": { + "input": [{"id": "query", "type": "String", "required": "true", "sourceType": "ref"}]}}), + inputs_schema={ + "systemFields": {"query": "${a}"}, + "userFields": {"d": "${a}"}, + "response_node": "${response_mode}", + "d": "${a}"}) + + flow.add_workflow_comp("a", StreamCompNode("a"), inputs_schema={"value": "${a}"}, + comp_ability=[ComponentAbility.STREAM], wait_for_all=True) + + flow.set_end_comp("end", End("end", "end", {"responseTemplate": "hello:{{end_input}}"}), + inputs_schema={ + "userFields": {"end_input": "${start.userFields.d}"}},response_mode="streaming") + flow.add_connection("start", "a") + flow.add_stream_connection("a", "end") + + index = 0 + async for chunk in flow.stream({"a": 1, "b": "haha"}, create_context_with_tracer()): + logger.info("stream chunk: {%s}", chunk) + index += 1 + + self.loop.run_until_complete(stream_workflow()) \ No newline at end of file diff --git a/tests/unit_tests/workflow/test_mock_node.py b/tests/unit_tests/workflow/test_mock_node.py index c1c267305a9891d3d13b50073c233885b5aa1f13..5cfa977263c63781e78780b4257b9a5b3792af7c 100644 --- a/tests/unit_tests/workflow/test_mock_node.py +++ b/tests/unit_tests/workflow/test_mock_node.py @@ -206,7 +206,7 @@ class StreamCompNode(MockNodeBase): self._node_id = node_id async def stream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: - logger.info(f"===StreamCompNode[{self._node_id}], input: {inputs}") + logger.debug(f"===StreamCompNode[{self._node_id}], input: {inputs}") if inputs is None: yield 1 else: