From e498e24fbb42f45c4c0001641ed7dbbf48d03714 Mon Sep 17 00:00:00 2001 From: zhongxiaotian Date: Thu, 17 Jul 2025 18:59:59 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E7=BB=84=E4=BB=B6=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 增加工具组件用例 #ICLT7L --- jiuwen/core/component/tool_comp.py | 44 +++++++--- tests/unit_tests/workflow/test_tool_comp.py | 94 +++++++++++++++++++++ 2 files changed, 126 insertions(+), 12 deletions(-) create mode 100644 tests/unit_tests/workflow/test_tool_comp.py diff --git a/jiuwen/core/component/tool_comp.py b/jiuwen/core/component/tool_comp.py index 3d33dbd..3d8dc47 100644 --- a/jiuwen/core/component/tool_comp.py +++ b/jiuwen/core/component/tool_comp.py @@ -3,19 +3,36 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved from dataclasses import dataclass, field -from typing import Dict, Any, AsyncIterator, Iterator +from typing import Dict, Any, AsyncIterator, Iterator, List from jiuwen.core.common.exception.exception import JiuWenBaseException from jiuwen.core.common.exception.status_code import StatusCode from jiuwen.core.component.base import ComponentConfig, WorkflowComponent from jiuwen.core.context.context import Context from jiuwen.core.graph.executable import Executable, Input, Output +from jiuwen.core.utils.tool.base import Tool @dataclass class ToolComponentConfig(ComponentConfig): header: Dict[str, Any] = field(default_factory=dict) + method: str = '' + auth: Dict[str, Any] = field(default_factory=dict) + pluginDependency: Dict[str, Any] = field(default_factory=dict) + systemFields: Dict[str, Any] = field(default_factory=dict) + exceptionEnable: bool = False + description: str = '' + url: str = '' + streaming: bool = False + userFields: Dict[str, Any] = field(default_factory=dict) + response: List[Any] = field(default_factory=list) + name: str = '' + arguments: List[Any] = field(default_factory=list) + id: str = '' + needValidate: bool = True + needConfirm: bool = False + apiId: str = '' class ToolExecutable(Executable): @@ -23,9 +40,9 @@ class ToolExecutable(Executable): def __init__(self, config: ToolComponentConfig): super().__init__() self._config = config - self._tool = None + self._tool: Tool = None - def invoke(self, inputs: Input, context: Context) -> Output: + async def invoke(self, inputs: Input, context: Context) -> Output: self._tool = self.get_tool(context) validated = inputs.get('validate', False) user_field = inputs.get('userFields', None) @@ -33,7 +50,7 @@ class ToolExecutable(Executable): self.validate_require_params(user_field) formatted_inputs = prepare_inputs(user_field, self.get_tool_param()) try: - response = self._tool.run(formatted_inputs) + response = self._tool.invoke(formatted_inputs) return self._create_output(response) except Exception as e: raise JiuWenBaseException( @@ -41,26 +58,26 @@ class ToolExecutable(Executable): message='tool component execution error' ) from e - async def ainvoke(self, inputs: Input, context: Context) -> Output: + async def stream(self, inputs: Input, context: Context) -> Iterator[Output]: pass - def stream(self, inputs: Input, context: Context) -> Iterator[Output]: + async def collect(self, inputs: AsyncIterator[Input], contex: Context) -> Output: pass - async def astream(self, inputs: Input, context: Context) -> AsyncIterator[Output]: + async def transform(self, inputs: AsyncIterator[Input], context: Context) -> AsyncIterator[Output]: pass - def interrupt(self, message: dict): + async def interrupt(self, message: dict): pass def _create_output(self, response): return response - def get_tool(self, context: Context): + def get_tool(self, context: Context) -> Tool: pass def get_tool_param(self): - pass + return self._tool.params def validate_require_params(self, user_field): require_params = self.get_tool_param() @@ -75,6 +92,7 @@ class ToolExecutable(Executable): } self.interrupt(interrupt_message) + TYPE_CASTER = { "str": str, "integer": int, @@ -82,6 +100,7 @@ TYPE_CASTER = { "bool": bool } + def _transform_type(value, expected_type, key): expected_type = expected_type.lower() caster = TYPE_CASTER.get(expected_type) @@ -97,7 +116,7 @@ def _transform_type(value, expected_type, key): return value -def prepare_inputs(user_field, defined_param): +def prepare_inputs(user_field, defined_param) -> dict: define_dict = {} formatted_inputs = {} for param in defined_param: @@ -112,6 +131,7 @@ def prepare_inputs(user_field, defined_param): error_code=StatusCode.TOOL_COMPONENT_INPUTS_ERROR.code, message=f'{StatusCode.TOOL_COMPONENT_INPUTS_ERROR.errmsg}, param is {k}' ) + return formatted_inputs class ToolComponent(WorkflowComponent): @@ -120,4 +140,4 @@ class ToolComponent(WorkflowComponent): self._config = config def to_executable(self) -> Executable: - return ToolExecutable(self._config) \ No newline at end of file + return ToolExecutable(self._config) diff --git a/tests/unit_tests/workflow/test_tool_comp.py b/tests/unit_tests/workflow/test_tool_comp.py new file mode 100644 index 0000000..9c46135 --- /dev/null +++ b/tests/unit_tests/workflow/test_tool_comp.py @@ -0,0 +1,94 @@ +from unittest.mock import patch, Mock, MagicMock + +import pytest + +from jiuwen.core.component.tool_comp import ToolComponentConfig, ToolExecutable, ToolComponent +from jiuwen.core.context.config import Config +from jiuwen.core.context.context import Context +from jiuwen.core.context.memory.base import InMemoryState +from jiuwen.core.utils.tool.service_api.param import Param +from jiuwen.core.utils.tool.service_api.restful_api import RestfulApi +from tests.unit_tests.workflow.test_mock_node import MockStartNode, MockEndNode +from tests.unit_tests.workflow.test_workflow import create_flow + + +@pytest.fixture +def fake_ctx(): + from unittest.mock import MagicMock + ctx = MagicMock(spec=Context) + ctx.store = MagicMock() + ctx.store.read.return_value = [] + return ctx + + +@pytest.fixture() +def mock_tool_config(): + return ToolComponentConfig( + needValidate=False + ) + + +@pytest.fixture +def mock_tool_input(): + return { + 'userFields': { + 'location': 'Beijing', + 'date': 15 + }, + 'validated': False + } + + +@pytest.fixture +def mock_tool(): + return RestfulApi( + name="test", + description="test", + params=[Param(name="location", description="location", type='string'), + Param(name="date", description="date", type='int')], + path="http://127.0.0.1:8000", + headers={}, + method="GET", + response=[], + ) + + +@patch('requests.request') +@patch('jiuwen.core.component.tool_comp.ToolExecutable.get_tool') +@pytest.mark.asyncio +async def test_tool_comp_invoke(mock_get_tool, mock_request, mock_tool, mock_tool_config, mock_tool_input, fake_ctx): + mock_get_tool.return_value = mock_tool + tool_executable = ToolExecutable(mock_tool_config) + + # mock request的response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = "{}" + mock_response.content = b"{}" + mock_request.return_value = mock_response + res = await tool_executable.invoke(mock_tool_input, fake_ctx) + + assert res.get('errCode') == 0 + + +@patch('jiuwen.core.component.tool_comp.ToolExecutable.invoke') +@patch('jiuwen.core.component.tool_comp.ToolExecutable.get_tool') +@pytest.mark.asyncio +async def test_tool_comp_invoke(mock_get_tool, mock_invoke, mock_tool, mock_tool_config, fake_ctx): + mock_get_tool.return_value = mock_tool + mock_invoke.return_value = 'res' + context = Context(config=Config(), state=InMemoryState(), store=None, tracer=None) + flow = create_flow() + + start_component = MockStartNode("s") + end_component = MockEndNode("e") + tool_component = ToolComponent(mock_tool_config) + + flow.set_start_comp("s", start_component) + flow.set_end_comp("e", end_component) + flow.add_workflow_comp("tool", tool_component) + + flow.add_connection("s", "tool") + flow.add_connection("tool", "e") + + await flow.invoke({}, context) -- Gitee