From 28241f1fa732c0f9412af3e6cd9fd290a9b575b9 Mon Sep 17 00:00:00 2001 From: lilei Date: Mon, 14 Jul 2025 14:21:25 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20prompt=E6=A8=A1=E5=9D=97=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- jiuwen/core/__init__.py | 0 jiuwen/core/common/exception/status_code.py | 15 ++ jiuwen/core/utils/__init__.py | 0 jiuwen/core/utils/llm/messages.py | 9 + jiuwen/core/utils/output_parser/base.py | 27 +++ .../output_parser/novel_tool_output_parser.py | 73 +++++++ .../utils/output_parser/null_output_parser.py | 13 ++ jiuwen/core/utils/prompt/assemble/__init__.py | 0 .../core/utils/prompt/assemble/assembler.py | 95 ++++++++ .../utils/prompt/assemble/message_handler.py | 155 +++++++++++++ .../prompt/assemble/variables/__init__.py | 0 .../prompt/assemble/variables/textable.py | 57 +++++ .../prompt/assemble/variables/variable.py | 45 ++++ jiuwen/core/utils/prompt/base.py | 0 jiuwen/core/utils/prompt/common/__init__.py | 0 jiuwen/core/utils/prompt/common/document.py | 20 ++ jiuwen/core/utils/prompt/common/singleton.py | 17 ++ jiuwen/core/utils/prompt/index/__init__.py | 0 .../prompt/index/template_store/__init__.py | 0 .../prompt/index/template_store/in_memory.py | 65 ++++++ .../in_memory_template_store.py | 72 +++++++ .../index/template_store/template_store.py | 40 ++++ jiuwen/core/utils/prompt/template/__init__.py | 0 jiuwen/core/utils/prompt/template/template.py | 46 ++++ .../utils/prompt/template/template_manager.py | 107 +++++++++ tests/unit_tests/prompt/__init__.py | 0 .../prompt/data/intent_recognition.pr | 3 + tests/unit_tests/prompt/data/prompts.yaml | 5 + tests/unit_tests/prompt/data/summary.pr | 2 + .../unit_tests/prompt/test_message_handler.py | 80 +++++++ tests/unit_tests/prompt/test_output_parser.py | 43 ++++ .../prompt/test_template_assemble.py | 204 ++++++++++++++++++ .../prompt/test_template_manager.py | 60 ++++++ 33 files changed, 1253 insertions(+) create mode 100644 jiuwen/core/__init__.py create mode 100644 jiuwen/core/utils/__init__.py create mode 100644 jiuwen/core/utils/llm/messages.py create mode 100644 jiuwen/core/utils/output_parser/base.py create mode 100644 jiuwen/core/utils/output_parser/novel_tool_output_parser.py create mode 100644 jiuwen/core/utils/output_parser/null_output_parser.py create mode 100644 jiuwen/core/utils/prompt/assemble/__init__.py create mode 100644 jiuwen/core/utils/prompt/assemble/assembler.py create mode 100644 jiuwen/core/utils/prompt/assemble/message_handler.py create mode 100644 jiuwen/core/utils/prompt/assemble/variables/__init__.py create mode 100644 jiuwen/core/utils/prompt/assemble/variables/textable.py create mode 100644 jiuwen/core/utils/prompt/assemble/variables/variable.py create mode 100644 jiuwen/core/utils/prompt/base.py create mode 100644 jiuwen/core/utils/prompt/common/__init__.py create mode 100644 jiuwen/core/utils/prompt/common/document.py create mode 100644 jiuwen/core/utils/prompt/common/singleton.py create mode 100644 jiuwen/core/utils/prompt/index/__init__.py create mode 100644 jiuwen/core/utils/prompt/index/template_store/__init__.py create mode 100644 jiuwen/core/utils/prompt/index/template_store/in_memory.py create mode 100644 jiuwen/core/utils/prompt/index/template_store/in_memory_template_store.py create mode 100644 jiuwen/core/utils/prompt/index/template_store/template_store.py create mode 100644 jiuwen/core/utils/prompt/template/__init__.py create mode 100644 jiuwen/core/utils/prompt/template/template.py create mode 100644 jiuwen/core/utils/prompt/template/template_manager.py create mode 100644 tests/unit_tests/prompt/__init__.py create mode 100644 tests/unit_tests/prompt/data/intent_recognition.pr create mode 100644 tests/unit_tests/prompt/data/prompts.yaml create mode 100644 tests/unit_tests/prompt/data/summary.pr create mode 100644 tests/unit_tests/prompt/test_message_handler.py create mode 100644 tests/unit_tests/prompt/test_output_parser.py create mode 100644 tests/unit_tests/prompt/test_template_assemble.py create mode 100644 tests/unit_tests/prompt/test_template_manager.py diff --git a/jiuwen/core/__init__.py b/jiuwen/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jiuwen/core/common/exception/status_code.py b/jiuwen/core/common/exception/status_code.py index 4fa003b..fbef15d 100644 --- a/jiuwen/core/common/exception/status_code.py +++ b/jiuwen/core/common/exception/status_code.py @@ -7,6 +7,21 @@ from enum import Enum class StatusCode(Enum): CONTROLLER_INTERRUPTED_ERROR = (10312, "controller interrupted error") + # Prompt 模板管理 102050 - 102099 + PROMPT_ASSEMBLER_VARIABLE_INIT_ERROR = (102050, "Wrong arguments for initializing the variable") + PROMPT_ASSEMBLER_INIT_ERROR = (102051, "Wrong arguments for initializing the assembler") + PROMPT_ASSEMBLER_INPUT_KEY_ERROR = ( + 102052, + "Missing or unexpected key-value pairs passed in as arguments for the assembler or variable when updating" + ) + PROMPT_ASSEMBLER_TEMPLATE_FORMAT_ERROR = ( + 102053, + "Errors occur when formatting the template content due to wrong format") + + # Prompt 模板管理 102100 - 102149 + PROMPT_TEMPLATE_DUPLICATED_ERROR = (102101, "Template duplicated") + PROMPT_TEMPLATE_NOT_FOUND_ERROR = (102102, "Template not found") + @property def code(self): return self.value[0] diff --git a/jiuwen/core/utils/__init__.py b/jiuwen/core/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jiuwen/core/utils/llm/messages.py b/jiuwen/core/utils/llm/messages.py new file mode 100644 index 0000000..1135db2 --- /dev/null +++ b/jiuwen/core/utils/llm/messages.py @@ -0,0 +1,9 @@ +from typing import Union, List, Dict, Optional + +from pydantic import BaseModel + + +class BaseMessage(BaseModel): + role: str + content: Union[str, List[Union[str, Dict]]] + name: Optional[str] = None \ No newline at end of file diff --git a/jiuwen/core/utils/output_parser/base.py b/jiuwen/core/utils/output_parser/base.py new file mode 100644 index 0000000..9e294be --- /dev/null +++ b/jiuwen/core/utils/output_parser/base.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import Any, Iterator + + +class BaseOutputParser: + """Base class for output parsers.""" + @classmethod + def from_config(cls, parse_method: str, parse_config: dict = None) -> BaseOutputParser: + """create a parser instance""" + if parse_config is None: + parse_config = dict() + if parse_method == "novel_tool": + from jiuwen.core.utils.output_parser.novel_tool_output_parser import NovelToolOutputParser + output_parser_cls = NovelToolOutputParser + else: + from jiuwen.core.utils.output_parser.null_output_parser import NullOutputParser + output_parser_cls = NullOutputParser + return output_parser_cls(**parse_config) + + def parse(self, llm_output: str) -> Any: + """convert content into its expected format""" + raise NotImplementedError() + + def stream_parse(self, streaming_inputs: Iterator[dict]) -> Iterator[dict]: + """parse in the streaming manner""" + raise NotImplementedError() \ No newline at end of file diff --git a/jiuwen/core/utils/output_parser/novel_tool_output_parser.py b/jiuwen/core/utils/output_parser/novel_tool_output_parser.py new file mode 100644 index 0000000..3b6976c --- /dev/null +++ b/jiuwen/core/utils/output_parser/novel_tool_output_parser.py @@ -0,0 +1,73 @@ +import json +import re +from typing import Iterator + +from jiuwen.core.utils.output_parser.base import BaseOutputParser + +MESSAGE_KEY = "message" +CONTENT_KEY = "content" +TOOL_CALLS_KEY = "tool_calls" +LLM_CONTENT_KEY = "llm_content" + + +class NovelToolOutputParser(BaseOutputParser): + """Novel tool output parser""" + def parse(self, llm_output: dict) -> dict: + """parse the llm output content""" + content = llm_output.get(MESSAGE_KEY, {}).get(CONTENT_KEY, "") + match_result = re.findall(r'^\[([a-zA-Z\d_]+)\(([\s\S]*)\)\]', content.strip()) + if match_result: + function_match = match_result[0] + arguments = dict() + params_pattern = re.compile(r'(\w+)\s*=\s*(.*?)(?=\s*,\s*\w+\s*=|$)') + for argument_match in params_pattern.finditer(function_match[1]): + value_str = argument_match[1].strip().replace('\'', '"') + try: + arg_value = json.loads(value_str) + except json.JSONDecodeError as _: + arg_value = argument_match[1] + arguments[argument_match[0]] = arg_value + function_call = { + "type": "function", + "function": { + "name": function_match[0].strip(), + "arguments": json.dumps(arguments, ensure_ascii=False) + } + } + else: + function_call = {} + if function_call: + message = llm_output.setdefault(MESSAGE_KEY, {}) + message[TOOL_CALLS_KEY] = [function_call] + message[LLM_CONTENT_KEY] = message.get(CONTENT_KEY, '') + message[CONTENT_KEY] = '' + return llm_output + + def stream_parse(self, streaming_inputs: Iterator[dict]) -> Iterator[dict]: + """parse the streaming input""" + is_valid_tool_call = True + cached_tokens = "" + + for output in streaming_inputs: + if output.get("type", "") == "full_result": + output = self.parse(output) + yield output + else: + cached_tokens += output.get(MESSAGE_KEY, {}).get(TOOL_CALLS_KEY, "") + if is_valid_tool_call: + if len(cached_tokens) > 0 and cached_tokens[0] != "[": + is_valid_tool_call = False + if len(cached_tokens) > 1: + if "(" in cached_tokens: + func_name = cached_tokens[1: cached_tokens.index("(")] + else: + func_name = cached_tokens[1:] + res = re.match(r"^[a-zA-Z\d_]+$", func_name) + if not res: + is_valid_tool_call = False + if not is_valid_tool_call: + output[MESSAGE_KEY][CONTENT_KEY] = cached_tokens + yield output + continue + if not is_valid_tool_call: + yield output \ No newline at end of file diff --git a/jiuwen/core/utils/output_parser/null_output_parser.py b/jiuwen/core/utils/output_parser/null_output_parser.py new file mode 100644 index 0000000..ad188e1 --- /dev/null +++ b/jiuwen/core/utils/output_parser/null_output_parser.py @@ -0,0 +1,13 @@ +from typing import Iterator + +from jiuwen.core.utils.output_parser.base import BaseOutputParser + + +class NullOutputParser(BaseOutputParser): + """Null output parser class""" + def __init__(self, llm_output: dict = None): + pass + + def stream_parse(self, streaming_input: Iterator[dict]) -> Iterator[dict]: + """parse in the streaming manner""" + return streaming_input \ No newline at end of file diff --git a/jiuwen/core/utils/prompt/assemble/__init__.py b/jiuwen/core/utils/prompt/assemble/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jiuwen/core/utils/prompt/assemble/assembler.py b/jiuwen/core/utils/prompt/assemble/assembler.py new file mode 100644 index 0000000..76a5b0d --- /dev/null +++ b/jiuwen/core/utils/prompt/assemble/assembler.py @@ -0,0 +1,95 @@ +import re +from copy import deepcopy +from typing import Union, List, Dict + +from jiuwen.core.common.exception.exception import JiuWenBaseException +from jiuwen.core.common.exception.status_code import StatusCode +from jiuwen.core.utils.llm.messages import BaseMessage +from jiuwen.core.utils.prompt.assemble.variables.textable import TextableVariable +from jiuwen.core.utils.prompt.assemble.variables.variable import Variable +from jiuwen.core.utils.prompt.assemble.message_handler import messages_to_template, template_to_messages + + +class Assembler: + """class for creating prompt based on a given template""" + + def __init__(self, + template_content: Union[List[Dict], List[BaseMessage], str], + return_format: str = "message", + **variables): + if isinstance(template_content, List): + try: + template_content = messages_to_template(template_content) + except Exception as e: + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_ASSEMBLER_TEMPLATE_FORMAT_ERROR.code, + message="The List-type prompt should confirm to the format of LLM message, please check." + ) from e + self.return_format = return_format + template_formater = TextableVariable(template_content, name="__inner__") + for name, variable in variables.items(): + if name not in template_formater.input_keys: + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_ASSEMBLER_INIT_ERROR.code, + message=f"Variable {name} is not defined in the Template." + ) + if not isinstance(variable, Variable): + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_ASSEMBLER_INIT_ERROR.code, + message=f"Variable {name} must be instantiated as a `Variable` object." + ) + for placeholder in template_formater.input_keys: + if placeholder in variables: + variables[placeholder].name = placeholder + else: + variables[placeholder] = TextableVariable(name=placeholder, text="{{" + placeholder + "}}") + self.prompt = "" + self.variables = variables + self._template_formater = template_formater + + @property + def input_keys(self) -> List[str]: + """Get the list of argument names for updating all the variables""" + keys = [] + for variable in self.variables.values(): + keys.extend(variable.input_keys) + return list(set(keys)) + + def assemble(self, **kwargs) -> Union[str, List[dict]]: + """Update the variables and format the template into a string-type or message-type prompt""" + kwargs = {k: v for k, v in kwargs.items() if v is not None and k in self.input_keys} + all_kwargs = {} + for k in self.input_keys: + if k not in kwargs: + all_kwargs[k] = "" + all_kwargs.update(**kwargs) + self._update(**all_kwargs) + return self._format() + + def _update(self, **kwargs) -> None: + """Update the variables based on the arguments passed in as key-value pairs""" + missing_keys = set(self.input_keys) - set(kwargs.keys()) + if missing_keys: + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_ASSEMBLER_TEMPLATE_FORMAT_ERROR.code, + message=f"Missing keys for updating the assembler: {list(missing_keys)}" + ) + unexpected_keys = set(kwargs.keys()) - set(self.input_keys) + if unexpected_keys: + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_ASSEMBLER_TEMPLATE_FORMAT_ERROR.code, + message=f"Unexpected keys for updating the assembler: {list(unexpected_keys)}" + ) + for variable in self.variables.values(): + input_kwargs = {k: v for k, v in kwargs.items() if k in variable.input_keys} + variable.eval(**input_kwargs) + + def _format(self) -> Union[str, List[dict]]: + """Substitute placeholders in the template with variables values and get formatted prompt.""" + format_kwargs = {var.name: var.value for var in self.variables.values()} + formatted_prompt = self._template_formater.eval(**format_kwargs) + message_prefix_matches = list(re.finditer(r'`#(system|assistant|user|tool|function)#`', formatted_prompt)) + if self.return_format == "text" or not message_prefix_matches: + self.prompt = formatted_prompt + return formatted_prompt + return deepcopy(template_to_messages(formatted_prompt)) diff --git a/jiuwen/core/utils/prompt/assemble/message_handler.py b/jiuwen/core/utils/prompt/assemble/message_handler.py new file mode 100644 index 0000000..045df75 --- /dev/null +++ b/jiuwen/core/utils/prompt/assemble/message_handler.py @@ -0,0 +1,155 @@ +import json +import re +from typing import List + +from jiuwen.core.common.exception.exception import JiuWenBaseException +from jiuwen.core.common.exception.status_code import StatusCode + +MESSAGE_VALIDATION_SCHEMA = { + "system": { + "role": str, + "content": str + }, + "assistant": { + "role": str, + "content": (type(None), str), + "function_call": (type(None), dict) + }, + "user": { + "role": str, + "content": str + }, + "function": { + "role": str, + "content": str, + "name": str + } +} + +EXTRA_VALIDATION_SCHEMA = { + "function_call": { + "name": str, + "arguments": str + } +} + + +def messages_to_template(messages: List[dict]) -> str: + """messages to template""" + template = "" + for message in messages: + if not isinstance(message, dict): + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_ASSEMBLER_TEMPLATE_FORMAT_ERROR.code, + message="Each message in the template must be a dict" + ) + role = message.get("role") + validate_schema = MESSAGE_VALIDATION_SCHEMA.get(role) + if not validate_schema: + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_ASSEMBLER_TEMPLATE_FORMAT_ERROR.code, + message=f"No validation schema found for message role `{role}`." + ) + validate(message, validate_schema) + content = message.get("content") + if not content: + content = "" + template += f"`#{role}#`\n{content}\n" + for extra_key in set(message.keys()) - {"role", "content"}: + if isinstance(message[extra_key], str): + extra_content = message[extra_key] + elif isinstance(message[extra_key], dict): + extra_content = json.dumps(message['function_call'], ensure_ascii=False) + else: + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_ASSEMBLER_TEMPLATE_FORMAT_ERROR.code, + message="Cannot parse data into string" + ) + template += f"`*{extra_key}*`\n{extra_content}\n" + + return template + + +def template_to_messages(template: str) -> List[dict]: + """template to messages""" + messages = [] + message_prefix_matches = list(re.finditer(r'`#(system|assistant|user|tool|function)#`', template)) + for message_index, message_match in enumerate(message_prefix_matches): + message_content, message_prefix, validation_schema = get_message( + message_index, message_match, message_prefix_matches, template + ) + message = padding_message(message_prefix, message_content, validation_schema) + validate(message, validation_schema) + messages.append(message) + return messages + + +def get_message(message_index, message_match, message_prefix_matches, template): + """get message""" + message_prefix = message_match.group(1) + message_start = message_match.end() + if message_index < len(message_prefix_matches) - 1: + message_end = message_prefix_matches[message_index + 1].start() + else: + message_end = len(template) + message_content = template[message_start:message_end].strip() + validation_schema = MESSAGE_VALIDATION_SCHEMA.get(message_prefix) + if not validation_schema: + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_ASSEMBLER_TEMPLATE_FORMAT_ERROR.code, + message=f"No validation schema found for message role `{message_prefix}`." + ) + return message_content, message_prefix, validation_schema + + +def validate(data: dict, schema: dict): + """validate data""" + if len(set(data.keys()) - set(schema.keys())) > 0: + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_ASSEMBLER_TEMPLATE_FORMAT_ERROR.code, + message="Failed validate the data against the schema." + ) + for name, data_type in schema.items(): + if not isinstance(data.get(name), data_type): + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_ASSEMBLER_TEMPLATE_FORMAT_ERROR.code, + message=f"Failed validate the data against the schema `{name}`." + ) + if name in EXTRA_VALIDATION_SCHEMA and data.get(name) is not None: + validate(data.get(name), EXTRA_VALIDATION_SCHEMA.get(name)) + + +def padding_message(message_prefix, message_content, validation_schema): + """message padding""" + key_role = "role" + key_content = "content" + message = { + key_role: message_prefix, + key_content: message_content + } + extra_fields_matches = list(re.finditer(r'`\*(name|function_call)\*`', message_content)) + for field_index, field_match in enumerate(extra_fields_matches): + field_name = field_match.group(1) + field_start = field_match.end() + if field_index < len(extra_fields_matches) - 1: + field_end = extra_fields_matches[field_index + 1].start() + else: + field_end = len(message_content) + field_content = message_content[field_start:field_end].strip() + try: + data_type = validation_schema.get(field_name) + if (isinstance(data_type, tuple) and dict in data_type) or dict == data_type: + field_content = json.loads(field_content) + except Exception as e: + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_ASSEMBLER_TEMPLATE_FORMAT_ERROR.code, + message=f"Errors occur when parsing field `{field_name}` into dict." + ) from e + message[field_name] = field_content + validate(message, MESSAGE_VALIDATION_SCHEMA.get(message_prefix)) + if field_index == 0: + message[key_content] = message_content[:field_match.start()].strip() + if len(message[key_content]) == 0: + message[key_content] = None + return message + diff --git a/jiuwen/core/utils/prompt/assemble/variables/__init__.py b/jiuwen/core/utils/prompt/assemble/variables/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jiuwen/core/utils/prompt/assemble/variables/textable.py b/jiuwen/core/utils/prompt/assemble/variables/textable.py new file mode 100644 index 0000000..38bb726 --- /dev/null +++ b/jiuwen/core/utils/prompt/assemble/variables/textable.py @@ -0,0 +1,57 @@ +import re + +from jiuwen.core.common.logging.base import logger +from jiuwen.core.common.exception.exception import JiuWenBaseException +from jiuwen.core.common.exception.status_code import StatusCode +from jiuwen.core.utils.prompt.assemble.variables.variable import Variable + + +class TextableVariable(Variable): + """Variable class for processing string-type placeholders""" + def __init__(self, text: str, name: str = "default"): + clean_text = text + placeholders = [] + input_keys = [] + placeholder_matches = re.finditer(r"\{\{(.*?)\}\}", text) + for match in placeholder_matches: + placeholder = match.group(1).strip() + if len(placeholder) == 0: + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_ASSEMBLER_VARIABLE_INIT_ERROR.code, + message="Placeholders cannot be empty string" + ) + if placeholder not in placeholders: + placeholders.append(placeholder) + input_key = placeholder.split(".")[0] + if input_key not in input_keys: + input_keys.append(input_key) + clean_text = clean_text.replace(match.group(0), "{{"+ placeholder + "}}") + self.text = clean_text + self.placeholders = placeholders + super().__init__(name, input_keys=input_keys) + + def update(self, **kwargs): + """Replace placeholders in the text with passed-in key-values and update `self.value` + + Args: + **kwargs: arguments passed in as key-value pairs for updating the variable. + """ + formatted_text = self.text + for placeholder in self.placeholders: + value = kwargs + try: + for node in placeholder.split("."): + if isinstance(value, dict): + value = value.get(node) + else: + value = getattr(value, node) + except Exception as e: + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_ASSEMBLER_VARIABLE_INIT_ERROR.code, + message=f"Error parsing the placeholder `{placeholder}`." + ) from e + if not isinstance(value, (str, int, float, bool)): + logger.info(f"Converting non-string value `{placeholder}` using str()." + f" Please check if the style is describe.") + formatted_text = formatted_text.replace("{{" + placeholder + "}}", str(value)) + self.value = formatted_text diff --git a/jiuwen/core/utils/prompt/assemble/variables/variable.py b/jiuwen/core/utils/prompt/assemble/variables/variable.py new file mode 100644 index 0000000..469ed4f --- /dev/null +++ b/jiuwen/core/utils/prompt/assemble/variables/variable.py @@ -0,0 +1,45 @@ +from abc import abstractmethod +from typing import List, Optional + +from jiuwen.core.common.exception.exception import JiuWenBaseException +from jiuwen.core.common.exception.status_code import StatusCode + + +class Variable: + """Base class for variable.""" + def __init__(self, name: str, input_keys: Optional[List] = None): + self.name = name + self.input_keys = input_keys + self.value = "" + + @abstractmethod + def update(self, **kwargs): + """update variable.""" + + def eval(self, **kwargs): + """Validate the input key-values, update `self.value`, perform selection (if there is), and return value. + Args: + **kwargs: input key-value pairs for validate the variable. + Returns: + str: updated value of variable. + """ + input_kwargs = self._prepare_inputs(**kwargs) + self.update(**input_kwargs) + return self.value + + def _prepare_inputs(self, **kwargs) -> dict: + """prepare input key-value pairs.""" + missing_keys = set(self.input_keys) - set(kwargs.keys()) + if missing_keys: + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_ASSEMBLER_INPUT_KEY_ERROR.code, + message=f"Missing keys for updating the variable {self.name}: {list(missing_keys)}" + ) + unexpected_keys = set(kwargs.keys()) - set(self.input_keys) + if unexpected_keys: + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_ASSEMBLER_INPUT_KEY_ERROR.code, + message=f"Unexpected keys for updating the variable {self.name}: {list(unexpected_keys)}" + ) + input_kwargs = {k:v for k, v in kwargs.items() if k in self.input_keys} + return input_kwargs \ No newline at end of file diff --git a/jiuwen/core/utils/prompt/base.py b/jiuwen/core/utils/prompt/base.py new file mode 100644 index 0000000..e69de29 diff --git a/jiuwen/core/utils/prompt/common/__init__.py b/jiuwen/core/utils/prompt/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jiuwen/core/utils/prompt/common/document.py b/jiuwen/core/utils/prompt/common/document.py new file mode 100644 index 0000000..27a399b --- /dev/null +++ b/jiuwen/core/utils/prompt/common/document.py @@ -0,0 +1,20 @@ +""" +embedding +""" +from __future__ import annotations + +from abc import ABC + +from pydantic import Field, BaseModel + + +class Document(BaseModel, ABC): + """ + Class for storing a piece of text and associated metadata. + + Args: + page_content (str): main content of the document. + metadata (dict, optional): arbitrary metadata associated with this document. + """ + page_content: str = Field(default="") + metadata: dict = Field(default={}) \ No newline at end of file diff --git a/jiuwen/core/utils/prompt/common/singleton.py b/jiuwen/core/utils/prompt/common/singleton.py new file mode 100644 index 0000000..ff620e4 --- /dev/null +++ b/jiuwen/core/utils/prompt/common/singleton.py @@ -0,0 +1,17 @@ +import abc +import threading + +singleton_lock = threading.Lock() + + +class Singleton(abc.ABCMeta, type): + """ + Singleton metaclass for ensuring only one instance of a class. + """ + _instances = {} + + def __call__(cls, *args, **kwargs): + with singleton_lock: + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] \ No newline at end of file diff --git a/jiuwen/core/utils/prompt/index/__init__.py b/jiuwen/core/utils/prompt/index/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jiuwen/core/utils/prompt/index/template_store/__init__.py b/jiuwen/core/utils/prompt/index/template_store/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jiuwen/core/utils/prompt/index/template_store/in_memory.py b/jiuwen/core/utils/prompt/index/template_store/in_memory.py new file mode 100644 index 0000000..49a84e7 --- /dev/null +++ b/jiuwen/core/utils/prompt/index/template_store/in_memory.py @@ -0,0 +1,65 @@ +""" +Interface for In memory index +""" +import threading + +from jiuwen.core.utils.prompt.common.document import Document + + +class TemplateId: + """ + TemplateId class + """ + def __init__(self, name: str, filter_data: dict = None): + self.name = name + self.filter_data = filter_data + + +class InMemory: + """ + InMemoryIndex class + """ + def __init__(self): + self.lock = threading.Lock() + self.cache = {} + + def add_document(self, record: Document, template_id: TemplateId = None): + """add_document""" + try: + self.lock.acquire(True) + self.cache[template_id.name] = record.metadata + return True + finally: + self.lock.release() + + def get_documents(self, template_id: TemplateId): + """get_document""" + try: + self.lock.acquire(True) + return self.cache.get(template_id.name, None) + finally: + self.lock.release() + + def update_document(self, template_id: TemplateId, data: Document): + """update_document""" + try: + self.lock.acquire(True) + if template_id.name in self.cache: + self.cache.pop(template_id.name) + self.cache[template_id.name] = data.metadata + else: + self.cache[template_id.name] = data.metadata + return True + finally: + self.lock.release() + + def delete_document(self, template_id: TemplateId): + """delete_document""" + try: + self.lock.acquire(True) + if template_id.name in self.cache: + self.cache.pop(template_id.name) + return True + return False + finally: + self.lock.release() \ No newline at end of file diff --git a/jiuwen/core/utils/prompt/index/template_store/in_memory_template_store.py b/jiuwen/core/utils/prompt/index/template_store/in_memory_template_store.py new file mode 100644 index 0000000..462bf89 --- /dev/null +++ b/jiuwen/core/utils/prompt/index/template_store/in_memory_template_store.py @@ -0,0 +1,72 @@ +"""In memory template store""" +from abc import ABC + +from jiuwen.core.common.exception.exception import JiuWenBaseException +from jiuwen.core.common.exception.status_code import StatusCode +from jiuwen.core.utils.prompt.common.document import Document +from jiuwen.core.utils.prompt.index.template_store.in_memory import InMemory, TemplateId +from jiuwen.core.utils.prompt.index.template_store.template_store import TemplateStore, Template + + +class InMemoryTemplateStore(TemplateStore, ABC): + """In memory template store""" + def __init__(self): + self.index = InMemory() + + @staticmethod + def __get_memory_name(name: str, filters: dict): + return name + "".join("###" + filters[item] for item in filters if filters.get(item)) if filters else name + + def register_template(self, template: Template): + """register a template""" + in_memory_name = self.__get_memory_name(name=template.name, filters=template.filters) + template_id = TemplateId(in_memory_name, filter_data=template.filters) + if self.index.get_documents(template_id): + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_TEMPLATE_DUPLICATED_ERROR.code, + message=f"Template: {template.name} is duplicated to register" + ) + return self.index.add_document( + record=Document(page_content='', metadata=self._convert_to_dict(template)), + template_id=template_id + ) + + def delete_template(self, name: str, filters: dict): + """delete a template""" + template_id = TemplateId(name, filter_data=filters) + if not self.index.get_documents(template_id): + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_TEMPLATE_NOT_FOUND_ERROR.code, + message=f"Template: {name} not found to delete" + ) + return self.index.delete_document(template_id) + + def update_template(self, template: Template, **kwargs): + """update a template""" + template_id = TemplateId( + name=self.__get_memory_name(template.name, filters=template.filters), + filter_data=template.filters + ) + return self.index.update_document( + template_id, + data=Document(page_content='', metadata=self._convert_to_dict(template)) + ) + + def search_template(self, name: str, filters: dict) -> Template: + """search a template""" + result = self.__get_document(name, filters) + if filters and not result: + result = self.__get_document(name, filters={"model_name": "default"}) + if not result: + raise JiuWenBaseException( + error_code=StatusCode.PROMPT_TEMPLATE_NOT_FOUND_ERROR.code, + message=StatusCode.PROMPT_TEMPLATE_NOT_FOUND_ERROR.errmsg.format(error_message=f"template name: {name}") + ) + return Template(name=result.get("name"), content=result.get("content")) + + def __get_document(self, name: str, filters: dict): + template_id = TemplateId( + name=self.__get_memory_name(name, filters), + filter_data=filters + ) + return self.index.get_documents(template_id) \ No newline at end of file diff --git a/jiuwen/core/utils/prompt/index/template_store/template_store.py b/jiuwen/core/utils/prompt/index/template_store/template_store.py new file mode 100644 index 0000000..d29b120 --- /dev/null +++ b/jiuwen/core/utils/prompt/index/template_store/template_store.py @@ -0,0 +1,40 @@ +""" +Interface for template index +""" +import copy +import json +from abc import ABC, abstractmethod + +from jiuwen.core.utils.prompt.template.template import Template + + +class TemplateStore(ABC): + """Template operation""" + @staticmethod + def _convert_to_dict(input_template: Template) -> dict[str, any]: + """convert template value to str""" + template = copy.deepcopy(input_template) + for attr, value in input_template.__dict__.items(): + if value is None: + setattr(template, attr, '') + elif isinstance(value, dict): + setattr(template, attr, json.dumps(value)) + else: + setattr(template, attr, str(value)) + return template.model_dump() + + @abstractmethod + def delete_template(self, name: str, filters: dict) -> bool: + """delete template by name""" + + @abstractmethod + def register_template(self, template: Template) -> bool: + """register template""" + + @abstractmethod + def search_template(self, name: str, filters: dict) -> Template: + """search template by name""" + + @abstractmethod + def update_template(self, template: Template, **kwargs) -> bool: + """update template by name""" diff --git a/jiuwen/core/utils/prompt/template/__init__.py b/jiuwen/core/utils/prompt/template/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jiuwen/core/utils/prompt/template/template.py b/jiuwen/core/utils/prompt/template/template.py new file mode 100644 index 0000000..7fc9f5f --- /dev/null +++ b/jiuwen/core/utils/prompt/template/template.py @@ -0,0 +1,46 @@ +from typing import Union, List, Dict + +from pydantic import BaseModel, Field + +from jiuwen.core.utils.llm.messages import BaseMessage +from jiuwen.core.utils.prompt.assemble.assembler import Assembler +from jiuwen.core.utils.prompt.assemble.message_handler import template_to_messages + + +class Template(BaseModel): + """ + template data + + """ + name: str + content: Union[List[Dict], List[BaseMessage], str] + filters: dict = Field(default=None) + + def to_messages(self) -> List[BaseMessage]: + """Return Template as a list of Messages.""" + messages = [] + if self.content is None or len(self.content) == 0: + self.content = [] + return messages + if isinstance(self.content, str): + self.content = template_to_messages(self._content) + + for msg in self.content: + if isinstance(msg, BaseMessage): + messages.append(msg) + elif isinstance(msg, dict): + messages.append(BaseMessage(**msg)) + else: + pass + self.content = messages + return self.content + + def format(self, keywords: dict = None): + """format prompt""" + assembler = Assembler(self.content) + input_keys = assembler.input_keys + format_dict = {} + for key in input_keys: + if keywords and keywords.get(key): + format_dict[key] = keywords.get(key) + self.content = assembler.assemble(**format_dict) diff --git a/jiuwen/core/utils/prompt/template/template_manager.py b/jiuwen/core/utils/prompt/template/template_manager.py new file mode 100644 index 0000000..6a37a17 --- /dev/null +++ b/jiuwen/core/utils/prompt/template/template_manager.py @@ -0,0 +1,107 @@ +import copy +import os +from typing import Callable, List + +from jiuwen.core.common.logging.base import logger +from jiuwen.core.utils.prompt.assemble.message_handler import messages_to_template, template_to_messages +from jiuwen.core.utils.prompt.common.singleton import Singleton +from jiuwen.core.utils.prompt.index.template_store.in_memory_template_store import InMemoryTemplateStore +from jiuwen.core.utils.prompt.index.template_store.template_store import TemplateStore, Template + + +class TemplateManager(metaclass=Singleton): + """Template manager class""" + template_store: TemplateStore + __filter_func: Callable = None + + def __init__(self): + self.template_store = InMemoryTemplateStore() + self.__filter_func = default_template_filter + self.init_prompt_templates() + + @staticmethod + def load_from_dir(dir_path: str, suffix: str = ".pr") -> List[Template]: + """Read all templates from dir_path""" + files = os.listdir(dir_path) + templates = [] + files = [f for f in files if f.endswith(suffix)] + for template_file in files: + name = os.path.splitext(template_file)[0] + with open(os.path.join(dir_path, template_file), 'r', encoding='utf-8') as f: + content = f.read() + templates.append(Template(name=name, content=content)) + return templates + + def format(self, keywords, template_name, filters: dict = None) -> Template: + template = self.get(template_name, filters=filters) if template_name else None + template.format(keywords) + return template + + def init_prompt_templates(self): + self.__init_customer_templates() + self.__init_default_templates() + + def __init_customer_templates(self): + """init customer templates""" + customer_templates_path = os.environ.get("PROMPT_DEFAULT_TEMPLATES_PATH", None) + if not customer_templates_path: + logger.warn("Customer templates path is not set") + return + self.__load_default_templates_dir(customer_templates_path) + + def __init_default_templates(self): + """init default templates""" + dir_path = os.path.join(os.path.dirname(__file__), "../resource") + self.__load_default_templates_dir(dir_path) + + def __load_default_templates_dir(self, root_dir: str): + """__load_default_templates_dir""" + templates_path = os.path.join(root_dir, "templates") + files = os.listdir(templates_path) + for file_name in files: + template_path = os.path.join(templates_path, file_name) + if not os.path.isdir(template_path): + continue + templates = self.load_from_dir(template_path) + for template in templates: + template.filters = dict(model_name=file_name) + if not self.register(template=template, force=True): + return ValueError("Invalid template to register") + + def register(self, template: Template, force: bool = False): + """register template""" + template_copy = copy.deepcopy(template) + if isinstance(template_copy.content, list): + template_copy.content = messages_to_template(template_copy.content) + if force: + return self.template_store.update_template(template_copy) + return self.template_store.register_template(template_copy) + + def get(self, name: str, filters: dict = None) -> Template: + """query prompt template by template name""" + all_filters = self.__filter_func(filters) + result_template = self.template_store.search_template(name, filters=all_filters) + if result_template: + messages = template_to_messages(result_template.content) + if messages: + result_template.content = messages + return result_template + + def delete(self, name: str, filters: dict = None): + """delete prompt template by template name""" + all_filters = self.__filter_func(filters) + return self.template_store.delete_template(name, filters=all_filters) + + def register_in_bulk(self, dir_path:str) -> bool: + """template register with bulk template from specified dir_path""" + if not os.path.isdir(dir_path): + raise NotADirectoryError("dir path is not a folder") + templates = self.load_from_dir(dir_path) + results = [] + for template in templates: + results.append(self.register(template=template, force=True)) + return all(results) + +def default_template_filter(default_filters) -> dict: + """default template filter""" + return default_filters \ No newline at end of file diff --git a/tests/unit_tests/prompt/__init__.py b/tests/unit_tests/prompt/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit_tests/prompt/data/intent_recognition.pr b/tests/unit_tests/prompt/data/intent_recognition.pr new file mode 100644 index 0000000..933bb52 --- /dev/null +++ b/tests/unit_tests/prompt/data/intent_recognition.pr @@ -0,0 +1,3 @@ +#角色:场景识别助手 +以下是用户的问题: {{query}} +注意:只输出'是'或'否',不要回复对于内容 \ No newline at end of file diff --git a/tests/unit_tests/prompt/data/prompts.yaml b/tests/unit_tests/prompt/data/prompts.yaml new file mode 100644 index 0000000..72fee17 --- /dev/null +++ b/tests/unit_tests/prompt/data/prompts.yaml @@ -0,0 +1,5 @@ +summary: + description: "test template for summary text" + +intent_recognition: + description: "template for judge intent" \ No newline at end of file diff --git a/tests/unit_tests/prompt/data/summary.pr b/tests/unit_tests/prompt/data/summary.pr new file mode 100644 index 0000000..80449a5 --- /dev/null +++ b/tests/unit_tests/prompt/data/summary.pr @@ -0,0 +1,2 @@ +你是一个文本总结高手,{{command}}, +{{info}} \ No newline at end of file diff --git a/tests/unit_tests/prompt/test_message_handler.py b/tests/unit_tests/prompt/test_message_handler.py new file mode 100644 index 0000000..af64c34 --- /dev/null +++ b/tests/unit_tests/prompt/test_message_handler.py @@ -0,0 +1,80 @@ +import json +import unittest + +from jiuwen.core.common.exception.exception import JiuWenBaseException +from jiuwen.core.utils.prompt.assemble.message_handler import validate, messages_to_template, template_to_messages + + +class TestMessageHandler(unittest.TestCase): + def test_validate_valid_data(self): + schema = {"role": str, "content": str} + data = {"role": "user", "content": "test"} + try: + validate(data, schema) + except JiuWenBaseException as e: + self.fail("validate raised JiuWenBaseException unexpectedly") + + def test_validate_extra_key(self): + schema = {"role": str} + data = {"role": "user", "content": "test"} + with self.assertRaises(JiuWenBaseException): + validate(data, schema) + + def test_validate_wrong_type(self): + schema = {"role": str} + data = {"role": 234} + with self.assertRaises(JiuWenBaseException): + validate(data, schema) + + def test_validate_nested_schema(self): + data = {"function_call": {"name": "func", "arguments": '{}'}} + schema = {"function_call": dict} + extra_schema = {"name": str, "arguments": str} + try: + validate(data, schema) + validate(data["function_call"], extra_schema) + except JiuWenBaseException: + self.fail("nested validation raised JiuWenBaseException unexpectedly") + + def test_messages_to_template_basic(self): + messages = [{"role": "user", "content": "Hello"}] + expected = "`#user#`\nHello\n" + self.assertEqual(messages_to_template(messages), expected) + + def test_messages_to_template_with_function_call(self): + messages = [{ + "role": "assistant", "content": None, + "function_call": {"name": "func", "arguments": '{}'} + }] + expected = "`#assistant#`\n\n`*function_call*`\n" + json.dumps( + {"name": "func", "arguments": '{}'}, ensure_ascii=False + ) + "\n" + self.assertEqual(messages_to_template(messages), expected) + + def test_messages_to_template_invalid_role(self): + messages = [{"role": "invalid", "content": "test"}] + with self.assertRaises(JiuWenBaseException): + messages_to_template(messages) + + def test_template_to_messages_simple(self): + template = "`#user#`\nHello\n`#system#`\nWelcome\n" + messages = template_to_messages(template) + self.assertEqual(len(messages), 2) + self.assertEqual(messages[0], {"role": "user", "content": "Hello"}) + self.assertEqual(messages[1], {"role": "system", "content": "Welcome"}) + + def test_template_to_messages_with_function_call(self): + template = ("`#assistant#`\n\n`*function_call*`\n" + '{"name": "func", "arguments": "{}"}\n') + messages = template_to_messages(template) + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["role"], "assistant") + self.assertIsNone(messages[0]["content"]) + self.assertEqual(messages[0]["function_call"], {"name": "func", "arguments": "{}"}) + + def test_template_to_messages_invalid_json(self): + template = ("`#assistant#`\n\n`*function_call*`\n" + 'invalid json\n') + with self.assertRaises(JiuWenBaseException): + template_to_messages(template) + diff --git a/tests/unit_tests/prompt/test_output_parser.py b/tests/unit_tests/prompt/test_output_parser.py new file mode 100644 index 0000000..4f17220 --- /dev/null +++ b/tests/unit_tests/prompt/test_output_parser.py @@ -0,0 +1,43 @@ +import unittest + +from jiuwen.core.utils.output_parser.novel_tool_output_parser import NovelToolOutputParser + + +class TestOutputParser(unittest.TestCase): + def setUp(self): + self.parser = NovelToolOutputParser() + + def test_parser_valid_function_call(self): + llm_output = { + "message": { + "content": "[my_function(param1='value1', param2=123, param3=True)]" + } + } + expected_output = { + 'message': + { + 'content': '', + 'llm_content': "[my_function(param1='value1', param2=123, param3=True)]", + 'tool_calls': + [ + { + 'function': { + 'arguments': '{"param1=\'value1\'": "param1", "param2=123": "param2", "param3=True": "param3"}', + 'name': 'my_function' + }, + 'type': 'function' + } + ] + } + } + output = self.parser.parse(llm_output) + + self.assertEqual(output, expected_output) + + def test_parser_invalid_function_call(self): + llm_output = { + "message": { + "content": "invalid_function_call_string" + } + } + self.assertEqual(self.parser.parse(llm_output), llm_output) \ No newline at end of file diff --git a/tests/unit_tests/prompt/test_template_assemble.py b/tests/unit_tests/prompt/test_template_assemble.py new file mode 100644 index 0000000..359a81d --- /dev/null +++ b/tests/unit_tests/prompt/test_template_assemble.py @@ -0,0 +1,204 @@ +import unittest + +from jiuwen.core.common.exception.exception import JiuWenBaseException +from jiuwen.core.utils.llm.messages import BaseMessage +from jiuwen.core.utils.prompt.assemble.variables.variable import Variable +from jiuwen.core.utils.prompt.template.template import Assembler +from jiuwen.core.utils.prompt.assemble.variables.textable import TextableVariable +from jiuwen.core.utils.prompt.index.template_store.template_store import Template + + +class TestPromptAssemble(unittest.TestCase): + def test_textable_variable(self): + self.assertRaises(JiuWenBaseException, TextableVariable, text="{{}}") + var1 = TextableVariable(text="{{x}}") + self.assertEqual(["x"], var1.input_keys) + self.assertEqual("default", var1.name) + + var2 = TextableVariable(text="{{x}}{{y}}") + self.assertEqual(["x", "y"], var2.input_keys) + self.assertEqual("12", var2.eval(x="1", y="2")) + self.assertEqual("12", var2.value) + self.assertRaises(JiuWenBaseException, var2.eval, x=1, y=2, z=3) + + def test_textable_variables(self): + self.assertRaises(JiuWenBaseException, TextableVariable, text="{{}}") + var1 = TextableVariable(text="{{x}}") + self.assertEqual(["x"], var1.input_keys) + self.assertEqual("default", var1.name) + + var2 = TextableVariable(text="{{x}}{{y}}") + self.assertEqual({"x", "y"}, set(var2.input_keys)) + self.assertRaises(JiuWenBaseException, var2.eval, x=1, y=2, z=3) + self.assertEqual("12", var2.eval(x="1", y="2")) + self.assertEqual("12", var2.value) + + def test_initialization(self): + text = "You're an expert in the domain of {{domain}}" + var = TextableVariable(text=text, name="role") + self.assertEqual(text, var.text) + self.assertEqual("role", var.name) + self.assertEqual(["domain"], var.input_keys) + self.assertEqual(["domain"], var.placeholders) + + text = "Hello, {{user.name}}" + var = TextableVariable(text=text) + self.assertEqual(["user"], var.input_keys) + self.assertEqual(["user.name"], var.placeholders) + + text = "Hello, {{}}!" + with self.assertRaises(JiuWenBaseException): + TextableVariable(text=text) + + def test_update(self): + text = "You're an expert in the domain of {{domain}}." + var = TextableVariable(text=text) + var.update(domain="science") + self.assertEqual("You're an expert in the domain of science.", var.value) + + text = "This value is {{value}}." + var = TextableVariable(text=text) + var.update(value=42) + self.assertEqual("This value is 42.", var.value) + + def test_eval(self): + text = "You're an expert in the domain of {{domain}}." + var = TextableVariable(text=text) + result = var.eval(domain="science") + self.assertEqual("You're an expert in the domain of science.", result) + + text = "Hello, {{user.name}}!" + var = TextableVariable(text=text) + result = var.eval(user={"name": "Alice"}) + self.assertEqual("Hello, Alice!", result) + + text = "Hello, {{name}}!" + var = TextableVariable(text=text) + with self.assertRaises(JiuWenBaseException): + var.eval(wrong_key="Alice") + + def test_variable_initialization(self): + var = Variable(name="test_var", input_keys=["key1", "key2"]) + self.assertEqual("test_var", var.name) + self.assertEqual(["key1", "key2"], var.input_keys) + self.assertEqual(var.value, "") + + var = Variable(name="test_var", input_keys=None) + self.assertIsNone(var.input_keys) + + def test_prepare_inputs(self): + var = Variable(name="test_var", input_keys=["key1", "key2"]) + + input_kwargs = var._prepare_inputs(key1="value1", key2="value2") + self.assertEqual({"key1": "value1", "key2": "value2"}, input_kwargs) + + with self.assertRaises(JiuWenBaseException): + var._prepare_inputs(key1="value1") + + with self.assertRaises(JiuWenBaseException): + var._prepare_inputs(key1="value1", key2="value2", key3="value3") + + def test_variable_eval(self): + class MockVariable(Variable): + def update(self, **kwargs): + self.value = kwargs.get("key1", "") + kwargs.get("key2", "") + + var = MockVariable(name="test_var", input_keys=["key1", "key2"]) + + result = var.eval(key1="value1", key2="value2") + self.assertEqual("value1value2", result) + + with self.assertRaises(JiuWenBaseException): + var.eval(key1="value1") + + with self.assertRaises(JiuWenBaseException): + var.eval(key1="value1", key2="value2", key3="value3") + + def test_assemble(self): + asm1 = Assembler( + template_content="`#system#`{{role}}`#user#`{{memory}}", + role=TextableVariable(text="你是一个精通{{domain}}领域的问答助手。") + ) + self.assertEqual({"domain", "memory"}, set(asm1.input_keys)) + self.assertIsInstance(asm1.assemble(memory=[{"role": "user", "content": "我是谁"}], domain="科学"), list) + + asm2 = Assembler( + template_content="`#assistant#`消息内容`*function_call*`{\"name\":\"func1\", \"arguments\":\"x\"}" + ) + self.assertEqual([ + {"role": "assistant", "content": "消息内容", "function_call": {"name": "func1", "arguments": "x"}}, + ], asm2.assemble()) + + asm3 = Assembler( + template_content="`#assistant#`消息内容`*function_call*`{\"name\":\"func1\", \"arguments\":1}" + ) + self.assertRaises(JiuWenBaseException, asm3.assemble) + + asm4 = Assembler( + template_content="`#assistant#`消息内容`*function_call*`{\"name\":\"func1\", \"arguments\":\"x\"," + "\"extra\":\"y\"}" + ) + self.assertRaises(JiuWenBaseException, asm4.assemble) + + asm5 = Assembler( + template_content=[ + {"role": "system", "content": "{{role}}"}, + {"role": "user", "content": "{{user_inputs}}"}, + {"role": "assistant", "content": "None", "function_call": {"name": "func1", "arguments": "x"}}, + {"role": "function", "content": "result of function call", "name": "func1"} + ], + role=TextableVariable(text="你是一个精通{{domain}}领域的问答助手"), + user_inputs=TextableVariable(text="问题: {{query}}\n答案:") + ) + self.assertEqual({"domain", "query"}, set(asm5.input_keys)) + self.assertEqual([ + {"role": "system", "content": "你是一个精通科学领域的问答助手"}, + {"role": "user", "content": "问题: 牛顿第三定律\n答案:"}, + {"role": "assistant", "content": "None", "function_call": {"name": "func1", "arguments": "x"}}, + {"role": "function", "content": "result of function call", "name": "func1"} + ], asm5.assemble(domain="科学", query="牛顿第三定律")) + + asm6 = Assembler( + template_content=""" +`#system#`this is a system message +`#assistant#`calling function ... `*function_call*`{"name": "search", "arguments": "{'x': [1,2,3], 'y': '2'}"} +`#function#`this is the result of the function `*name*` search +`#user#`ok""" + ) + self.assertEqual([ + {"role": "system", "content": "this is a system message"}, + {"role": "assistant", "content": "calling function ...", "function_call": + {"name": "search", "arguments": "{'x': [1,2,3], 'y': '2'}"}}, + {"role": "function", "content": "this is the result of the function", "name": "search"}, + {"role": "user", "content": "ok"} + ], asm6.assemble()) + + def test_template_format(self): + template = Template( + name="test", + content="`#system#`你是一个精通{{domain}}领域的问答助手`#user#`{{memory}}") + template.format({"memory": [{"role": "user", "content": "你是谁"}], "domain": "数学"}) + self.assertEqual( + template.to_messages(), + [ + BaseMessage(**{"role": "system", "content": "你是一个精通数学领域的问答助手"}), + BaseMessage(**{"role": "user", "content": "[{'role': 'user', 'content': '你是谁'}]"}) + ] + ) + + template2 = Template( + name="xxx", + content=""" +`#system#`this is a system message +`#assistant#`calling function ... `*function_call*`{"name": "search", "arguments": "{'x': [1,2,3], 'y': '2'}"} +`#function#`this is the result of the function `*name*` search +`#user#`ok""" + ) + template2.format() + self.assertEqual([ + BaseMessage(**{"role": "system", "content": "this is a system message"}), + BaseMessage(**{"role": "assistant", "content": "calling function ...", "function_call": + {"name": "search", "arguments": "{'x': [1,2,3], 'y': '2'}"}}), + BaseMessage(**{"role": "function", "content": "this is the result of the function", "name": "search"}), + BaseMessage(**{"role": "user", "content": "ok"}) + ], template2.to_messages()) diff --git a/tests/unit_tests/prompt/test_template_manager.py b/tests/unit_tests/prompt/test_template_manager.py new file mode 100644 index 0000000..38188c1 --- /dev/null +++ b/tests/unit_tests/prompt/test_template_manager.py @@ -0,0 +1,60 @@ +import os +import unittest + +from jiuwen.core.common.exception.exception import JiuWenBaseException +from jiuwen.core.common.exception.status_code import StatusCode +from jiuwen.core.utils.llm.messages import BaseMessage +from jiuwen.core.utils.prompt.index.template_store.template_store import Template +from jiuwen.core.utils.prompt.template.template_manager import TemplateManager + + +class TestTemplateManager(unittest.TestCase): + def test_template_register_in_bulk(self): + dir_path = os.path.join(os.path.dirname(__file__), "data/") + TemplateManager().register_in_bulk(dir_path=dir_path) + self.assertEqual( + TemplateManager().get(name="summary").content, + "你是一个文本总结高手,{{command}},\n{{info}}" + ) + + self.assertEqual( + TemplateManager().get(name="intent_recognition").content, + "#角色:场景识别助手\n以下是用户的问题: {{query}}\n注意:只输出'是'或'否',不要回复对于内容" + ) + + def test_template_consistency(self): + template = Template(name="test_template_consistent", content=[{"role":"system", "content": "here is a test"}]) + TemplateManager().register(template=template, force=True) + self.assertEqual(isinstance(template.content, list), True) + self.assertEqual( + TemplateManager().get(name="test_template_consistent").content, + [{"role": "system", "content": "here is a test"}] + ) + + template = Template(name="test_template_consistent", content="here is a test") + TemplateManager().register(template=template, force=True) + self.assertEqual( + TemplateManager().get(name="test_template_consistent").content, + "here is a test" + ) + + TemplateManager().delete(name="test_template_consistent") + try: + TemplateManager().get(name="test_template_consistent") + except JiuWenBaseException as e: + self.assertEqual(e.error_code, StatusCode.PROMPT_TEMPLATE_NOT_FOUND_ERROR.code) + + def test_template_manager_format(self): + template = Template( + name="test_template_manager_format", + content="`#system#`你是一个精通{{domain}}领域的问答助手`#user#`{{memory}}") + TemplateManager().register(template=template, force=True) + keyword = {"memory": [{"role": "user", "content": "你是谁"}], "domain": "数学"} + template = TemplateManager().format(keyword, "test_template_manager_format") + self.assertEqual( + template.to_messages(), + [ + BaseMessage(**{"role": "system", "content": "你是一个精通数学领域的问答助手"}), + BaseMessage(**{"role": "user", "content": "[{'role': 'user', 'content': '你是谁'}]"}) + ] + ) -- Gitee