diff --git a/jiuwen/core/component/questioner_comp.py b/jiuwen/core/component/questioner_comp.py index 2b914a948355b7ab75ff858781bab1b339884f12..8fe4411d31016931c6a5b32f171ad13bc8b8add2 100644 --- a/jiuwen/core/component/questioner_comp.py +++ b/jiuwen/core/component/questioner_comp.py @@ -18,6 +18,7 @@ from jiuwen.core.component.base import ComponentConfig, WorkflowComponent from jiuwen.core.context.context import Context from jiuwen.core.graph.executable import Executable, Input, Output from jiuwen.core.utils.llm.base import BaseChatModel +from jiuwen.core.utils.llm.messages import BaseMessage from jiuwen.core.utils.llm.model_utils.model_factory import ModelFactory from jiuwen.core.utils.prompt.template.template import Template from jiuwen.core.utils.prompt.template.template_manager import TemplateManager @@ -287,10 +288,10 @@ class QuestionerDirectReplyHandler: result.append(dict(role="user", content=self._query)) return result - def _build_llm_inputs(self, chat_history: list = None) -> List: + def _build_llm_inputs(self, chat_history: list = None) -> List[BaseMessage]: prompt_template_input = self._create_prompt_template_keywords(chat_history) - prompt_template = TemplateManager().format(prompt_template_input, self._prompt) - return prompt_template.content + formatted_template: Template = self._prompt.format(prompt_template_input) + return formatted_template.to_messages() def _create_prompt_template_keywords(self, chat_history): params_list, required_name_list = list(), list() @@ -304,9 +305,9 @@ class QuestionerDirectReplyHandler: return dict(required_name=required_name_str, required_params_list=all_param_str, extra_info=self._config.extra_prompt_for_fields_extraction, example=self._config.example_content, - dig_history=dialogue_history_str) + dialogue_history=dialogue_history_str) - def _invoke_llm_for_extraction(self, llm_inputs): + def _invoke_llm_for_extraction(self, llm_inputs: List[BaseMessage]): try: response = self._model.invoke(llm_inputs).content 
except Exception as e: @@ -441,7 +442,7 @@ class QuestionerExecutable(Executable): def _init_prompt(self) -> Template: if self._config.prompt_template: - return Template(name="user_prompt", content=self._config.prompt_template) + return Template(name="question_user_prompt", content=self._config.prompt_template) filters = dict(model_name=self._config.model.model_info.model_name) return TemplateManager().get(name=TEMPLATE_NAME, filters=filters) diff --git a/tests/system_tests/workflow/test_real_workflow.py b/tests/system_tests/workflow/test_real_workflow.py index fe9a215570e0531dfcb21bf5ba5d70430a49a76b..e797fd5e30dc1000faa71fa6502406e12281f149 100644 --- a/tests/system_tests/workflow/test_real_workflow.py +++ b/tests/system_tests/workflow/test_real_workflow.py @@ -1,4 +1,5 @@ import asyncio +import os import unittest from unittest.mock import patch @@ -10,7 +11,6 @@ from jiuwen.core.context.config import WorkflowConfig, Config from jiuwen.core.context.context import Context from jiuwen.core.context.memory.base import InMemoryState from jiuwen.core.utils.llm.base import BaseModelInfo -from jiuwen.core.utils.prompt.template.template import Template from jiuwen.core.utils.tool.service_api.param import Param from jiuwen.core.utils.tool.service_api.restful_api import RestfulApi from jiuwen.core.workflow.base import Workflow @@ -28,11 +28,32 @@ MOCK_TOOL = RestfulApi( response=[], ) FINAL_RESULT = "Success" -API_BASE = "" -API_KEY = "" -MODEL_NAME = "" -MODEL_PROVIDER = "" - +API_BASE = os.getenv("API_BASE", "") +API_KEY = os.getenv("API_KEY", "") +MODEL_NAME = os.getenv("MODEL_NAME", "") +MODEL_PROVIDER = os.getenv("MODEL_PROVIDER", "") + +QUESTIONER_SYSTEM_TEMPLATE = """你是一个信息收集助手,你需要根据指定的参数收集用户的信息,然后提交到系统。 +请注意:不要使用任何工具、不用理会问题的具体含义,并保证你的输出仅有json格式的结果数据,以保证返回结果可以被json.loads直接解析。 +请严格遵循如下规则: + 1、让我们一步一步思考。 + 2、用户输入中没有提及的参数提取为None,并直接向用户询问没有明确提供的参数,让用户来提供。 + 3、通过用户提供的对话历史以及当前输入中提取{{required_name}},不要追问任何其他信息。 + 4、参数收集完成后,将收集到的信息通过JSON的方式展示给用户。 +例如,指定的参数为[A,B,C,D] 
+当前用户输入为'A为a,B得更新,C是3',此时根据用户输入,B需要更新,D未提及,那么你的返回结果即为:{'A':'a', 'B':None, 'C':3, 'D':None} +##指定参数 +{{required_params_list}} +##约束 +{{extra_info}} +##示例 +{{example}} +""" + +QUESTIONER_USER_TEMPLATE = """对话历史 +{{dialogue_history}} +请充分考虑以上对话历史及用户输入正确提取最符合约束要求的json格式参数,保证生成结果可以直接被json.loads解析 +""" class RealWorkflowTest(unittest.TestCase): def setUp(self): @@ -49,27 +70,10 @@ class RealWorkflowTest(unittest.TestCase): @patch('jiuwen.core.utils.tool.service_api.restful_api.RestfulApi.invoke') @patch('jiuwen.core.component.tool_comp.ToolExecutable.get_tool') - @patch("jiuwen.core.component.questioner_comp.QuestionerDirectReplyHandler._invoke_llm_for_extraction") - @patch("jiuwen.core.component.questioner_comp.QuestionerDirectReplyHandler._build_llm_inputs") - @patch("jiuwen.core.component.questioner_comp.QuestionerExecutable._init_prompt") - def test_workflow_llm_questioner_plugin(self, mock_questioner_init_prompt, - mock_questioner_llm_inputs, - mock_questioner_extraction, mock_plugin_get_tool, mock_plugin_invoke): + def test_workflow_llm_questioner_plugin(self, + mock_plugin_get_tool, + mock_plugin_invoke): """Start -> LLM -> Questioner -> Plugin -> End""" - # LLM的mock逻辑 - # fake_llm = AsyncMock() - # fake_llm.ainvoke = AsyncMock(return_value="mocked response") - # mock_get_model.return_value = fake_llm - - # 提问器的mock逻辑 - mock_prompt_template = [ - dict(role="system", content="系统提示词"), - dict(role="user", content="你是一个AI助手") - ] - mock_questioner_init_prompt.return_value = Template(name="test", content=mock_prompt_template) - mock_questioner_llm_inputs.return_value = mock_prompt_template - mock_questioner_extraction.return_value = dict(location="hangzhou") - # 插件组件的mock逻辑 mock_plugin_get_tool.return_value = MOCK_TOOL mock_plugin_invoke.return_value = {"result": FINAL_RESULT} @@ -77,7 +81,7 @@ class RealWorkflowTest(unittest.TestCase): # 实例化工作流和上下文 flow = self._create_flow() context = Context(config=Config(), state=InMemoryState(), store=None) - # 
context.state.update({WORKFLOW_CHAT_HISTORY: []}) + # 实例化组件 start_component = MockStartNode("start") llm_component = self._create_llm_component() @@ -99,7 +103,8 @@ class RealWorkflowTest(unittest.TestCase): }) flow.add_workflow_comp("questioner", questioner_component, inputs_schema={ - "query": "${llm.userFields.result}" + # "query": "${llm.userFields.result}" + "query": "${start.query}" # TODO:临时mock }) flow.add_workflow_comp("plugin", plugin_component, inputs_schema={ @@ -112,8 +117,10 @@ class RealWorkflowTest(unittest.TestCase): }) # 组件间的连边 - flow.add_connection("start", "llm") - flow.add_connection("llm", "questioner") + # flow.add_connection("start", "llm") + # flow.add_connection("llm", "questioner") + flow.add_connection("start", "questioner") # TODO:临时mock + flow.add_connection("questioner", "plugin") flow.add_connection("plugin", "end") @@ -137,7 +144,7 @@ class RealWorkflowTest(unittest.TestCase): def _get_mode_config(): model_config = ModelConfig(model_provider=MODEL_PROVIDER, model_info=BaseModelInfo( - model_name=MODEL_NAME, + model=MODEL_NAME, api_base=API_BASE, api_key=API_KEY, temperature=0.7, @@ -158,7 +165,9 @@ class RealWorkflowTest(unittest.TestCase): question_content="", extract_fields_from_response=True, field_names=key_fields, - with_chat_history=False + with_chat_history=False, + prompt_template=[dict(role="system", content=QUESTIONER_SYSTEM_TEMPLATE), + dict(role="user", content=QUESTIONER_USER_TEMPLATE)] ) return QuestionerComponent(questioner_comp_config=questioner_config)