From ee7bfcfc8c8328e1a4f7f2e434d79372278896c6 Mon Sep 17 00:00:00 2001
From: Jiayuan Kang
Date: Fri, 25 Jul 2025 16:48:34 +0800
Subject: [PATCH] feat: optimize inter-component streaming for the LLM
 component
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Optimize inter-component streaming for the LLM component #ICLT7L: component feature implementation
---
 jiuwen/core/component/llm_comp.py             |  7 +++-
 .../system_tests/agent/test_workflow_agent.py | 37 +++++++++++++++---
 tests/unit_tests/workflow/test_llm_comp.py    | 38 +++++++++++++++++++
 3 files changed, 76 insertions(+), 6 deletions(-)

diff --git a/jiuwen/core/component/llm_comp.py b/jiuwen/core/component/llm_comp.py
index ac97214..6df700a 100644
--- a/jiuwen/core/component/llm_comp.py
+++ b/jiuwen/core/component/llm_comp.py
@@ -328,6 +328,7 @@ class LLMExecutable(Executable):
 
     async def _invoke_for_json_format(self, inputs: Input) -> AsyncIterator[Output]:
         model_inputs = self._prepare_model_inputs(inputs)
+        logger.info("[%s] model inputs %s", self._context.executable_id(), model_inputs)
         llm_output = await self._llm.ainvoke(model_inputs)  # ainvoke is an async interface, so it must be awaited
         yield self._create_output(llm_output)
 
@@ -336,7 +337,11 @@ class LLMExecutable(Executable):
         # self._llm.astream is assumed to be an async generator
         async for chunk in self._llm.astream(model_inputs):
             content = WorkflowLLMUtils.extract_content(chunk)
-            yield content
+            formatted_res = OutputFormatter.format_response(content,
+                                                            self._config.response_format,
+                                                            self._config.output_config)
+            stream_out = {USER_FIELDS: formatted_res}
+            yield stream_out
 
     def _format_response_content(self, response_content: str) -> dict:
         pass
diff --git a/tests/system_tests/agent/test_workflow_agent.py b/tests/system_tests/agent/test_workflow_agent.py
index 2635fa7..6de0900 100644
--- a/tests/system_tests/agent/test_workflow_agent.py
+++ b/tests/system_tests/agent/test_workflow_agent.py
@@ -24,8 +24,6 @@ from jiuwen.core.utils.tool.service_api.restful_api import RestfulApi
 from jiuwen.core.workflow.base import Workflow
 from jiuwen.core.workflow.workflow_config import WorkflowConfig, WorkflowMetadata
 from jiuwen.graph.pregel.graph import PregelGraph
-from tests.system_tests.workflow.test_real_workflow import RealWorkflowTest, MODEL_PROVIDER, MODEL_NAME, API_BASE, \
-    API_KEY, _FINAL_RESULT, _QUESTIONER_USER_TEMPLATE, _QUESTIONER_SYSTEM_TEMPLATE
 from tests.unit_tests.workflow.test_mock_node import MockStartNode, MockEndNode
 
 API_BASE = os.getenv("API_BASE", "")
@@ -46,6 +44,35 @@ _MOCK_TOOL = RestfulApi(
     response=[],
 )
 
+_FINAL_RESULT: str = "上海今天晴 30°C"
+
+# --------------------------- Prompt templates --------------------------- #
+_QUESTIONER_SYSTEM_TEMPLATE = """\
+你是一个信息收集助手,你需要根据指定的参数收集用户的信息,然后提交到系统。
+请注意:不要使用任何工具、不用理会问题的具体含义,并保证你的输出仅有 JSON 格式的结果数据。
+请严格遵循如下规则:
+    1. 让我们一步一步思考。
+    2. 用户输入中没有提及的参数提取为 None,并直接向用户询问没有明确提供的参数。
+    3. 通过用户提供的对话历史以及当前输入中提取 {{required_name}},不要追问任何其他信息。
+    4. 参数收集完成后,将收集到的信息通过 JSON 的方式展示给用户。
+
+## 指定参数
+{{required_params_list}}
+
+## 约束
+{{extra_info}}
+
+## 示例
+{{example}}
+"""
+
+_QUESTIONER_USER_TEMPLATE = """\
+对话历史
+{{dialogue_history}}
+
+请充分考虑以上对话历史及用户输入,正确提取最符合约束要求的 JSON 格式参数。
+"""
+
 class WorkflowAgentTest(unittest.IsolatedAsyncioTestCase):
     """Test class dedicated to WorkflowAgent.invoke."""
 
@@ -68,7 +95,7 @@
     @staticmethod
     def _create_intent_detection_component() -> IntentDetectionComponent:
         """Create the intent detection component."""
-        model_config = RealWorkflowTest._create_model_config()
+        model_config = WorkflowAgentTest._create_model_config()
         user_prompt = """
 {{user_prompt}}
 
@@ -103,7 +130,7 @@
     @staticmethod
     def _create_llm_component() -> LLMComponent:
         """Create the LLM component, used only to extract structured fields (location/date)."""
-        model_config = RealWorkflowTest._create_model_config()
+        model_config = WorkflowAgentTest._create_model_config()
         config = LLMCompConfig(
             model=model_config,
             template_content=[{"role": "user", "content": "{{query}}"}],
@@ -126,7 +153,7 @@
                 default_value="today",
             ),
         ]
-        model_config = RealWorkflowTest._create_model_config()
+        model_config = WorkflowAgentTest._create_model_config()
         config = QuestionerConfig(
             model=model_config,
             question_content="",
diff --git a/tests/unit_tests/workflow/test_llm_comp.py b/tests/unit_tests/workflow/test_llm_comp.py
index adf602e..4d500cf 100644
--- a/tests/unit_tests/workflow/test_llm_comp.py
+++ b/tests/unit_tests/workflow/test_llm_comp.py
@@ -1,5 +1,6 @@
 import sys
 import types
+from typing import Any
 
 import pytest
 from unittest.mock import Mock
@@ -94,6 +95,43 @@ class TestLLMExecutableInvoke:
         assert output[USER_FIELDS] == {'result': 'mocked response'}
         fake_llm.ainvoke.assert_called_once()
 
+    @pytest.mark.asyncio
+    async def test_stream_success(
+            self,
+            mock_get_model,  # this is the patch fixture
+            fake_ctx,
+            fake_input,
+            fake_model_config,
+    ):
+        config = LLMCompConfig(
+            model=fake_model_config,
+            template_content=[{"role": "user", "content": "Hello {query}"}],
+            response_format={"type": "text"},
+            output_config={"result": {
+                "type": "string",
+                "required": True,
+            }},
+        )
+        exe = LLMExecutable(config)
+
+        fake_llm = AsyncMock()
+
+        # mock async generator that yields several AIMessage chunks
+        async def mock_stream_response(input: Any):
+            for chunk in ["mocked ", "response"]:
+                yield AIMessage(content=chunk)
+
+        fake_llm.astream = mock_stream_response
+        mock_get_model.return_value = fake_llm
+
+        # call stream and asynchronously iterate over all chunks
+        chunks = []
+        async for chunk in exe.stream(fake_input(userFields=dict(query="pytest")), fake_ctx):
+            chunks.append(chunk)
+
+        # LLMExecutable.stream yields one formatted, USER_FIELDS-wrapped output per AIMessage chunk
+        assert len(chunks) == 2
+
     @pytest.mark.asyncio
     async def test_invoke_llm_exception(
         self,
-- 
Gitee
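
Reviewer illustration (not part of the patch): with this change, _invoke_for_streaming no longer yields raw text chunks; each chunk first passes through OutputFormatter.format_response and is wrapped under USER_FIELDS, so streamed chunks carry the same schema as the final non-streamed output. The sketch below is a minimal model of that contract under stated assumptions: the concrete value of USER_FIELDS is a guess, format_response is a simplified stand-in for OutputFormatter.format_response, and fake_astream stands in for self._llm.astream.

import asyncio
from typing import Any, AsyncIterator

USER_FIELDS = "userFields"  # assumption: the real constant's value may differ


def format_response(content: str, output_config: dict) -> dict:
    # Simplified stand-in for OutputFormatter.format_response: place the raw
    # model text under the single field declared in output_config.
    field = next(iter(output_config))
    return {field: content}


async def fake_astream(_: Any) -> AsyncIterator[str]:
    # Stand-in for self._llm.astream: an async generator of text chunks.
    for piece in ["mocked ", "response"]:
        yield piece


async def invoke_for_streaming(model_inputs: Any) -> AsyncIterator[dict]:
    # Mirrors the patched _invoke_for_streaming: format every chunk and wrap
    # it under USER_FIELDS before yielding it downstream.
    async for content in fake_astream(model_inputs):
        yield {USER_FIELDS: format_response(content, {"result": {"type": "string"}})}


async def main() -> None:
    async for out in invoke_for_streaming({"query": "pytest"}):
        print(out)  # {'userFields': {'result': 'mocked '}}, then {'userFields': {'result': 'response'}}


asyncio.run(main())

Because every streamed chunk already has the {USER_FIELDS: {...}} shape, a downstream component can consume a streaming upstream through the same code path it uses for the final output, which is what test_stream_success exercises by counting the wrapped chunks.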