diff --git a/mxAgent/README.md b/mxAgent/README.md deleted file mode 100644 index a5f081ebe83975d2f141f90cc7b0ddb32398c6b9..0000000000000000000000000000000000000000 --- a/mxAgent/README.md +++ /dev/null @@ -1,38 +0,0 @@ -# mxAgent: 基于工具调用的多模式LLM Agent框架 -**mxAgent**是一个基于LLMs的通用Agent框架,应用多种框架解决不同场景和复杂度的问题,并通过工具调用的方式允许LLMs与外部源进行交互来获取信息,使LLMs生成更加可靠和实际。mxAgent通过构建DAG(Directed Acyclic Graph)的方式建立工具之间的依赖关系,通过并行执行的方式,提高多工具执行的效率,缩短Agent在复杂场景的执行时间。mxAgent框架还在框架级别支持流式输出 -提供一套Agent实现框架,让用户可以通过框架搭建自己的Agent应用 -## Router Agent -提供意图识别的能力,用户可预设意图的分类,通过Router Agent给出具体问题的分类结果,用于设别不同的问题场景。 -## Recipe Agent -设置复杂问题执行的workflow,在解决具体问题时,将workflow翻译成有向无环图的的节点编排,通过并行的方式执行节点。 -适用于有相对固定workflow的复杂问题场景。 -1)通过自然语言描述复杂问题的workflow, -2)workflow中每一个步骤对应一次工具使用,并描述步骤间关系 -3)recipe Agent将按照workflow的指导完成工具调用 -4)使用模型总结工作流结果,解决复杂问题 -Recipe Agent利用用户所提供的流程指导和工具,使用LLMs生成SOP,并构建DAG图描述Steps之间的依赖关系。agent识别那些可并行的step,通过并行执行提高agent的执行效率。 -使用Recipe Agent,仅需要提供一段解决问题的SOP指导、用于提示最终答案生成的final prompt、解决问题可能使用的工具。 -示例见[travelagent.py](./travel_agent/travelagent.py)运行方式如下: -``` -cd mxAgent -export PYTHONPATH=. -python samples/travel_agent/travelagent.py -``` - -## ReAct Agent -使用Thought、Action、Action Input、Observation的循环流程,解决复杂问题: -1)ReAct通过大模型思考并给出下一步的工具调用, -2)执行工具调用,得到工具执行结果 -3)将工具执行结果应用于下一次的模型思考 -4)循环上述过程,直到模型认为问题得到解决 -## Single Action Agent - -通过模型反思、调用工具执行,总结工具结果的执行轨迹,完成一次复杂问题的处理。Single Action Agent使用一次工具调用帮助完成复杂问题解决 -使用示例: -``` -cd mxAgent -export PYTHONPATH=. -python samples/traj_generate_test.py -``` - -## \ No newline at end of file diff --git a/mxAgent/agent_sdk/llms/llm.py b/mxAgent/agent_sdk/llms/llm.py index a8d5da6e8f6d24aba8dbdd83bf4f1f52eb006482..953f1c47efbc66d3a34ced08fdc41fef6ffb3b50 100644 --- a/mxAgent/agent_sdk/llms/llm.py +++ b/mxAgent/agent_sdk/llms/llm.py @@ -6,8 +6,8 @@ from .openai_compatible import OpenAICompatibleLLM BACKEND_OPENAI_COMPATIBLE = 1 -def get_llm_backend(backend, api_base, api_key, llm_name): +def get_llm_backend(backend, base_url, api_key, llm_name): if backend == BACKEND_OPENAI_COMPATIBLE: - return OpenAICompatibleLLM(api_base, api_key, llm_name) + return OpenAICompatibleLLM(base_url, api_key, llm_name) else: raise Exception(f"not support backend: {backend}") \ No newline at end of file diff --git a/mxAgent/agent_sdk/requirements.txt b/mxAgent/agent_sdk/requirements.txt deleted file mode 100644 index c96356599ebcaaa40b690e10568a5fc507e5b07d..0000000000000000000000000000000000000000 --- a/mxAgent/agent_sdk/requirements.txt +++ /dev/null @@ -1,19 +0,0 @@ -requests=2.27.1 -tqdm -selenium=4.9.0 -bs4 -transformers -openai -pandas -datasets -peft -fschat -langchain -vllm -rouge -langchain_openai -colorlog -rouge-score -langchain-community -loguru -tiktoken \ No newline at end of file diff --git a/mxAgent/samples/basic_demo/agent_test.py b/mxAgent/samples/basic_demo/agent_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7cc0e664d4a37da296fe8aa14808b28fa20fbbc8 --- /dev/null +++ b/mxAgent/samples/basic_demo/agent_test.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + +import warnings +import os + +from loguru import logger + +from agent_sdk.agentchain.react_agent import ReactAgent, ReactReflectAgent +from agent_sdk.agentchain.tool_less_agent import ToollessAgent +from agent_sdk.llms.llm import get_llm_backend, BACKEND_OPENAI_COMPATIBLE +from samples.tools import QueryAccommodations, QueryAttractions, \ + QueryGoogleDistanceMatrix, QueryTransports, Finish + +warnings.filterwarnings('ignore') + +API_BASE = os.environ.get("OPENAI_API_BASE", "http://10.44.115.98:8006/v1") +API_KEY = os.environ.get("OPENAI_API_KEY", "EMPTY") +LLM_NAME = os.environ.get("MODEL_NAME", "Qwen2-7b-Instruct") + +MAX_CONTEXT_LEN = 4096 + + +EXAMPLE = ''' +Question: Can you help with a 5 day trip from Orlando to New York? Departure date is March 10, 2022. +Thought: To create a travel itinerary, I need to find accommodations, transportation, and attractions in New York. I will first find hotels in New York. +Action: QueryAccommodations +Action Input: {"destination_city": "New York", "position": "Central Park", "rank": "four stars"} +Observation1: [{"title": "紐約市10 大最佳四星級酒店 - Tripadvisor", "url": "https://www.tripadvisor.com.hk/Hotels-g60763-zfc4-New_York_City_New_York-Hotels.html", "snippet": "紐約市四星級酒店 · 1. Moxy NYC Times Square · 3,825 則評論 · 2. 格甚溫酒店 · 1,155 則評論 · 3. 托米哈德森廣場飯店 · 3,277 則評論 · 4. 時代廣場愛迪生酒店 · 5. Hard ..."}, {"title": "中央公園酒店| 人氣優惠及套餐", "url": "https://www.agoda.com/zh-hk/park-central-hotel/hotel/new-york-ny-us.html", "snippet": "中央公園酒店是一家位於紐約市的4.0星級酒店,提供豪華的住宿體驗。酒店於2013年進行了最後一次翻新,確保客人能夠享受現代化的設施和舒適的環境。酒店擁有761間客房,提供 ..."}, {"title": "紐約中央公園艾美酒店(Le Méridien New York, Central Park)", "url": "https://www.agoda.com/zh-hk/viceroy-central-park-new-york_2/hotel/new-york-ny-us.html", "snippet": "紐約中央公園艾美酒店位於美國紐約市,是一家四星級酒店。這家酒店提供240間客房 ... 作為一家五星級酒店,紐約中央公園艾美酒店提供優質的服務和舒適的住宿環境 ..."}] +Thought: Now that I have found some hotels in New York, I will next find transportation options from Orlando to New York for the travel dates specified. I will use the QueryTransports API for this purpose. +Action: QueryTransports +Action Input: {"departure_city": "Orlando", "destination_city": "New York", "date": "2022-03-10", "requirement": "budget-friendly"} +Observation2: [{"title": "从奥兰多出发前往纽约的特价机票,往返 ...", "url": "https://www.tianxun.com/routes/orlb/nyca/orlando-to-new-york.html", "snippet": "查找从奥兰多飞往纽约最便宜的月份 ; 3月. ¥750 起 ; 4月. ¥927 起 ; 5月. ¥1,012 起 ; 6月. ¥1,107 起 ; 7月. ¥1,283 起."}, {"title": "从奥兰多国际出发前往纽约拉瓜迪亚的特价机票,往返 ...", "url": "https://www.tianxun.com/routes/mco/lga/orlando-international-to-new-york-laguardia.html", "snippet": "我们始终关注票价,方便您在几秒钟内锁定好价机票。目前,从奥兰多国际飞往纽约拉瓜迪亚最便宜的月份是十一月。"}] +Thought: Now that I have found some transportation options, I will proceed to find attractions in New York. I will use the QueryAttractions API to search for tourist attractions in New York. +Action: QueryAttractions +Action Input: {"destination": "New York"} +Observation3: - snippet: 從參觀紐約市最知名觀光景點開始:時代廣場、帝國大廈、自由女神像,然後利用剩餘的時間探索周圍地區推薦去處。 探索熱門體驗. 根據評等和預訂次數,看看其他旅客喜歡從事 ... + title: 紐約市10 大最佳旅遊景點(2024) - Tripadvisor + url: https://www.tripadvisor.com.hk/Attractions-g60763-Activities-New_York_City_New_York.html +- snippet: 紐約景點推薦 · 紐約景點#1 紐約中央公園 · 紐約景點#2 范德堡一號大樓 SUMMIT · 紐約景點#3 第五大道(Fifth Avenue) + · 紐約景點#4 大都會藝術博物館The ... + title: 【2024紐約景點】漫遊曼哈頓!26個必去行程&免費景點整理 + url: https://www.klook.com/zh-TW/blog/new-york-must-go/ +- snippet: 【紐約NewYork景點推薦】紐約「10大必去」打卡景點整理懶人包 · 紐約NewYork景點推薦-10大必去景點 · 1.中央公園(Central + Park) · 2.第五大道(Fifth Avenue) · 3.大都會 ... + title: 【紐約NewYork景點推薦】紐約「10大必去」打卡景點整理懶人包 + url: https://schoolaplus.com/articles-detail.asp?seq=35 +Thought: Now that I have found some attractions in New York, I will summarize the information and create a travel itinerary for the 5-day trip. I will use the Finish tool to provide the final answer. +Action: Finish +Action Input: {"plan details": "Day 1: Depart from Orlando to New York on March 10, 2022. Stay at the Park Central Hotel in Central Park. Visit the Empire State Building and Times Square. Have dinner at Lombardi's Pizza. +Day 2: Visit Central Park, the Metropolitan Museum of Art, and the American Museum of Natural History. Have lunch at Shake Shack and dinner at Le Pain Quotidien. +Day 3: Explore the Brooklyn Bridge, Brooklyn Heights, and DUMBO. Have lunch at Di Fara Pizza and dinner at Peter Luger Steak House. +Day 4: Visit the Statue of Liberty and Ellis Island. Have lunch at The Boil and dinner at Xi'an Famous Foods. +Day 5: Spend the day shopping on Fifth Avenue and visiting the Rockefeller Center. Have lunch at Shake Shack and dinner at Katz's Delicatessen."} +''' + + +def get_default_react_agent(api_base, api_key, llm_name, max_context_len): + llm = get_llm_backend(BACKEND_OPENAI_COMPATIBLE, api_base, api_key, llm_name).run + + tool_list = [QueryAccommodations, QueryTransports, QueryGoogleDistanceMatrix, QueryAttractions, Finish] + + agent = ReactAgent(llm=llm, tool_list=tool_list, max_context_len=max_context_len) + return agent + + +def get_default_react_agent_fewshot(api_base, api_key, llm_name, max_context_len): + llm = get_llm_backend(BACKEND_OPENAI_COMPATIBLE, api_base, api_key, llm_name).run + + tool_list = [QueryAccommodations, QueryTransports, QueryGoogleDistanceMatrix, QueryAttractions, Finish] + + agent = ReactAgent(llm=llm, example=EXAMPLE, tool_list=tool_list, max_context_len=max_context_len) + return agent + + +def get_default_toolless_agent(api_base, api_key, llm_name, max_context_len): + llm = get_llm_backend(BACKEND_OPENAI_COMPATIBLE, api_base, api_key, llm_name).run + + agent = ToollessAgent(llm=llm, max_context_len=max_context_len) + return agent + + +def get_default_react_reflect_agent(api_base, api_key, llm_name, max_context_len): + llm = get_llm_backend(BACKEND_OPENAI_COMPATIBLE, api_base, api_key, llm_name).run + + tool_list = [QueryAccommodations, QueryTransports, QueryGoogleDistanceMatrix, QueryAttractions, Finish] + agent = ReactReflectAgent(reflect_llm=llm, react_llm=llm, example=EXAMPLE, + tool_list=tool_list, max_context_len=max_context_len) + return agent + + +def test_react_agent(): + a = get_default_react_agent_fewshot(API_BASE, API_KEY, LLM_NAME, MAX_CONTEXT_LEN) + response = a.run("Can you help with a 5 day trip from Orlando to Paris? Departure date is April 10, 2022.") + + logger.info(f"5 day trip from Orlando to Paris:{response.answer}") + + +def test_toolless_agent(): + a = get_default_toolless_agent(API_BASE, API_KEY, LLM_NAME, MAX_CONTEXT_LEN) + response = a.run("Can you help with a 5 day trip from Orlando to Paris? Departure date is April 10, 2022.", + text="given information") + + logger.info(f"5 day trip from Orlando to Paris:{response.answer}") + + +def test_react_reflect_agent(): + a = get_default_react_reflect_agent(API_BASE, API_KEY, LLM_NAME, MAX_CONTEXT_LEN) + response = a.run("Can you help with a 5 day trip from Orlando to Paris? Departure date is April 10, 2022.", + text="given information") + + logger.info(f"5 day trip from Orlando to Paris:{response.answer}") + + +if __name__ == '__main__': + logger.info("react agent test begin") + test_react_agent() + logger.info("react agent test end") + + logger.info("toolless agent test begin") + test_toolless_agent() + logger.info("toolless agent test end") + + logger.info("react reflect agent test begin") + test_react_reflect_agent() + logger.info("react reflect agent test end") + + + + + + + diff --git a/mxAgent/samples/basic_demo/agent_traj_systhesis.py b/mxAgent/samples/basic_demo/agent_traj_systhesis.py new file mode 100644 index 0000000000000000000000000000000000000000..ced5ad1ae2d1ff1be861b7808604ea320de586d5 --- /dev/null +++ b/mxAgent/samples/basic_demo/agent_traj_systhesis.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + +import os +import warnings + +from langchain._api import LangChainDeprecationWarning +from loguru import logger +from tqdm import tqdm + +from agent_sdk.agentchain.react_agent import ReactAgent +from agent_sdk.common.enum_type import AgentRunStatus +from agent_sdk.llms.llm import get_llm_backend, BACKEND_OPENAI_COMPATIBLE +from samples.tools import QueryAttractions, QueryTransports, QueryAccommodations, \ + QueryRestaurants, QueryGoogleDistanceMatrix +from mxAgent.samples.basic_demo.agent_test import EXAMPLE + + +warnings.filterwarnings('ignore') +warnings.filterwarnings('ignore', category=DeprecationWarning) +warnings.filterwarnings('ignore', category=LangChainDeprecationWarning) + +os.environ["WORKNING_DIR"] = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +API_BASE = os.environ.get("OPENAI_API_BASE", "http://10.44.115.98:8006/v1") +API_KEY = os.environ.get("OPENAI_API_KEY", "EMPTY") +LLM_NAME = os.environ.get("MODEL_NAME", "Qwen2-7b-Instruct") + +MAX_CONTEXT_LEN = 4096 + + +def get_default_react_agent(api_base, api_key, llm_name, max_context_len): + llm = get_llm_backend(BACKEND_OPENAI_COMPATIBLE, api_base, api_key, llm_name).run + tool_list = [QueryAttractions, QueryTransports, QueryAccommodations, QueryRestaurants, QueryGoogleDistanceMatrix] + return ReactAgent(llm=llm, example=EXAMPLE, tool_list=tool_list, max_context_len=max_context_len) + + +if __name__ == '__main__': + agent = get_default_react_agent(API_BASE, API_KEY, LLM_NAME, MAX_CONTEXT_LEN) + + queries = [ + "Book a rental car for two people in Salt Lake City from April 15 to April 18, 2022.", + "Research and list down outdoor activities suitable for adrenaline junkies in Moab \ +between April 12 and 14, 2022.", + "Write a short itinerary for a weekend trip to Nashville, starting on April 15, including live music venues." + ] + + s = AgentRunStatus() + + for query in tqdm(queries): + result = agent.run(query) + s.total_cnt += 1 + if agent.finished: + s.success_cnt += 1 + agent.save_agent_status("./save_instructions.jsonl") + agent.reset() + logger.info("\n") + logger.info("*" * 150) + logger.info(f"Question: {query}") + logger.info("*" * 150) + logger.info(f"Final answer: {result.answer}") + logger.info("*" * 150) + logger.info(f"Trajectory Path: {result.scratchpad}") + logger.info("*" * 150) + + logger.info(f"success rates: {s}") + logger.info(f"Total success rates: {s}") diff --git a/mxAgent/samples/basic_demo/intent_router.py b/mxAgent/samples/basic_demo/intent_router.py new file mode 100644 index 0000000000000000000000000000000000000000..3826da5bd0d082e9d678db23e12a1f43ddd1fd56 --- /dev/null +++ b/mxAgent/samples/basic_demo/intent_router.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + +from loguru import logger + +from agent_sdk.llms.llm import get_llm_backend, BACKEND_OPENAI_COMPATIBLE +from agent_sdk.agentchain.router_agent import RouterAgent + +llm = get_llm_backend(backend=BACKEND_OPENAI_COMPATIBLE, + api_base="http://10.44.115.108:1055/v1", api_key="EMPTY", llm_name="Qwen1.5-32B-Chat").run + +INTENT = { + "query_flight": "用户期望查询航班信息", + "query_attraction": "用户期望查询旅游景点信息", + "query_hotel": "用户期望查询酒店和住宿信息", + "plan_attraction": "用户期望给出旅行规划建议", + "whimsical": "异想天开", + "other": "其他不符合上述意图的描述" +} + +querys = [ + "帮我查一下从北京去深圳的机票", + "帮我查一下北京的旅游景点", + "我想去北京旅游", + "去北京旅游可以住在哪里呢,推荐一下", + "帮我去书城买本书", "我想上天" +] + +agent = RouterAgent(llm=llm, intents=INTENT) + +for query in querys: + response = agent.run(query) + agent.reset() + logger.info(f"query: {query}, intent: {response.answer}") diff --git a/mxAgent/samples/basic_demo/traj_generate_test.py b/mxAgent/samples/basic_demo/traj_generate_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bd4b73048bf64da09069218ba2301f38aa308699 --- /dev/null +++ b/mxAgent/samples/basic_demo/traj_generate_test.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + +import json +import os +import warnings +from typing import Callable, List +from tqdm import tqdm +from loguru import logger +from langchain._api import LangChainDeprecationWarning + +from agent_sdk.agentchain.base_agent import BaseAgent +from agent_sdk.agentchain.single_action_agent import SingleActionAgent +from agent_sdk.llms.llm import get_llm_backend, BACKEND_OPENAI_COMPATIBLE +from samples.tools import QueryAccommodations, QueryAttractions, QueryRestaurants, \ + QueryTransports, QueryGoogleDistanceMatrix + + +API_BASE = "http://10.44.115.98:8006/v1" +API_KEY = "EMPTY" +MODEL_NAME = "Qwen2-7b-Instruct" + +os.environ["OPENAI_API_BASE"] = API_BASE +os.environ["OPENAI_API_KEY"] = API_KEY +os.environ["MODEL_NAME"] = MODEL_NAME +os.environ["WORKING_DIR"] = os.path.dirname( + os.path.dirname(os.path.realpath(__file__))) + +warnings.filterwarnings('ignore') +warnings.filterwarnings('ignore', category=DeprecationWarning) +warnings.filterwarnings('ignore', category=LangChainDeprecationWarning) + + +class TrajectoryGenerator: + + @staticmethod + def generate(output_path: str, agent: BaseAgent, load_dataset: Callable[[], List[str]], **kwargs): + questions = load_dataset() + for q in tqdm(questions): + try: + agent.run(q, **kwargs) + agent.save_agent_status(output_path) + agent.reset() + + except Exception as err: + logger.warning(f"generate traj failed, query: {q}, agent: {agent.name}, err: {err}") + continue + + @staticmethod + def _check_data_format(data): + if not isinstance(data, list): + raise ValueError("Data should be a list of dict") + + if len(data) == 0: + raise ValueError("Data should not be empty") + + if not isinstance(data[0], dict): + raise ValueError("Data item should be a dict") + + alpaca_format_keys = ["instruction", "input", "output", "status"] + data_keys_set = set(data[0].keys()) + + if not all([key in data_keys_set for key in alpaca_format_keys]): + raise ValueError("need alpaca data format") + + def _load_data_from_file(self, data_path): + if not os.path.exists(data_path): + raise FileNotFoundError(f"File not found: {data_path}") + + if data_path.endswith(".jsonl"): + traj_data = [json.loads(line) for line in open(data_path, "r")] + else: + raise ValueError("Unknown file format") + + self._check_data_format(traj_data) + return traj_data + + +def get_single_action_agent(api_base, api_key, llm_name): + tool_list = [ + QueryAccommodations, QueryAttractions, QueryRestaurants, + QueryTransports, QueryGoogleDistanceMatrix + ] + llm = get_llm_backend(BACKEND_OPENAI_COMPATIBLE, + api_base, api_key, llm_name).run + return SingleActionAgent(llm=llm, tool_list=tool_list, max_steps=5) + + +if __name__ == '__main__': + single_agent = get_single_action_agent(API_BASE, API_KEY, MODEL_NAME) + queries = [ + "Write a review of the hotel \"The Beach House\" in Charlotte Amalie.", + "Book a flight from Evansville to Sacramento for April 10th.", + "Create a list of top 5 attractions in Hilo for a solo traveler.", + "Compare the prices of hotels in Newark for a 3-night stay.", + "Book a hotel room in Paducah for April 12th.", + "Write a travel blog post about visiting the Golden Gate Bridge in San Francisco.", + "Recommend the best mode of transportation from Flagstaff to Phoenix.", + "Determine the best time to visit the Statue of Liberty.", + "Compare the prices of car rentals in Seattle.", + "What are the top - rated museums in Harrisburg?" + ] + generator = TrajectoryGenerator() + generator.generate(output_path="./save_instructions.jsonl", agent=single_agent, + load_dataset=lambda: queries) diff --git a/mxAgent/samples/tools/__init__.py b/mxAgent/samples/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..030d89752044baadfb799b193da5aa32d6082b53 --- /dev/null +++ b/mxAgent/samples/tools/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + +__all__ = [ + 'CostEnquiry', 'Finish', 'QueryAccommodations', 'QueryAttractions', 'CitySearch', + 'QueryGoogleDistanceMatrix', 'QueryTransports', 'QueryWeather', "QueryRestaurants", + 'PlanSummary', 'WebSummary' +] + +from samples.tools.tool_cost_enquiry import CostEnquiry +from samples.tools.tool_finish import Finish + +from samples.tools.tool_query_accommodations import QueryAccommodations +from samples.tools.tool_query_restaurants import QueryRestaurants +from samples.tools.tool_query_attractions import QueryAttractions +from samples.tools.tool_query_city import CitySearch +from samples.tools.tool_query_distance_matrix import QueryGoogleDistanceMatrix +from samples.tools.tool_query_transports import QueryTransports +from samples.tools.tool_query_weather import QueryWeather + +from samples.tools.tool_summary import PlanSummary +from samples.tools.web_summary_api import WebSummary \ No newline at end of file diff --git a/mxAgent/samples/tools/duck_search.py b/mxAgent/samples/tools/duck_search.py new file mode 100644 index 0000000000000000000000000000000000000000..4a5782d66ecba9d3dc8b15c8d912f4e9c74e4049 --- /dev/null +++ b/mxAgent/samples/tools/duck_search.py @@ -0,0 +1,114 @@ +import json +from typing import List + +from langchain_community.tools import DuckDuckGoSearchResults +from langchain_community.utilities import DuckDuckGoSearchAPIWrapper +from utils.log import LOGGER as logger + +from toolmngt.api import API + + +class DuckDuckGoSearch(API): + name = "DuckDuckGoSearch" + description = ("DuckDuckGoSearch engine can search for rich external knowledge on the Internet based on keywords, " + "which can compensate for knowledge fallacy and knowledge outdated.") + input_parameters = { + 'query': {'type': 'str', 'description': "the query string to be search"} + } + output_parameters = { + 'information': {'type': 'str', 'description': 'the result information from Bing search engine'} + } + usage = ("DuckDuckGoSearch[query], which searches the exact detailed query on the Internet and returns the " + "relevant information to the query. Be specific and precise with your query to increase the chances of " + "getting relevant results. For example, DuckDuckGoSearch[popular dog breeds in the United States]") + + def __init__(self) -> None: + self.scratchpad = "" + self.bingsearch_results = "" + + def format_tool_input_parameters(self, llm_output) -> dict: + input_parameters = {"query": llm_output} + return input_parameters + + def check_api_call_correctness(self, response: dict, groundtruth=None) -> bool: + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + + ex = response.get("exception") + + if ex is not None: + return False + else: + return True + + def call(self, input_parameters: dict, **kwargs) -> dict: + """ + Calls the API with the given parameters. + + Parameters: + + input_parameters = { + 'query': query + } + + Returns: + - response (str): the response from the API call. + """ + logger.debug(f"{input_parameters}") + query = input_parameters.get('query', "") + + try: + responses = self.call_duck_duck_go_search(query=query, count=4) + logger.debug(f"responses is {responses}") + output = "" + if len(responses) > 0: + for r in responses: + output += self.format_step(r) + else: + output = "Bing search error" + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, + 'output': f'Search error,please try again', + 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': output, + 'exception': None} + + def format_result(self, res): + snippet_idx = res.find("snippet:") + title_idx = res.find("title:") + link_idx = res.find("link:") + snippet = res[snippet_idx + len("snippet:"):title_idx] + title = res[title_idx + len("title:"):link_idx] + link = res[link_idx + len("link:"):] + return {"snippet": snippet.replace("", "").replace("", ""), "title": title, "link": link} + + def call_duck_duck_go_search(self, query: str, count: int) -> List[str]: + try: + logger.debug(f"search DuckDuckGo({query}, {count})") + duck_duck_search = DuckDuckGoSearchAPIWrapper(max_results=count) + search = DuckDuckGoSearchResults(api_wrapper=duck_duck_search) + self.bingsearch_results = [] + temp = search.run(query) + logger.debug(temp) + + for x in temp.split("["): + snippet = x.split("]")[0].strip() + if len(snippet) == 0: + continue + logger.debug(f"snippet is {snippet}") + self.bingsearch_results.append(self.format_result(snippet)) + logger.success(f"{json.dumps(self.bingsearch_results, indent=4)}") + except Exception as e: + self.scratchpad += f'Search error {str(e)}, please try again' + + return [x['snippet'] for x in self.bingsearch_results] diff --git a/mxAgent/samples/tools/tool_cost_enquiry.py b/mxAgent/samples/tools/tool_cost_enquiry.py new file mode 100644 index 0000000000000000000000000000000000000000..a2ae34c5aff1b1db1ab330224e4a675a4cd406d8 --- /dev/null +++ b/mxAgent/samples/tools/tool_cost_enquiry.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + + +from typing import Union + +from agent_sdk.toolmngt.api import API +from agent_sdk.toolmngt.tool_manager import ToolManager + + +@ToolManager.register_tool() +class CostEnquiry(API): + name = "CostEnquiry" + description = "Indicate the final answer for the task" + input_parameters = { + 'Sub Plan': {'type': 'str', 'description': 'Sub Plan'} + } + + output_parameters = { + + } + + example = ( + """ + { + "Sub Plan": "This function calculates the cost of a detailed subn plan, which you need to input ' + 'the people number and plan in JSON format. The sub plan encompass a complete one-day plan. An' + 'example will be provide for reference." + } + """) + + def format_tool_input_parameters(self, text) -> Union[dict, str]: + input_parameters = {"answer": text} + return input_parameters + + def check_api_call_correctness(self, response, groundtruth) -> bool: + ex = response.get("exception") + + if ex is not None: + return False + else: + return True + + def call(self, input_parameter: dict, **kwargs): + action_arg = input_parameter.get('Sub Plan', "") + react_env = kwargs.get("react_env is missing") + + if react_env is None: + raise Exception("react_env is missing") + + try: + input_arg = eval(action_arg) + if not isinstance(input_arg, dict): + raise ValueError( + 'The sub plan can not be parsed into json format, please check. Only one day plan is ' + 'supported.' + ) + result = f"Cost: {react_env.run(input_arg)}" + + except SyntaxError: + result = f"The sub plan can not be parsed into json format, please check." + + except ValueError as e: + result = str(e) + + return self.make_response(input_parameter, result) diff --git a/mxAgent/samples/tools/tool_finish.py b/mxAgent/samples/tools/tool_finish.py new file mode 100644 index 0000000000000000000000000000000000000000..9ee7aafd761be3edc93e5032db8eaf15b9aa7872 --- /dev/null +++ b/mxAgent/samples/tools/tool_finish.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + +from typing import Union + +from agent_sdk.toolmngt.api import API +from agent_sdk.toolmngt.tool_manager import ToolManager + + +@ToolManager.register_tool() +class Finish(API): + description = "Provide a final answer to the given task." + + input_parameters = { + 'answer': {'type': 'str', 'description': "the final result"} + } + + output_parameters = {} + + example = ( + """ + { + "plan details": "The final answer the task." + } + """) + + def __init__(self) -> None: + super().__init__() + + def format_tool_input_parameters(self, text) -> Union[dict, str]: + input_parameter = {"answer": text} + return input_parameter + + def gen_few_shot(self, thought: str, param: str, idx: int) -> str: + return (f"Thought: {thought}\n" + f"Action: {self.__class__.__name__}\n" + f"Action Input: {param}\n") + + def check_api_call_correctness(self, response, groundtruth) -> bool: + ex = response.get("exception") + + if ex is not None: + return False + else: + return True + + def call(self, input_parameter: dict, **kwargs): + answer = input_parameter.get('answer', "") + return self.make_response(input_parameter, answer) diff --git a/mxAgent/samples/tools/tool_general_query.py b/mxAgent/samples/tools/tool_general_query.py new file mode 100644 index 0000000000000000000000000000000000000000..312d6e7cf864a3aa30a577453c18cc246aabdb5b --- /dev/null +++ b/mxAgent/samples/tools/tool_general_query.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + +import json + +from agent_sdk.toolmngt.api import API +from agent_sdk.toolmngt.tool_manager import ToolManager +from loguru import logger +from samples.tools.web_summary_api import WebSummary + + +@ToolManager.register_tool() +class GeneralQuery(API): + name = "GeneralQuery" + description = "This api can collect information or answer about the travel related query from internet." + input_parameters = { + "keywords": {'type': 'str', + "description": "the keys words related to travel plan included in the user's query"}, + + } + + output_parameters = { + 'reply': {'type': 'str', 'description': 'the replay from internet to the query'}, + } + + example = ( + """ + { + "keywords": "北京,美食" + } + """) + + def __init__(self): + pass + + def check_api_call_correctness(self, response, groundtruth=None) -> bool: + + if response['exception'] is None: + return True + else: + return False + + def call(self, input_parameter: dict, **kwargs): + keywords = input_parameter.get('keywords') + try: + if keywords is None or len(keywords) == 0: + return self.make_response(input_parameter, results="", exception="") + prompt = """你是一个擅长文字处理和信息总结的智能助手,你的任务是将提供的网页信息进行总结,并以精简的文本的形式进行返回, + 请添加适当的词语,使得语句内容连贯,通顺,但不要自行杜撰,保证内容总结的客观性。 + 下面是网页的输入: + {input} + 请生成总结段落: + """ + webs = WebSummary.web_summary( + keys=keywords, search_num=3, summary_num=3, summary_prompt=prompt) + + if len(webs) == 0: + content = "" + else: + content = json.dumps(webs, ensure_ascii=False) + logger.info(content) + res = { + 'reply': content + } + + except Exception as e: + logger.error(e) + e = str(e) + return self.make_response(input_parameter, results=e, success=False, exception=e) + else: + return self.make_response(input_parameter, results=content, exception="") + + +if __name__ == '__main__': + accommodationSearch = GeneralQuery() + tes = { + "keywords": "[北京,天气]" + } + test = accommodationSearch.call(tes) diff --git a/mxAgent/samples/tools/tool_query_accommodations.py b/mxAgent/samples/tools/tool_query_accommodations.py new file mode 100644 index 0000000000000000000000000000000000000000..c0bba1103f2b03d6155515f903311b826eed0a84 --- /dev/null +++ b/mxAgent/samples/tools/tool_query_accommodations.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + +import json + +import tiktoken +from agent_sdk.toolmngt.api import API +from agent_sdk.toolmngt.tool_manager import ToolManager +from loguru import logger +from samples.tools.web_summary_api import WebSummary + + +@ToolManager.register_tool() +class QueryAccommodations(API): + name = "QueryAccommodations" + description = "This api can discover accommodations in your desired city." + input_parameters = { + "destination_city": {'type': 'str', 'description': 'The city you aim to reach.'}, + "position": {'type': 'str', 'description': 'The geographical position of accomodation appointed by the user'}, + "rank": {'type': 'str', 'description': 'The rank of hotel the user want to query'} + } + + output_parameters = { + 'accommodation': { + 'type': 'str', + 'description': 'Contain hotel name, price, type, check-in requirements and other information' + } + } + + example = ( + """ + { + "destination_city": "Rome", + "position": "Central Park", + "rank": "five stars" + }""") + + def __init__(self): + self.encoding = tiktoken.get_encoding("gpt2") + + def check_api_call_correctness(self, response, groundtruth) -> bool: + ex = response.exception + if ex is not None: + return False + else: + return True + + def call(self, input_parameter, **kwargs): + destination = input_parameter.get('destination_city') + position = input_parameter.get("position") + rank = input_parameter.get("rank") + llm = kwargs.get("llm", None) + keys = [destination, position, rank] + keyword = [] + logger.debug(f"search accommodation key words: {','.join(keyword)}") + for val in keys: + if val is None or len(val) == 0: + continue + if '无' in val or '未' in val or '没' in val: + continue + if isinstance(val, list): + it = flatten(val) + keyword.append(it) + keyword.append(val) + if len(keyword) == 0: + return self.make_response(input_parameter, results="", exception="") + keyword.append("住宿") + prompt = """你是一个擅长文字处理和信息总结的智能助手,你的任务是将提供的网页信息进行总结,并以精简的文本的形式进行返回, + 请添加适当的词语,使得语句内容连贯,通顺。提供的信息是为用户推荐的酒店的网页数据, + 请总结网页信息,要求从以下几个方面考虑: + 1. 酒店的地理位置,星级、评分,评价,品牌信息 + 2. 不同的户型对应的价格、房间情况,对入住用户的要求等 + 并给出一到两个例子介绍这些情况 + 若输入的内容没有包含有效的酒店和住宿信息,请统一返回:【无】 + 下面是网页的输入: + {input} + 请生成总结: + """ + try: + webs = WebSummary.web_summary( + keys=keyword, search_num=3, summary_num=3, summary_prompt=prompt, llm=llm) + except Exception as e: + logger.error(e) + return self.make_response(input_parameter, results=e, success=False, exception=e) + else: + if len(webs) == 0: + content = "" + else: + content = json.dumps(webs, ensure_ascii=False) + logger.info(content) + res = { + 'accommodation': content + } + return self.make_response(input_parameter, results=res, exception="") + + +def flatten(nested_list): + """递归地扁平化列表""" + for item in nested_list: + if isinstance(item, list): + return flatten(item) + else: + return item diff --git a/mxAgent/samples/tools/tool_query_attractions.py b/mxAgent/samples/tools/tool_query_attractions.py new file mode 100644 index 0000000000000000000000000000000000000000..7bb076e807b8059b4d257909774fd58096ca497f --- /dev/null +++ b/mxAgent/samples/tools/tool_query_attractions.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + +import tiktoken +import yaml +from agent_sdk.toolmngt.api import API +from agent_sdk.toolmngt.tool_manager import ToolManager +from loguru import logger +from samples.tools.web_summary_api import WebSummary + + +@ToolManager.register_tool() +class QueryAttractions(API): + name = "QueryAttractions" + description = "This api can be used to Search for tourist attractions from websites that '\ + users expect and summarize them." + input_parameters = { + 'destination': {'type': 'str', 'description': "The destination where the user wants to travel."}, + 'scene': {'type': 'str', 'description': 'The specific scenic spot mentioned by the user'}, + 'type': {'type': 'str', + 'description': 'The specific type of scenic spot mentioned by the user, eg museum, park'}, + 'requirement': {'type': 'str', 'description': 'The requirement of scenic spot mentioned by the user'}, + + } + + output_parameters = { + 'attractions': { + 'type': 'str', + 'description': 'Contains local attractions address, contact information, website, latitude "\ + and longitude and other information' + } + } + + example = ( + """ + { + "destination": "Paris", + "scene": "The Louvre Museum", + "type": "Museum", + "requirement": "historical" + }""") + + def __init__(self): + self.encoding = tiktoken.get_encoding("gpt2") + + def check_api_call_correctness(self, response, groundtruth) -> bool: + ex = response.exception + if ex is not None: + return False + else: + return True + + def call(self, input_parameter: dict, **kwargs): + destination = input_parameter.get('destination') + scene = input_parameter.get('scene') + scene_type = input_parameter.get('type') + requirement = input_parameter.get('requirement') + llm = kwargs.get("llm", None) + keyword = [] + keys = [destination, scene, scene_type, requirement] + for val in keys: + if val is None or len(val) == 0: + continue + if '无' in val or '未' in val or '没' in val: + continue + if isinstance(val, list): + it = flatten(val) + keyword.append(it) + keyword.append(val) + if len(keyword) == 0: + return self.make_response(input_parameter, results="", + exception="failed to obtain search keyword") + + keyword.append('景点') + logger.debug(f"search attraction key words: {','.join(keyword)}") + + summary_prompt = """你是一个擅长于网页信息总结的智能助手,提供的网页是关于旅游规划的信息,现在已经从网页中获取到了相关的文字内容信息,你需要从网页中找到与**景区**介绍相关的内容,并进行提取, + 你务必保证提取的内容都来自所提供的文本,保证结果的客观性,真实性。 + 网页中可能包含多个景点的介绍,你需要以YAML文件的格式返回,每个景点的返回的参数和格式如下: + **输出格式**: + - name: xx + introduction: xx + **参数介绍**: + name:景点名称 + introduction:精简的景区介绍,可以从以下这些方面阐述:景点的基本情况、历史文化等信息、景区门票信息、景区开放时间、景区的联系方式、预约方式以及链接,景区对游客的要求等。 + **注意** + 请注意:不要添加任何解释或注释,且严格遵循YAML格式 + 下面是提供的网页文本信息: + {input} + 请开始生成: + """ + + web_output = WebSummary.web_summary( + keyword, search_num=3, summary_num=3, summary_prompt=summary_prompt, llm=llm) + + if len(web_output) == 0: + yaml_str = "" + else: + yaml_str = yaml.dump(web_output, allow_unicode=True) + + responses = { + 'attractions': yaml_str + } + + return self.make_response(input_parameter, results=responses, exception="") + + +def flatten(nested_list): + """递归地扁平化列表""" + for item in nested_list: + if isinstance(item, list): + return flatten(item) + else: + return item diff --git a/mxAgent/samples/tools/tool_query_city.py b/mxAgent/samples/tools/tool_query_city.py new file mode 100644 index 0000000000000000000000000000000000000000..4e1c90a037c61c17d25ef95bc430b36dfe08df5c --- /dev/null +++ b/mxAgent/samples/tools/tool_query_city.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + +import os +from typing import Union + +from agent_sdk.toolmngt.api import API +from loguru import logger +from agent_sdk.toolmngt.tool_manager import ToolManager + +current_file_path = os.path.abspath(__file__) +current_folder_path = os.path.dirname(current_file_path) +parent_folder_path = os.path.dirname(current_folder_path) + + +@ToolManager.register_tool() +class CitySearch(API): + name = "CitySearch" + input_parameters = { + 'state': {'type': 'str', 'description': "the name of the state"} + } + + output_parameters = { + "state": {'type': 'str', 'description': "the name of the state"}, + "city": {'type': 'str', 'description': "the name of the city in the state"} + } + + usage = f"""{name}[state]: + Description: This api can be used to retrieve cities in your target state. + Parameter: + state: The name of the state where you're finding cities. + Example: {name}[state: New York] would return cities in New York. + """ + + example = ( + """ + { + "state": "New York" + }""") + + def __init__(self, path="database/background"): + self.states_path = os.path.join(parent_folder_path, path, "stateSet.txt") + self.states_cities_path = os.path.join(parent_folder_path, path, "citySet_with_states.txt") + self.states = [] + self.cities_in_state = {} + + with open(self.states_path, "r") as f: + content = f.read() + content.split('\n') + for state in content: + self.states.append(state.strip()) + + with open(self.states_cities_path, "r") as f: + context = f.read() + context = context.split("\n") + + for city_state in context: + city_state = city_state.split('\t') + city = city_state[0].strip() + state = city_state[1].strip() + + if state in self.cities_in_state.keys(): + self.cities_in_state[state].append(city) + else: + self.cities_in_state[state] = [city] + + logger.info("cities and states loaded.") + + def format_tool_input_parameters(self, text) -> Union[dict, str]: + return text + + def check_api_call_correctness(self, response, groundtruth) -> bool: + if response["exception"] is None: + return True + else: + return False + + def call(self, input_parameter: dict, **kwargs): + state = input_parameter.get('state', '') + + if state in self.cities_in_state.keys(): + results = self.cities_in_state[state] + results = ", ".join(results) + results = f"{state} has {results}" + + logger.info("search the cities in state successfully, results:") + logger.info(results) + + return self.make_response(input_parameter, results) + else: + return self.make_response(input_parameter, "Failed to search the cities in state", + exception='cant find state') diff --git a/mxAgent/samples/tools/tool_query_distance_matrix.py b/mxAgent/samples/tools/tool_query_distance_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..f6a10c4e151aa76657f0e5e893ba00bc52e31ee5 --- /dev/null +++ b/mxAgent/samples/tools/tool_query_distance_matrix.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + + +import json +import os +import re +from typing import Tuple +from agent_sdk.toolmngt.tool_manager import ToolManager + +import numpy as np +import pandas as pd +from agent_sdk.toolmngt.api import API, APIResponse +from loguru import logger + + +@ToolManager.register_tool() +class QueryGoogleDistanceMatrix(API): + name = "QueryGoogleDistanceMatrix" + input_parameters = { + 'origin': {'type': 'str', 'description': "The departure city of your journey."}, + 'destination': {'type': 'str', 'description': "The destination city of your journey."}, + 'mode': {'type': 'str', + 'description': "The method of transportation. Choices include 'self-driving' and 'taxi'."} + } + + output_parameters = { + 'origin': {'type': 'str', 'description': 'The origin city of the flight.'}, + 'destination': {'type': 'str', 'description': 'The destination city of your flight.'}, + 'cost': {'type': 'str', 'description': 'The cost of the flight.'}, + 'duration': {'type': 'str', 'description': 'The duration of the flight. Format: X hours Y minutes.'}, + 'distance': {'type': 'str', 'description': 'The distance of the flight. Format: Z km.'}, + } + + usage = f"""{name}[origin, destination, mode]: + Description: This api can retrieve the distance, time and cost between two cities. + Parameter: + origin: The departure city of your journey. + destination: The destination city of your journey. + mode: The method of transportation. Choices include 'self-driving' and 'taxi'. + Example: {name}[origin: Paris, destination: Lyon, mode: self-driving] would provide driving distance, time and cost between Paris and Lyon. + """ + + example = ( + """ + { + "origin": "Paris", + "destination": "Lyon", + "mode": "self-driving" + }""") + + def __init__(self) -> None: + logger.info("QueryGoogleDistanceMatrix API loaded.") + + def check_api_call_correctness(self, response, groundtruth) -> bool: + if response['exception'] is None: + return True + else: + return False + + def call(self, input_parameter: dict, **kwargs): + origin = input_parameter.get('origin', "") + destination = input_parameter.get('destination', "") + mode = input_parameter.get('mode', "") + return self.make_response(input_parameter, f"success to get {mode}, from {origin} to {destination}") diff --git a/mxAgent/samples/tools/tool_query_restaurants.py b/mxAgent/samples/tools/tool_query_restaurants.py new file mode 100644 index 0000000000000000000000000000000000000000..caed5ed7aa4b9cf17630211083362fe59a12cb86 --- /dev/null +++ b/mxAgent/samples/tools/tool_query_restaurants.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + +from loguru import logger + +from agent_sdk.toolmngt.api import API, APIResponse +from agent_sdk.toolmngt.tool_manager import ToolManager + + +@ToolManager.register_tool() +class QueryRestaurants(API): + description = 'Explore dining options in a city of your choice.' + input_parameters = { + 'City': {'type': 'str', 'description': "The name of the city where you're seeking restaurants."} + } + + output_parameters = { + 'restaurant_name': {'type': 'str', 'description': 'The name of the restaurant.'}, + 'city': {'type': 'str', 'description': 'The city where the restaurant is located.'}, + 'cuisines': {'type': 'str', 'description': 'The cuisines offered by the restaurant.'}, + 'average_cost': {'type': 'int', 'description': 'The average cost for a meal at the restaurant.'}, + 'aggregate_rating': {'type': 'float', 'description': 'The aggregate rating of the restaurant.'} + } + + example = ( + """ + { + "City": "Tokyo" + }""") + + def __init__(self): + super().__init__() + logger.info("Restaurants loaded.") + + def call(self, input_parameter, **kwargs): + city = input_parameter.get('City', "") + return self.make_response(input_parameter, f"success to get restaurant in {city}") + + def check_api_call_correctness(self, response, ground_truth=None) -> bool: + return True diff --git a/mxAgent/samples/tools/tool_query_transports.py b/mxAgent/samples/tools/tool_query_transports.py new file mode 100644 index 0000000000000000000000000000000000000000..b48f03585065bed566b2cb99da4675e9008daff5 --- /dev/null +++ b/mxAgent/samples/tools/tool_query_transports.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + + +import json + +import tiktoken +from agent_sdk.toolmngt.api import API +from agent_sdk.toolmngt.tool_manager import ToolManager +from loguru import logger +from samples.tools.web_summary_api import WebSummary + + +@ToolManager.register_tool() +class QueryTransports(API): + name = "QueryTransports" + description = "This API is used to query relevant travel traffic information from the \ + networkAccording to the user's input question," + input_parameters = { + "departure_city": {'type': 'str', 'description': "The city you'll be flying out from."}, + "destination_city": {'type': 'str', 'description': 'The city user aim to reach.'}, + "travel_mode": {'type': 'str', 'description': 'The mode of travel appointed by the user'}, + "date": {'type': 'str', 'description': 'The date of the user plan to travel'}, + 'requirement': {'type': 'str', 'description': 'The more requirement of transportation mentioned by the user'}, + } + output_parameters = { + "transport": {'type': 'str', + 'description': 'the transport information'}, + } + + example = ( + """ + { + "departure_city": "New York", + "destination_city": "London", + "date": "2022-10-01", + "travel_mode": "flight" + } + """) + + def __init__(self): + self.encoding = tiktoken.get_encoding("gpt2") + + def check_api_call_correctness(self, response, groundtruth=None) -> bool: + ex = response.exception + if ex is not None: + return False + else: + return True + + def call(self, input_parameter, **kwargs): + origin = input_parameter.get('departure_city') + destination = input_parameter.get('destination_city') + req = input_parameter.get("requirement") + travel_mode = input_parameter.get("travel_mode") + llm = kwargs.get("llm", None) + try: + prefix = f"从{origin}出发" if origin else "" + prefix += f"前往{destination}" if destination else "" + keys = [prefix, req, travel_mode] + filtered = [] + for val in keys: + if val is None or len(val) == 0: + continue + if '无' in val or '未' in val or '没' in val: + continue + filtered.append(val) + if len(filtered) == 0: + return self.make_response(input_parameter, results="", exception="") + filtered.append("购票") + logger.debug(f"search transport key words: {','.join(filtered)}") + + prompt = """你的任务是将提供的网页信息进行总结,并以精简的文本的形式进行返回, + 请添加适当的词语,使得语句内容连贯,通顺。输入是为用户查询的航班、高铁等交通数据,请将这些信息总结 + 请总结网页信息,要求从以下几个方面考虑: + 总结出航班或者高铁的价格区间、需要时长区间、并给出2-3例子,介绍车次、时间、时长、价格等 + 下面是网页的输入: + {input} + 请生成总结: + """ + webs = WebSummary.web_summary( + filtered, search_num=2, summary_num=2, summary_prompt=prompt, llm=llm) + if len(webs) == 0: + content = "" + else: + content = json.dumps(webs, ensure_ascii=False) + logger.info(f"search:{webs}") + res = { + 'transport': content + } + except Exception as e: + logger.error(e) + e = str(e) + return self.make_response(input_parameter, results=e, success=False, exception=e) + else: + return self.make_response(input_parameter, results=res, exception="") diff --git a/mxAgent/samples/tools/tool_query_weather.py b/mxAgent/samples/tools/tool_query_weather.py new file mode 100644 index 0000000000000000000000000000000000000000..7fe739bc3110b83b9b1045eaebc376ec685fada7 --- /dev/null +++ b/mxAgent/samples/tools/tool_query_weather.py @@ -0,0 +1,185 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + +import datetime +import json +import json +import os +from zoneinfo import ZoneInfo + +import requests +import urllib3 +from agent_sdk.toolmngt.api import API +from agent_sdk.toolmngt.tool_manager import ToolManager +from loguru import logger + +AMAP_API_KEY = "75bcb2edf5800884a31172dd0d970369" +WEEK_MAP = { + 0: "Monday", + 1: "Tuesday", + 2: "Wednesday", + 3: "Thursday", + 4: "Friday", + 5: "Saturday", + 6: "Sunday" +} +REQUEST_HEADERS = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) '\ + Chrome/126.0.0.0 Safari/537.36" +} + + +@ToolManager.register_tool() +class QueryWeather(API): + name = "QueryWeather" + description = "This API is used to query weather forecast from the network according to the user's input question," + weekday = WEEK_MAP.get(datetime.datetime.now(ZoneInfo("Asia/Shanghai")).weekday(), '') + input_parameters = { + + 'destination_city': {'type': 'str', 'description': 'the destination city user aim to query weather.'}, + "province": {'type': 'str', 'description': 'The province corresponding to the city'}, + "date": {'type': 'str', + 'description': ("The date of the user want to query, today is" + + f"{datetime.date.today()}, and today is {weekday}, " + + "please reason the date from user's query, and format with YYYY-MM-DD,") + }, + 'requirement': {'type': 'str', 'description': 'The more requirement of weather mentioned by the user'}, + } + output_parameters = { + "forecast": {'type': 'str', + 'description': 'the weather forecast information'}, + } + + example = ( + """ + { + "destination_city": "ShenZhen", + "province": "GuangDong", + "date": "2022-10-01" + } + """) + + def __init__(self, ): + os.environ['CURL_CA_BUNDLE'] = '' # 关闭SSL证书验证 + urllib3.disable_warnings() + + def check_api_call_correctness(self, response, groundtruth=None) -> bool: + ex = response.exception + if ex is not None: + return False + else: + return True + + def get_forecast(self, url, param, city=""): + headers = REQUEST_HEADERS + response = requests.get(url, params=param, headers=headers, timeout=5) + if response.status_code != 200: + logger.error(f"获取网页{url}内容失败") + raise Exception(f"获取网页{url}内容失败") + content = response.content + text = json.loads(content) + return text.get("data") + + def get_city2province(self, url, city): + headers = REQUEST_HEADERS + params = { + "city": city, + "source": "pc" + } + response = requests.get(url, params=params, headers=headers, timeout=5) + if response.status_code != 200: + logger.error(f"获取网页{url}内容失败") + raise Exception(f"获取网页{url}内容失败") + content = response.content + text = json.loads(content) + return text.get("data") + + def format_weather(self, weekly_weather): + # 精简输入 + key_keeps = [ + 'day_weather', 'day_wind_direction', 'day_wind_power', + 'max_degree', 'min_degree', 'night_weather', 'night_wind_direction', 'night_wind_power' + ] + summary_copy = [] + for key, info in weekly_weather.items(): + time = info.get('time', key) + daily = {} + if isinstance(info, dict): + info_keeps = {k: info[k] for k in key_keeps if k in info} + daily[time] = info_keeps + summary_copy.append(daily) + return summary_copy + + def format_request_param(self, data, weather_type): + for key, value in data.items(): + city2province = value.replace(" ", "").split(",") + data[key] = city2province + # 遇到城市同名,认为是市的概率大于县 + _, max_probablity = min(data.items(), key=lambda item: len(item[1])) + if len(max_probablity) >= 2: + province = max_probablity[0] + city = max_probablity[1] + country = max_probablity[2] if len(max_probablity) >= 3 else "" + params = { + "source": "pc", # 请求来源,可以填 pc 即来自PC端 + "province": province, # 省, + "city": city, # 市, + "country": country, # 县区 + "weather_type": weather_type + } + return params + + def call(self, input_parameter, **kwargs): + des = input_parameter.get('destination_city') + departure_date = input_parameter.get("date") + weather_type = "forecast_24h" + + try: + if des is None: + return self.make_response(input_parameter, results="", success=False, exception="") + try: + data = self.get_city2province("https://wis.qq.com/city/like", des) + except Exception as e: + e = str(e) + return self.make_response(input_parameter, results=e, success=False, exception=e) + if len(data) == 0: + return self.make_response(input_parameter, + results="未能找到所查询城市所在的省份或市", success=False, exception="") + + params = self.format_request_param(data, weather_type) + try: + forecast = self.get_forecast( + "https://wis.qq.com/weather/common", params) + except Exception as e: + e = str(e) + return self.make_response(input_parameter, results=e, success=False, exception=e) + weekly_weather = forecast.get(weather_type) + summary_copy = self.format_weather(weekly_weather) + if departure_date is None: + res = { + 'forecast': summary_copy + } + return self.make_response(input_parameter, results=res, exception="") + + try: + formated_departure = datetime.datetime.strptime( + departure_date, "%Y-%m-%d").date() + except ValueError as e: + logger.warning(e) + formated_departure = datetime.date.today() + gaps = (formated_departure - datetime.date.today()).days + weather_summary = summary_copy[gaps + 1:] + + if len(weather_summary) == 0: + weather_summary = "**抱歉,我最多只能查询最近7天的天气情况,例如下面是我将为你提供最近的天气预报**:\n" + \ + json.dumps(summary_copy, ensure_ascii=False) + res = { + 'forecast': weather_summary + } + except Exception as e: + logger.error(e) + e = str(e) + return self.make_response(input_parameter, results=e, + success=False, exception=e) + else: + return self.make_response(input_parameter, results=res, exception="") diff --git a/mxAgent/samples/tools/tool_summary.py b/mxAgent/samples/tools/tool_summary.py new file mode 100644 index 0000000000000000000000000000000000000000..9d1b706c9cde747b2739e3fbeb6ada8a51891d67 --- /dev/null +++ b/mxAgent/samples/tools/tool_summary.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + +from agent_sdk.toolmngt.api import API +from agent_sdk.toolmngt.tool_manager import ToolManager + + +@ToolManager.register_tool() +class PlanSummary(API): + name = "PlanSummary" + description = "this api uesed to summary all the travel plan." + input_parameters = { + 'attractions': {'type': 'str', 'description': "the planned arrangement of attraction."}, + 'accomadation': {'type': 'str', 'description': "the accomodation information"}, + 'transport': {'type': 'str', 'description': "the transport information"}, + 'weather': {'type': 'str', 'description': "Weather information for the next few days"}, + 'duration': {'type': 'str', 'description': "The days of travel"}, + } + + output_parameters = { + 'summary': {'type': 'str', 'description': 'Summary all the plan of this travel'}, + } + + example = "PlanSummary[attractions,hotel,flight] will summary all the plan of travel inculed attractions,'\ + accomadation,and transport information" + example = ( + """ + { + "attractions": "London Bridge, any of several successive structures spanning the River Thames between '\ + Borough High Street in Southwark and King William Street.", + "accomadation": "Park Plaza London Riverbank In the heart of London, with great transport connections, '\ + culture, shopping, and green spaces", + "transport": "10 hours from Beijing to London cost $1000.", + } + """) + + def __init__(self): + pass + + def format_tool_input_parameters(self, llm_output) -> dict: + return llm_output if llm_output else {} + + def check_api_call_correctness(self, response, groundtruth=None) -> bool: + if response['exception'] is None: + return True + else: + return False + + def call(self, input_parameters, **kwargs): + # 总的只能输入2500个字左右 + attraction = input_parameters.get('attractions') + hotel = input_parameters.get('accomadation') + transport = input_parameters.get('transport') + weather = input_parameters.get("weather") + duration = input_parameters.get("duration") + + res = "" + if duration is not None: + res += f"【用户需要旅行的天数】:{duration}天\n" + if attraction is not None: + res = res + f"【景点汇总】:\n{str(attraction)[:1000]}\n" + if hotel is not None: + res = res + f"【住宿安排】:\n{str(hotel)[:500]}\n" + if transport is not None: + res = res + f"【交通安排】:\n{str(transport)[:500]}\n" + if weather is not None: + res = res + f"【未来几天的天气情况】:\n{str(weather)[:500]}\n" + summary = { + "summary": res + } + return self.make_response(input_parameters, results=summary, exception="") \ No newline at end of file diff --git a/mxAgent/samples/tools/web_summary_api.py b/mxAgent/samples/tools/web_summary_api.py new file mode 100644 index 0000000000000000000000000000000000000000..feeb4bb5b306b22834ab57e5952cf7794f3e6981 --- /dev/null +++ b/mxAgent/samples/tools/web_summary_api.py @@ -0,0 +1,176 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + + +import asyncio +import os +import re +import time +from concurrent.futures import ThreadPoolExecutor, wait, as_completed + +import aiohttp +import requests +import tiktoken +import urllib3 +from bs4 import BeautifulSoup +from loguru import logger +from samples.tools.google_search_api import google_search + + +def check_number_input(num, crow): + if not num.isdigit(): + return False + num = int(num) + if num > crow: + return False + return True + + +async def bai_du(url): + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) '\ + Chrome/126.0.0.0 Safari/537.36" + } + async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=False, limit=25), trust_env=True, + headers=headers, timeout=aiohttp.ClientTimeout(total=5)) as session: + async with session.get(url) as response: + res = await response.text() + return res + + +class WebSummary: + encoder = tiktoken.get_encoding("gpt2") + + @classmethod + def get_detail_copy(cls, url, summary_prompt): + os.environ['CURL_CA_BUNDLE'] = '' # 关闭SSL证书验证 + urllib3.disable_warnings() + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) '\ + Chrome/126.0.0.0 Safari/537.36" + } + try: + mommt = time.time() + logger.info(f"start request website: {mommt},{url}") + response = requests.get( + url, headers=headers, timeout=(3, 3), stream=True) + mommt = time.time() + logger.info(f"finish request website: {mommt},{url}") + if response.status_code != 200: + logger.error(f"获取网页{url}内容失败") + return '', f"获取网页{url}内容失败" + + content = response.content + bsobj = BeautifulSoup(content, 'html.parser') + txt = bsobj.get_text() + text = re.sub(r'\n{2,}', '\n', txt).replace(' ', '') + text = re.sub(r'\n{2,}', '\n', text) + except Exception as e: + logger.error(e) + return '', e + res = cls.generate_content(text, summary_prompt) + mommt = time.time() + logger.info(f"finish summary website: {mommt},{url}") + return res, "" + + @classmethod + async def get_details(cls, url, summary_prompt): + os.environ['CURL_CA_BUNDLE'] = '' + urllib3.disable_warnings() + try: + mommt = time.time() + logger.info(f"start request website: {mommt},{url}") + response = await bai_du(url) + mommt = time.time() + logger.debug(f"finish request website: {mommt},{url}") + content = response + bsobj = BeautifulSoup(content, 'html.parser') + txt = bsobj.get_text() + text = re.sub(r'\n{2,}', '\n', txt).replace(' ', '') + text = re.sub(r'\n{2,}', '\n', text) + if 'PleaseenableJSanddisableanyadblocker' in text: + text = "" + except Exception as e: + logger.error(e) + return '', e + if len(text) == 0: + return "", "no valid website content" + res = cls.generate_content(text, summary_prompt) + mommt = time.time() + logger.info(f"finish summary website: {mommt},{url}") + return res, "" + + @classmethod + def summary_call(cls, web, max_summary_number, summary_prompt): + title = web.get("title", "") + url = web.get("url") + snippet = web.get("snippet", "") + web_summary = {} + if url is None: + return web_summary + + web_summary['title'] = title + web_summary['url'] = url + try: + content = asyncio.run(cls.get_details(url, summary_prompt)) + except Exception as e: + logger.error(e) + if not isinstance(content, str) or len(content) == 0: + web_summary['snippet'] = snippet + else: + web_summary['content'] = content + + return web_summary + + @classmethod + def web_summary(cls, keys, search_num, summary_num, summary_prompt, llm): + logger.add('app.log', level='DEBUG') + cls.llm = llm + try: + mommt = time.time() + logger.debug(f"start google search: {mommt}") + if isinstance(keys, list): + keys = ",".join(keys) + search_result = google_search(keys, search_num) + mommt = time.time() + logger.debug(f"finish google search: {mommt}") + except Exception as e: + logger.error(e) + return [] + + max_summary_number = summary_num + + webs = [] + with ThreadPoolExecutor(max_workers=3) as executor: + futures = [] + for web in search_result: + thread = executor.submit( + cls.summary_call, web, max_summary_number, summary_prompt) + futures.append(thread) + for future in as_completed(futures): + webs.append(future.result()) + wait(futures) + return webs + + @classmethod + def build_summary_prompt(cls, query, prompt): + max_input_token_num = 4096 + if len(query) == 0: + return prompt.format(text=query) + input_token_len = len(WebSummary.encoder.encode(query)) + prompt_len = len(WebSummary.encoder.encode(prompt)) + clip_text_index = int( + len(query) * (max_input_token_num - prompt_len) / input_token_len) + clip_text = query[:clip_text_index] + return prompt.format(input=clip_text) + + @classmethod + def generate_content(cls, query, prompt): + max_tokens = 1000 + try: + pmt = WebSummary.build_summary_prompt(query, prompt) + output = cls.llm(prompt=pmt, max_tokens=max_tokens) + except Exception as e: + logger.error(e) + return e + return output diff --git a/mxAgent/samples/travel_agent_demo/front/chat_bot_release.py b/mxAgent/samples/travel_agent_demo/front/chat_bot_release.py new file mode 100644 index 0000000000000000000000000000000000000000..d74e036966136029788afe7d05b445f3b09229db --- /dev/null +++ b/mxAgent/samples/travel_agent_demo/front/chat_bot_release.py @@ -0,0 +1,62 @@ +import streamlit as st +from samples.travel_agent.travelagent import TravelAgent + +if __name__ == "__main__": + st.set_page_config( + page_title="旅游规划agent", + page_icon="./logo.jpg" + ) + st.logo("logo.jpg") + st.markdown('

旅游规划Agent

', unsafe_allow_html=True) + + + placeholder1 = st.empty() + placeholder2 = st.empty() + placeholder3 = st.empty() + + if "messages" not in st.session_state: + st.session_state.messages = [] + + if "aagent" not in st.session_state: + st.session_state.agent = TravelAgent() + + with placeholder1: + container = st.container(height=300, border=False) + with placeholder2: + _, col1, _ = st.columns([10, 2, 10]) + with col1: + st.image("logo.jpg", use_column_width=True) + with placeholder3: + _, col2, _ = st.columns([1, 20, 1]) + helloinfo = """

您好,我是旅游规划agent,擅长旅行规划、景点攻略查询

+

例如:从北京到西安旅游规划

+

例如:西安有哪些免费的博物馆景点

+

例如:查一下西安的酒店

""" + with col2: + st.markdown(helloinfo, unsafe_allow_html=True) + + for message in st.session_state.messages: + with st.chat_message(message["role"]): + st.empty() + st.markdown(message["content"]) + + if prompt = st.chat_input("send message"): + st.session_state.messages.append({"role":"user", "content":prompt}) + placeholder1.empty() + placeholder2.empty() + placeholder3.empty() + + agent = st.session_state["agent"] + + with st.chat_message("user"): + st.markdown(prompt) + + with st.chat_message("assistant"): + with st.spinner("thinking..."): + response = agent.run(query=prompt, stream=True) + if isinstance(response, str): + st.markdown(response) + else: + response = st.write_stream(response) + + st.session_state.messages.append({"role":"assistant", "content":response}) diff --git a/mxAgent/samples/travel_agent_demo/travelagent.py b/mxAgent/samples/travel_agent_demo/travelagent.py new file mode 100644 index 0000000000000000000000000000000000000000..2474029402ecdca3f01c7eb24b5dafa9a66e8beb --- /dev/null +++ b/mxAgent/samples/travel_agent_demo/travelagent.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + +import enum +from abc import ABC +from loguru import logger + +from agent_sdk.llms.llm import get_llm_backend, BACKEND_OPENAI_COMPATIBLE +from agent_sdk.agentchain.recipe_agent import RecipeAgent +from agent_sdk.agentchain.router_agent import RouterAgent +from agent_sdk.agentchain.tool_less_agent import ToollessAgent +from agent_sdk.agentchain.base_agent import AgentRunResult + +from samples.tools.tool_query_accommodations import QueryAccommodations +from samples.tools.tool_query_transports import QueryTransports +from samples.tools.tool_query_attractions import QueryAttractions +from samples.tools.tool_summary import PlanSummary +from samples.tools.tool_general_query import GeneralQuery +from samples.tools.tool_query_weather import QueryWeather + + +PESUEDE_CODE = """步骤1:根据用户问题中对景点相关的需求,从网络中搜索相关的景点信息 +步骤2:根据用户的问题,从网络中查询相关的出行交通信息。 +步骤3:根据用户的问题,从网络中搜索相关的住宿和酒店信息; +步骤4:根据用户的问题,查询用户需要的城市天气情况; +步骤5:总结以上的出行信息、景点游玩、住宿信息等。""" +TRANSPORT_INST = """步骤1:根据用户的输入问题,从网络中查询相关的出行交通信息""" +ATTRACTION_INST = "步骤1:根据用户问题中对景点相关的需求,从网络中搜索相关的景点信息" +HOTEL_INST = """步骤1:根据用户的问题,从网络中搜索相关的住宿和酒店信息,""" +WEATHER_INST = "步骤一:根据用户的问题,查询用户需要的城市天气情况" +OTHER_INST = """步骤1:根据用户的输入,从互联网中查询相关的解答""" + + +GENERAL_FINAL_PROMPT = """你是一个擅长文字处理和信息总结的智能助手,你的任务是将提供的网页摘要信息进行总结,并以markdown的格式进行返回, +请添加适当的词语,使得语句内容连贯,通顺 +请将content和snippet的信息进行综合处理,进行总结,生成一个段落。 +涉及到url字段时,使用超链接的格式将网页url链接到网页title上。 +参数介绍】: +title:网页标题 +url:网页链接 +snippet:网页摘要信息 +content:网页的内容总结 +下面是JSON格式的输入: +{text} +请生成markdown段落:""" + +WEATHER_FINAL_PROMPT = """你是一个擅长文字处理和信息总结的智能助手, +当前的工作场景是:天气出行建议;输入的内容是JSON格式的用户所查询城市未来的天气预报,请将这些信息总结为的自然语言展示天气预报的信息,并对用户的出游给除建议, +根据天气的情况,你可以做出一些出行建议,比如是否需要雨具、防晒、保暖等 +请添加适当的词语,使得语句内容连贯,通顺,并尽可能保留输入的信息和数据,但不要自行杜撰信息。 +提供的信息以JSON的格式进行展示 +【参数介绍】: +date:日期 +day_weather:白天的天气情况 +day_wind_direction:白天风向 +day_wind_power: 白天风力 +night_weather:夜晚的天气情况 +night_wind_direction:夜晚风向 +night_wind_power: 夜晚风力 +max_degree: 最高温 +min_degree:最低温 +下面是JSON格式的输入: +{text} +请生成markdown段落: +""" +PLANNER_FINAL_PROMPT = """你是一个擅长规划和文字处理的智能助手,你需要将提供的信息按照下面的步骤撰写一份旅游攻略,输出markdown格式的段落, +你可以添加适当的语句,使得段落通顺,但不要自己杜撰信息。 +步骤】 +1. 根据【用户需要旅行的天数】,将输入的景点分配到每一天的行程中,每天2-3个景点,并介绍景点的详细情况 +2. 叙述输入中推荐的住宿情况,详细介绍酒店的详细情况,和预定链接 +3. 叙述输入中查询的交通安排,详细介绍每个出行方案的价格、时间、时长等详细情况,和预定链接 +4. 介绍输入中天气预报的情况,根据天气的情况,你可以做出一些出行建议,比如是否需要雨具、防晒、保暖等 +【参数介绍】: +title:网页标题 +url:网页链接,满足用户需求的酒店筛选结果 +content:网页主要内容提取物 +snippet:网页摘要信息 +输入的信息以JSON格式,下面是的输入: +{text} +请生成markdown段落:""" + + + +TRAVEL_PLAN = "TRAVEL_PLAN" +QUERY_ATTRACTION = "QUERY_ATTRACTION" +QUERY_HOTEL = "QUERY_HOTEL" +QUERY_TRANSPORT = "QUERY_TRANSPORT" +QUERY_WEATHER = "QUERY_WEATHER" +OTHERS = "OTHERS" + +classifer = [TRAVEL_PLAN, QUERY_ATTRACTION, QUERY_HOTEL, QUERY_TRANSPORT, QUERY_WEATHER, OTHERS] + +INST_MAP = { + TRAVEL_PLAN :PESUEDE_CODE, + QUERY_ATTRACTION :ATTRACTION_INST, + QUERY_HOTEL:HOTEL_INST, + QUERY_TRANSPORT :TRANSPORT_INST, + QUERY_WEATHER :WEATHER_INST, + OTHERS:OTHER_INST +} + +FINAL_PMT_MAP = { + TRAVEL_PLAN :PLANNER_FINAL_PROMPT, + QUERY_ATTRACTION :GENERAL_FINAL_PROMPT, + QUERY_HOTEL:GENERAL_FINAL_PROMPT, + QUERY_TRANSPORT :GENERAL_FINAL_PROMPT, + QUERY_WEATHER :WEATHER_FINAL_PROMPT, + OTHERS: GENERAL_FINAL_PROMPT + +} + +TOOL_LIST_MAP = { + TRAVEL_PLAN :[QueryAccommodations, QueryAttractions, QueryTransports, PlanSummary, QueryWeather], + QUERY_ATTRACTION :[QueryAttractions], + QUERY_HOTEL:[QueryAccommodations], + QUERY_TRANSPORT :[QueryTransports], + QUERY_WEATHER : [QueryWeather], + OTHERS:[] +} + +intents = { + TRAVEL_PLAN :"询问旅行规划,问题中要求旅游项目日程安排、交通查询、查询当地住宿等方面的能力", + QUERY_ATTRACTION :"查询旅游项目、景区、旅游活动", + QUERY_HOTEL: "仅查询酒店和住宿信息", + QUERY_TRANSPORT : "与现实中出行、乘坐交通、如高铁、动车、飞机、火车等相关的意图", + QUERY_WEATHER :"包括气温、湿度、降水等与天气、天气预报相关的意图", + OTHERS :"与旅游场景不相干的查询" +} +LLM_MODEL = get_llm_backend(backend=BACKEND_OPENAI_COMPATIBLE, + api_base="http://10.44.115.108:1055/v1", api_key="EMPTY", llm_name="Qwen1.5-32B-Chat").run + + +class TalkShowAgent(ToollessAgent, ABC): + def __init__(self, llm, prompt="你的名字叫昇腾智搜,是一个帮助用户完成旅行规划的助手,你的能力范围包括:'\ + 目的地推荐、行程规划、交通信息查询、酒店住宿推荐、旅行攻略推荐,请利用你的知识回答问题,这是用户的问题:{query}", + **kwargs): + super().__init__(llm, prompt, **kwargs) + self.query = "" + + def _build_agent_prompt(self, **kwargs): + return self.prompt.format( + query=self.query + ) + + +class TravelAgent: + @classmethod + def route_query(cls, query): + router_agent = RouterAgent(llm=LLM_MODEL, intents=intents) + classify = router_agent.run(query).answer + if classify not in classifer or classify == OTHERS: + return TalkShowAgent(llm=LLM_MODEL) + return RecipeAgent(name=classify, + description="你的名字叫昇腾智搜,是一个帮助用户完成旅行规划的助手,你的能力范围包括:目的地推荐、行程规划、交通信息查询、酒店住宿推荐、旅行攻略推荐", + llm=LLM_MODEL, + tool_list=TOOL_LIST_MAP[classify], + recipe=INST_MAP[classify], + max_steps=3, + max_token_number=4096, + final_prompt=FINAL_PMT_MAP[classify]) + + def run(self, query, stream): + agent = self.route_query(query) + return agent.run(query, stream=stream) + +if __name__ == "__main__": + # request = "去北京的旅游规划" + # request = "从北京到西安的机票" + # request = "查询北京王府井附近的高档酒店" + # request = "泰国有哪些值得推荐的景点" + # request = "帮我查一下北京最近的天气" + # request = "上海酒店查询" + # request = "北京到上海的高铁" + # request = "上海天气怎么样" + request = "帮我制定一份从北京到上海6天的旅游计划" + + travel_agent = TravelAgent() + res = travel_agent.run(request, stream=True) + if isinstance(res, AgentRunResult): + logger.info("-----------run agent success-------------") + logger.info(res.answer) + else: + for char in res: + logger.info(char)