From fab48e8e0603ea5238877c5587dbd63244a29826 Mon Sep 17 00:00:00 2001 From: ObjectNotFound Date: Fri, 13 Dec 2024 14:56:22 +0800 Subject: [PATCH] Framework 930 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 移除冗余代码 --- .dockerignore | 9 + .env.example | 82 ++++++ .gitignore | 19 ++ Dockerfile | 14 + Dockerfile-base | 17 ++ Jenkinsfile | 41 +++ LICENSE | 194 +++++++++++++ apps/__init__.py | 1 + apps/common/__init__.py | 1 + apps/common/config.py | 125 +++++++++ apps/common/cryptohub.py | 29 ++ apps/common/oidc.py | 169 ++++++++++++ apps/common/security.py | 115 ++++++++ apps/common/singleton.py | 17 ++ apps/common/thread.py | 21 ++ apps/common/wordscheck.py | 66 +++++ apps/constants.py | 4 + apps/cron/__init__.py | 1 + apps/cron/delete_user.py | 37 +++ apps/dependency/__init__.py | 1 + apps/dependency/csrf.py | 24 ++ apps/dependency/limit.py | 29 ++ apps/dependency/session.py | 51 ++++ apps/dependency/user.py | 61 +++++ apps/entities/__init__.py | 1 + apps/entities/blacklist.py | 27 ++ apps/entities/comment.py | 11 + apps/entities/plugin.py | 43 +++ apps/entities/request_data.py | 57 ++++ apps/entities/response_data.py | 51 ++++ apps/entities/user.py | 9 + apps/gunicorn.conf.py | 33 +++ apps/llm.py | 117 ++++++++ apps/logger/__init__.py | 100 +++++++ apps/main.py | 85 ++++++ apps/manager/__init__.py | 1 + apps/manager/api_key.py | 102 +++++++ apps/manager/audit_log.py | 34 +++ apps/manager/blacklist.py | 300 ++++++++++++++++++++ apps/manager/comment.py | 58 ++++ apps/manager/conversation.py | 115 ++++++++ apps/manager/domain.py | 82 ++++++ apps/manager/gitee_white_list.py | 23 ++ apps/manager/plugin_token.py | 84 ++++++ apps/manager/record.py | 126 +++++++++ apps/manager/session.py | 166 +++++++++++ apps/manager/user.py | 108 ++++++++ apps/manager/user_domain.py | 77 ++++++ apps/models/__init__.py | 1 + apps/models/mysql.py | 196 +++++++++++++ apps/models/redis.py | 43 +++ apps/routers/__init__.py | 1 + apps/routers/api_key.py | 44 +++ apps/routers/auth.py | 170 ++++++++++++ apps/routers/blacklist.py | 64 +++++ apps/routers/chat.py | 199 ++++++++++++++ apps/routers/client.py | 77 ++++++ apps/routers/comment.py | 48 ++++ apps/routers/conversation.py | 127 +++++++++ apps/routers/domain.py | 49 ++++ apps/routers/file.py | 51 ++++ apps/routers/health.py | 13 + apps/routers/plugin.py | 22 ++ apps/routers/record.py | 43 +++ apps/scheduler/__init__.py | 1 + apps/scheduler/call/__init__.py | 18 ++ apps/scheduler/call/api/__init__.py | 1 + apps/scheduler/call/api/api.py | 199 ++++++++++++++ apps/scheduler/call/api/sanitizer.py | 111 ++++++++ apps/scheduler/call/choice.py | 49 ++++ apps/scheduler/call/core.py | 55 ++++ apps/scheduler/call/extract.py | 57 ++++ apps/scheduler/call/llm.py | 73 +++++ apps/scheduler/call/render/__init__.py | 1 + apps/scheduler/call/render/option.json | 20 ++ apps/scheduler/call/render/render.py | 118 ++++++++ apps/scheduler/call/render/style.py | 169 ++++++++++++ apps/scheduler/call/sql.py | 83 ++++++ apps/scheduler/core.py | 49 ++++ apps/scheduler/encoder.py | 29 ++ apps/scheduler/executor/__init__.py | 7 + apps/scheduler/executor/flow.py | 178 ++++++++++++ apps/scheduler/files.py | 138 ++++++++++ apps/scheduler/gen_json.py | 167 +++++++++++ apps/scheduler/parse_json.py | 55 ++++ apps/scheduler/pool/__init__.py | 1 + apps/scheduler/pool/entities.py | 33 +++ apps/scheduler/pool/loader.py | 233 ++++++++++++++++ apps/scheduler/pool/pool.py | 305 +++++++++++++++++++++ apps/scheduler/scheduler.py | 206 ++++++++++++++ 
apps/scheduler/utils/__init__.py | 21 ++ apps/scheduler/utils/backprop.py | 47 ++++ apps/scheduler/utils/consistency.py | 159 +++++++++++ apps/scheduler/utils/evaluate.py | 127 +++++++++ apps/scheduler/utils/json.py | 227 +++++++++++++++ apps/scheduler/utils/recommend.py | 59 ++++ apps/scheduler/utils/reflect.py | 125 +++++++++ apps/scheduler/utils/select.py | 159 +++++++++++ apps/scheduler/utils/summary.py | 91 ++++++ apps/scheduler/vector.py | 132 +++++++++ apps/service/__init__.py | 10 + apps/service/activity.py | 34 +++ apps/service/domain.py | 98 +++++++ apps/service/history.py | 59 ++++ apps/service/rag.py | 46 ++++ apps/service/suggestion.py | 34 +++ apps/service/summary.py | 19 ++ apps/utils/user_exporter.py | 242 ++++++++++++++++ assets/euler-copilot-frame.sql | 77 ++++++ assets/host.example.json | 12 + requirements.txt | 41 +++ sdk/example_plugin/flows/flow.yaml | 53 ++++ sdk/example_plugin/lib/__init__.py | 8 + sdk/example_plugin/lib/sub_lib/__init__.py | 0 sdk/example_plugin/lib/user_tool.py | 45 +++ sdk/example_plugin/openapi.yaml | 37 +++ sdk/example_plugin/plugin.json | 11 + 117 files changed, 8415 insertions(+) create mode 100644 .dockerignore create mode 100644 .env.example create mode 100644 .gitignore create mode 100644 Dockerfile create mode 100644 Dockerfile-base create mode 100644 Jenkinsfile create mode 100644 LICENSE create mode 100644 apps/__init__.py create mode 100644 apps/common/__init__.py create mode 100644 apps/common/config.py create mode 100644 apps/common/cryptohub.py create mode 100644 apps/common/oidc.py create mode 100644 apps/common/security.py create mode 100644 apps/common/singleton.py create mode 100644 apps/common/thread.py create mode 100644 apps/common/wordscheck.py create mode 100644 apps/constants.py create mode 100644 apps/cron/__init__.py create mode 100644 apps/cron/delete_user.py create mode 100644 apps/dependency/__init__.py create mode 100644 apps/dependency/csrf.py create mode 100644 apps/dependency/limit.py create mode 100644 apps/dependency/session.py create mode 100644 apps/dependency/user.py create mode 100644 apps/entities/__init__.py create mode 100644 apps/entities/blacklist.py create mode 100644 apps/entities/comment.py create mode 100644 apps/entities/plugin.py create mode 100644 apps/entities/request_data.py create mode 100644 apps/entities/response_data.py create mode 100644 apps/entities/user.py create mode 100644 apps/gunicorn.conf.py create mode 100644 apps/llm.py create mode 100644 apps/logger/__init__.py create mode 100644 apps/main.py create mode 100644 apps/manager/__init__.py create mode 100644 apps/manager/api_key.py create mode 100644 apps/manager/audit_log.py create mode 100644 apps/manager/blacklist.py create mode 100644 apps/manager/comment.py create mode 100644 apps/manager/conversation.py create mode 100644 apps/manager/domain.py create mode 100644 apps/manager/gitee_white_list.py create mode 100644 apps/manager/plugin_token.py create mode 100644 apps/manager/record.py create mode 100644 apps/manager/session.py create mode 100644 apps/manager/user.py create mode 100644 apps/manager/user_domain.py create mode 100644 apps/models/__init__.py create mode 100644 apps/models/mysql.py create mode 100644 apps/models/redis.py create mode 100644 apps/routers/__init__.py create mode 100644 apps/routers/api_key.py create mode 100644 apps/routers/auth.py create mode 100644 apps/routers/blacklist.py create mode 100644 apps/routers/chat.py create mode 100644 apps/routers/client.py create mode 100644 apps/routers/comment.py 
create mode 100644 apps/routers/conversation.py create mode 100644 apps/routers/domain.py create mode 100644 apps/routers/file.py create mode 100644 apps/routers/health.py create mode 100644 apps/routers/plugin.py create mode 100644 apps/routers/record.py create mode 100644 apps/scheduler/__init__.py create mode 100644 apps/scheduler/call/__init__.py create mode 100644 apps/scheduler/call/api/__init__.py create mode 100644 apps/scheduler/call/api/api.py create mode 100644 apps/scheduler/call/api/sanitizer.py create mode 100644 apps/scheduler/call/choice.py create mode 100644 apps/scheduler/call/core.py create mode 100644 apps/scheduler/call/extract.py create mode 100644 apps/scheduler/call/llm.py create mode 100644 apps/scheduler/call/render/__init__.py create mode 100644 apps/scheduler/call/render/option.json create mode 100644 apps/scheduler/call/render/render.py create mode 100644 apps/scheduler/call/render/style.py create mode 100644 apps/scheduler/call/sql.py create mode 100644 apps/scheduler/core.py create mode 100644 apps/scheduler/encoder.py create mode 100644 apps/scheduler/executor/__init__.py create mode 100644 apps/scheduler/executor/flow.py create mode 100644 apps/scheduler/files.py create mode 100644 apps/scheduler/gen_json.py create mode 100644 apps/scheduler/parse_json.py create mode 100644 apps/scheduler/pool/__init__.py create mode 100644 apps/scheduler/pool/entities.py create mode 100644 apps/scheduler/pool/loader.py create mode 100644 apps/scheduler/pool/pool.py create mode 100644 apps/scheduler/scheduler.py create mode 100644 apps/scheduler/utils/__init__.py create mode 100644 apps/scheduler/utils/backprop.py create mode 100644 apps/scheduler/utils/consistency.py create mode 100644 apps/scheduler/utils/evaluate.py create mode 100644 apps/scheduler/utils/json.py create mode 100644 apps/scheduler/utils/recommend.py create mode 100644 apps/scheduler/utils/reflect.py create mode 100644 apps/scheduler/utils/select.py create mode 100644 apps/scheduler/utils/summary.py create mode 100644 apps/scheduler/vector.py create mode 100644 apps/service/__init__.py create mode 100644 apps/service/activity.py create mode 100644 apps/service/domain.py create mode 100644 apps/service/history.py create mode 100644 apps/service/rag.py create mode 100644 apps/service/suggestion.py create mode 100644 apps/service/summary.py create mode 100644 apps/utils/user_exporter.py create mode 100644 assets/euler-copilot-frame.sql create mode 100644 assets/host.example.json create mode 100644 requirements.txt create mode 100644 sdk/example_plugin/flows/flow.yaml create mode 100644 sdk/example_plugin/lib/__init__.py create mode 100644 sdk/example_plugin/lib/sub_lib/__init__.py create mode 100644 sdk/example_plugin/lib/user_tool.py create mode 100644 sdk/example_plugin/openapi.yaml create mode 100644 sdk/example_plugin/plugin.json diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..859ebc26a --- /dev/null +++ b/.dockerignore @@ -0,0 +1,9 @@ +*venv/ +**/__pycache__/ +Dockerfile +.dockerignore +.idea +.vscode +*.bak +.gitignore +.git/ \ No newline at end of file diff --git a/.env.example b/.env.example new file mode 100644 index 000000000..3085c166e --- /dev/null +++ b/.env.example @@ -0,0 +1,82 @@ +DEPLOY_MODE= +COOKIE_MODE= +# WEB +WEB_FRONT_URL= + +# Plugin_Token_URL +AOPS_TOKEN_URL= +AOPS_TOKEN_EXPIRE_TIME= + +# Redis +REDIS_HOST= +REDIS_PORT= +REDIS_PWD= + +# OIDC +OIDC_APP_ID= +OIDC_APP_SECRET= +OIDC_USER_URL= +OIDC_TOKEN_URL= +OIDC_REFRESH_TOKEN_URL= +OIDC_REDIRECT_URL= 
+EULER_LOGIN_API=
+OIDC_ACCESS_TOKEN_EXPIRE_TIME=
+OIDC_REFRESH_TOKEN_EXPIRE_TIME=
+SESSION_TTL=
+
+
+# Sensitive Word
+DETECT_TYPE=
+WORDS_CHECK=
+WORDS_LIST=
+
+# logging
+LOG=
+
+# Vectorize
+VECTORIZE_HOST=
+
+# RAG
+RAG_HOST=
+RAG_KB_SN=
+
+# FastAPI
+UVICORN_HOST=
+UVICORN_PORT=
+SSL_ENABLE=
+SSL_CERTFILE=
+SSL_KEYFILE=
+SSL_KEY_PWD=
+DOMAIN=
+JWT_KEY=
+
+# LLM
+MODEL=
+## Spark
+SPARK_APP_ID=
+SPARK_API_KEY=
+SPARK_API_SECRET=
+SPARK_API_URL=
+SPARK_LLM_DOMAIN=
+## OpenAI Compatible
+LLM_URL=
+LLM_KEY=
+LLM_MODEL=
+
+# Scheduler
+SCHEDULER_URL=
+SCHEDULER_API_KEY=
+PLUGIN_DIR=
+
+# MySQL
+MYSQL_HOST=
+MYSQL_DATABASE=
+MYSQL_USER=
+MYSQL_PWD=
+
+# PostgreSQL
+POSTGRES_HOST=
+POSTGRES_PORT=
+POSTGRES_DATABASE=
+POSTGRES_USER=
+POSTGRES_PWD=
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..0aff34953
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,19 @@
+*venv/
+**/__pycache__/
+**/*.pyc
+.env
+# ide
+.idea
+.vscode
+.chroma
+*.key
+*.crt
+apps/utils/init
+start.sh
+*.bak
+encrypted_config.json
+apps/embedding
+apps/scheduler/plugin
+logs
+data/
+plugins/
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 000000000..ed9b14e92
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,14 @@
+FROM hub.oepkgs.net/neocopilot/framework-baseimg:0.9.1
+
+USER root
+RUN sed -i 's/umask 002/umask 027/g' /etc/bashrc && \
+    sed -i 's/umask 022/umask 027/g' /etc/bashrc && \
+    yum remove -y gdb-gdbserver
+
+USER eulercopilot
+COPY --chown=1001:1001 --chmod=550 ./ /euler-copilot-frame/
+
+WORKDIR /euler-copilot-frame
+ENV PYTHONPATH /euler-copilot-frame
+
+CMD bash -c "python3 -m gunicorn -c apps/gunicorn.conf.py apps.main:app"
diff --git a/Dockerfile-base b/Dockerfile-base
new file mode 100644
index 000000000..49c1e10c4
--- /dev/null
+++ b/Dockerfile-base
@@ -0,0 +1,17 @@
+FROM hub.oepkgs.net/openeuler/openeuler:22.03-lts-sp4
+
+ENV PATH /home/eulercopilot/.local/bin:$PATH
+RUN sed -i 's|repo.openeuler.org|mirrors.nju.edu.cn/openeuler|g' /etc/yum.repos.d/openEuler.repo && \
+    sed -i '/metalink/d' /etc/yum.repos.d/openEuler.repo && \
+    sed -i '/metadata_expire/d' /etc/yum.repos.d/openEuler.repo && \
+    yum update -y &&\
+    yum install -y python3 python3-pip shadow-utils findutils &&\
+    groupadd -g 1001 eulercopilot && useradd -u 1001 -g eulercopilot eulercopilot &&\
+    yum clean all
+
+USER eulercopilot
+
+COPY --chown=1001:1001 requirements.txt .
+ +RUN pip3 install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple && \ + chmod -R 750 /home/eulercopilot \ No newline at end of file diff --git a/Jenkinsfile b/Jenkinsfile new file mode 100644 index 000000000..8fb5afa1f --- /dev/null +++ b/Jenkinsfile @@ -0,0 +1,41 @@ +node { + echo "拉取代码仓库" + checkout scm + + def REPO = scm.getUserRemoteConfigs()[0].getUrl().tokenize('/').last().split("\\.")[0] + def BRANCH = scm.branches[0].name.split("/")[1] + def BUILD = sh(script: 'git rev-parse --short HEAD', returnStdout: true).trim() + + withCredentials([string(credentialsId: "host", variable: "HOST")]) { + echo "构建当前分支Docker Image镜像" + sh "sed -i 's|framework_base|${HOST}:30000/framework-baseimg|g' Dockerfile" + docker.withRegistry("http://${HOST}:30000", "dockerAuth") { + def image = docker.build("${HOST}:30000/${REPO}:${BUILD}", "-f Dockerfile .") + image.push() + image.push("${BRANCH}") + } + + def remote = [:] + remote.name = "machine" + remote.host = "${HOST}" + withCredentials([usernamePassword(credentialsId: "ssh", usernameVariable: 'sshUser', passwordVariable: 'sshPass')]) { + remote.user = sshUser + remote.password = sshPass + } + remote.allowAnyHosts = true + + echo "清除构建缓存" + sshCommand remote: remote, command: "sh -c \"docker rmi ${HOST}:30000/${REPO}:${BUILD} || true\";" + sshCommand remote: remote, command: "sh -c \"docker rmi ${REPO}:${BUILD} || true\";" + sshCommand remote: remote, command: "sh -c \"docker rmi ${REPO}:${BRANCH} || true\";" + sshCommand remote: remote, command: "sh -c \"docker image prune -f || true\";"; + sshCommand remote: remote, command: "sh -c \"docker builder prune -f || true\";"; + sshCommand remote: remote, command: "sh -c \"k3s crictl rmi --prune || true\";"; + + echo "重新部署" + withCredentials([usernamePassword(credentialsId: "dockerAuth", usernameVariable: 'dockerUser', passwordVariable: 'dockerPass')]) { + sshCommand remote: remote, command: "sh -c \"cd /home/registry/registry-cli; python3 ./registry.py -l ${dockerUser}:${dockerPass} -r http://${HOST}:30000 --delete --keep-tags 'master' '0001' '330-feature' '430-feature' || true\";" + } + sshCommand remote: remote, command: "sh -c \"kubectl -n euler-copilot set image deployment/framework-deploy framework=${HOST}:30000/${REPO}:${BUILD}\";" + } +} diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..f63f5a9cf --- /dev/null +++ b/LICENSE @@ -0,0 +1,194 @@ +木兰宽松许可证,第2版 + +木兰宽松许可证,第2版 + +2020年1月 http://license.coscl.org.cn/MulanPSL2 + +您对“软件”的复制、使用、修改及分发受木兰宽松许可证,第2版(“本许可证”)的如下条款的约束: + +0. 定义 + +“软件” 是指由“贡献”构成的许可在“本许可证”下的程序和相关文档的集合。 + +“贡献” 是指由任一“贡献者”许可在“本许可证”下的受版权法保护的作品。 + +“贡献者” 是指将受版权法保护的作品许可在“本许可证”下的自然人或“法人实体”。 + +“法人实体” 是指提交贡献的机构及其“关联实体”。 + +“关联实体” 是指,对“本许可证”下的行为方而言,控制、受控制或与其共同受控制的机构,此处的控制是 +指有受控方或共同受控方至少50%直接或间接的投票权、资金或其他有价证券。 + +1. 授予版权许可 + +每个“贡献者”根据“本许可证”授予您永久性的、全球性的、免费的、非独占的、不可撤销的版权许可,您可 +以复制、使用、修改、分发其“贡献”,不论修改与否。 + +2. 授予专利许可 + +每个“贡献者”根据“本许可证”授予您永久性的、全球性的、免费的、非独占的、不可撤销的(根据本条规定 +撤销除外)专利许可,供您制造、委托制造、使用、许诺销售、销售、进口其“贡献”或以其他方式转移其“贡 +献”。前述专利许可仅限于“贡献者”现在或将来拥有或控制的其“贡献”本身或其“贡献”与许可“贡献”时的“软 +件”结合而将必然会侵犯的专利权利要求,不包括对“贡献”的修改或包含“贡献”的其他结合。如果您或您的“ +关联实体”直接或间接地,就“软件”或其中的“贡献”对任何人发起专利侵权诉讼(包括反诉或交叉诉讼)或 +其他专利维权行动,指控其侵犯专利权,则“本许可证”授予您对“软件”的专利许可自您提起诉讼或发起维权 +行动之日终止。 + +3. 无商标许可 + +“本许可证”不提供对“贡献者”的商品名称、商标、服务标志或产品名称的商标许可,但您为满足第4条规定 +的声明义务而必须使用除外。 + +4. 分发限制 + +您可以在任何媒介中将“软件”以源程序形式或可执行形式重新分发,不论修改与否,但您必须向接收者提供“ +本许可证”的副本,并保留“软件”中的版权、商标、专利及免责声明。 + +5. 
免责声明与责任限制 + +“软件”及其中的“贡献”在提供时不带任何明示或默示的担保。在任何情况下,“贡献者”或版权所有者不对 +任何人因使用“软件”或其中的“贡献”而引发的任何直接或间接损失承担责任,不论因何种原因导致或者基于 +何种法律理论,即使其曾被建议有此种损失的可能性。 + +6. 语言 + +“本许可证”以中英文双语表述,中英文版本具有同等法律效力。如果中英文版本存在任何冲突不一致,以中文 +版为准。 + +条款结束 + +如何将木兰宽松许可证,第2版,应用到您的软件 + +如果您希望将木兰宽松许可证,第2版,应用到您的新软件,为了方便接收者查阅,建议您完成如下三步: + +1, 请您补充如下声明中的空白,包括软件名、软件的首次发表年份以及您作为版权人的名字; + +2, 请您在软件包的一级目录下创建以“LICENSE”为名的文件,将整个许可证文本放入该文件中; + +3, 请将如下声明文本放入每个源文件的头部注释中。 + +Copyright (c) [Year] [name of copyright holder] +[Software Name] is licensed under Mulan PSL v2. +You can use this software according to the terms and conditions of the Mulan +PSL v2. +You may obtain a copy of Mulan PSL v2 at: + http://license.coscl.org.cn/MulanPSL2 +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +See the Mulan PSL v2 for more details. + +Mulan Permissive Software License,Version 2 + +Mulan Permissive Software License,Version 2 (Mulan PSL v2) + +January 2020 http://license.coscl.org.cn/MulanPSL2 + +Your reproduction, use, modification and distribution of the Software shall +be subject to Mulan PSL v2 (this License) with the following terms and +conditions: + +0. Definition + +Software means the program and related documents which are licensed under +this License and comprise all Contribution(s). + +Contribution means the copyrightable work licensed by a particular +Contributor under this License. + +Contributor means the Individual or Legal Entity who licenses its +copyrightable work under this License. + +Legal Entity means the entity making a Contribution and all its +Affiliates. + +Affiliates means entities that control, are controlled by, or are under +common control with the acting entity under this License, ‘control’ means +direct or indirect ownership of at least fifty percent (50%) of the voting +power, capital or other securities of controlled or commonly controlled +entity. + +1. Grant of Copyright License + +Subject to the terms and conditions of this License, each Contributor hereby +grants to you a perpetual, worldwide, royalty-free, non-exclusive, +irrevocable copyright license to reproduce, use, modify, or distribute its +Contribution, with modification or not. + +2. Grant of Patent License + +Subject to the terms and conditions of this License, each Contributor hereby +grants to you a perpetual, worldwide, royalty-free, non-exclusive, +irrevocable (except for revocation under this Section) patent license to +make, have made, use, offer for sale, sell, import or otherwise transfer its +Contribution, where such patent license is only limited to the patent claims +owned or controlled by such Contributor now or in future which will be +necessarily infringed by its Contribution alone, or by combination of the +Contribution with the Software to which the Contribution was contributed. +The patent license shall not apply to any modification of the Contribution, +and any other combination which includes the Contribution. If you or your +Affiliates directly or indirectly institute patent litigation (including a +cross claim or counterclaim in a litigation) or other patent enforcement +activities against any individual or entity by alleging that the Software or +any Contribution in it infringes patents, then any patent license granted to +you under this License for the Software shall terminate as of the date such +litigation or activity is filed or taken. + +3. 
No Trademark License + +No trademark license is granted to use the trade names, trademarks, service +marks, or product names of Contributor, except as required to fulfill notice +requirements in section 4. + +4. Distribution Restriction + +You may distribute the Software in any medium with or without modification, +whether in source or executable forms, provided that you provide recipients +with a copy of this License and retain copyright, patent, trademark and +disclaimer statements in the Software. + +5. Disclaimer of Warranty and Limitation of Liability + +THE SOFTWARE AND CONTRIBUTION IN IT ARE PROVIDED WITHOUT WARRANTIES OF ANY +KIND, EITHER EXPRESS OR IMPLIED. IN NO EVENT SHALL ANY CONTRIBUTOR OR +COPYRIGHT HOLDER BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT +LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING +FROM YOUR USE OR INABILITY TO USE THE SOFTWARE OR THE CONTRIBUTION IN IT, NO +MATTER HOW IT’S CAUSED OR BASED ON WHICH LEGAL THEORY, EVEN IF ADVISED OF +THE POSSIBILITY OF SUCH DAMAGES. + +6. Language + +THIS LICENSE IS WRITTEN IN BOTH CHINESE AND ENGLISH, AND THE CHINESE VERSION +AND ENGLISH VERSION SHALL HAVE THE SAME LEGAL EFFECT. IN THE CASE OF +DIVERGENCE BETWEEN THE CHINESE AND ENGLISH VERSIONS, THE CHINESE VERSION +SHALL PREVAIL. + +END OF THE TERMS AND CONDITIONS + +How to Apply the Mulan Permissive Software License,Version 2 +(Mulan PSL v2) to Your Software + +To apply the Mulan PSL v2 to your work, for easy identification by +recipients, you are suggested to complete following three steps: + +i. Fill in the blanks in following statement, including insert your software +name, the year of the first publication of your software, and your name +identified as the copyright owner; + +ii. Create a file named "LICENSE" which contains the whole context of this +License in the first directory of your software package; + +iii. Attach the statement to the appropriate annotated syntax at the +beginning of each source file. + +Copyright (c) [Year] [name of copyright holder] +[Software Name] is licensed under Mulan PSL v2. +You can use this software according to the terms and conditions of the Mulan +PSL v2. +You may obtain a copy of Mulan PSL v2 at: + http://license.coscl.org.cn/MulanPSL2 +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +See the Mulan PSL v2 for more details. diff --git a/apps/__init__.py b/apps/__init__.py new file mode 100644 index 000000000..821dc0853 --- /dev/null +++ b/apps/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. diff --git a/apps/common/__init__.py b/apps/common/__init__.py new file mode 100644 index 000000000..821dc0853 --- /dev/null +++ b/apps/common/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. diff --git a/apps/common/config.py b/apps/common/config.py new file mode 100644 index 000000000..1bc8c8480 --- /dev/null +++ b/apps/common/config.py @@ -0,0 +1,125 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+import os +from typing import Optional +import secrets + +from dotenv import dotenv_values +from pydantic import BaseModel, Field + + +class ConfigModel(BaseModel): + """ + 配置文件的校验Class + """ + + # DEPLOY + DEPLOY_MODE: str = Field(description="oidc 部署方式", default="online") + COOKIE_MODE: str = Field(description="COOKIE SET 方式", default="domain") + # WEB + WEB_FRONT_URL: str = Field(description="web前端地址") + # Redis + REDIS_HOST: str = Field(description="Redis主机名") + REDIS_PORT: int = Field(description="Redis端口号", default=6379) + REDIS_PWD: str = Field(description="Redis连接密码") + # OIDC + DISABLE_LOGIN: bool = Field(description="是否禁用登录", default=False) + DEFAULT_USER: Optional[str] = Field(description="禁用登录后,默认的用户ID", default=None) + OIDC_APP_ID: Optional[str] = Field(description="OIDC AppID", default=None) + OIDC_APP_SECRET: Optional[str] = Field(description="OIDC App Secret", default=None) + OIDC_USER_URL: Optional[str] = Field(description="OIDC USER URL", default=None) + OIDC_TOKEN_URL: Optional[str] = Field(description="OIDC Token获取URL", default=None) + OIDC_REFRESH_TOKEN_URL: Optional[str] = Field(description="OIDC 刷新token", default=None) + OIDC_REDIRECT_URL: Optional[str] = Field(description="OIDC Redirect URL", default=None) + EULER_LOGIN_API: Optional[str] = Field(description="Euler Login API", default=None) + OIDC_ACCESS_TOKEN_EXPIRE_TIME: int = Field(description="OIDC access token过期时间", default=30) + OIDC_REFRESH_TOKEN_EXPIRE_TIME: int = Field(description="OIDC refresh token过期时间", default=180) + SESSION_TTL: int = Field(description="用户需要刷新Token的间隔(min)", default=30) + # Logging + LOG: str = Field(description="日志记录模式") + # Vectorize + VECTORIZE_HOST: str = Field(description="Vectorize服务域名") + # RAG + RAG_HOST: str = Field(description="RAG服务域名") + RAG_KB_SN: Optional[str] = Field(description="RAG 资产库", default=None) + # FastAPI + DOMAIN: str = Field(description="当前实例的域名") + JWT_KEY: str = Field(description="JWT key", default=secrets.token_hex(16)) + PICKLE_KEY: str = Field(description="Pickle Key", default=secrets.token_hex(16)) + # 风控 + DETECT_TYPE: Optional[str] = Field(description="敏感词检测系统类型", default=None) + WORDS_CHECK: Optional[str] = Field(description="AutoGPT敏感词检测系统API URL", default=None) + WORDS_LIST: Optional[str] = Field(description="敏感词列表文件路径", default=None) + # CSRF + ENABLE_CSRF: bool = Field(description="是否启用CSRF Token功能", default=True) + # MySQL + MYSQL_HOST: str = Field(description="MySQL主机名、端口号") + MYSQL_DATABASE: str = Field(description="MySQL数据库名") + MYSQL_USER: str = Field(description="MySQL用户名") + MYSQL_PWD: str = Field(description="MySQL密码") + # PGSQL + POSTGRES_HOST: str = Field(description="PGSQL主机名、端口号") + POSTGRES_DATABASE: str = Field(description="PGSQL数据库名") + POSTGRES_USER: str = Field(description="PGSQL用户名") + POSTGRES_PWD: str = Field(description="PGSQL密码") + # Security + HALF_KEY1: str = Field(description="Half key 1") + HALF_KEY2: str = Field(description="Half key 2") + HALF_KEY3: str = Field(description="Half key 3") + # 模型类型 + MODEL: str = Field(description="选择的模型类型", default="openai") + # OpenAI API + LLM_KEY: Optional[str] = Field(description="OpenAI API 密钥", default=None) + LLM_URL: Optional[str] = Field(description="OpenAI API URL地址", default=None) + LLM_MODEL: Optional[str] = Field(description="OpenAI API 模型名", default=None) + # 星火大模型 + SPARK_APP_ID: Optional[str] = Field(description="星火大模型API 应用名", default=None) + SPARK_API_KEY: Optional[str] = Field(description="星火大模型API 密钥名", default=None) + SPARK_API_SECRET: Optional[str] = 
Field(description="星火大模型API 密钥值", default=None) + SPARK_API_URL: Optional[str] = Field(description="星火大模型API URL地址", default=None) + SPARK_LLM_DOMAIN: Optional[str] = Field(description="星火大模型API 领域名", default=None) + # 参数猜解 + SCHEDULER_BACKEND: Optional[str] = Field(description="参数猜解后端", default=None) + SCHEDULER_URL: Optional[str] = Field(description="参数猜解 URL地址", default=None) + SCHEDULER_API_KEY: Optional[str] = Field(description="参数猜解 API密钥", default=None) + SCHEDULER_STRUCTURED_OUTPUT: Optional[bool] = Field(description="是否启用结构化输出", default=True) + # 插件位置 + PLUGIN_DIR: Optional[str] = Field(description="插件路径", default=None) + # 临时路径 + TEMP_DIR: str = Field(description="临时目录位置", default="/tmp") + # SQL接口路径 + SQL_URL: str = Field(description="Chat2DB接口路径") + + +class Config: + """ + 配置文件读取和使用Class + """ + + config: ConfigModel + + def __init__(self): + """ + 读取配置文件;当PROD环境变量设置时,配置文件将在读取后删除 + """ + if os.getenv("CONFIG"): + config_file = os.getenv("CONFIG") + else: + config_file = "./config/.env" + self.config = ConfigModel(**(dotenv_values(config_file))) + + if os.getenv("PROD"): + os.remove(config_file) + + def __getitem__(self, key): + """ + 获得配置文件中特定条目的值 + :param key: 配置文件条目名 + :return: 条目的值;不存在则返回None + """ + if key in self.config.__dict__: + return self.config.__dict__[key] + else: + return None + + +config = Config() diff --git a/apps/common/cryptohub.py b/apps/common/cryptohub.py new file mode 100644 index 000000000..a7bedf362 --- /dev/null +++ b/apps/common/cryptohub.py @@ -0,0 +1,29 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +import hashlib + +from apps.common.security import Security + + +class CryptoHub: + + @staticmethod + def generate_str_from_sha256(plain_txt): + hash_object = hashlib.sha256(plain_txt.encode('utf-8')) + hex_dig = hash_object.hexdigest() + return hex_dig[:] + + @staticmethod + def decrypt_with_config(encrypted_plaintext): + secret_dict_key_list = [ + "encrypted_work_key", + "encrypted_work_key_iv", + "encrypted_iv", + "half_key1" + ] + encryption_config = {} + for key in secret_dict_key_list: + encryption_config[key] = encrypted_plaintext[1][CryptoHub.generate_str_from_sha256( + key)] + plaintext = Security.decrypt(encrypted_plaintext[0], encryption_config) + del encryption_config + return plaintext diff --git a/apps/common/oidc.py b/apps/common/oidc.py new file mode 100644 index 000000000..bef67589a --- /dev/null +++ b/apps/common/oidc.py @@ -0,0 +1,169 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
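A minimal usage sketch for apps/common/config.py above: `Config.__getitem__` returns None for any key that is not a declared `ConfigModel` field, so optional settings can be probed without try/except. The keys below are fields from this patch; the snippet itself is illustrative:

# Illustrative usage of the module-level `config` singleton from apps/common/config.py.
from apps.common.config import config

rag_host = config["RAG_HOST"]        # required field, validated at startup
kb_sn = config["RAG_KB_SN"]          # optional field, may be None
unknown = config["NO_SUCH_KEY"]      # undeclared key -> None, not a KeyError
if kb_sn is not None:
    print(f"RAG at {rag_host}, knowledge base {kb_sn}")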
+from __future__ import annotations
+
+from typing import Dict, Any
+
+import aiohttp
+import logging
+
+
+from apps.common.config import config
+from apps.models.redis import RedisConnectionPool
+from apps.manager.gitee_white_list import GiteeIDManager
+
+from fastapi import status, HTTPException
+
+logger = logging.getLogger('gunicorn.error')
+
+
+async def get_oidc_token(code: str) -> Dict[str, Any]:
+    if config["DEPLOY_MODE"] == 'local':
+        ret = await get_local_oidc_token(code)
+        return ret
+    data = {
+        "client_id": config["OIDC_APP_ID"],
+        "client_secret": config["OIDC_APP_SECRET"],
+        "redirect_uri": config["EULER_LOGIN_API"],
+        "grant_type": "authorization_code",
+        "code": code
+    }
+    url = config['OIDC_TOKEN_URL']
+    headers = {
+        "Content-Type": "application/x-www-form-urlencoded"
+    }
+    result = None
+    async with aiohttp.ClientSession() as session:
+        async with session.post(url, headers=headers, data=data, timeout=10) as resp:
+            if resp.status != 200:
+                raise Exception(f"Get OIDC token error: {resp.status}, full output is: {await resp.text()}")
+            logger.info(f'full response is {await resp.text()}')
+            result = await resp.json()
+    return {
+        "access_token": result["access_token"],
+        "refresh_token": result["refresh_token"],
+    }
+
+
+async def get_oidc_user(access_token: str, refresh_token: str) -> dict:
+    if config["DEPLOY_MODE"] == 'local':
+        ret = await get_local_oidc_user(access_token, refresh_token)
+        return ret
+    elif config["DEPLOY_MODE"] == 'gitee':
+        ret = await get_gitee_oidc_user(access_token, refresh_token)
+        return ret
+
+    if not access_token:
+        raise Exception("Access token is empty.")
+    url = config['OIDC_USER_URL']
+    headers = {
+        "Authorization": access_token
+    }
+
+    result = None
+    async with aiohttp.ClientSession() as session:
+        async with session.get(url, headers=headers, timeout=10) as resp:
+            if resp.status != 200:
+                raise Exception(f"Get OIDC user error: {resp.status}, full response is: {await resp.text()}")
+            logger.info(f'full response is {await resp.text()}')
+            result = await resp.json()
+
+    if not result["phone_number_verified"]:
+        raise Exception("Could not validate credentials.")
+
+    user_sub = result['sub']
+    with RedisConnectionPool.get_redis_connection() as r:
+        r.set(f'{user_sub}_oidc_access_token', access_token, int(config['OIDC_ACCESS_TOKEN_EXPIRE_TIME'])*60)
+        r.set(f'{user_sub}_oidc_refresh_token', refresh_token, int(config['OIDC_REFRESH_TOKEN_EXPIRE_TIME'])*60)
+
+    return {
+        "user_sub": user_sub
+    }
+
+
+async def get_local_oidc_token(code: str):
+    data = {
+        "client_id": config["OIDC_APP_ID"],
+        "redirect_uri": config["EULER_LOGIN_API"],
+        "grant_type": "authorization_code",
+        "code": code
+    }
+    headers = {
+        "Content-Type": "application/json"
+    }
+    url = config['OIDC_TOKEN_URL']
+    result = None
+    async with aiohttp.ClientSession() as session:
+        async with session.post(url, headers=headers, json=data, timeout=10) as resp:
+            if resp.status != 200:
+                raise Exception(f"Get OIDC token error: {resp.status}, full response is: {await resp.text()}")
+            logger.info(f'full response is {await resp.text()}')
+            result = await resp.json()
+    return {
+        "access_token": result["data"]["access_token"],
+        "refresh_token": result["data"]["refresh_token"],
+    }
+
+
+async def get_local_oidc_user(access_token: str, refresh_token: str) -> dict:
+    if not access_token:
+        raise Exception("Access token is empty.")
+    headers = {
+        "Content-Type": "application/json"
+    }
+    url = config['OIDC_USER_URL']
+    data = {
+        "token": access_token,
+        "client_id": config["OIDC_APP_ID"],
+    }
+    result = None
+    async with aiohttp.ClientSession() as session:
+        async with session.post(url, headers=headers, json=data, timeout=10) as resp:
+            if resp.status != 200:
+                raise Exception(f"Get OIDC user error: {resp.status}, full response is: {await resp.text()}")
+            logger.info(f'full response is {await resp.text()}')
+            result = await resp.json()
+    user_sub = result['data']
+    with RedisConnectionPool.get_redis_connection() as r:
+        r.set(
+            f'{user_sub}_oidc_access_token',
+            access_token,
+            int(config['OIDC_ACCESS_TOKEN_EXPIRE_TIME'])*60
+        )
+        r.set(
+            f'{user_sub}_oidc_refresh_token',
+            refresh_token,
+            int(config['OIDC_REFRESH_TOKEN_EXPIRE_TIME'])*60
+        )
+
+    return {
+        "user_sub": user_sub
+    }
+
+
+async def get_gitee_oidc_user(access_token: str, refresh_token: str) -> dict:
+    if not access_token:
+        raise Exception("Access token is empty.")
+
+    url = f'''{config['OIDC_USER_URL']}?access_token={access_token}'''
+    result = None
+    async with aiohttp.ClientSession() as session:
+        async with session.get(url, timeout=10) as resp:
+            if resp.status != 200:
+                raise Exception(f"Get OIDC user error: {resp.status}, full response is: {await resp.text()}")
+            logger.info(f'full response is {await resp.text()}')
+            result = await resp.json()
+
+    user_sub = result['login']
+    if not GiteeIDManager.check_user_exist_or_not(user_sub):
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="auth error"
+        )
+    with RedisConnectionPool.get_redis_connection() as r:
+        r.set(f'{user_sub}_oidc_access_token', access_token, int(config['OIDC_ACCESS_TOKEN_EXPIRE_TIME'])*60)
+        r.set(f'{user_sub}_oidc_refresh_token', refresh_token, int(config['OIDC_REFRESH_TOKEN_EXPIRE_TIME'])*60)
+
+    return {
+        "user_sub": user_sub
+    }
+
diff --git a/apps/common/security.py b/apps/common/security.py
new file mode 100644
index 000000000..9a53c91a6
--- /dev/null
+++ b/apps/common/security.py
@@ -0,0 +1,115 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
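The two public coroutines in apps/common/oidc.py above make up the login callback path: exchange the authorization code for tokens, then resolve the user (dispatching on DEPLOY_MODE) and cache both tokens in Redis. A hedged sketch of how a callback handler might chain them; the handler itself is hypothetical, not part of this patch:

# Hypothetical glue code; only get_oidc_token/get_oidc_user are defined in this patch.
from apps.common.oidc import get_oidc_token, get_oidc_user

async def handle_login_callback(code: str) -> str:
    tokens = await get_oidc_token(code)                  # {"access_token", "refresh_token"}
    user = await get_oidc_user(tokens["access_token"],
                               tokens["refresh_token"])  # stores both tokens in Redis
    return user["user_sub"]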
+import base64 +import binascii +import hashlib +import secrets + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + +from apps.common.config import config + + +class Security: + + @staticmethod + def encrypt(plaintext: str) -> tuple[str, dict]: + """ + 加密公共方法 + :param plaintext: + :return: + """ + half_key1 = config['HALF_KEY1'] + + encrypted_work_key, encrypted_work_key_iv = Security._generate_encrypted_work_key( + half_key1) + encrypted_plaintext, encrypted_iv = Security._encrypt_plaintext(half_key1, encrypted_work_key, + encrypted_work_key_iv, plaintext) + del plaintext + secret_dict = { + "encrypted_work_key": encrypted_work_key, + "encrypted_work_key_iv": encrypted_work_key_iv, + "encrypted_iv": encrypted_iv, + "half_key1": half_key1 + } + return encrypted_plaintext, secret_dict + + @staticmethod + def decrypt(encrypted_plaintext: str, secret_dict: dict): + """ + 解密公共方法 + :param encrypted_plaintext: 待解密的字符串 + :param secret_dict: 存放工作密钥的dict + :return: + """ + plaintext = Security._decrypt_plaintext(half_key1=secret_dict.get("half_key1"), + encrypted_work_key=secret_dict.get( + "encrypted_work_key"), + encrypted_work_key_iv=secret_dict.get( + "encrypted_work_key_iv"), + encrypted_iv=secret_dict.get( + "encrypted_iv"), + encrypted_plaintext=encrypted_plaintext) + return plaintext + + @staticmethod + def _get_root_key(half_key1: str) -> bytes: + half_key2 = config['HALF_KEY2'] + key = (half_key1 + half_key2).encode("utf-8") + half_key3 = config['HALF_KEY3'].encode("utf-8") + hash_key = hashlib.pbkdf2_hmac("sha256", key, half_key3, 10000) + return binascii.hexlify(hash_key)[13:45] + + @staticmethod + def _generate_encrypted_work_key(half_key1: str) -> tuple[str, str]: + bin_root_key = Security._get_root_key(half_key1) + bin_work_key = secrets.token_bytes(32) + bin_encrypted_work_key_iv = secrets.token_bytes(16) + bin_encrypted_work_key = Security._root_encrypt(bin_root_key, bin_encrypted_work_key_iv, bin_work_key) + encrypted_work_key = base64.b64encode(bin_encrypted_work_key).decode("ascii") + encrypted_work_key_iv = base64.b64encode(bin_encrypted_work_key_iv).decode("ascii") + return encrypted_work_key, encrypted_work_key_iv + + @staticmethod + def _get_work_key(half_key1: str, encrypted_work_key: str, encrypted_work_key_iv: str) -> bytes: + bin_root_key = Security._get_root_key(half_key1) + bin_encrypted_work_key = base64.b64decode(encrypted_work_key.encode("ascii")) + bin_encrypted_work_key_iv = base64.b64decode(encrypted_work_key_iv.encode("ascii")) + return Security._root_decrypt(bin_root_key, bin_encrypted_work_key_iv, bin_encrypted_work_key) + + @staticmethod + def _root_encrypt(key: bytes, encrypted_iv: bytes, plaintext: bytes) -> bytes: + encryptor = Cipher(algorithms.AES(key), modes.GCM(encrypted_iv), default_backend()).encryptor() + encrypted = encryptor.update(plaintext) + encryptor.finalize() + return encrypted + + @staticmethod + def _root_decrypt(key: bytes, encrypted_iv: bytes, encrypted: bytes) -> bytes: + encryptor = Cipher(algorithms.AES(key), modes.GCM(encrypted_iv), default_backend()).encryptor() + plaintext = encryptor.update(encrypted) + return plaintext + + @staticmethod + def _encrypt_plaintext(half_key1: str, encrypted_work_key: str, encrypted_work_key_iv: str, + plaintext: str) -> tuple[str, str]: + bin_work_key = Security._get_work_key(half_key1, encrypted_work_key, encrypted_work_key_iv) + salt = f"{half_key1}{plaintext}" + plaintext_temp = salt.encode("utf-8") + del 
plaintext + del salt + bin_encrypted_iv = secrets.token_bytes(16) + bin_encrypted_plaintext = Security._root_encrypt(bin_work_key, bin_encrypted_iv, plaintext_temp) + encrypted_plaintext = base64.b64encode(bin_encrypted_plaintext).decode("ascii") + encrypted_iv = base64.b64encode(bin_encrypted_iv).decode("ascii") + return encrypted_plaintext, encrypted_iv + + @staticmethod + def _decrypt_plaintext(half_key1: str, encrypted_work_key: str, encrypted_work_key_iv: str, + encrypted_plaintext: str, encrypted_iv) -> str: + bin_work_key = Security._get_work_key(half_key1, encrypted_work_key, encrypted_work_key_iv) + bin_encrypted_plaintext = base64.b64decode(encrypted_plaintext.encode("ascii")) + bin_encrypted_iv = base64.b64decode(encrypted_iv.encode("ascii")) + plaintext_temp = Security._root_decrypt(bin_work_key, bin_encrypted_iv, bin_encrypted_plaintext) + plaintext_salt = plaintext_temp.decode("utf-8") + plaintext = plaintext_salt[len(half_key1):] + return plaintext diff --git a/apps/common/singleton.py b/apps/common/singleton.py new file mode 100644 index 000000000..c14a48976 --- /dev/null +++ b/apps/common/singleton.py @@ -0,0 +1,17 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +from threading import Lock + + +class Singleton(type): + """ + 用于实现全局单例的Class + """ + + _instances = {} + _lock: Lock = Lock() + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + with cls._lock: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] diff --git a/apps/common/thread.py b/apps/common/thread.py new file mode 100644 index 000000000..e419de1c8 --- /dev/null +++ b/apps/common/thread.py @@ -0,0 +1,21 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +from concurrent.futures import ThreadPoolExecutor + +from apps.common.singleton import Singleton + +class ProcessThreadPool(metaclass=Singleton): + """ + 给每个进程分配一个线程池 + """ + + thread_executor: ThreadPoolExecutor + + def __init__(self, thread_worker_num: int = 5): + self.thread_executor = ThreadPoolExecutor(max_workers=thread_worker_num) + + def exec(self): + """ + 获取线程执行器 + :return: 线程执行器对象;可将任务提交到线程池中 + """ + return self.thread_executor diff --git a/apps/common/wordscheck.py b/apps/common/wordscheck.py new file mode 100644 index 000000000..46755c5a4 --- /dev/null +++ b/apps/common/wordscheck.py @@ -0,0 +1,66 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
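The Security class above returns everything needed for later decryption in secret_dict: the work key encrypted under a root key derived via PBKDF2 from the three HALF_KEY values, plus both IVs. A round-trip sketch, assuming HALF_KEY1/2/3 are present in the config:

# Round-trip sketch; assumes HALF_KEY1/2/3 are configured. The value is illustrative.
from apps.common.security import Security

ciphertext, secret_dict = Security.encrypt("db-password-example")
# secret_dict holds: encrypted_work_key, encrypted_work_key_iv, encrypted_iv, half_key1
assert Security.decrypt(ciphertext, secret_dict) == "db-password-example"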
+from __future__ import annotations
+
+import http
+import re
+import logging
+
+import requests
+
+from apps.common.config import config
+
+logger = logging.getLogger('gunicorn.error')
+
+
+class APICheck(object):
+
+    @classmethod
+    def check(cls, content: str) -> int:
+        url = config['WORDS_CHECK']
+        headers = {"Content-Type": "application/json"}
+        data = {"content": content}
+        try:
+            response = requests.post(url=url, json=data, headers=headers, timeout=10)
+            if response.status_code == http.HTTPStatus.OK:
+                if re.search("ok", str(response.content)):
+                    return 1
+            return 0
+        except Exception as e:
+            logger.info("过滤敏感词错误:" + str(e))
+            return -1
+
+
+class KeywordCheck:
+    words_list: list
+
+    def __init__(self):
+        with open(config["WORDS_LIST"], "r", encoding="utf-8") as f:
+            self.words_list = f.read().splitlines()
+
+    def check(self, message: str) -> int:
+        if message in self.words_list:
+            return 0
+        return 1
+
+
+class WordsCheck:
+    tool: APICheck | KeywordCheck | None = None
+
+    def __init__(self):
+        raise NotImplementedError("WordsCheck无法被实例化!")
+
+    @classmethod
+    def init(cls):
+        if config["DETECT_TYPE"] == "keyword":
+            cls.tool = KeywordCheck()
+        elif config["DETECT_TYPE"] == "wordscheck":
+            cls.tool = APICheck()
+        else:
+            cls.tool = None
+
+    @classmethod
+    async def check(cls, message: str) -> int:
+        # -1 = error, 0 = blocked, 1 = OK
+        if not cls.tool:
+            return 1
+        return cls.tool.check(message)
diff --git a/apps/constants.py b/apps/constants.py
new file mode 100644
index 000000000..12a3d7051
--- /dev/null
+++ b/apps/constants.py
@@ -0,0 +1,4 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+
+CURRENT_REVISION_VERSION = '0.0.0'
+NEW_CHAT = 'New Chat'
diff --git a/apps/cron/__init__.py b/apps/cron/__init__.py
new file mode 100644
index 000000000..821dc0853
--- /dev/null
+++ b/apps/cron/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
diff --git a/apps/cron/delete_user.py b/apps/cron/delete_user.py
new file mode 100644
index 000000000..76a358283
--- /dev/null
+++ b/apps/cron/delete_user.py
@@ -0,0 +1,37 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
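WordsCheck above is a process-level static holder: init() selects the backend from DETECT_TYPE once per worker (done in apps/gunicorn.conf.py later in this patch), and check() passes everything through when no backend is configured. A sketch of the call pattern:

# Sketch of the call pattern; return codes: -1 error, 0 blocked, 1 OK.
from apps.common.wordscheck import WordsCheck

WordsCheck.init()                                   # once per process

async def is_clean(message: str) -> bool:
    return await WordsCheck.check(message) == 1     # True only for clean input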
+from datetime import datetime, timedelta, timezone
+
+import pytz
+import logging
+
+from apps.manager.audit_log import AuditLogData, AuditLogManager
+from apps.manager.comment import CommentManager
+from apps.manager.record import RecordManager
+from apps.manager.user import UserManager
+from apps.manager.conversation import ConversationManager
+
+
+class DeleteUserCron:
+    logger = logging.getLogger('gunicorn.error')
+
+    @staticmethod
+    def delete_user():
+        try:
+            now = datetime.now(timezone.utc).astimezone(pytz.timezone('Asia/Shanghai'))
+            thirty_days_ago = now - timedelta(days=30)
+            userinfos = UserManager.query_userinfo_by_login_time(
+                thirty_days_ago)
+            for user in userinfos:
+                conversations = ConversationManager.get_conversation_by_user_sub(
+                    user.user_sub)
+                for conv in conversations:
+                    RecordManager.delete_encrypted_qa_pair_by_conversation_id(
+                        conv.conversation_id)
+                CommentManager.delete_comment_by_user_sub(user.user_sub)
+                UserManager.delete_userinfo_by_user_sub(user.user_sub)
+                data = AuditLogData(method_type='internal_scheduler_job', source_name='delete_user', ip='internal',
+                                    result=f'Deleted user: {user.user_sub}', reason='30天未登录')
+                AuditLogManager.add_audit_log(user.user_sub, data)
+        except Exception as e:
+            DeleteUserCron.logger.error(
+                f"Scheduler delete user failed: {e}")
diff --git a/apps/dependency/__init__.py b/apps/dependency/__init__.py
new file mode 100644
index 000000000..821dc0853
--- /dev/null
+++ b/apps/dependency/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
diff --git a/apps/dependency/csrf.py b/apps/dependency/csrf.py
new file mode 100644
index 000000000..3c9cb5ffd
--- /dev/null
+++ b/apps/dependency/csrf.py
@@ -0,0 +1,24 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+from fastapi import Request, HTTPException, status, Response
+
+from apps.manager.session import SessionManager
+from apps.common.config import config
+
+
+def verify_csrf_token(request: Request, response: Response):
+    if not config["ENABLE_CSRF"]:
+        return
+
+    csrf_token = request.headers.get('x-csrf-token', '').strip("\"")
+    session = request.cookies.get('ECSESSION')
+
+    if not SessionManager.verify_csrf_token(session, csrf_token):
+        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail='CSRF token is invalid.')
+
+    new_csrf_token = SessionManager.create_csrf_token(session)
+    if not new_csrf_token:
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Renew CSRF token failed.")
+
+    response.set_cookie("_csrf_tk", new_csrf_token, max_age=config["SESSION_TTL"] * 60,
+                        secure=True, domain=config["DOMAIN"], samesite="strict")
+    return response
diff --git a/apps/dependency/limit.py b/apps/dependency/limit.py
new file mode 100644
index 000000000..07d0812ad
--- /dev/null
+++ b/apps/dependency/limit.py
@@ -0,0 +1,29 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
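verify_csrf_token above both validates the incoming x-csrf-token header and rotates the _csrf_tk cookie on success, so it is meant to be attached as a dependency on state-changing routes. A hedged wiring sketch; the router and endpoint are hypothetical:

# Hypothetical route; verify_csrf_token is the dependency defined in this patch.
from fastapi import APIRouter, Depends
from apps.dependency.csrf import verify_csrf_token

router = APIRouter(dependencies=[Depends(verify_csrf_token)])

@router.post("/api/comment")
async def add_comment() -> dict:
    return {"code": 200, "message": "success", "result": {}}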
+from limits import storage, strategies, RateLimitItemPerMinute +from functools import wraps + +from apps.models.redis import RedisConnectionPool + +from fastapi import Response + + +class Limit: + memory_storage = storage.MemoryStorage() + moving_window = strategies.MovingWindowRateLimiter(memory_storage) + limit_rate = RateLimitItemPerMinute(50) + + +def moving_window_limit(func): + @wraps(func) + async def wrapper(*args, **kwargs): + user_sub = kwargs.get('user').user_sub + rate_limit_response = Response(content='Rate limit exceeded', status_code=429) + with RedisConnectionPool.get_redis_connection() as r: + if r.get(f'{user_sub}_active'): + return rate_limit_response + if not Limit.moving_window.hit(Limit.limit_rate, "stream_answer", cost=1): + return rate_limit_response + r.setex(f'{user_sub}_active', 300, user_sub) + return await func(*args, **kwargs) + + return wrapper diff --git a/apps/dependency/session.py b/apps/dependency/session.py new file mode 100644 index 000000000..0b671bd80 --- /dev/null +++ b/apps/dependency/session.py @@ -0,0 +1,51 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request + +from apps.manager.session import SessionManager +from apps.common.config import config + + +BYPASS_LIST = [ + "/health_check", + "/api/auth/login", + "/api/auth/logout", +] + + +class VerifySessionMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + if request.url.path in BYPASS_LIST: + return await call_next(request) + + cookie = request.cookies.get("ECSESSION", "") + host = request.client.host + session_id = SessionManager.get_session(cookie, host) + + if session_id != request.cookies.get("ECSESSION", ""): + cookie_str = "" + + for item in request.scope["headers"]: + if item[0] == b"cookie": + cookie_str = item[1].decode() + request.scope["headers"].remove(item) + break + + all_cookies = "" + if cookie_str != "": + other_headers = cookie_str.split(";") + for item in other_headers: + if "ECSESSION" not in item: + all_cookies += "{}; ".format(item) + + all_cookies += "ECSESSION={}".format(session_id) + request.scope["headers"].append((b"cookie", all_cookies.encode())) + + response = await call_next(request) + response.set_cookie("ECSESSION", session_id, httponly=True, secure=True, samesite="strict", + max_age=config["SESSION_TTL"] * 60, domain=config["DOMAIN"]) + else: + response = await call_next(request) + response.set_cookie("ECSESSION", session_id, httponly=True, secure=True, samesite="strict", + max_age=config["SESSION_TTL"] * 60, domain=config["DOMAIN"]) + return response diff --git a/apps/dependency/user.py b/apps/dependency/user.py new file mode 100644 index 000000000..eceaae652 --- /dev/null +++ b/apps/dependency/user.py @@ -0,0 +1,61 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
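The decorator above layers two gates: a per-user Redis `_active` flag (set for 300 s, presumably cleared elsewhere once the streamed answer finishes) and a process-local 50-per-minute moving window shared by all users. It assumes the wrapped endpoint receives the authenticated user as a `user` keyword argument. Hypothetical usage:

# Hypothetical endpoint; moving_window_limit reads kwargs["user"].user_sub.
from fastapi import APIRouter, Depends
from apps.dependency.limit import moving_window_limit
from apps.dependency.user import get_user
from apps.entities.user import User

router = APIRouter()

@router.post("/api/chat")
@moving_window_limit
async def chat(user: User = Depends(get_user)):
    ...  # produce the streamed answer here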
+from fastapi import Depends +from fastapi.security import OAuth2PasswordBearer +from starlette import status +from starlette.exceptions import HTTPException +from starlette.requests import HTTPConnection + +from apps.entities.user import User +from apps.manager.api_key import ApiKeyManager +from apps.manager.session import SessionManager + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + +def verify_user(request: HTTPConnection): + """ + 验证Session是否已鉴权;未鉴权则抛出HTTP 401 + 接口级dependence + :param request: HTTP请求 + :return: + """ + session_id = request.cookies["ECSESSION"] + if not SessionManager.verify_user(session_id): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") + +def get_session(request: HTTPConnection): + """ + 验证Session是否已鉴权,并返回Session ID;未鉴权则抛出HTTP 401 + 参数级dependence + :param request: HTTP请求 + :return: Session ID + """ + session_id = request.cookies["ECSESSION"] + if not SessionManager.verify_user(session_id): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") + return session_id + +def get_user(request: HTTPConnection) -> User: + """ + 验证Session是否已鉴权;若已鉴权,查询对应的user_sub;若未鉴权,抛出HTTP 401 + 参数级dependence + :param request: + :return: + """ + session_id = request.cookies["ECSESSION"] + user = SessionManager.get_user(session_id) + if not user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") + return user + + +def verify_api_key(api_key: str = Depends(oauth2_scheme)): + if not ApiKeyManager.verify_api_key(api_key): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key!") + + +def get_user_by_api_key(api_key: str = Depends(oauth2_scheme)) -> User: + user = ApiKeyManager.get_user_by_api_key(api_key) + if user is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key!") + return user diff --git a/apps/entities/__init__.py b/apps/entities/__init__.py new file mode 100644 index 000000000..821dc0853 --- /dev/null +++ b/apps/entities/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. diff --git a/apps/entities/blacklist.py b/apps/entities/blacklist.py new file mode 100644 index 000000000..f138479c4 --- /dev/null +++ b/apps/entities/blacklist.py @@ -0,0 +1,27 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +from pydantic import BaseModel + + +# 问题相关 FastAPI 所需数据 +class QuestionBlacklistRequest(BaseModel): + question: str + answer: str + is_deletion: int + + +# 用户相关 FastAPI 所需数据 +class UserBlacklistRequest(BaseModel): + user_sub: str + is_ban: int + + +# 举报相关 FastAPI 所需数据 +class AbuseRequest(BaseModel): + record_id: str + reason: str + + +# 举报审核相关 FastAPI 所需数据 +class AbuseProcessRequest(BaseModel): + id: int + is_deletion: int diff --git a/apps/entities/comment.py b/apps/entities/comment.py new file mode 100644 index 000000000..2c3f4a5f8 --- /dev/null +++ b/apps/entities/comment.py @@ -0,0 +1,11 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +from dataclasses import dataclass + + +@dataclass +class CommentData: + record_id: str + is_like: bool + dislike_reason: str + reason_link: str + reason_description: str diff --git a/apps/entities/plugin.py b/apps/entities/plugin.py new file mode 100644 index 000000000..a310eb905 --- /dev/null +++ b/apps/entities/plugin.py @@ -0,0 +1,43 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
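The three session dependencies above differ only in what they yield (nothing, the session ID, or the User), while the two API-key variants serve token-authenticated clients; `get_user` is the usual parameter-level choice. A hypothetical route:

# Hypothetical route showing parameter-level use of get_user from this patch.
from fastapi import APIRouter, Depends
from apps.dependency.user import get_user
from apps.entities.user import User

router = APIRouter()

@router.get("/api/conversation")
async def list_conversations(user: User = Depends(get_user)):
    return {"code": 200, "result": {"user_sub": user.user_sub}}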
+# 数据结构定义 +from typing import List, Dict, Any, Optional + +from pydantic import BaseModel + + +class ToolData(BaseModel): + name: str + params: Dict[str, Any] + + +class Step(BaseModel): + name: str + dangerous: bool = False + call_type: str + params: Dict[str, Any] = {} + next: Optional[str] = None + + +class Flow(BaseModel): + on_error: Optional[Step] = Step( + name="error", + call_type="llm", + params={ + "user_prompt": "当前工具执行发生错误,原始错误信息为:{data}. 请向用户展示错误信息,并给出可能的解决方案。\n\n背景信息:{context}" + } + ) + steps: Dict[str, Step] + next_flow: Optional[List[str]] = None + + +class PluginData(BaseModel): + id: str + plugin_name: str + plugin_description: str + plugin_auth: Optional[dict] = None + + +class PluginListData(BaseModel): + code: int + message: str + result: list[PluginData] diff --git a/apps/entities/request_data.py b/apps/entities/request_data.py new file mode 100644 index 000000000..0401c23d6 --- /dev/null +++ b/apps/entities/request_data.py @@ -0,0 +1,57 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +from typing import List, Optional + +from pydantic import BaseModel, Field + + +class RequestData(BaseModel): + question: str = Field(..., max_length=2000) + language: Optional[str] = Field(default="zh") + conversation_id: str = Field(..., min_length=32, max_length=32) + record_id: Optional[str] = Field(default=None) + user_selected_plugins: List[str] = Field(default=[]) + user_selected_flow: Optional[str] = Field(default=None) + files: Optional[List[str]] = Field(default=None) + flow_id: Optional[str] = Field(default=None) + + +class ClientChatRequestData(BaseModel): + session_id: str = Field(..., min_length=32, max_length=32) + question: str = Field(..., max_length=2000) + language: Optional[str] = Field(default="zh") + conversation_id: str = Field(..., min_length=32, max_length=32) + record_id: Optional[str] = Field(default=None) + user_selected_plugins: List[str] = Field(default=[]) + user_selected_flow: Optional[str] = Field(default=None) + files: Optional[List[str]] = Field(default=None) + flow_id: Optional[str] = Field(default=None) + + +class ClientSessionData(BaseModel): + session_id: Optional[str] = Field(default=None) + + +class ModifyConversationData(BaseModel): + title: str = Field(..., min_length=1, max_length=2000) + + +class ModifyRevisionData(BaseModel): + revision_num: str = Field(..., min_length=5, max_length=5) + + +class DeleteConversationData(BaseModel): + conversation_list: list[str] = Field(...) + + +class AddCommentData(BaseModel): + record_id: str = Field(..., min_length=32, max_length=32) + is_like: bool = Field(...) + dislike_reason: str = Field(default=None, max_length=100) + reason_link: str = Field(default=None, max_length=200) + reason_description: str = Field( + default=None, max_length=500) + + +class AddDomainData(BaseModel): + domain_name: str = Field(..., min_length=1, max_length=100) + domain_description: str = Field(..., max_length=2000) diff --git a/apps/entities/response_data.py b/apps/entities/response_data.py new file mode 100644 index 000000000..ab4153863 --- /dev/null +++ b/apps/entities/response_data.py @@ -0,0 +1,51 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
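In the Flow model above, steps are chained by name through Step.next, and on_error defaults to an LLM step that renders the failure for the user. A minimal two-step flow in the spirit of sdk/example_plugin/flows/flow.yaml; the step names and params are illustrative:

# Illustrative flow built from the pydantic models defined in this patch.
from apps.entities.plugin import Flow, Step

flow = Flow(steps={
    "start": Step(name="start", call_type="api",
                  params={"endpoint": "GET /v1/hosts"}, next="report"),
    "report": Step(name="report", call_type="llm",
                   params={"user_prompt": "Summarize the data: {data}"}),
})
assert flow.on_error is not None  # default error handler is pre-filled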
+from datetime import datetime +from typing import Optional + +from pydantic import BaseModel + + +class ResponseData(BaseModel): + code: int + message: str + result: dict + + +class ConversationData(BaseModel): + conversation_id: str + title: str + created_time: datetime + + +class ConversationListData(BaseModel): + code: int + message: str + result: list[ConversationData] + + +class RecordData(BaseModel): + conversation_id: str + record_id: str + question: str + answer: str + is_like: Optional[int] = None + created_time: datetime + group_id: str + + +class RecordListData(BaseModel): + code: int + message: str + result: list[RecordData] + + +class RecordQueryData(BaseModel): + conversation_id: str + record_id: str + encrypted_question: str + question_encryption_config: dict + encrypted_answer: str + answer_encryption_config: dict + created_time: str + is_like: Optional[int] = None + group_id: str diff --git a/apps/entities/user.py b/apps/entities/user.py new file mode 100644 index 000000000..01d486e87 --- /dev/null +++ b/apps/entities/user.py @@ -0,0 +1,9 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +from typing import Optional + +from pydantic import BaseModel, Field + + +class User(BaseModel): + user_sub: str = Field(..., description="user sub") + revision_number: Optional[str] = None diff --git a/apps/gunicorn.conf.py b/apps/gunicorn.conf.py new file mode 100644 index 000000000..d7f10d45b --- /dev/null +++ b/apps/gunicorn.conf.py @@ -0,0 +1,33 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +from apps.common.thread import ProcessThreadPool +from apps.common.wordscheck import WordsCheck +from apps.scheduler.pool.loader import Loader + + +preload_app = True +bind = "0.0.0.0:8002" +workers = 8 +timeout = 300 +accesslog = "-" +capture_output = True +worker_class = "uvicorn.workers.UvicornWorker" + +def on_starting(server): + """ + Gunicorn服务器启动时的初始化代码 + :param server: 服务器配置项 + :return: + """ + WordsCheck.init() + Loader.init() + + +def post_fork(server, worker): + """ + Gunicorn服务器每个Worker进程启动后的初始化代码 + :param server: 服务器配置项 + :param worker: Worker配置项 + :return: + """ + ProcessThreadPool(thread_worker_num=5) diff --git a/apps/llm.py b/apps/llm.py new file mode 100644 index 000000000..e2cfabaa8 --- /dev/null +++ b/apps/llm.py @@ -0,0 +1,117 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+
+from __future__ import annotations
+
+import re
+from typing import List, Dict, Any
+
+from openai import AsyncOpenAI
+from sglang import RuntimeEndpoint
+from sglang.lang.chat_template import get_chat_template
+from sparkai.llm.llm import ChatSparkLLM
+from langchain_openai import ChatOpenAI
+from langchain_core.messages import ChatMessage as LangchainChatMessage
+from sparkai.messages import ChatMessage as SparkChatMessage
+import openai
+from untruncate_json import untrunc
+from json_minify import json_minify
+
+from apps.common.config import config
+
+
+def get_scheduler() -> RuntimeEndpoint | AsyncOpenAI:
+    if config["SCHEDULER_BACKEND"] == "sglang":
+        endpoint = RuntimeEndpoint(config["SCHEDULER_URL"], api_key=config["SCHEDULER_API_KEY"])
+        endpoint.chat_template = get_chat_template("chatml")
+        return endpoint
+    else:
+        # Use vLLM's native extension API, which supports NVIDIA GPUs below sm75
+        client = openai.AsyncOpenAI(
+            base_url=config["SCHEDULER_URL"],
+            api_key=config["SCHEDULER_API_KEY"],
+        )
+        return client
+
+
+async def create_vllm_stream(client: openai.AsyncOpenAI, messages: List[Dict[str, str]],
+                             max_tokens: int, extra_body: Dict[str, Any]):
+    # create() is a coroutine and must be awaited here; returning it un-awaited
+    # would hand callers a coroutine object instead of the async stream.
+    return await client.chat.completions.create(
+        model=config["SCHEDULER_MODEL"],
+        messages=messages,
+        max_tokens=max_tokens,
+        extra_body=extra_body,
+        top_p=0.5,
+        temperature=0.01,
+        stream=True
+    )
+
+async def stream_to_str(stream) -> str:
+    """
+    Concatenate an OpenAI client stream into the complete result
+    :param stream: openai async iterator
+    :return: the full LLM output
+    """
+    result = ""
+    async for chunk in stream:
+        result += chunk.choices[0].delta.content or ""
+    return result
+
+
+def get_llm():
+    """
+    Get the LLM API client
+    :return: an OpenAI LLM client, or an iFLYTEK Spark SDK client
+    """
+    if config["MODEL"] == "openai":
+        return ChatOpenAI(
+            openai_api_key=config["LLM_KEY"],
+            openai_api_base=config["LLM_URL"],
+            model_name=config["LLM_MODEL"],
+            tiktoken_model_name="cl100k_base",
+            max_tokens=4096,
+            streaming=True,
+            temperature=0.07
+        )
+    elif config["MODEL"] == "spark":
+        return ChatSparkLLM(
+            spark_app_id=config["SPARK_APP_ID"],
+            spark_api_key=config["SPARK_API_KEY"],
+            spark_api_secret=config["SPARK_API_SECRET"],
+            spark_api_url=config["SPARK_API_URL"],
+            spark_llm_domain=config["SPARK_LLM_DOMAIN"],
+            request_timeout=600,
+            max_tokens=4096,
+            streaming=True,
+            temperature=0.07
+        )
+    else:
+        raise NotImplementedError
+
+
+def get_message_model(llm):
+    """
+    Get the message class that matches the LLM client's class
+    :param llm: LLM client
+    :return: the message class for that client
+    """
+    if isinstance(llm, ChatOpenAI):
+        return LangchainChatMessage
+    elif isinstance(llm, ChatSparkLLM):
+        return SparkChatMessage
+    else:
+        raise NotImplementedError
+
+
+def get_json_code_block(text):
+    """
+    Extract the JSON code block from an LLM response
+    :param text: the LLM response
+    :return: the extracted JSON snippet, minified and completed if truncated
+    """
+    pattern = r'```(json)?(.*)```'
+    matches = re.search(pattern, text, re.DOTALL)
+    raw_result = matches.group(2)
+    raw_mini = json_minify(raw_result)
+    raw_fixed = untrunc.complete(raw_mini)
+
+    return raw_fixed diff --git a/apps/logger/__init__.py b/apps/logger/__init__.py new file mode 100644 index 000000000..ac4066cc3 --- /dev/null +++ b/apps/logger/__init__.py @@ -0,0 +1,100 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
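A sketch of what get_json_code_block in apps/llm.py above is meant to do with a typical fenced, possibly truncated LLM reply; the exact repair depends on the json_minify and untruncate_json packages, so the output shown is an expectation, not a guarantee:

    from apps.llm import get_json_code_block

    raw = 'Here is the result:\n```json\n{"name": "web", "port": 8080\n```'
    print(get_json_code_block(raw))  # expected: '{"name":"web","port":8080}' with the closing brace restored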
+ +import logging +import os +import time +from logging.handlers import TimedRotatingFileHandler + +from apps.common.config import config + + +class SizedTimedRotatingFileHandler(TimedRotatingFileHandler): + def __init__(self, filename, max_bytes=0, backup_count=0, encoding=None, + delay=False, when='midnight', interval=1, utc=False): + super().__init__(filename, when, interval, backup_count, encoding, delay, utc) + self.max_bytes = max_bytes + + def shouldRollover(self, record): + if self.stream is None: + self.stream = self._open() + if self.max_bytes > 0: + msg = "%s\n" % self.format(record) + self.stream.seek(0, 2) + if self.stream.tell()+len(msg) >= self.max_bytes: + return 1 + t = int(time.time()) + if t >= self.rolloverAt: + return 1 + return 0 + + def doRollover(self): + self.stream.close() + os.chmod(self.baseFilename, 0o440) + TimedRotatingFileHandler.doRollover(self) + os.chmod(self.baseFilename, 0o640) + +LOG_FORMAT = '[{asctime}][{levelname}][{name}][P{process}][T{thread}][{message}][{funcName}({filename}:{lineno})]' + +if config["LOG"] == "stdout": + handlers = { + "default": { + "formatter": "default", + "class": "logging.StreamHandler", + "stream": "ext://sys.stdout", + }, + } +else: + LOG_DIR = './logs' + if not os.path.exists(LOG_DIR): + os.makedirs(LOG_DIR, 0o750) + handlers = { + 'default': { + 'formatter': 'default', + 'class': 'apps.logger.SizedTimedRotatingFileHandler', + 'filename': f"{LOG_DIR}/app.log", + 'backup_count': 30, + 'when': 'MIDNIGHT', + 'max_bytes': 5000000 + } + } + +log_config = { + "version": 1, + 'disable_existing_loggers': False, + "formatters": { + "default": { + '()': 'logging.Formatter', + 'fmt': LOG_FORMAT, + 'style': '{' + } + }, + "handlers": handlers, + "loggers": { + "uvicorn": { + "level": "INFO", + "handlers": ["default"], + 'propagate': False + }, + "uvicorn.errors": { + "level": "INFO", + "handlers": ["default"], + 'propagate': False + }, + "uvicorn.access": { + "level": "INFO", + "handlers": ["default"], + 'propagate': False + } + } +} + + +def get_logger(): + logger = logging.getLogger('uvicorn') + logger.setLevel(logging.INFO) + if config['LOG'] != 'stdout': + rotate_handler = SizedTimedRotatingFileHandler( + filename=f'{LOG_DIR}/app.log', when='MIDNIGHT', backup_count=30, max_bytes=5000000) + logger.addHandler(rotate_handler) + logger.propagate = False + return logger diff --git a/apps/main.py b/apps/main.py new file mode 100644 index 000000000..8265a2d28 --- /dev/null +++ b/apps/main.py @@ -0,0 +1,85 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
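Standalone use of the SizedTimedRotatingFileHandler defined in apps/logger above (file name and limits are arbitrary):

    import logging

    from apps.logger import SizedTimedRotatingFileHandler

    log = logging.getLogger("demo")
    handler = SizedTimedRotatingFileHandler(
        "demo.log", when="MIDNIGHT", backup_count=3, max_bytes=1_000_000)
    handler.setFormatter(logging.Formatter("[{asctime}][{levelname}][{message}]", style="{"))
    log.addHandler(handler)
    log.warning("rolls over at midnight or at ~1 MB, whichever comes first")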
+ +import logging + +import uvicorn +from apscheduler.schedulers.background import BackgroundScheduler +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from apps.common.config import config +from apps.cron.delete_user import DeleteUserCron +from apps.dependency.session import VerifySessionMiddleware +from apps.logger import log_config +from apps.models.redis import RedisConnectionPool +from apps.routers import ( + api_key, + auth, + blacklist, + chat, + client, + comment, + conversation, + file, + health, + plugin, + record, +) +from apps.scheduler.files import Files + +# 定义FastAPI app +app = FastAPI(docs_url=None, redoc_url=None) +# 定义FastAPI全局中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=[config['WEB_FRONT_URL']], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) +app.add_middleware(VerifySessionMiddleware) +# 关联API路由 +app.include_router(conversation.router) +app.include_router(auth.router) +app.include_router(api_key.router) +app.include_router(comment.router) +app.include_router(record.router) +app.include_router(health.router) +app.include_router(plugin.router) +app.include_router(chat.router) +app.include_router(client.router) +app.include_router(blacklist.router) +app.include_router(file.router) +# 初始化日志记录器 +logger = logging.getLogger('gunicorn.error') +# 初始化后台定时任务 +scheduler = BackgroundScheduler() +scheduler.start() +scheduler.add_job(DeleteUserCron.delete_user, 'cron', hour=3) +scheduler.add_job(Files.delete_old_files, 'cron', hour=3) +# 初始化Redis连接池 +RedisConnectionPool.get_redis_pool() + + +if __name__ == "__main__": + try: + ssl_enable = config["SSL_ENABLE"] + if ssl_enable: + uvicorn.run( + app, + host=config["UVICORN_HOST"], + port=int(config["UVICORN_PORT"]), + log_config=log_config, + ssl_certfile=config["SSL_CERTFILE"], + ssl_keyfile=config["SSL_KEYFILE"], + ssl_keyfile_password=config["SSL_KEY_PWD"] + ) + else: + uvicorn.run( + app, + host=config["UVICORN_HOST"], + port=int(config["UVICORN_PORT"]), + log_config=log_config + ) + except Exception as e: + logger.error(e) diff --git a/apps/manager/__init__.py b/apps/manager/__init__.py new file mode 100644 index 000000000..821dc0853 --- /dev/null +++ b/apps/manager/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. diff --git a/apps/manager/api_key.py b/apps/manager/api_key.py new file mode 100644 index 000000000..c4787f809 --- /dev/null +++ b/apps/manager/api_key.py @@ -0,0 +1,102 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+ +from __future__ import annotations + +import hashlib +import logging +import uuid + +from apps.entities.user import User as UserInfo +from apps.manager.user import UserManager +from apps.models.mysql import ApiKey, MysqlDB + +logger = logging.getLogger('gunicorn.error') + + +class ApiKeyManager: + def __init__(self): + raise NotImplementedError("ApiKeyManager无法被实例化") + + @staticmethod + def generate_api_key(userinfo: UserInfo) -> str | None: + user_sub = userinfo.user_sub + api_key = str(uuid.uuid4().hex) + api_key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16] + try: + with MysqlDB().get_session() as session: + session.add(ApiKey(user_sub=user_sub, api_key_hash=api_key_hash)) + session.commit() + except Exception as e: + logger.info(f"Add API key failed due to error: {e}") + return None + return api_key + + @staticmethod + def delete_api_key(userinfo: UserInfo) -> bool: + user_sub = userinfo.user_sub + if not ApiKeyManager.api_key_exists(userinfo): + return False + try: + with MysqlDB().get_session() as session: + session.query(ApiKey).filter(ApiKey.user_sub == user_sub).delete() + session.commit() + except Exception as e: + logger.info(f"Delete API key failed due to error: {e}") + return False + else: + return True + + @staticmethod + def api_key_exists(userinfo: UserInfo) -> bool: + user_sub = userinfo.user_sub + try: + with MysqlDB().get_session() as session: + result = session.query(ApiKey).filter(ApiKey.user_sub == user_sub).first() + except Exception as e: + logger.info(f"Check API key existence failed due to error: {e}") + return False + else: + return result is not None + + @staticmethod + def get_user_by_api_key(api_key: str) -> UserInfo | None: + api_key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16] + try: + with MysqlDB().get_session() as session: + user_sub = session.query(ApiKey).filter(ApiKey.api_key_hash == api_key_hash).first().user_sub + if user_sub is None: + return None + userdata = UserManager.get_userinfo_by_user_sub(user_sub) + if userdata is None: + return None + except Exception as e: + logger.info(f"Get user info by API key failed due to error: {e}") + else: + return UserInfo(user_sub=userdata.user_sub, revision_number=userdata.revision_number) + + @staticmethod + def verify_api_key(api_key: str) -> bool: + api_key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16] + try: + with MysqlDB().get_session() as session: + user_sub = session.query(ApiKey).filter(ApiKey.api_key_hash == api_key_hash).first().user_sub + except Exception as e: + logger.info(f"Verify API key failed due to error: {e}") + return False + return user_sub is not None + + @staticmethod + def update_api_key(userinfo: UserInfo) -> str | None: + if not ApiKeyManager.api_key_exists(userinfo): + return None + user_sub = userinfo.user_sub + api_key = str(uuid.uuid4().hex) + api_key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16] + try: + with MysqlDB().get_session() as session: + session.query(ApiKey).filter(ApiKey.user_sub == user_sub).update({"api_key_hash": api_key_hash}) + session.commit() + except Exception as e: + logger.info(f"Update API key failed due to error: {e}") + return None + return api_key diff --git a/apps/manager/audit_log.py b/apps/manager/audit_log.py new file mode 100644 index 000000000..f20eb3dde --- /dev/null +++ b/apps/manager/audit_log.py @@ -0,0 +1,34 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
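The key scheme used by ApiKeyManager in apps/manager/api_key.py above, in isolation: the caller receives the raw 32-character key once, while only a truncated SHA-256 digest is persisted. A self-contained sketch:

    import hashlib
    import uuid

    api_key = uuid.uuid4().hex                                       # returned to the caller, never stored
    stored_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16]  # what the api_key table keeps
    print(api_key, "->", stored_hash)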
+ +from dataclasses import dataclass +import logging + +from apps.models.mysql import AuditLog, MysqlDB + +logger = logging.getLogger('gunicorn.error') + + +@dataclass +class AuditLogData: + method_type: str + source_name: str + ip: str + result: str + reason: str + + +class AuditLogManager: + def __init__(self): + raise NotImplementedError("AuditLogManager无法被实例化") + + @staticmethod + def add_audit_log(user_sub: str, data: AuditLogData): + try: + with MysqlDB().get_session() as session: + add_audit_log = AuditLog(user_sub=user_sub, method_type=data.method_type, + source_name=data.source_name, ip=data.ip, + result=data.result, reason=data.reason) + session.add(add_audit_log) + session.commit() + except Exception as e: + logger.info(f"Add audit log failed due to error: {e}") diff --git a/apps/manager/blacklist.py b/apps/manager/blacklist.py new file mode 100644 index 000000000..3a94eef57 --- /dev/null +++ b/apps/manager/blacklist.py @@ -0,0 +1,300 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +from __future__ import annotations + +import json +import logging + +from sqlalchemy import select + +from apps.common.security import Security +from apps.models.mysql import ( + MysqlDB, + Record, + QuestionBlacklist, + User, + Conversation, +) + + +logger = logging.getLogger('gunicorn.error') + + +class QuestionBlacklistManager: + def __init__(self): + raise NotImplementedError("QuestionBlacklistManager无法被实例化") + + # 给定问题,查找问题是否在黑名单里 + @staticmethod + def check_blacklisted_questions(input_question: str) -> bool | None: + try: + # 搜索问题 + with MysqlDB().get_session() as session: + result = session.scalars( + select(QuestionBlacklist).filter_by(is_audited=True) + .order_by(QuestionBlacklist.id) + ) + + # 问题表为空,则下面的代码不会执行 + for item in result: + if item.question in input_question: + # 用户输入的问题中包含黑名单问题的一部分,故拉黑 + logger.info("Question in blacklist.") + return False + + return True + except Exception as e: + # 访问数据库异常 + logger.info(f"Check question blacklist failed: {e}") + return None + + # 删除或修改已在黑名单里的问题,is_deletion标识是否为删除操作 + @staticmethod + def change_blacklisted_questions(question: str, answer: str, is_deletion: bool = False) -> bool: + try: + with MysqlDB().get_session() as session: + # 搜索问题,精确匹配 + result = session.scalars( + select(QuestionBlacklist).filter_by(is_audited=True).filter_by(question=question).limit(1) + ).first() + + if result is None: + if not is_deletion: + # 没有查到任何结果,进行添加问题 + logger.info("Question not found in blacklist.") + session.add(QuestionBlacklist(question=question, answer=answer, is_audited=True, + reason_description="手动添加")) + else: + logger.info("Question does not exist.") + else: + # search_question就是搜到的结果,进行答案更改 + if not is_deletion: + # 修改 + logger.info("Modify question in blacklist.") + + result.question = question + result.answer = answer + else: + # 删除 + logger.info("Delete question in blacklist.") + session.delete(result) + + session.commit() + return True + except Exception as e: + # 数据库操作异常 + logger.info(f"Change question blacklist failed: {e}") + # 放弃执行后续操作 + return False + + # 分页式获取目前所有的问题(待审核或已拉黑)黑名单 + @staticmethod + def get_blacklisted_questions(limit: int, offset: int, is_audited: bool) -> list | None: + try: + with MysqlDB().get_session() as session: + query = session.scalars( + # limit:取多少条;offset:跳过前多少条; + select(QuestionBlacklist).filter_by(is_audited=is_audited). 
+ order_by(QuestionBlacklist.id).limit(limit).offset(offset) + ) + + result = [] + # 无条目,则下面的语句不会执行 + for item in query: + result.append({ + "id": item.id, + "question": item.question, + "answer": item.answer, + "reason": item.reason_description, + "created_time": item.created_time, + }) + + return result + except Exception as e: + logger.info(f"Query question blacklist failed: {e}") + # 异常,返回None + return None + + +# 用户黑名单相关操作 +class UserBlacklistManager: + def __init__(self): + raise NotImplementedError("UserBlacklistManager无法被实例化") + + # 获取当前所有黑名单用户 + @staticmethod + def get_blacklisted_users(limit: int, offset: int) -> list | None: + try: + with MysqlDB().get_session() as session: + result = session.scalars( + select(User).order_by(User.user_sub).filter(User.credit <= 0) + .filter_by(is_whitelisted=False).limit(limit).offset(offset) + ) + + user = [] + # 无条目,则下面的语句不会执行 + for item in result: + user.append({ + "user_sub": item.user_sub, + "organization": item.organization, + "credit": item.credit, + "login_time": item.login_time + }) + + return user + + except Exception as e: + logger.info(f"Query user blacklist failed: {e}") + return None + + # 检测某用户是否已被拉黑 + @staticmethod + def check_blacklisted_users(user_sub: str) -> bool | None: + try: + with MysqlDB().get_session() as session: + result = session.scalars( + select(User).filter_by(user_sub=user_sub).filter(User.credit <= 0) + .filter_by(is_whitelisted=False).limit(1) + ).first() + + # 有条目,说明被拉黑 + if result is not None: + logger.info("User blacklisted.") + return True + + return False + + except Exception as e: + logger.info(f"Check user blacklist failed: {e}") + return None + + # 修改用户的信用分 + @staticmethod + def change_blacklisted_users(user_sub: str, credit_diff: int, credit_limit: int = 100) -> bool | None: + try: + with MysqlDB().get_session() as session: + # 查找当前用户信用分 + result = session.scalars( + select(User).filter_by(user_sub=user_sub).limit(1) + ).first() + + # 用户不存在 + if result is None: + logger.info("User does not exist.") + return False + + # 用户已被加白,什么都不做 + if result.is_whitelisted: + return False + + if result.credit > 0 and credit_diff > 0: + logger.info("User already unbanned.") + return True + if result.credit <= 0 and credit_diff < 0: + logger.info("User already banned.") + return True + + # 给当前用户的信用分加上偏移量 + result.credit += credit_diff + # 不得超过积分上限 + if result.credit > credit_limit: + result.credit = credit_limit + # 不得小于0 + elif result.credit < 0: + result.credit = 0 + + session.commit() + return True + except Exception as e: + # 数据库错误 + logger.info(f"Change user blacklist failed: {e}") + return None + + +# 用户举报相关操作 +class AbuseManager: + def __init__(self): + raise NotImplementedError("AbuseManager无法被实例化") + + # 存储用户举报详情 + @staticmethod + def change_abuse_report(user_sub: str, qa_record_id: str, reason: str) -> bool | None: + try: + with MysqlDB().get_session() as session: + # 检查qa_record_id是否在当前user下 + qa_record = session.scalars( + select(Record).filter_by(qa_record_id=qa_record_id).limit(1) + ).first() + + # qa_record_id 不存在 + if qa_record is None: + logger.info("QA record invalid.") + return False + + user = session.scalars( + select(Conversation).filter_by( + user_sub=user_sub, + user_qa_record_id=qa_record.conversation_id + ).limit(1) + ).first() + + # qa_record_id 不在当前用户下 + if user is None: + logger.info("QA record user mismatch.") + return False + + # 获得用户的明文输入 + user_question = Security.decrypt(qa_record.encrypted_question, + json.loads(qa_record.question_encryption_config)) + user_answer = 
Security.decrypt(qa_record.encrypted_answer, + json.loads(qa_record.answer_encryption_config)) + + # 检查该条目是否已被举报 + query = session.scalars( + select(QuestionBlacklist).filter_by(question=user_question).order_by(QuestionBlacklist.id).limit(1) + ).first() + # 结果为空 + if query is None: + # 新增举报信息;具体的举报类型在前端拼接 + session.add(QuestionBlacklist( + question=user_question, + answer=user_answer, + is_audited=False, + reason_description=reason + )) + session.commit() + return True + else: + # 类似问题已待审核/被加入黑名单,什么都不做 + logger.info("Question has been reported before.") + session.commit() + return True + + except Exception as e: + logger.info(f"Change user abuse report failed: {e}") + return None + + # 对某一特定的待审问题进行操作,包括批准审核与删除未审问题 + @staticmethod + def audit_abuse_report(question_id: int, is_deletion: int = False) -> bool | None: + try: + with MysqlDB().get_session() as session: + # 从数据库中查找该问题 + question = session.scalars( + select(QuestionBlacklist).filter_by(id=question_id).filter_by(is_audited=False).limit(1) + ).first() + + # 条目不存在 + if question is None: + return False + + # 删除 + if is_deletion: + session.delete(question) + else: + question.is_audited = True + + session.commit() + return True + except Exception as e: + logger.info(f"Audit user abuse report failed: {e}") + return None diff --git a/apps/manager/comment.py b/apps/manager/comment.py new file mode 100644 index 000000000..7ce2e8e28 --- /dev/null +++ b/apps/manager/comment.py @@ -0,0 +1,58 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +import logging + +from apps.models.mysql import Comment, MysqlDB +from apps.entities.comment import CommentData + + +class CommentManager: + logger = logging.getLogger('gunicorn.error') + + @staticmethod + def query_comment(record_id: str): + result = None + try: + with MysqlDB().get_session() as session: + result = session.query(Comment).filter( + Comment.record_id == record_id).first() + except Exception as e: + CommentManager.logger.info( + f"Query comment failed due to error: {e}") + return result + + @staticmethod + def add_comment(user_sub: str, data: CommentData): + try: + with MysqlDB().get_session() as session: + add_comment = Comment(user_sub=user_sub, record_id=data.record_id, + is_like=data.is_like, dislike_reason=data.dislike_reason, + reason_link=data.reason_link, reason_description=data.reason_description) + session.add(add_comment) + session.commit() + except Exception as e: + CommentManager.logger.info( + f"Add comment failed due to error: {e}") + + @staticmethod + def update_comment(user_sub: str, data: CommentData): + try: + with MysqlDB().get_session() as session: + session.query(Comment).filter(Comment.user_sub == user_sub).filter( + Comment.record_id == data.record_id).update( + {"is_like": data.is_like, "dislike_reason": data.dislike_reason, "reason_link": data.reason_link, + "reason_description": data.reason_description}) + session.commit() + except Exception as e: + CommentManager.logger.info( + f"Add comment failed due to error: {e}") + + @staticmethod + def delete_comment_by_user_sub(user_sub: str): + try: + with MysqlDB().get_session() as session: + session.query(Comment).filter( + Comment.user_sub == user_sub).delete() + session.commit() + except Exception as e: + CommentManager.logger.info( + f"delete comment by user_sub failed due to error: {e}") diff --git a/apps/manager/conversation.py b/apps/manager/conversation.py new file mode 100644 index 000000000..db0ad3522 --- /dev/null +++ b/apps/manager/conversation.py @@ -0,0 +1,115 @@ +# Copyright 
(c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +import uuid +from datetime import datetime, timezone + +import pytz +import logging + +from apps.models.mysql import MysqlDB, Conversation + + +class ConversationManager: + logger = logging.getLogger('gunicorn.error') + + @staticmethod + def get_conversation_by_user_sub(user_sub): + results = [] + try: + with MysqlDB().get_session() as session: + results = session.query(Conversation).filter( + Conversation.user_sub == user_sub).all() + except Exception as e: + ConversationManager.logger.info( + f"Get conversation by user_sub failed: {e}") + return results + + @staticmethod + def get_conversation_by_conversation_id(conversation_id): + result = None + try: + with MysqlDB().get_session() as session: + result = session.query(Conversation).filter( + Conversation.conversation_id == conversation_id).first() + except Exception as e: + ConversationManager.logger.info( + f"Get conversation by conversation_id failed: {e}") + return result + + @staticmethod + def add_conversation_by_user_sub(user_sub): + conversation_id = str(uuid.uuid4().hex) + try: + with MysqlDB().get_session() as session: + conv = Conversation(conversation_id=conversation_id, + user_sub=user_sub, title="New Chat", + created_time=datetime.now(timezone.utc).astimezone( + pytz.timezone('Asia/Shanghai') + )) + session.add(conv) + session.commit() + session.refresh(conv) + except Exception as e: + conversation_id = None + ConversationManager.logger.info( + f"Add conversation by user_sub failed: {e}") + return conversation_id + + @staticmethod + def update_conversation_by_conversation_id(conversation_id, title): + try: + with MysqlDB().get_session() as session: + session.query(Conversation).filter(Conversation.conversation_id == + conversation_id).update({"title": title}) + session.commit() + except Exception as e: + ConversationManager.logger.info( + f"Update conversation by conversation_id failed: {e}") + result = ConversationManager.get_conversation_by_conversation_id( + conversation_id) + return result + + @staticmethod + def update_conversation_metadata_by_conversation_id(conversation_id, title, created_time): + try: + with MysqlDB().get_session() as session: + session.query(Conversation).filter(Conversation.conversation_id == conversation_id).update({ + "title": title, + "created_time": created_time + }) + except Exception as e: + ConversationManager.logger.info(f"Update conversation metadata by conversation_id failed: {e}") + result = ConversationManager.get_conversation_by_conversation_id(conversation_id) + return result + + @staticmethod + def delete_conversation_by_conversation_id(conversation_id): + try: + with MysqlDB().get_session() as session: + session.query(Conversation).filter(Conversation.conversation_id == conversation_id).delete() + session.commit() + except Exception as e: + ConversationManager.logger.info( + f"Delete conversation by conversation_id failed: {e}") + + @staticmethod + def delete_conversation_by_user_sub(user_sub): + try: + with MysqlDB().get_session() as session: + session.query(Conversation).filter( + Conversation.user_sub == user_sub).delete() + session.commit() + except Exception as e: + ConversationManager.logger.info( + f"Delete conversation by user_sub failed: {e}") + + @staticmethod + def update_summary(conversation_id, summary): + try: + with MysqlDB().get_session() as session: + session.query(Conversation).filter(Conversation.conversation_id == conversation_id).update({ + "summary": summary + }) + session.commit() + except 
Exception as e: + ConversationManager.logger.info(f"Update summary failed: {e}") diff --git a/apps/manager/domain.py b/apps/manager/domain.py new file mode 100644 index 000000000..0c369c61d --- /dev/null +++ b/apps/manager/domain.py @@ -0,0 +1,82 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +from datetime import datetime, timezone + +import pytz +import logging + +from apps.models.mysql import MysqlDB, Domain +from apps.entities.request_data import AddDomainData + +logger = logging.getLogger('gunicorn.error') + + +class DomainManager: + def __init__(self): + raise NotImplementedError() + + @staticmethod + def get_domain(): + results = [] + try: + with MysqlDB().get_session() as session: + results = session.query(Domain).all() + except Exception as e: + logger.info(f"Get domain by domain_name failed: {e}") + return results + + @staticmethod + def get_domain_by_domain_name(domain_name): + results = [] + try: + with MysqlDB().get_session() as session: + results = session.query(Domain).filter( + Domain.domain_name == domain_name).all() + except Exception as e: + logger.info(f"Get domain by domain_name failed: {e}") + return results + + @staticmethod + def add_domain(add_domain_data: AddDomainData) -> bool: + try: + domain = Domain( + domain_name=add_domain_data.domain_name, + domain_description=add_domain_data.domain_description, + created_time=datetime.now(timezone.utc).astimezone(pytz.timezone('Asia/Shanghai')), + updated_time=datetime.now(timezone.utc).astimezone(pytz.timezone('Asia/Shanghai'))) + with MysqlDB().get_session() as session: + session.add(domain) + session.commit() + return True + except Exception as e: + logger.info(f"Add domain failed due to error: {e}") + return False + + @staticmethod + def update_domain_by_domain_name(update_domain_data: AddDomainData): + result = None + try: + update_dict = { + "domain_description": update_domain_data.domain_description, + "updated_time": datetime.now(timezone.utc).astimezone(pytz.timezone('Asia/Shanghai')) + } + + with MysqlDB().get_session() as session: + session.query(Domain).filter(Domain.domain_name == update_domain_data.domain_name).update(update_dict) + session.commit() + result = DomainManager.get_domain_by_domain_name(update_domain_data.domain_name) + except Exception as e: + logger.info(f"Update domain by domain_name failed due to error: {e}") + finally: + return result + + @staticmethod + def delete_domain_by_domain_name(delete_domain_data: AddDomainData): + try: + with MysqlDB().get_session() as session: + session.query(Domain).filter(Domain.domain_name == delete_domain_data.domain_name).delete() + session.commit() + return True + except Exception as e: + logger.info(f"Delete domain by domain_name failed due to error: {e}") + return False diff --git a/apps/manager/gitee_white_list.py b/apps/manager/gitee_white_list.py new file mode 100644 index 000000000..6a0ef9aaa --- /dev/null +++ b/apps/manager/gitee_white_list.py @@ -0,0 +1,23 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+
+import logging
+
+from apps.models.mysql import GiteeIDWhiteList, MysqlDB
+
+class GiteeIDManager:
+    logger = logging.getLogger('gunicorn.error')
+
+    @staticmethod
+    def check_user_exist_or_not(gitee_id):
+        result = None
+        try:
+            with MysqlDB().get_session() as session:
+                result = session.query(GiteeIDWhiteList).filter(
+                    GiteeIDWhiteList.gitee_id == gitee_id).count()
+        except Exception as e:
+            GiteeIDManager.logger.info(
+                f"Check user existence failed: {e}")
+        if not result:
+            return False
+        return True
+ diff --git a/apps/manager/plugin_token.py b/apps/manager/plugin_token.py new file mode 100644 index 000000000..beeb7e018 --- /dev/null +++ b/apps/manager/plugin_token.py @@ -0,0 +1,84 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+
+from __future__ import annotations
+
+import requests
+import logging
+
+from apps.manager.session import SessionManager
+from apps.models.redis import RedisConnectionPool
+from apps.common.config import config
+
+logger = logging.getLogger('gunicorn.error')
+
+
+class PluginTokenManager:
+
+    @staticmethod
+    def get_plugin_token(plugin_domain, session_id, access_token_url, expire_time):
+        user_sub = SessionManager.get_user(session_id=session_id).user_sub
+        with RedisConnectionPool.get_redis_connection() as r:
+            token = r.get(f'{plugin_domain}_{user_sub}_token')
+        if not token:
+            token = PluginTokenManager.generate_plugin_token(
+                plugin_domain,
+                session_id,
+                user_sub,
+                access_token_url,
+                expire_time
+            )
+        if isinstance(token, str):
+            return token
+        else:
+            return token.decode()
+
+
+    @staticmethod
+    def generate_plugin_token(
+            plugin_domain, session_id: str,
+            user_sub: str,
+            access_token_url: str,
+            expire_time: int
+    ):
+        with RedisConnectionPool.get_redis_connection() as r:
+            oidc_access_token = r.get(f'{user_sub}_oidc_access_token')
+            oidc_refresh_token = r.get(f'{user_sub}_oidc_refresh_token')
+        if oidc_access_token:
+            # Redis returns bytes; normalize to str for the requests below
+            oidc_access_token = oidc_access_token.decode()
+        if not oidc_refresh_token:
+            # Both tokens have expired: the user must log in again
+            SessionManager.delete_session(session_id)
+            return None
+        elif not oidc_access_token:
+            # The access token has expired: fetch a fresh one via the refresh token
+            url = config['OIDC_REFRESH_TOKEN_URL']
+            response = requests.post(
+                url=url,
+                json={
+                    "refresh_token": oidc_refresh_token.decode(),
+                    "client_id": config["OIDC_APP_ID"]
+                }
+            )
+            ret = response.json()
+            if response.status_code != 200:
+                logger.error('Failed to get OIDC access token')
+                return None
+            # No trailing comma here: it would silently turn the token into a tuple
+            oidc_access_token = ret['data']['access_token']
+            with RedisConnectionPool.get_redis_connection() as r:
+                r.set(
+                    f'{user_sub}_oidc_access_token',
+                    oidc_access_token,
+                    int(config['OIDC_ACCESS_TOKEN_EXPIRE_TIME']) * 60
+                )
+        response = requests.post(
+            url=access_token_url,
+            json={
+                "client_id": config['OIDC_APP_ID'],
+                "access_token": oidc_access_token
+            }
+        )
+        ret = response.json()
+        if response.status_code != 200:
+            logger.error(f'Failed to get {plugin_domain} token')
+            return None
+        with RedisConnectionPool.get_redis_connection() as r:
+            r.set(f'{plugin_domain}_{user_sub}_token', ret['data']['access_token'], int(expire_time)*60)
+        return ret['data']['access_token']
+ diff --git a/apps/manager/record.py b/apps/manager/record.py new file mode 100644 index 000000000..22bad519d --- /dev/null +++ b/apps/manager/record.py @@ -0,0 +1,126 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
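A small inspection sketch for the Redis keys PluginTokenManager above reads and writes (the user_sub and plugin domain values are hypothetical; assumes a reachable Redis):

    from apps.models.redis import RedisConnectionPool

    user_sub, plugin_domain = "example-user", "aops"
    with RedisConnectionPool.get_redis_connection() as r:
        for key in (f"{user_sub}_oidc_access_token",
                    f"{user_sub}_oidc_refresh_token",
                    f"{plugin_domain}_{user_sub}_token"):
            print(key, "->", r.get(key))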
+ +import json +import logging +from typing import Literal + +from sqlalchemy import desc, func, asc + +from apps.common.security import Security +from apps.entities.response_data import RecordQueryData +from apps.models.mysql import Comment, MysqlDB, Record + + +class RecordManager: + logger = logging.getLogger('gunicorn.error') + + @staticmethod + def insert_encrypted_data(conversation_id, record_id, group_id, user_sub, question, answer): + try: + encrypted_question, question_encryption_config = Security.encrypt( + question) + except Exception as e: + RecordManager.logger.info(f"Encryption failed: {e}") + return + try: + encrypted_answer, answer_encryption_config = Security.encrypt( + answer) + except Exception as e: + RecordManager.logger.info(f"Encryption failed: {e}") + return + + question_encryption_config = json.dumps(question_encryption_config) + answer_encryption_config = json.dumps(answer_encryption_config) + + new_qa_record = Record(conversation_id=conversation_id, + record_id=record_id, + encrypted_question=encrypted_question, + question_encryption_config=question_encryption_config, + encrypted_answer=encrypted_answer, + answer_encryption_config=answer_encryption_config, + group_id=group_id) + try: + with MysqlDB().get_session()as session: + session.add(new_qa_record) + session.commit() + RecordManager.logger.info( + f"Inserted encrypted data succeeded: {user_sub}") + except Exception as e: + RecordManager.logger.info( + f"Insert encrypted data failed: {e}") + del question_encryption_config + del answer_encryption_config + + @staticmethod + def query_encrypted_data_by_conversation_id(conversation_id, total_pairs=None, group_id=None, order: Literal["desc", "asc"] = "desc"): + if order == "desc": + order_func = desc + else: + order_func = asc + + results = [] + try: + with MysqlDB().get_session() as session: + subquery = session.query( + Record, + Comment.is_like, + func.row_number().over( + partition_by=Record.group_id, + order_by=order_func(Record.created_time) + ).label("rn") + ).join( + Comment, Record.record_id == Comment.record_id, isouter=True + ).filter( + Record.conversation_id == conversation_id + ).subquery() + + if group_id is not None: + query = session.query(subquery).filter( + subquery.c.group_id != group_id, subquery.c.rn == 1).order_by( + order_func(subquery.c.created_time)) + else: + query = session.query(subquery).filter(subquery.c.rn == 1).order_by(order_func(subquery.c.created_time)) + + if total_pairs is not None: + query = query.limit(total_pairs) + else: + query = query + + query_results = query.all() + for re in query_results: + res = RecordQueryData( + conversation_id=re.conversation_id, record_id=re.record_id, + encrypted_answer=re.encrypted_answer, encrypted_question=re.encrypted_question, + created_time=str(re.created_time), + is_like=re.is_like, group_id=re.group_id, question_encryption_config=json.loads( + re.question_encryption_config), + answer_encryption_config=json.loads(re.answer_encryption_config)) + results.append(res) + except Exception as e: + RecordManager.logger.info( + f"Query encrypted data by conversation_id failed: {e}") + + return results + + @staticmethod + def query_encrypted_data_by_record_id(record_id): + try: + with MysqlDB().get_session() as session: + result = session.query(Record).filter( + Record.record_id == record_id).first() + return result + except Exception as e: + RecordManager.logger.info( + f"query encrypted data by record_id failed: {e}") + + @staticmethod + def 
delete_encrypted_qa_pair_by_conversation_id(conversation_id): + try: + with MysqlDB().get_session() as session: + session.query(Record) \ + .filter(Record.conversation_id == conversation_id) \ + .delete() + session.commit() + except Exception as e: + RecordManager.logger.info( + f"Query encrypted data by conversation_id failed: {e}") diff --git a/apps/manager/session.py b/apps/manager/session.py new file mode 100644 index 000000000..7373fbfdc --- /dev/null +++ b/apps/manager/session.py @@ -0,0 +1,166 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +from __future__ import annotations + +import base64 +import hashlib +import hmac +import logging +import secrets +from typing import Any, Dict + +from apps.common.config import config +from apps.entities.user import User +from apps.manager.blacklist import UserBlacklistManager +from apps.manager.user import UserManager +from apps.models.redis import RedisConnectionPool + +logger = logging.getLogger("gunicorn.error") + + +class SessionManager: + def __init__(self): + raise NotImplementedError("SessionManager不可以被实例化") + + @staticmethod + def create_session(ip: str , extra_keys: Dict[str, Any] | None = None) -> str: + session_id = secrets.token_hex(16) + data = { + "ip": ip + } + if config["DISABLE_LOGIN"]: + data.update({ + "user_sub": config["DEFAULT_USER"] + }) + + if extra_keys is not None: + data.update(extra_keys) + with RedisConnectionPool.get_redis_connection() as r: + try: + r.hmset(session_id, data) + r.expire(session_id, config["SESSION_TTL"] * 60) + except Exception as e: + logger.error(f"Session error: {e}") + return session_id + + @staticmethod + def delete_session(session_id: str) -> bool: + if not session_id: + return True + with RedisConnectionPool.get_redis_connection() as r: + try: + if not r.exists(session_id): + return True + num = r.delete(session_id) + if num != 1: + return True + return False + except Exception as e: + logger.error(f"Delete session error: {e}") + return False + + @staticmethod + def get_session(session_id: str, session_ip: str) -> str: + if not session_id: + session_id = SessionManager.create_session(session_ip) + return session_id + + ip = None + with RedisConnectionPool.get_redis_connection() as r: + try: + ip = r.hget(session_id, "ip").decode() + r.expire(session_id, config["SESSION_TTL"] * 60) + except Exception as e: + logger.error(f"Read session error: {e}") + + return session_id + + @staticmethod + def verify_user(session_id: str) -> bool: + with RedisConnectionPool.get_redis_connection() as r: + try: + user_exist = r.hexists(session_id, "user_sub") + r.expire(session_id, config["SESSION_TTL"] * 60) + return user_exist + except Exception as e: + logger.error(f"User not in session: {e}") + return False + + @staticmethod + def get_user(session_id: str) -> User | None: + # 从session_id查询user_sub + with RedisConnectionPool.get_redis_connection() as r: + try: + user_sub = r.hget(session_id, "user_sub") + r.expire(session_id, config["SESSION_TTL"] * 60) + except Exception as e: + logger.error(f"Get user from session error: {e}") + return None + + # 查询黑名单 + if UserBlacklistManager.check_blacklisted_users(user_sub): + logger.error("User in session blacklisted.") + with RedisConnectionPool.get_redis_connection() as r: + try: + r.hdel(session_id, "user_sub") + r.expire(session_id, config["SESSION_TTL"] * 60) + return None + except Exception as e: + logger.error(f"Delete user from session error: {e}") + return None + + user = UserManager.get_userinfo_by_user_sub(user_sub) + 
return User(user_sub=user.user_sub, revision_number=user.revision_number)
+
+    @staticmethod
+    def create_csrf_token(session_id: str) -> str | None:
+        rand = secrets.token_hex(8)
+
+        with RedisConnectionPool.get_redis_connection() as r:
+            try:
+                r.hset(session_id, "nonce", rand)
+                r.expire(session_id, config["SESSION_TTL"] * 60)
+            except Exception as e:
+                logger.error(f"Create csrf token from session error: {e}")
+                return None
+
+        csrf_value = f"{session_id}{rand}"
+        csrf_b64 = base64.b64encode(bytes.fromhex(csrf_value))
+
+        hmac_processor = hmac.new(key=bytes.fromhex(config["JWT_KEY"]), msg=csrf_b64, digestmod=hashlib.sha256)
+        signature = base64.b64encode(hmac_processor.digest())
+
+        csrf_b64 = csrf_b64.decode("utf-8")
+        signature = signature.decode("utf-8")
+        return f"{csrf_b64}.{signature}"
+
+    @staticmethod
+    def verify_csrf_token(session_id: str, token: str) -> bool:
+        if not token:
+            return False
+
+        token_msg = token.split(".")
+        if len(token_msg) != 2:
+            return False
+
+        first_part = base64.b64decode(token_msg[0]).hex()
+        current_session_id = first_part[:32]
+        if current_session_id != session_id:
+            return False
+
+        current_nonce = first_part[32:]
+        with RedisConnectionPool.get_redis_connection() as r:
+            try:
+                nonce = r.hget(current_session_id, "nonce")
+                # Redis returns bytes; decode before comparing with the hex string
+                if nonce is None or nonce.decode() != current_nonce:
+                    return False
+                r.expire(current_session_id, config["SESSION_TTL"] * 60)
+            except Exception as e:
+                logger.error(f"Get csrf token from session error: {e}")
+
+        hmac_obj = hmac.new(key=bytes.fromhex(config["JWT_KEY"]),
+                            msg=token_msg[0].encode("utf-8"), digestmod=hashlib.sha256)
+        signature = hmac_obj.digest()
+        current_signature = base64.b64decode(token_msg[1])
+
+        return hmac.compare_digest(signature, current_signature) diff --git a/apps/manager/user.py b/apps/manager/user.py new file mode 100644 index 000000000..8f962c412 --- /dev/null +++ b/apps/manager/user.py @@ -0,0 +1,108 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
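The CSRF token layout produced by create_csrf_token above, reconstructed offline with stand-in values (stdlib only, no Redis needed): base64(session_id + nonce), dot-joined with base64 of an HMAC-SHA256 over that first part.

    import base64
    import hashlib
    import hmac
    import secrets

    jwt_key = secrets.token_hex(32)      # stand-in for config["JWT_KEY"]
    session_id = secrets.token_hex(16)   # 32 hex chars, as in SessionManager.create_session
    nonce = secrets.token_hex(8)

    csrf_b64 = base64.b64encode(bytes.fromhex(session_id + nonce))
    sig = base64.b64encode(hmac.new(bytes.fromhex(jwt_key), csrf_b64, hashlib.sha256).digest())
    print(f"{csrf_b64.decode()}.{sig.decode()}")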
+ +import logging +from datetime import datetime, timezone + +import pytz + +from apps.entities.user import User as UserInfo +from apps.models.mysql import MysqlDB, User + + +class UserManager: + logger = logging.getLogger('gunicorn.error') + + @staticmethod + def add_userinfo(userinfo: User): + user_slice = User( + user_sub=userinfo.user_sub, + revision_number=userinfo.revision_number, + login_time=datetime.now(timezone.utc).astimezone(pytz.timezone('Asia/Shanghai')) + ) + try: + with MysqlDB().get_session() as session: + session.add(user_slice) + session.commit() + session.refresh(user_slice) + except Exception as e: + UserManager.logger.info(f"Add userinfo failed due to error: {e}") + + @staticmethod + def get_all_user_sub(): + result = None + try: + with MysqlDB().get_session() as session: + result = session.query(User.user_sub).all() + except Exception as e: + UserManager.logger.info( + f"Get all user_sub failed due to error: {e}") + return result + + @staticmethod + def get_userinfo_by_user_sub(user_sub): + result = None + try: + with MysqlDB().get_session() as session: + result = session.query(User).filter( + User.user_sub == user_sub).first() + except Exception as e: + UserManager.logger.info( + f"Get userinfo by user_sub failed due to error: {e}") + return result + + @staticmethod + def get_revision_number_by_user_sub(user_sub): + userinfo = UserManager.get_userinfo_by_user_sub(user_sub) + revision_number = None + if userinfo is not None: + revision_number = userinfo.revision_number + return revision_number + + @staticmethod + def update_userinfo_by_user_sub(userinfo: UserInfo, refresh_revision=False): + user_slice = UserManager.get_userinfo_by_user_sub( + userinfo.user_sub) + if not user_slice: + UserManager.add_userinfo(userinfo) + return UserManager.get_userinfo_by_user_sub(userinfo.user_sub) + user_dict = { + "user_sub": userinfo.user_sub, + "login_time": datetime.now(timezone.utc).astimezone(pytz.timezone('Asia/Shanghai')) + } + if refresh_revision: + user_dict.update({"revision_number": userinfo.revision_number}) + try: + with MysqlDB().get_session() as session: + session.query(User).filter(User.user_sub == userinfo.user_sub).update(user_dict) + session.commit() + except Exception as e: + UserManager.logger.info( + f"Update userinfo by user_sub failed due to error: {e}") + ret = UserManager.get_userinfo_by_user_sub(userinfo.user_sub) + ret_dict = ret.__dict__ + if '_sa_instance_state' in ret_dict: + del ret_dict['_sa_instance_state'] + return User(**ret_dict) + + @staticmethod + def query_userinfo_by_login_time(login_time): + result = [] + try: + with MysqlDB().get_session() as session: + result = session.query(User).filter( + User.login_time <= login_time).all() + except Exception as e: + UserManager.logger.info( + f"Get userinfo by login_time failed due to error: {e}") + return result + + @staticmethod + def delete_userinfo_by_user_sub(user_sub): + try: + with MysqlDB().get_session() as session: + session.query(User).filter( + User.user_sub == user_sub).delete() + session.commit() + except Exception as e: + UserManager.logger.info( + f"Delete userinfo by user_sub failed due to error: {e}") diff --git a/apps/manager/user_domain.py b/apps/manager/user_domain.py new file mode 100644 index 000000000..cf7ddbe1f --- /dev/null +++ b/apps/manager/user_domain.py @@ -0,0 +1,77 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+ +from datetime import datetime, timezone + +import pytz +import logging + +from apps.models.mysql import MysqlDB, UserDomain, Domain +from sqlalchemy import desc + +logger = logging.getLogger('gunicorn.error') + + +class UserDomainManager: + def __init__(self): + raise NotImplementedError() + + @staticmethod + def get_user_domain_by_user_sub_and_domain_name(user_sub, domain_name): + result = None + try: + with MysqlDB().get_session() as session: + result = session.query(UserDomain).filter( + UserDomain.user_sub == user_sub, UserDomain.domain_name == domain_name).first() + except Exception as e: + logger.info(f"Get user_domain by user_sub_and_domain_name failed: {e}") + return result + + @staticmethod + def get_user_domain_by_user_sub(user_sub): + results = [] + try: + with MysqlDB().get_session() as session: + results = session.query(UserDomain).filter( + UserDomain.user_sub == user_sub).all() + except Exception as e: + logger.info(f"Get user_domain by user_sub failed: {e}") + return results + + @staticmethod + def get_user_domain_by_user_sub_and_topk(user_sub, topk): + results = [] + try: + with MysqlDB().get_session() as session: + results = session.query(UserDomain.domain_count, Domain.domain_name, Domain.domain_description).join(Domain, UserDomain.domain_name==Domain.domain_name).filter( + UserDomain.user_sub == user_sub).order_by( + desc(UserDomain.domain_count)).limit(topk).all() + except Exception as e: + logger.info(f"Get user_domain by user_sub and topk failed: {e}") + return results + + @staticmethod + def update_user_domain_by_user_sub_and_domain_name(user_sub, domain_name): + result = None + try: + with MysqlDB().get_session() as session: + cur_user_domain = UserDomainManager.get_user_domain_by_user_sub_and_domain_name(user_sub, domain_name) + if not cur_user_domain: + cur_user_domain = UserDomain( + user_sub=user_sub, domain_name=domain_name, domain_count=1, + created_time=datetime.now(timezone.utc).astimezone(pytz.timezone('Asia/Shanghai')), + updated_time=datetime.now(timezone.utc).astimezone(pytz.timezone('Asia/Shanghai'))) + session.add(cur_user_domain) + session.commit() + else: + update_dict = { + "domain_count": cur_user_domain.domain_count+1, + "updated_time": datetime.now(timezone.utc).astimezone(pytz.timezone('Asia/Shanghai')) + } + session.query(UserDomain).filter(UserDomain.user_sub == user_sub, + UserDomain.domain_name == domain_name).update(update_dict) + session.commit() + result = UserDomainManager.get_user_domain_by_user_sub_and_domain_name(user_sub, domain_name) + except Exception as e: + logger.info(f"Update user_domain by user_sub and domain_name failed due to error: {e}") + finally: + return result diff --git a/apps/models/__init__.py b/apps/models/__init__.py new file mode 100644 index 000000000..821dc0853 --- /dev/null +++ b/apps/models/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. diff --git a/apps/models/mysql.py b/apps/models/mysql.py new file mode 100644 index 000000000..3c263a657 --- /dev/null +++ b/apps/models/mysql.py @@ -0,0 +1,196 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
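Example flow for UserDomainManager above (assumes a reachable MySQL; identifiers are made up): bump the counter for one (user, domain) pair, then read back that user's top domains.

    from apps.manager.user_domain import UserDomainManager

    UserDomainManager.update_user_domain_by_user_sub_and_domain_name("example-sub", "os_basics")
    for domain_count, domain_name, domain_description in \
            UserDomainManager.get_user_domain_by_user_sub_and_topk("example-sub", 3):
        print(domain_name, domain_count)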
+ +from datetime import datetime + +import pytz +from sqlalchemy import BigInteger, Column, DateTime, Integer, String, create_engine, Boolean, Text +from sqlalchemy.orm import declarative_base, sessionmaker +import logging + +from apps.common.config import config +from apps.common.singleton import Singleton + +Base = declarative_base() + + +class User(Base): + __tablename__ = "user" + __table_args__ = { + "mysql_engine": "InnoDB", + "mysql_charset": "utf8mb4", + "mysql_collate": "utf8mb4_general_ci" + } + id = Column(BigInteger().with_variant(Integer, "sqlite"), primary_key=True, autoincrement=True) + user_sub = Column(String(length=100)) + revision_number = Column(String(length=100), nullable=True) + login_time = Column(DateTime, nullable=True) + created_time = Column(DateTime, default=lambda: datetime.now(pytz.timezone('Asia/Shanghai'))) + credit = Column(Integer, default=100) + is_whitelisted = Column(Boolean, default=False) + + +class Conversation(Base): + __tablename__ = "conversation" + __table_args__ = { + "mysql_engine": "InnoDB", + "mysql_charset": "utf8mb4", + "mysql_collate": "utf8mb4_general_ci" + } + id = Column(BigInteger().with_variant(Integer, "sqlite"), primary_key=True, autoincrement=True) + conversation_id = Column(String(length=100), unique=True) + summary = Column(Text, nullable=True) + user_sub = Column(String(length=100)) + title = Column(String(length=200)) + created_time = Column(DateTime, default=lambda: datetime.now(pytz.timezone('Asia/Shanghai'))) + + +class Record(Base): + __tablename__ = "record" + __table_args__ = { + "mysql_engine": "InnoDB", + "mysql_charset": "utf8mb4", + "mysql_collate": "utf8mb4_general_ci" + } + id = Column(BigInteger().with_variant(Integer, "sqlite"), primary_key=True, autoincrement=True) + conversation_id = Column(String(length=100), index=True) + record_id = Column(String(length=100)) + encrypted_question = Column(Text) + question_encryption_config = Column(String(length=1000)) + encrypted_answer = Column(Text) + answer_encryption_config = Column(String(length=1000)) + structured_output = Column(Text) + flow_id = Column(String(length=100)) + created_time = Column(DateTime, default=lambda: datetime.now(pytz.timezone('Asia/Shanghai'))) + group_id = Column(String(length=100), nullable=True) + + +class AuditLog(Base): + __tablename__ = "audit_log" + __table_args__ = { + "mysql_engine": "InnoDB", + "mysql_charset": "utf8mb4", + "mysql_collate": "utf8mb4_general_ci" + } + id = Column(BigInteger().with_variant(Integer, "sqlite"), primary_key=True, autoincrement=True) + user_sub = Column(String(length=100), nullable=True) + method_type = Column(String(length=100), nullable=True) + source_name = Column(String(length=100), nullable=True) + ip = Column(String(length=100), nullable=True) + result = Column(String(length=100), nullable=True) + reason = Column(String(length=100), nullable=True) + created_time = Column(DateTime, default=lambda: datetime.now(pytz.timezone('Asia/Shanghai'))) + + +class Comment(Base): + __tablename__ = "comment" + __table_args__ = { + "mysql_engine": "InnoDB", + "mysql_charset": "utf8mb4", + "mysql_collate": "utf8mb4_general_ci" + } + id = Column(BigInteger().with_variant(Integer, "sqlite"), primary_key=True, autoincrement=True) + record_id = Column(String(length=100), unique=True) + is_like = Column(Boolean, nullable=True) + dislike_reason = Column(String(length=100), nullable=True) + reason_link = Column(String(length=200), nullable=True) + reason_description = Column(String(length=500), nullable=True) + created_time 
= Column(DateTime, default=lambda: datetime.now(pytz.timezone('Asia/Shanghai'))) + user_sub = Column(String(length=100), nullable=True) + + +class ApiKey(Base): + __tablename__ = "api_key" + __table_args__ = { + "mysql_engine": "InnoDB", + "mysql_charset": "utf8mb4", + "mysql_collate": "utf8mb4_general_ci" + } + id = Column(BigInteger().with_variant(Integer, "sqlite"), primary_key=True, autoincrement=True) + user_sub = Column(String(length=100)) + api_key_hash = Column(String(length=16)) + created_time = Column(DateTime, default=lambda: datetime.now(pytz.timezone('Asia/Shanghai'))) + + +class QuestionBlacklist(Base): + __tablename__ = "blacklist" + __table_args__ = { + "mysql_engine": "InnoDB", + "mysql_charset": "utf8mb4", + "mysql_collate": "utf8mb4_general_ci" + } + id = Column(BigInteger().with_variant(Integer, "sqlite"), primary_key=True, autoincrement=True) + question = Column(Text) + answer = Column(Text) + is_audited = Column(Boolean, default=False) + reason_description = Column(String(length=200)) + created_time = Column(DateTime, default=lambda: datetime.now(pytz.timezone('Asia/Shanghai'))) + + +class Domain(Base): + __tablename__ = "domain" + __table_args__ = { + "mysql_engine": "InnoDB", + "mysql_charset": "utf8mb4", + "mysql_collate": "utf8mb4_general_ci" + } + id = Column(BigInteger().with_variant(Integer, "sqlite"), primary_key=True, autoincrement=True) + domain_name = Column(String(length=100)) + domain_description = Column(Text) + created_time = Column(DateTime, default=lambda: datetime.now(pytz.timezone('Asia/Shanghai'))) + updated_time = Column(DateTime) + + +class GiteeIDWhiteList(Base): + __tablename__ = "gitee_id_white_list" + __table_args__ = { + "mysql_engine": "InnoDB", + "mysql_charset": "utf8mb4", + "mysql_collate": "utf8mb4_general_ci" + } + id = Column(BigInteger, primary_key=True, autoincrement=True) + gitee_id = Column(String(length=100)) + + +class UserDomain(Base): + __tablename__ = "user_domain" + __table_args__ = { + "mysql_engine": "InnoDB", + "mysql_charset": "utf8mb4", + "mysql_collate": "utf8mb4_general_ci" + } + id = Column(BigInteger().with_variant(Integer, "sqlite"), primary_key=True, autoincrement=True) + user_sub = Column(String(length=100)) + domain_name = Column(String(length=100)) + domain_count = Column(Integer) + created_time = Column(DateTime, default=lambda: datetime.now(pytz.timezone('Asia/Shanghai'))) + updated_time = Column(DateTime) + + +class MysqlDB(metaclass=Singleton): + + def __init__(self): + self.logger = logging.getLogger('gunicorn.error') + try: + self.engine = create_engine( + f'mysql+pymysql://{config["MYSQL_USER"]}:{config["MYSQL_PWD"]}' + f'@{config["MYSQL_HOST"]}/{config["MYSQL_DATABASE"]}', + hide_parameters=True, + echo=False, + pool_recycle=300, + pool_pre_ping=True) + Base.metadata.create_all(self.engine) + except Exception as e: + self.logger.info(f"Error creating a session: {e}") + + def get_session(self): + try: + return sessionmaker(bind=self.engine)() + except Exception as e: + self.logger.info(f"Error creating a session: {e}") + return None + + def close(self): + try: + self.engine.dispose() + except Exception as e: + self.logger.info(f"Error closing the engine: {e}") diff --git a/apps/models/redis.py b/apps/models/redis.py new file mode 100644 index 000000000..51f356e4a --- /dev/null +++ b/apps/models/redis.py @@ -0,0 +1,43 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
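A short usage sketch for MysqlDB above: because of the Singleton metaclass, repeated construction reuses one engine, and sessions are used as context managers throughout this patch.

    from apps.models.mysql import MysqlDB, User

    assert MysqlDB() is MysqlDB()  # same instance, same engine

    with MysqlDB().get_session() as session:
        print(session.query(User).count())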
+
+import redis
+import logging
+
+from apps.common.config import config
+
+
+class RedisConnectionPool:
+    _redis_pool = None
+    logger = logging.getLogger('gunicorn.error')
+
+    @classmethod
+    def get_redis_pool(cls):
+        if not cls._redis_pool:
+            cls._redis_pool = redis.ConnectionPool(
+                host=config['REDIS_HOST'],
+                port=config['REDIS_PORT'],
+                password=config['REDIS_PWD']
+            )
+        return cls._redis_pool
+
+    @classmethod
+    def get_redis_connection(cls):
+        try:
+            connection = redis.Redis(connection_pool=cls.get_redis_pool())
+        except Exception as e:
+            cls.logger.error(f"Init redis connection failed: {e}")
+            # Re-raise: returning None would break callers that use
+            # "with RedisConnectionPool.get_redis_connection() as r"
+            raise
+        return cls._ConnectionManager(connection)
+
+    class _ConnectionManager:
+        def __init__(self, connection):
+            self.connection = connection
+
+        def __enter__(self):
+            return self.connection
+
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            try:
+                self.connection.close()
+            except Exception as e:
+                RedisConnectionPool.logger.error(f"Redis connection close failed: {e}") diff --git a/apps/routers/__init__.py b/apps/routers/__init__.py new file mode 100644 index 000000000..821dc0853 --- /dev/null +++ b/apps/routers/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. diff --git a/apps/routers/api_key.py b/apps/routers/api_key.py new file mode 100644 index 000000000..133e4d5b6 --- /dev/null +++ b/apps/routers/api_key.py @@ -0,0 +1,44 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+
+from fastapi import APIRouter, Depends, status
+
+from apps.dependency.user import get_user, verify_user
+from apps.dependency.csrf import verify_csrf_token
+from apps.entities.response_data import ResponseData
+from apps.entities.user import User
+from apps.manager.api_key import ApiKeyManager
+
+router = APIRouter(
+    prefix="/api/auth/key",
+    tags=["key"],
+    dependencies=[Depends(verify_user)]
+)
+
+
+@router.get("", response_model=ResponseData)
+def check_api_key_existence(user: User = Depends(get_user)):
+    exists: bool = ApiKeyManager.api_key_exists(user)
+    return ResponseData(code=status.HTTP_200_OK, message="success", result={
+        "api_key_exists": exists
+    })
+
+
+@router.post("", response_model=ResponseData, dependencies=[Depends(verify_csrf_token)])
+def manage_api_key(action: str, user: User = Depends(get_user)):
+    action = action.lower()
+    if action == "create":
+        api_key = ApiKeyManager.generate_api_key(user)
+    elif action == "update":
+        api_key = ApiKeyManager.update_api_key(user)
+    elif action == "delete":
+        success = ApiKeyManager.delete_api_key(user)
+        if success:
+            return ResponseData(code=status.HTTP_200_OK, message="success", result={})
+        return ResponseData(code=status.HTTP_400_BAD_REQUEST, message="failed to revoke api key", result={})
+    else:
+        return ResponseData(code=status.HTTP_400_BAD_REQUEST, message="invalid request body", result={})
+    if api_key is None:
+        return ResponseData(code=status.HTTP_400_BAD_REQUEST, message="failed to generate api key", result={})
+    return ResponseData(code=status.HTTP_200_OK, message="success", result={
+        "api_key": api_key
+    }) diff --git a/apps/routers/auth.py b/apps/routers/auth.py new file mode 100644 index 000000000..167c89ca1 --- /dev/null +++ b/apps/routers/auth.py @@ -0,0 +1,170 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+ +from __future__ import annotations + +import logging + +from fastapi import APIRouter, Depends, HTTPException, Request, Response, status +from fastapi.responses import RedirectResponse + +from apps.common.config import config +from apps.common.oidc import get_oidc_token, get_oidc_user +from apps.dependency.csrf import verify_csrf_token +from apps.dependency.user import get_user, verify_user +from apps.entities.request_data import ModifyRevisionData +from apps.entities.response_data import ResponseData +from apps.entities.user import User +from apps.manager.audit_log import AuditLogData, AuditLogManager +from apps.manager.session import SessionManager +from apps.manager.user import UserManager +from apps.models.redis import RedisConnectionPool + +logger = logging.getLogger('gunicorn.error') + +router = APIRouter( + prefix="/api/auth", + tags=["auth"] +) + + +@router.get("/login", response_class=RedirectResponse) +async def oidc_login(request: Request, code: str, redirect_index: str = None): + if redirect_index: + response = RedirectResponse(redirect_index, status_code=301) + else: + response = RedirectResponse(config["WEB_FRONT_URL"], status_code=301) + try: + token = await get_oidc_token(code) + user_info = await get_oidc_user(token["access_token"], token["refresh_token"]) + user_sub: str | None = user_info.get('user_sub', None) + except Exception as e: + logger.error(f"User login failed: {e}") + if 'auth error' in str(e): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="auth error") + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User login failed.") + + user_host = request.client.host + if not user_sub: + logger.error("OIDC no user_sub associated.") + data = AuditLogData(method_type='get', source_name='/authorize/login', + ip=user_host, result='fail', reason="OIDC no user_sub associated.") + AuditLogManager.add_audit_log('None', data) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User login failed.") + + UserManager.update_userinfo_by_user_sub(User(**user_info)) + + current_session = request.cookies.get("ECSESSION") + try: + SessionManager.delete_session(current_session) + current_session = SessionManager.create_session(user_host, extra_keys={ + "user_sub": user_sub + }) + except Exception as e: + logger.error(f"Change session failed: {e}") + data = AuditLogData(method_type='get', source_name='/authorize/login', + ip=user_host, result='fail', reason="Change session failed.") + AuditLogManager.add_audit_log(user_sub, data) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User login failed.") + + new_csrf_token = SessionManager.create_csrf_token(current_session) + if config['COOKIE_MODE'] == 'DEBUG': + response.set_cookie( + "_csrf_tk", + new_csrf_token + ) + response.set_cookie( + "ECSESSION", + current_session + ) + else: + response.set_cookie( + "_csrf_tk", + new_csrf_token, + max_age=config["SESSION_TTL"] * 60, + secure=True, + domain=config["DOMAIN"], + samesite="strict" + ) + response.set_cookie( + "ECSESSION", + current_session, + max_age=config["SESSION_TTL"] * 60, + secure=True, + domain=config["DOMAIN"], + httponly=True, + samesite="strict" + ) + data = AuditLogData( + method_type='get', + source_name='/authorize/login', + ip=user_host, + result='success', + reason="User login." 
+    )
+    AuditLogManager.add_audit_log(user_sub, data)
+
+    return response
+
+
+# User-initiated logout
+@router.get("/logout", response_model=ResponseData, dependencies=[Depends(verify_user), Depends(verify_csrf_token)])
+async def logout(request: Request, response: Response, user: User = Depends(get_user)):
+    session_id = request.cookies['ECSESSION']
+    if not SessionManager.verify_user(session_id):
+        logger.info("User already logged out.")
+        return ResponseData(code=200, message="ok", result={})
+
+    # Delete the OIDC-related tokens
+    user_sub = user.user_sub
+    with RedisConnectionPool.get_redis_connection() as r:
+        r.delete(f'{user_sub}_oidc_access_token')
+        r.delete(f'{user_sub}_oidc_refresh_token')
+        r.delete(f'aops_{user_sub}_token')
+
+    SessionManager.delete_session(session_id)
+    new_session = SessionManager.create_session(request.client.host)
+
+    response.set_cookie("ECSESSION", new_session, max_age=config["SESSION_TTL"] * 60,
+                        httponly=True, secure=True, samesite="strict", domain=config["DOMAIN"])
+    response.delete_cookie("_csrf_tk")
+
+    data = AuditLogData(method_type='get', source_name='/authorize/logout',
+                        ip=request.client.host, result='success', reason='User logout succeeded.')
+    AuditLogManager.add_audit_log(user.user_sub, data)
+    return {
+        "code": 200,
+        "message": "success",
+        "result": {}
+    }
+
+
+@router.get("/redirect")
+async def oidc_redirect():
+    return {
+        "code": 200,
+        "message": "success",
+        "result": config["OIDC_REDIRECT_URL"]
+    }
+
+
+@router.get("/user", dependencies=[Depends(verify_user)], response_model=ResponseData)
+async def userinfo(user: User = Depends(get_user)):
+    revision_number = UserManager.get_revision_number_by_user_sub(user_sub=user.user_sub)
+    user.revision_number = revision_number
+    return {
+        "code": 200,
+        "message": "success",
+        "result": user.__dict__
+    }
+
+
+@router.post("/update_revision_number", dependencies=[Depends(verify_user), Depends(verify_csrf_token)],
+             response_model=ResponseData)
+async def update_revision_number(post_body: ModifyRevisionData, user: User = Depends(get_user)):
+    user.revision_number = post_body.revision_num
+    ret = UserManager.update_userinfo_by_user_sub(user, refresh_revision=True)
+    return {
+        "code": 200,
+        "message": "success",
+        "result": ret.__dict__
+    }
diff --git a/apps/routers/blacklist.py b/apps/routers/blacklist.py
new file mode 100644
index 000000000..21c861980
--- /dev/null
+++ b/apps/routers/blacklist.py
@@ -0,0 +1,64 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+
+from typing import Any
+
+from fastapi import APIRouter, Depends, Response, status
+
+from apps.dependency.user import verify_user, get_user
+from apps.dependency.csrf import verify_csrf_token
+from apps.entities.blacklist import (
+    AbuseProcessRequest,
+    AbuseRequest,
+    QuestionBlacklistRequest,
+    UserBlacklistRequest,
+)
+from apps.entities.response_data import ResponseData
+from apps.manager.blacklist import (
+    AbuseManager,
+    QuestionBlacklistManager,
+    UserBlacklistManager,
+)
+from apps.models.mysql import User
+
+router = APIRouter(
+    prefix="/api/blacklist",
+    tags=["blacklist"],
+    dependencies=[Depends(verify_user)],
+)
+
+PAGE_SIZE = 20
+MAX_CREDIT = 100
+
+
+# Shared helper: normalize manager results into a uniform response
+def check_result(result: Any, response: Response, error_msg: str) -> ResponseData:
+    if result is None:
+        response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
+        return ResponseData(
+            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            message=error_msg,
+            result={}
+        )
+    response.status_code = status.HTTP_200_OK
+    if isinstance(result, dict):
+        return ResponseData(
+            code=status.HTTP_200_OK,
+            message="ok",
+            result=result
+        )
+    return ResponseData(
+        code=status.HTTP_200_OK,
+        message="ok",
+        result={"value": result}
+    )
+
+
+# A user files an abuse report
+@router.post("/complaint", dependencies=[Depends(verify_csrf_token)])
+async def abuse_report(request: AbuseRequest, response: Response, user: User = Depends(get_user)):
+    result = AbuseManager.change_abuse_report(
+        user.user_sub,
+        request.record_id,
+        request.reason
+    )
+    return check_result(result, response, "Report abuse complaint error.")
diff --git a/apps/routers/chat.py b/apps/routers/chat.py
new file mode 100644
index 000000000..bd31ddc1f
--- /dev/null
+++ b/apps/routers/chat.py
@@ -0,0 +1,199 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
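`check_result` maps the three possible manager outcomes onto uniform `ResponseData` payloads; roughly (values illustrative):

```python
# result is None   -> HTTP 500, result={}              (manager failure)
# result={"a": 1}  -> HTTP 200, result={"a": 1}        (dict passed through)
# result=42        -> HTTP 200, result={"value": 42}   (scalar wrapped)
```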
+
+import json
+import logging
+import uuid
+
+from fastapi import APIRouter, Depends, HTTPException, status
+from fastapi.responses import StreamingResponse
+
+from apps.common.wordscheck import WordsCheck
+from apps.dependency.csrf import verify_csrf_token
+from apps.dependency.limit import moving_window_limit
+from apps.dependency.user import get_session, get_user, verify_user
+from apps.entities.request_data import RequestData
+from apps.entities.response_data import ResponseData
+from apps.entities.user import User
+from apps.manager.blacklist import (
+    QuestionBlacklistManager,
+    UserBlacklistManager,
+)
+from apps.manager.conversation import ConversationManager
+from apps.manager.record import RecordManager
+from apps.scheduler.scheduler import Scheduler
+from apps.service import RAG, Activity, ChatSummary, Suggestion
+from apps.service.history import History
+
+logger = logging.getLogger('gunicorn.error')
+RECOMMEND_THRES = 5
+
+router = APIRouter(
+    prefix="/api",
+    tags=["chat"]
+)
+
+
+async def generate_content_stream(user_sub: str, session_id: str, post_body: RequestData):
+    if not Activity.is_active(user_sub):
+        raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests")
+
+    try:
+        if await WordsCheck.check(post_body.question) != 1:
+            yield "data: [SENSITIVE]\n\n"
+            return
+    except Exception as e:
+        logger.error(msg="Sensitive-word check failed: {}".format(str(e)))
+        yield "data: [ERROR]\n\n"
+        Activity.remove_active(user_sub)
+        return
+
+    try:
+        summary = History.get_summary(post_body.conversation_id)
+        group_id, history = History.get_history_messages(post_body.conversation_id, post_body.record_id)
+    except Exception as e:
+        logger.error("Failed to fetch conversation history: {}".format(str(e)))
+        yield "data: [ERROR]\n\n"
+        Activity.remove_active(user_sub)
+        return
+
+    # Determine which flow to execute
+    if post_body.user_selected_flow is None:
+        logger.info("No flow specified; choosing one automatically.")
+        flow_id = await Scheduler.choose_flow(
+            question=post_body.question,
+            user_selected_plugins=post_body.user_selected_plugins
+        )
+    else:
+        flow_id = post_body.user_selected_flow
+
+    # If flow_id is still None, fall back to knowledge-base Q&A
+    full_answer = ""
+    if flow_id is None:
+        logger.info("Executing: KnowledgeBase")
+        async for line in RAG.get_rag_result(
+            user_sub,
+            post_body.question,
+            post_body.language,
+            history
+        ):
+            if Activity.is_active(user_sub):
+                yield line
+            try:
+                data = json.loads(line[6:])["content"]
+                full_answer += data
+            except Exception:
+                continue
+
+    # Otherwise, execute the selected flow
+    else:
+        logger.info("Executing: {}".format(flow_id))
+        async for line in Scheduler.run_certain_flow(
+            user_selected_flow=flow_id,
+            question=post_body.question,
+            files=post_body.files,
+            context=summary,
+            session_id=session_id
+        ):
+            if Activity.is_active(user_sub):
+                yield line
+            try:
+                data = json.loads(line[6:])["content"]
+                full_answer += data
+            except Exception:
+                continue
+
+    # Run the sensitive-word check on the full answer
+    if await WordsCheck.check(full_answer) != 1:
+        yield "data: [SENSITIVE]\n\n"
+        return
+    # Persist the record and refresh the conversation summary
+    record_id = str(uuid.uuid4().hex)
+    RecordManager().insert_encrypted_data(
+        post_body.conversation_id,
+        record_id,
+        group_id,
+        user_sub,
+        post_body.question,
+        full_answer
+    )
+    Suggestion.update_user_domain(user_sub, post_body.question, full_answer)
+    new_summary = await ChatSummary.generate_chat_summary(
+        last_summary=summary, question=post_body.question, answer=full_answer)
+    del summary
+    ConversationManager.update_summary(post_body.conversation_id, new_summary)
+    yield 'data: {"qa_record_id": "' + record_id + '"}\n\n'
+
+    if len(post_body.user_selected_plugins) != 0:
+        # Plugins selected: recommend follow-up flows
+        suggestions = await Scheduler.plan_next_flow(
+            question=post_body.question,
+            summary=new_summary,
+            user_selected_plugins=post_body.user_selected_plugins,
+            current_flow_name=flow_id
+        )
+    else:
+        # No plugins selected: skip flow recommendation
+        suggestions = []
+
+    # Cap the number of recommendations
+    if len(suggestions) < RECOMMEND_THRES:
+        domain_suggestions = Suggestion.generate_suggestions(
+            post_body.conversation_id, summary=new_summary, question=post_body.question, answer=full_answer)
+        for i in range(min(RECOMMEND_THRES - len(suggestions), 3, len(domain_suggestions))):
+            suggestions.append(domain_suggestions[i])
+    yield 'data: {"search_suggestions": ' + json.dumps(suggestions, ensure_ascii=False) + '}' + '\n\n'
+
+    # Clean up and clear the activity flag
+    del new_summary
+    if not Activity.is_active(user_sub):
+        return
+
+    yield 'data: [DONE]\n\n'
+    Activity.remove_active(user_sub)
+
+
+async def natural_language_post_func(post_body: RequestData, user: User, session_id: str):
+    user_sub = user.user_sub
+    try:
+        headers = {
+            "X-Accel-Buffering": "no"
+        }
+        # Check the question blacklist
+        if QuestionBlacklistManager.check_blacklisted_questions(input_question=post_body.question):
+            res = generate_content_stream(user_sub, session_id, post_body)
+        else:
+            # Blacklisted question: deduct user credit
+            UserBlacklistManager.change_blacklisted_users(user_sub, -10)
+            res_data = ['data: [SENSITIVE]' + '\n\n']
+            res = iter(res_data)
+
+        response = StreamingResponse(
+            content=res,
+            media_type="text/event-stream",
+            headers=headers
+        )
+        return response
+    except Exception as ex:
+        logger.error(f"Get stream answer failed due to error: {ex}")
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+
+@router.post("/chat", dependencies=[Depends(verify_csrf_token), Depends(verify_user)])
+@moving_window_limit
+async def natural_language_post(
+    post_body: RequestData,
+    user: User = Depends(get_user),
+    session_id: str = Depends(get_session)
+):
+    return await natural_language_post_func(post_body, user, session_id)
+
+
+@router.post("/stop", response_model=ResponseData, dependencies=[Depends(verify_csrf_token)])
+async def stop_generation(user: User = Depends(get_user)):
+    user_sub = user.user_sub
+    Activity.remove_active(user_sub)
+    return ResponseData(
+        code=status.HTTP_200_OK,
+        message="stop generation success",
+        result={}
+    )
diff --git a/apps/routers/client.py b/apps/routers/client.py
new file mode 100644
index 000000000..810ec3f81
--- /dev/null
+++ b/apps/routers/client.py
@@ -0,0 +1,77 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
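The `/api/chat` handler above streams Server-Sent Events whose `data:` payloads are either control markers (`[SENSITIVE]`, `[ERROR]`, `[DONE]`) or JSON fragments (`{"content": ...}`, `{"qa_record_id": ...}`, `{"search_suggestions": ...}`). A minimal consumer sketch; the base URL and cookie values are assumptions:

```python
import json
import aiohttp

async def consume_chat(base_url: str, payload: dict, cookies: dict) -> str:
    answer = ""
    async with aiohttp.ClientSession(cookies=cookies) as session:
        async with session.post(f"{base_url}/api/chat", json=payload) as resp:
            async for raw in resp.content:  # StreamReader yields one line per iteration
                line = raw.decode().strip()
                if not line.startswith("data: "):
                    continue
                data = line[6:]
                if data in ("[SENSITIVE]", "[ERROR]", "[DONE]"):
                    break
                answer += json.loads(data).get("content", "")
    return answer
```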
+
+from typing import Optional
+
+from fastapi import APIRouter, Depends, status
+from starlette.requests import HTTPConnection
+
+from apps.dependency.limit import moving_window_limit
+from apps.dependency.user import get_user_by_api_key, verify_api_key
+from apps.entities.plugin import PluginListData
+from apps.entities.request_data import ClientChatRequestData, ClientSessionData, RequestData
+from apps.entities.response_data import ResponseData
+from apps.entities.user import User
+from apps.manager.session import SessionManager
+from apps.routers.chat import natural_language_post_func
+from apps.routers.conversation import add_conversation_func
+from apps.scheduler.pool.pool import Pool
+from apps.service import Activity
+
+router = APIRouter(
+    prefix="/api/client",
+    tags=["client"]
+)
+
+
+@router.post("/session", response_model=ResponseData)
+async def get_session_id(
+    request: HTTPConnection,
+    post_body: ClientSessionData,
+    user: User = Depends(get_user_by_api_key)
+):
+    session_id: Optional[str] = post_body.session_id
+    if not session_id or not SessionManager.verify_user(session_id):
+        return ResponseData(
+            code=status.HTTP_200_OK, message="gen new session id success", result={
+                "session_id": SessionManager.create_session(request.client.host, extra_keys={
+                    "user_sub": user.user_sub
+                })
+            }
+        )
+    return ResponseData(
+        code=status.HTTP_200_OK, message="verify session id success", result={"session_id": session_id}
+    )
+
+
+@router.get("/plugin", response_model=PluginListData, dependencies=[Depends(verify_api_key)])
+async def get_plugin_list():
+    return PluginListData(code=status.HTTP_200_OK, message="success", result=Pool().get_plugin_list())
+
+
+@router.post("/conversation", response_model=ResponseData)
+async def add_conversation(user: User = Depends(get_user_by_api_key)):
+    return await add_conversation_func(user)
+
+
+@router.post("/chat")
+@moving_window_limit
+async def natural_language_post(post_body: ClientChatRequestData, user: User = Depends(get_user_by_api_key)):
+    body: RequestData = RequestData(
+        question=post_body.question,
+        language=post_body.language,
+        conversation_id=post_body.conversation_id,
+        record_id=post_body.record_id,
+        user_selected_plugins=post_body.user_selected_plugins,
+        user_selected_flow=post_body.user_selected_flow,
+        files=post_body.files,
+        flow_id=post_body.flow_id,
+    )
+    session_id: str = post_body.session_id
+    return await natural_language_post_func(body, user, session_id)
+
+
+@router.post("/stop", response_model=ResponseData)
+async def stop_generation(user: User = Depends(get_user_by_api_key)):
+    user_sub = user.user_sub
+    Activity.remove_active(user_sub)
+    return ResponseData(code=status.HTTP_200_OK, message="stop generation success", result={})
diff --git a/apps/routers/comment.py b/apps/routers/comment.py
new file mode 100644
index 000000000..8630d405f
--- /dev/null
+++ b/apps/routers/comment.py
@@ -0,0 +1,48 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
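Taken together, the client endpoints above form a key-authenticated flow: obtain a session, create a conversation, then chat. A rough sketch with `requests`; the base URL and the header used to carry the API key are assumptions (the actual transport is defined in `apps/dependency/user.py`):

```python
import requests

BASE = "http://localhost:8000"                    # hypothetical deployment
HEADERS = {"Authorization": "Bearer <api_key>"}   # assumed API-key transport

session_id = requests.post(f"{BASE}/api/client/session", headers=HEADERS,
                           json={"session_id": None}).json()["result"]["session_id"]
conversation_id = requests.post(f"{BASE}/api/client/conversation",
                                headers=HEADERS).json()["result"]["conversation_id"]
```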
+ +import json + +from fastapi import APIRouter, Depends, status, HTTPException +import logging + +from apps.dependency.user import verify_user, get_user +from apps.dependency.csrf import verify_csrf_token +from apps.entities.request_data import AddCommentData +from apps.entities.response_data import ResponseData +from apps.entities.user import User +from apps.manager.comment import CommentData, CommentManager +from apps.manager.record import RecordManager +from apps.manager.conversation import ConversationManager + + +router = APIRouter( + prefix="/api/comment", + tags=["comment"], + dependencies=[ + Depends(verify_user) + ] +) +logger = logging.getLogger('gunicorn.error') + + +@router.post("", response_model=ResponseData, dependencies=[Depends(verify_csrf_token)]) +async def add_comment(post_body: AddCommentData, user: User = Depends(get_user)): + user_sub = user.user_sub + cur_record = RecordManager.query_encrypted_data_by_record_id( + post_body.record_id) + if not cur_record: + logger.error("Comment: record_id not found.") + raise HTTPException(status_code=status.HTTP_204_NO_CONTENT) + cur_conv = ConversationManager.get_conversation_by_conversation_id( + cur_record.conversation_id) + if not cur_conv or cur_conv.user_sub != user.user_sub: + logger.error("Comment: conversation_id not found.") + raise HTTPException(status_code=status.HTTP_204_NO_CONTENT) + cur_comment = CommentManager.query_comment(post_body.record_id) + comment_data = CommentData(post_body.record_id, post_body.is_like, post_body.dislike_reason, + post_body.reason_link, post_body.reason_description) + if cur_comment: + CommentManager.update_comment(user_sub, comment_data) + else: + CommentManager.add_comment(user_sub, comment_data) + return ResponseData(code=status.HTTP_200_OK, message="success", result={}) diff --git a/apps/routers/conversation.py b/apps/routers/conversation.py new file mode 100644 index 000000000..43d3c52a8 --- /dev/null +++ b/apps/routers/conversation.py @@ -0,0 +1,127 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+ +import logging +from datetime import datetime + +import pytz +from fastapi import APIRouter, Depends, HTTPException, Query, Request, status + +from apps.constants import NEW_CHAT +from apps.dependency.csrf import verify_csrf_token +from apps.dependency.user import get_user, verify_user +from apps.entities.request_data import DeleteConversationData, ModifyConversationData +from apps.entities.response_data import ConversationData, ConversationListData, ResponseData +from apps.entities.user import User +from apps.manager.audit_log import AuditLogData, AuditLogManager +from apps.manager.conversation import ConversationManager +from apps.manager.record import RecordManager + +router = APIRouter( + prefix="/api/conversation", + tags=["conversation"], + dependencies=[ + Depends(verify_user) + ] +) +logger = logging.getLogger('gunicorn.error') + + +@router.get("", response_model=ConversationListData) +async def get_conversation_list(user: User = Depends(get_user)): + user_sub = user.user_sub + conversations = ConversationManager.get_conversation_by_user_sub(user_sub) + for conv in conversations: + record_list = RecordManager.query_encrypted_data_by_conversation_id(conv.conversation_id) + if not record_list: + ConversationManager.update_conversation_metadata_by_conversation_id( + conv.conversation_id, + NEW_CHAT, + datetime.now(pytz.timezone('Asia/Shanghai')) + ) + break + conversations = ConversationManager.get_conversation_by_user_sub(user_sub) + result_conversations = [] + for conv in conversations: + conv_data = ConversationData( + conversation_id=conv.conversation_id, title=conv.title, created_time=conv.created_time) + result_conversations.append(conv_data) + return ConversationListData(code=status.HTTP_200_OK, message="success", result=result_conversations) + + +@router.post("", response_model=ResponseData, dependencies=[Depends(verify_csrf_token)]) +async def add_conversation(user: User = Depends(get_user)): + return await add_conversation_func(user) + + +@router.put("", response_model=ResponseData, dependencies=[Depends(verify_csrf_token)]) +async def update_conversation( + post_body: ModifyConversationData, + user: User = Depends(get_user), + conversation_id: str = Query(min_length=1, max_length=100) +): + cur_conv = ConversationManager.get_conversation_by_conversation_id( + conversation_id) + if not cur_conv or cur_conv.user_sub != user.user_sub: + logger.error("Conversation: conversation_id not found.") + raise HTTPException(status_code=status.HTTP_204_NO_CONTENT) + conv = ConversationManager.update_conversation_by_conversation_id( + conversation_id, post_body.title) + converse_result = ConversationData( + conversation_id=conv.conversation_id, title=conv.title, created_time=conv.created_time) + return ResponseData(code=status.HTTP_200_OK, message="success", result={ + "conversation": converse_result + }) + + +@router.post("/delete", response_model=ResponseData, dependencies=[Depends(verify_csrf_token)]) +async def delete_conversation(request: Request, post_body: DeleteConversationData, user: User = Depends(get_user)): + deleted_conversation = [] + for conversation_id in post_body.conversation_list: + cur_conv = ConversationManager.get_conversation_by_conversation_id( + conversation_id) + # Session有误,跳过 + if not cur_conv or cur_conv.user_sub != user.user_sub: + continue + + try: + RecordManager.delete_encrypted_qa_pair_by_conversation_id(conversation_id) + ConversationManager.delete_conversation_by_conversation_id(conversation_id) + data = AuditLogData(method_type='post', 
source_name='/conversation/delete', ip=request.client.host,
+                                    result=f'deleted conversation with id: {conversation_id}', reason='')
+            AuditLogManager.add_audit_log(user.user_sub, data)
+            deleted_conversation.append(conversation_id)
+        except Exception as e:
+            # An error occurred while deleting; skip this conversation
+            logger.error(f"Failed to delete conversation {conversation_id}: {str(e)}")
+            continue
+    return ResponseData(code=status.HTTP_200_OK, message="success", result={
+        "conversation_id_list": deleted_conversation
+    })
+
+
+async def add_conversation_func(user: User):
+    user_sub = user.user_sub
+    conversations = ConversationManager.get_conversation_by_user_sub(user_sub)
+    for conv in conversations:
+        record_list = RecordManager.query_encrypted_data_by_conversation_id(conv.conversation_id)
+        if not record_list:
+            ConversationManager.update_conversation_metadata_by_conversation_id(
+                conv.conversation_id,
+                NEW_CHAT,
+                datetime.now(pytz.timezone('Asia/Shanghai'))
+            )
+            return ResponseData(
+                code=status.HTTP_200_OK,
+                message="success",
+                result={
+                    "conversation_id": conv.conversation_id
+                }
+            )
+    conversation_id = ConversationManager.add_conversation_by_user_sub(
+        user_sub)
+    if not conversation_id:
+        return ResponseData(
+            code=status.HTTP_500_INTERNAL_SERVER_ERROR, message="generate conversation_id fail", result={})
+    return ResponseData(code=status.HTTP_200_OK, message="success", result={
+        "conversation_id": conversation_id
+    })
diff --git a/apps/routers/domain.py b/apps/routers/domain.py
new file mode 100644
index 000000000..e06f1f245
--- /dev/null
+++ b/apps/routers/domain.py
@@ -0,0 +1,49 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+
+from fastapi import APIRouter, Depends, HTTPException, status
+
+from apps.entities.request_data import AddDomainData
+from apps.entities.response_data import ResponseData
+from apps.manager.domain import DomainManager
+from apps.dependency.csrf import verify_csrf_token
+from apps.dependency.user import verify_user
+
+
+router = APIRouter(
+    prefix='/api/domain',
+    tags=['domain'],
+    dependencies=[
+        Depends(verify_csrf_token),
+        Depends(verify_user),
+    ]
+)
+
+
+@router.post('', response_model=ResponseData)
+async def add_domain(post_body: AddDomainData):
+    if DomainManager.get_domain_by_domain_name(post_body.domain_name):
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                            detail="domain name already exists.")
+    if not DomainManager.add_domain(post_body):
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="add domain failed")
+    return ResponseData(code=status.HTTP_200_OK, message="add domain success.", result={})
+
+
+@router.put('', response_model=ResponseData)
+async def update_domain(post_body: AddDomainData):
+    if not DomainManager.get_domain_by_domain_name(post_body.domain_name):
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                            detail="domain name does not exist.")
+    if not DomainManager.update_domain_by_domain_name(post_body):
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="update domain failed")
+    return ResponseData(code=status.HTTP_200_OK, message="update domain success.", result={})
+
+
+@router.post("/delete", response_model=ResponseData)
+async def delete_domain(post_body: AddDomainData):
+    if not DomainManager.get_domain_by_domain_name(post_body.domain_name):
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                            detail="domain name does not exist.")
+    if not DomainManager.delete_domain_by_domain_name(post_body):
+        raise
HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="delete domain failed") + return ResponseData(code=status.HTTP_200_OK, message="delete domain success.", result={}) diff --git a/apps/routers/file.py b/apps/routers/file.py new file mode 100644 index 000000000..3069e2d5d --- /dev/null +++ b/apps/routers/file.py @@ -0,0 +1,51 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +import time +from typing import List + +from fastapi import APIRouter, Depends, File, UploadFile +from starlette.responses import JSONResponse +import aiofiles +import uuid + +from apps.common.config import config +from apps.scheduler.files import Files +from apps.dependency.csrf import verify_csrf_token +from apps.dependency.user import verify_user + +router = APIRouter( + prefix="/api/file", + tags=["file"], + dependencies=[ + Depends(verify_csrf_token), + Depends(verify_user), + ] +) + + +@router.post("") +async def data_report_upload(files: List[UploadFile] = File(...)): + file_ids = [] + + for file in files: + file_id = str(uuid.uuid4()) + file_ids.append(file_id) + + current_filename = file.filename + suffix = current_filename.split(".")[-1] + + async with aiofiles.open("{}/{}.{}".format(config["TEMP_DIR"], file_id, suffix), 'wb') as f: + content = await file.read() + await f.write(content) + + file_metadata = { + "time": time.time(), + "name": current_filename, + "path": "{}/{}.{}".format(config["TEMP_DIR"], file_id, suffix) + } + + Files.add(file_id, file_metadata) + + return JSONResponse(status_code=200, content={ + "files": file_ids, + }) diff --git a/apps/routers/health.py b/apps/routers/health.py new file mode 100644 index 000000000..e020fcf10 --- /dev/null +++ b/apps/routers/health.py @@ -0,0 +1,13 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +from fastapi import APIRouter, Response, status + +router = APIRouter( + prefix="/health_check", + tags=["health_check"] +) + + +@router.get("") +def health_check(): + return Response(status_code=status.HTTP_200_OK, content="ok") diff --git a/apps/routers/plugin.py b/apps/routers/plugin.py new file mode 100644 index 000000000..d5698b2c3 --- /dev/null +++ b/apps/routers/plugin.py @@ -0,0 +1,22 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +from fastapi import APIRouter, Depends, status + +from apps.dependency.user import verify_user +from apps.entities.plugin import PluginData, PluginListData +from apps.scheduler.pool.pool import Pool + +router = APIRouter( + prefix="/api/plugin", + tags=["plugin"], + dependencies=[ + Depends(verify_user) + ] +) + + +# 前端展示插件详情 +@router.get("", response_model=PluginListData) +async def get_plugin_list(): + plugins = Pool().get_plugin_list() + return PluginListData(code=status.HTTP_200_OK, message="success", result=plugins) diff --git a/apps/routers/record.py b/apps/routers/record.py new file mode 100644 index 000000000..03cf37d6f --- /dev/null +++ b/apps/routers/record.py @@ -0,0 +1,43 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
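A sketch of calling the upload endpoint above; host and cookie values are hypothetical, and the exact CSRF transport is defined in `apps/dependency/csrf.py`. The response carries one generated ID per uploaded file:

```python
import requests

resp = requests.post(
    "http://localhost:8000/api/file",  # hypothetical host
    files=[("files", ("report.txt", open("report.txt", "rb"), "text/plain"))],
    cookies={"ECSESSION": "<session>", "_csrf_tk": "<token>"},
)
file_ids = resp.json()["files"]  # IDs later passed as `files` in chat requests
```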
+ +from typing import Union + +from fastapi import APIRouter, Depends, Query, status + +from apps.common.security import Security +from apps.dependency.user import verify_user, get_user +from apps.entities.response_data import RecordData, RecordListData, ResponseData +from apps.entities.user import User +from apps.manager.record import RecordManager +from apps.manager.conversation import ConversationManager + +router = APIRouter( + prefix="/api/record", + tags=["record"], + dependencies=[ + Depends(verify_user) + ] +) + + +@router.get("", response_model=Union[RecordListData, ResponseData]) +async def get_record( + user: User = Depends(get_user), + conversation_id: str = Query(min_length=1, max_length=100) +): + cur_conv = ConversationManager.get_conversation_by_conversation_id( + conversation_id) + if not cur_conv or cur_conv.user_sub != user.user_sub: + return ResponseData(code=status.HTTP_204_NO_CONTENT, message="session_id not found", result={}) + record_list = RecordManager.query_encrypted_data_by_conversation_id(conversation_id, order="asc") + result = [] + for item in record_list: + question = Security.decrypt( + item.encrypted_question, item.question_encryption_config) + answer = Security.decrypt( + item.encrypted_answer, item.answer_encryption_config) + tmp_record = RecordData( + conversation_id=item.conversation_id, record_id=item.record_id, question=question, answer=answer, + created_time=item.created_time, is_like=item.is_like, group_id=item.group_id) + result.append(tmp_record) + return RecordListData(code=status.HTTP_200_OK, message="success", result=result) diff --git a/apps/scheduler/__init__.py b/apps/scheduler/__init__.py new file mode 100644 index 000000000..821dc0853 --- /dev/null +++ b/apps/scheduler/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. diff --git a/apps/scheduler/call/__init__.py b/apps/scheduler/call/__init__.py new file mode 100644 index 000000000..e408d9df9 --- /dev/null +++ b/apps/scheduler/call/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +from apps.scheduler.call.sql import SQL +from apps.scheduler.call.api.api import API +from apps.scheduler.call.choice import Choice +from apps.scheduler.call.render.render import Render +from apps.scheduler.call.llm import LLM +from apps.scheduler.call.core import CallParams +from apps.scheduler.call.extract import Extract + +exported = [ + SQL, + API, + Choice, + Render, + LLM, + Extract +] diff --git a/apps/scheduler/call/api/__init__.py b/apps/scheduler/call/api/__init__.py new file mode 100644 index 000000000..821dc0853 --- /dev/null +++ b/apps/scheduler/call/api/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. diff --git a/apps/scheduler/call/api/api.py b/apps/scheduler/call/api/api.py new file mode 100644 index 000000000..2990e1343 --- /dev/null +++ b/apps/scheduler/call/api/api.py @@ -0,0 +1,199 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+# Tool: API invocation
+
+from __future__ import annotations
+
+import json
+from typing import Dict, Tuple, Any, Union
+import logging
+
+from apps.scheduler.call.core import CoreCall, CallParams
+from apps.scheduler.gen_json import check_upload_file
+from apps.scheduler.files import Files, choose_file
+from apps.scheduler.utils import Json
+from apps.scheduler.pool.pool import Pool
+from apps.manager.plugin_token import PluginTokenManager
+from apps.scheduler.call.api.sanitizer import APISanitizer
+
+from pydantic import Field
+from langchain_community.agent_toolkits.openapi.spec import ReducedOpenAPISpec
+import aiohttp
+
+
+logger = logging.getLogger('gunicorn.error')
+
+
+class APIParams(CallParams):
+    plugin: str = Field(description="Plugin名称")
+    endpoint: str = Field(description="API接口HTTP Method 与 URI")
+    timeout: int = Field(description="工具超时时间", default=300)
+    retry: int = Field(description="调用发生错误时,最大的重试次数。", default=3)
+
+
+class API(CoreCall):
+    name = "api"
+    description = "API调用工具,用于根据给定的用户输入和历史记录信息,向某一个API接口发送请求、获取数据。"
+    params_obj: APIParams
+
+    server: str
+    data_type: Union[str, None] = None
+    session: aiohttp.ClientSession
+    usage: str
+    spec: ReducedOpenAPISpec
+    auth: Dict[str, Any]
+    session_id: str
+
+    def __init__(self, params: Dict[str, Any]):
+        self.params_obj = APIParams(**params)
+
+    async def call(self, fixed_params: Union[Dict[str, Any], None] = None):
+        # Split the endpoint into HTTP method and URI
+        method, url = self.params_obj.endpoint.split()
+
+        # Fetch the plugin's full OpenAPI spec from the Pool
+        plugin_metadata = Pool().get_plugin(self.params_obj.plugin)
+        self.spec = Pool.deserialize_data(plugin_metadata.spec, plugin_metadata.signature)
+        self.auth = json.loads(plugin_metadata.auth)
+        self.session_id = self.params_obj.session_id
+
+        # Server address; only a single server is supported
+        self.server = self.spec.servers[0]["url"].rstrip("/")
+
+        spec = None
+        # Locate the spec entry for this endpoint
+        for item in self.spec.endpoints:
+            name, description, out = item
+            if name == self.params_obj.endpoint:
+                spec = item
+                break
+        if spec is None:
+            raise ValueError("Endpoint not found!")
+
+        self.usage = spec[1]
+
+        # Make the call and return the data; always close the session
+        self.session = aiohttp.ClientSession()
+        try:
+            return await self._call_api(method, url, spec)
+        finally:
+            await self.session.close()
+
+    async def _make_api_call(self, method: str, url: str, data: dict, files: aiohttp.FormData):
+        if self.data_type != "form":
+            header = {
+                "Content-Type": "application/json",
+            }
+        else:
+            header = {}
+        cookie = {}
+        params = {}
+
+        if self.auth is not None and "type" in self.auth:
+            if self.auth["type"] == "header":
+                header.update(self.auth["args"])
+            elif self.auth["type"] == "cookie":
+                cookie.update(self.auth["args"])
+            elif self.auth["type"] == "params":
+                params.update(self.auth["args"])
+            elif self.auth["type"] == "oidc":
+                header.update({
+                    "access-token": PluginTokenManager.get_plugin_token(
+                        self.auth["domain"],
+                        self.session_id,
+                        self.auth["access_token_url"],
+                        int(self.auth["token_expire_time"])
+                    )
+                })
+
+        if method == "GET":
+            params.update(data)
+            return self.session.get(self.server + url, params=params, headers=header, cookies=cookie,
+                                    timeout=self.params_obj.timeout)
+        elif method == "POST":
+            if self.data_type == "form":
+                form_data = files
+                for key, val in data.items():
+                    form_data.add_field(key, val)
+                return self.session.post(self.server + url, data=form_data, headers=header, cookies=cookie,
+                                         timeout=self.params_obj.timeout)
+            else:
+                return self.session.post(self.server + url, json=data, headers=header,
+                                         cookies=cookie, timeout=self.params_obj.timeout)
+        else:
+            raise NotImplementedError("Method not implemented.")
+
+    def _check_data_type(self, spec: dict) -> dict:
+        if "application/json" in spec:
+            self.data_type = "json"
+            return spec["application/json"]["schema"]
+        if "application/x-www-form-urlencoded" in spec:
+            self.data_type = "form"
+            return spec["application/x-www-form-urlencoded"]["schema"]
+        if "multipart/form-data" in spec:
+            self.data_type = "form"
+            return spec["multipart/form-data"]["schema"]
+        raise NotImplementedError("Data type not implemented.")
+
+    def _file_to_lists(self, spec: Dict[str, Any]) -> aiohttp.FormData:
+        file_form = aiohttp.FormData()
+
+        if self.params_obj.files is None:
+            return file_form
+
+        file_names = []
+        for file in self.params_obj.files:
+            file_names.append(Files.get_by_id(file)["name"])
+
+        file_spec = check_upload_file(spec, file_names)
+        selected_file = choose_file(file_names, file_spec, self.params_obj.question, self.params_obj.background, self.usage)
+
+        for key, val in json.loads(selected_file).items():
+            if isinstance(val, str):
+                file_form.add_field(key, open(Files.get_by_name(val)["path"], "rb"), filename=val)
+            else:
+                for item in val:
+                    file_form.add_field(key, open(Files.get_by_name(item)["path"], "rb"), filename=item)
+        return file_form
+
+    async def _call_api(self, method: str, url: str, spec: Tuple[str, str, dict]):
+        param_spec = {}
+
+        if method == "POST":
+            if "requestBody" in spec[2]:
+                param_spec = self._check_data_type(spec[2]["requestBody"]["content"])
+        elif method == "GET":
+            if "parameters" in spec[2]:
+                param_spec = APISanitizer.parameters_to_spec(spec[2]["parameters"])
+        else:
+            raise NotImplementedError("HTTP method not implemented.")
+
+        if param_spec != {}:
+            json_data = await Json().generate_json(self.params_obj.background, self.params_obj.question, param_spec)
+        else:
+            json_data = {}
+
+        if "properties" in param_spec:
+            file_list = self._file_to_lists(param_spec["properties"])
+        else:
+            file_list = aiohttp.FormData()
+
+        logger.info(f"Calling API {url} with request data {json_data}")
+        session_context = await self._make_api_call(method, url, json_data, file_list)
+        async with session_context as response:
+            if response.status != 200:
+                response_data = "API发生错误:API返回状态码{}, 详细原因为{},附加信息为{}。".format(response.status, response.reason, await response.text())
+            else:
+                response_data = await response.text()
+
+        # Only JSON response bodies are supported; in OpenAPI, responses are keyed by status code
+        response_schema = {}
+        if "responses" in spec[2]:
+            response_schema = spec[2]["responses"].get("200", {}).get("content", {}).get(
+                "application/json", {}).get("schema", {})
+        logger.info(f"Called API {url}; result: {response_data}")
+
+        result = APISanitizer.process_response_data(response_data, url, self.params_obj.question, self.usage, response_schema)
+        return result
diff --git a/apps/scheduler/call/api/sanitizer.py b/apps/scheduler/call/api/sanitizer.py
new file mode 100644
index 000000000..93f897649
--- /dev/null
+++ b/apps/scheduler/call/api/sanitizer.py
@@ -0,0 +1,111 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
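The API call above is driven entirely by `APIParams`; a sketch of constructing and invoking it from an async context (plugin and endpoint names are hypothetical):

```python
from apps.scheduler.call.api.api import API

tool = API({
    "plugin": "example_plugin",
    "endpoint": "GET /stats",   # "<HTTP method> <URI>", as split by call()
    "question": "查询软件统计信息",
    "background": "",
    "session_id": "<session>",
})
result = await tool.call()      # -> {"output": ..., "message": ...}
```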
+
+from typing import Union, Dict, Any, List
+from untruncate_json import untrunc
+import json
+
+from apps.llm import get_llm, get_message_model
+from apps.scheduler.parse_json import parse_json
+
+
+class APISanitizer:
+    """
+    Post-process API return values
+    """
+
+    def __init__(self):
+        raise NotImplementedError("APISanitizer不可被实例化")
+
+    @staticmethod
+    def parameters_to_spec(raw_schema: List[Dict[str, Any]]):
+        """
+        Convert the list-style parameter spec of an OpenAPI GET endpoint into a JSON Schema
+        :param raw_schema: OpenAPI parameter list
+        :return: the converted JSON Schema
+        """
+
+        schema = {
+            "type": "object",
+            "required": [],
+            "properties": {}
+        }
+        for item in raw_schema:
+            if item["required"]:
+                schema["required"].append(item["name"])
+            schema["properties"][item["name"]] = {}
+            schema["properties"][item["name"]]["description"] = item["description"]
+            for key, val in item["schema"].items():
+                schema["properties"][item["name"]][key] = val
+        return schema
+
+    @staticmethod
+    def _process_response_schema(response_data: str, response_schema: Dict[str, Any]) -> str:
+        """
+        Process the API response field by field
+        :param response_data: raw API response data
+        :param response_schema: JSON Schema of the API response
+        :return: the processed API response
+        """
+
+        # The tool failed and response_data holds an error message; leave it untouched
+        try:
+            response_dict = json.loads(response_data)
+        except Exception:
+            return response_data
+
+        # The OpenAPI spec has no schema for HTTP 200; leave the data untouched
+        if not response_schema:
+            return response_data
+
+        return json.dumps(parse_json(response_dict, response_schema), ensure_ascii=False)
+
+    @staticmethod
+    def process_response_data(response_data: Union[str, None], url: str, question: str, usage: str, response_schema: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Process the response as a whole
+        :param response_data: raw API response
+        :param url: API URL
+        :param question: the user input that triggered the API call
+        :param usage: description of the API endpoint
+        :param response_schema: JSON Schema of the API response
+        :return: the processed result, packed as {"output": "xxx", "message": "xxx"}
+        """
+
+        # Empty responses are reported as-is; long responses are truncated, then summarized by the LLM
+        if response_data is None:
+            return {
+                "output": "",
+                "message": f"调用接口{url}成功,但返回值为空。"
+            }
+
+        if len(response_data) > 4096:
+            response_data = response_data[:4096]
+            # Cut at the last comma so untrunc can complete the truncated JSON
+            response_data = response_data[:response_data.rfind(",")]
+            response_data = untrunc.complete(response_data)
+
+        response_data = APISanitizer._process_response_schema(response_data, response_schema)
+
+        llm = get_llm()
+        msg_cls = get_message_model(llm)
+        messages = [
+            msg_cls(role="system",
+                    content="你是一个智能助手,能根据用户提供的指令、工具描述信息与工具输出信息,生成自然语言总结信息。要求尽可能详细,不要漏掉关键信息。"),
+            msg_cls(role="user", content=f"""## 用户指令
+                {question}
+
+                ## 工具用途
+                {usage}
+
+                ## 工具输出Schema
+                {response_schema}
+
+                ## 工具输出
+                {response_data}""")
+        ]
+        result_summary = llm.invoke(messages, timeout=30)
+
+        return {
+            "output": response_data,
+            "message": result_summary.content
+        }
diff --git a/apps/scheduler/call/choice.py b/apps/scheduler/call/choice.py
new file mode 100644
index 000000000..89ab1f7d2
--- /dev/null
+++ b/apps/scheduler/call/choice.py
@@ -0,0 +1,49 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
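A worked example of the `parameters_to_spec` transformation above (values illustrative):

```python
raw = [{
    "name": "arch",
    "required": True,
    "description": "CPU architecture",
    "schema": {"type": "string"},
}]
# APISanitizer.parameters_to_spec(raw) returns:
# {
#     "type": "object",
#     "required": ["arch"],
#     "properties": {"arch": {"description": "CPU architecture", "type": "string"}},
# }
```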
+ +from typing import Dict, Any, List, Union +from pydantic import Field + +from apps.scheduler.call.core import CoreCall, CallParams +from apps.scheduler.utils.consistency import Consistency + + +class ChoiceParams(CallParams): + """ + Choice工具所需的额外参数 + """ + instruction: str = Field(description="针对哪一个问题进行答案选择?") + choices: List[Dict[str, Any]] = Field(description="Choice工具所有可能的选项") + + +class Choice(CoreCall): + """ + Choice工具。用于大模型在多个选项中选择一个,并跳转到对应的Step。 + """ + name = "choice" + description = "选择工具,用于根据给定的上下文和问题,判断正确/错误,或从选项列表中选择最符合用户要求的一项。" + params_obj: ChoiceParams + + def __init__(self, params: Dict[str, Any]): + """ + 初始化Choice工具,解析参数。 + :param params: Choice工具所需的参数 + """ + self.params_obj = ChoiceParams(**params) + + async def call(self, fixed_params: Union[Dict[str, Any], None] = None) -> Dict[str, Any]: + """ + 调用Choice工具。 + :param fixed_params: 经用户修正过的参数(暂未使用) + :return: Choice工具的输出信息。包含下一个Step的名称、自然语言解释等。 + """ + result = await Consistency().consistency( + instruction=self.params_obj.instruction, + background=self.params_obj.background, + data=self.params_obj.previous_data, + choices=self.params_obj.choices + ) + return { + "output": result, + "next_step": result, + "message": f"针对“{self.params_obj.instruction}”,作出的选择为:{result}。" + } diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py new file mode 100644 index 000000000..cc0f4f687 --- /dev/null +++ b/apps/scheduler/call/core.py @@ -0,0 +1,55 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +# 基础工具类 + +from abc import ABC, abstractmethod +from typing import Dict, Any, List, Optional, Union +from pydantic import BaseModel, Field +import logging + + +logger = logging.getLogger('gunicorn.error') + +class CallParams(BaseModel): + """ + 所有Call都需要接受的参数。包含用户输入、上下文信息、上一个Step的输出等 + """ + background: str = Field(description="上下文信息") + question: str = Field(description="改写后的用户输入") + files: Optional[List[str]] = Field(description="用户询问该问题时上传的文件") + previous_data: Optional[Dict[str, Any]] = Field(description="Executor中上一个工具的结构化数据") + session_id: Optional[str] = Field(description="用户 user_sub", default="") + + +class CoreCall(ABC): + """ + Call抽象类。所有Call必须继承此类,并实现所有方法。 + """ + + # 工具名字 + name: str = "" + # 工具描述 + description: str = "" + # 工具的参数对象 + params_obj: CallParams + + @abstractmethod + def __init__(self, params: Dict[str, Any]): + """ + 初始化Call,并对参数进行解析。 + :param params: Call所需的参数。目前由Executor直接填充。后续可以借助LLM能力进行补全。 + """ + # 使用此种方式进行params校验 + self.params_obj = CallParams(**params) + raise NotImplementedError + + @abstractmethod + async def call(self, fixed_params: Union[Dict[str, Any], None] = None) -> Dict[str, Any]: + """ + 运行Call。 + :param fixed_params: 经用户修正后的参数。当前未使用,后续用户可对参数动态修改时使用。 + :return: Dict类型的数据。返回值中"output"为工具的原始返回信息(有格式字符串);"message"为工具经LLM处理后的返回信息(字符串)。也可以带有其他字段,其他字段将起到额外的说明和信息传递作用。 + """ + return { + "message": "", + "output": "" + } diff --git a/apps/scheduler/call/extract.py b/apps/scheduler/call/extract.py new file mode 100644 index 000000000..ae50a538a --- /dev/null +++ b/apps/scheduler/call/extract.py @@ -0,0 +1,57 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
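For context, a minimal sketch of the `CoreCall` contract defined above, as a hypothetical echo tool (not part of this patch):

```python
from typing import Any, Dict, Union

from apps.scheduler.call.core import CoreCall, CallParams


class Echo(CoreCall):
    name = "echo"
    description = "示例工具:原样返回用户问题。"
    params_obj: CallParams

    def __init__(self, params: Dict[str, Any]):
        # Validate incoming params, as every Call subclass does
        self.params_obj = CallParams(**params)

    async def call(self, fixed_params: Union[Dict[str, Any], None] = None) -> Dict[str, Any]:
        return {"output": self.params_obj.question, "message": self.params_obj.question}
```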
+# JSON字段提取 + +from typing import List, Dict, Any, Union +from pydantic import Field +import json + +from apps.scheduler.call.core import CoreCall, CallParams + + +class ExtractParams(CallParams): + """ + 校验Extract Call需要的额外参数 + """ + keys: List[str] = Field(description="待提取的JSON字段名称") + + +class Extract(CoreCall): + """ + Extract 工具,用于从前一个工具的原始输出中提取指定字段 + """ + + name: str = "extract" + description: str = "从上一步的工具的原始JSON返回结果中,提取特定字段的信息。" + params_obj: ExtractParams + + def __init__(self, params: Dict[str, Any]): + self.params_obj = ExtractParams(**params) + + async def call(self, fixed_params: Union[Dict[str, Any], None] = None) -> Dict[str, Any]: + """ + 调用Extract工具 + :param fixed_params: 经用户确认后的参数(目前未使用) + :return: 提取出的字段 + """ + + if len(self.params_obj.keys) == 0: + raise ValueError("提供的JSON字段Key不能为空!") + + self.params_obj.previous_data = self.params_obj.previous_data["data"]["output"] + + # 根据用户给定的key,找到指定字段 + message_dict = {} + for key in self.params_obj.keys: + key_split = key.split(".") + current_dict = self.params_obj.previous_data + if isinstance(current_dict, str): + current_dict = json.loads(current_dict) + for dict_key in key_split: + current_dict = current_dict[dict_key] + message_dict[key_split[-1]] = current_dict + + return { + "message": json.dumps(message_dict, ensure_ascii=False), + # 临时将Output字段的类型设置为string,后续应统一改为dict + "output": json.dumps(message_dict, ensure_ascii=False) + } diff --git a/apps/scheduler/call/llm.py b/apps/scheduler/call/llm.py new file mode 100644 index 000000000..76a712ef0 --- /dev/null +++ b/apps/scheduler/call/llm.py @@ -0,0 +1,73 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +# 工具:大模型处理 + +from __future__ import annotations + +from datetime import datetime +from typing import Dict, Any +import json + +from apps.scheduler.call.core import CoreCall, CallParams +from apps.llm import get_llm, get_message_model +from apps.scheduler.encoder import JSONSerializer + +from pydantic import Field +import pytz +from langchain_openai import ChatOpenAI +from sparkai.llm.llm import ChatSparkLLM + + +class LLMParams(CallParams): + temperature: float = Field(description="大模型温度设置", default=1.0) + system_prompt: str = Field(description="大模型系统提示词", default="你是一个乐于助人的助手。") + user_prompt: str = Field( + description="大模型用户提示词", + default=r"""{question} + + 工具信息: + {data} + + 附加信息: + 当前的时间为{time}。{context} + """) + timeout: int = Field(description="超时时间", default=30) + + +class LLM(CoreCall): + name = "llm" + description = "大模型调用工具,用于以指定的提示词和上下文信息调用大模型,并获得输出。" + + model: ChatOpenAI | ChatSparkLLM + params_obj: LLMParams + + def __init__(self, params: Dict[str, Any]): + self.model = get_llm() + self.message_class = get_message_model(self.model) + + self.params_obj = LLMParams(**params) + + async def call(self, fixed_params: Dict[str, Any] | None = None) -> Dict[str, Any]: + if fixed_params is not None: + self.params_obj = LLMParams(**fixed_params) + + # 参数 + time = datetime.now(tz=pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") + formatter = { + "time": time, + "context": self.params_obj.background, + "question": self.params_obj.question, + "data": self.params_obj.previous_data, + } + + timeout = self.params_obj.timeout + message = [ + self.message_class(role="system", content=self.params_obj.system_prompt.format(**formatter)), + self.message_class(role="user", content=self.params_obj.user_prompt.format(**formatter)), + ] + + result = self.model.invoke(message, timeout=timeout) + + return { + "output": result.content, + 
"message": "已成功调用大模型,对之前步骤的输出数据进行了处理", + } diff --git a/apps/scheduler/call/render/__init__.py b/apps/scheduler/call/render/__init__.py new file mode 100644 index 000000000..821dc0853 --- /dev/null +++ b/apps/scheduler/call/render/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. diff --git a/apps/scheduler/call/render/option.json b/apps/scheduler/call/render/option.json new file mode 100644 index 000000000..7cacab150 --- /dev/null +++ b/apps/scheduler/call/render/option.json @@ -0,0 +1,20 @@ +{ + "tooltip": {}, + "legend": {}, + "dataset": { + "source": [] + }, + "xAxis": { + "type": "category", + "axisTick": { + "alignWithLabel": false + } + }, + "yAxis": { + "type": "value", + "axisTick": { + "alignWithLabel": false + } + }, + "series": [] +} \ No newline at end of file diff --git a/apps/scheduler/call/render/render.py b/apps/scheduler/call/render/render.py new file mode 100644 index 000000000..ac0666f65 --- /dev/null +++ b/apps/scheduler/call/render/render.py @@ -0,0 +1,118 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +from __future__ import annotations + +import json +import os +from typing import Dict, Any, List + +from apps.scheduler.call.core import CoreCall, CallParams +from apps.scheduler.encoder import JSONSerializer +from apps.scheduler.call.render.style import RenderStyle + + +class Render(CoreCall): + """ + Render Call,用于将SQL Tool查询出的数据转换为图表 + """ + + name = "render" + description = "渲染图表工具,可将给定的数据绘制为图表。" + params_obj: CallParams + + option_template: Dict[str, Any] + + def __init__(self, params: Dict[str, Any]): + """ + 初始化Render Call,校验参数,读取option模板 + :param params: Render Call参数 + """ + self.params_obj = CallParams(**params) + + option_location = os.path.join(os.path.dirname(os.path.realpath(__file__)), "option.json") + self.option_template = json.load(open(option_location, "r", encoding="utf-8")) + + async def call(self, fixed_params: Dict[str, Any] | None = None): + if fixed_params is not None: + self.params_obj = CallParams(**fixed_params) + + # 检测前一个工具是否为SQL + data = self.params_obj.previous_data + if data["type"] != "sql": + return { + "output": "", + "message": "图表生成失败!Render必须在SQL后调用!" + } + data = json.loads(data["data"]["output"]) + + # 判断数据格式是否满足要求 + # 样例:[{'openeuler_version': 'openEuler-22.03-LTS-SP2', '软件数量': 10}] + malformed = True + if isinstance(data, list): + if len(data) > 0 and isinstance(data[0], dict): + malformed = False + + # 将执行SQL工具查询到的数据转换为特定格式 + if malformed: + return { + "output": "", + "message": "SQL未查询到数据,或数据格式错误,无法生成图表!" 
+ } + + # 对少量数据进行处理 + column_num = len(data[0]) - 1 + if column_num == 0: + data = Render._separate_key_value(data) + column_num = 1 + + # 该格式满足ECharts DataSource要求,与option模板进行拼接 + self.option_template["dataset"]["source"] = data + + llm_output = await RenderStyle().generate_option(self.params_obj.question) + add_style = "" + if "add" in llm_output: + add_style = llm_output["add"] + + self._parse_options(column_num, llm_output["style"], add_style, llm_output["scale"]) + + return { + "output": json.dumps(self.option_template, cls=JSONSerializer), + "message": "图表生成成功!图表将使用外置工具进行展示。" + } + + @staticmethod + def _separate_key_value(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + 若数据只有一组(例如:{"aaa": "bbb"}),则分离键值对。 + 样例:{"type": "aaa", "value": "bbb"} + :param data: 待分离的数据 + :return: 分离后的数据 + """ + result = [] + for item in data: + for key, val in item.items(): + result.append({"type": key, "value": val}) + return result + + def _parse_options(self, column_num: int, graph_style: str, additional_style: str, scale_style: str): + series_template = {} + + if graph_style == "line": + series_template["type"] = "line" + elif graph_style == "scatter": + series_template["type"] = "scatter" + elif graph_style == "pie": + column_num = 1 + series_template["type"] = "pie" + if additional_style == "ring": + series_template["radius"] = ["40%", "70%"] + else: + series_template["type"] = "bar" + if additional_style == "stacked": + series_template["stack"] = "total" + + if scale_style == "log": + self.option_template["yAxis"]["type"] = "log" + + for i in range(column_num): + self.option_template["series"].append(series_template) diff --git a/apps/scheduler/call/render/style.py b/apps/scheduler/call/render/style.py new file mode 100644 index 000000000..2f225952f --- /dev/null +++ b/apps/scheduler/call/render/style.py @@ -0,0 +1,169 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +from __future__ import annotations +from typing import Dict, Any +import asyncio + +import sglang +import openai + +from apps.llm import get_scheduler, create_vllm_stream, stream_to_str +from apps.common.thread import ProcessThreadPool + + +class RenderStyle: + system_prompt = """You are a helpful assistant. Help the user make style choices when drawing a chart. +Chart title should be short and less than 3 words. + +Available styles: +- `bar`: Bar graph +- `pie`: Pie graph +- `line`: Line graph +- `scatter`: Scatter graph + +Available bar graph styles: +- `normal`: Normal bar graph +- `stacked`: Stacked bar graph + +Available pie graph styles: +- `normal`: Normal pie graph +- `ring`: Ring pie graph + +Available scale styles: +- `linear`: Linear scale +- `log`: Logarithmic scale + +Here are some examples: + +EXAMPLE + +## Question + +查询数据库中的数据,并绘制堆叠柱状图。 + +## Thought + +Let's think step by step. The user requires drawing a stacked bar chart, so the chart type should be `bar`, \ +i.e. a bar chart; the chart style should be `stacked`, i.e. a stacked form. 
+ +## Answer + +The chart style should be: bar +The bar graph style should be: stacked + +END OF EXAMPLE +""" + user_prompt = """## Question + +{question} + +## Thought +""" + + def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None): + if system_prompt is not None: + self.system_prompt = system_prompt + if user_prompt is not None: + self.user_prompt = user_prompt + + @staticmethod + @sglang.function + def _generate_option_sglang(s, system_prompt: str, user_prompt: str, question: str): + s += sglang.system(system_prompt) + s += sglang.user(user_prompt.format(question=question)) + + s += sglang.assistant_begin() + s += "Let's think step by step:\n" + for i in range(3): + s += f"{i}. " + sglang.gen(max_tokens=200, stop="\n") + "\n" + + s += "## Answer\n\n" + s += "The chart style should be: " + sglang.gen(choices=["bar", "scatter", "line", "pie"], name="style") + "\n" + if s["style"] == "bar": + s += "The bar graph style should be: " + sglang.gen(choices=["normal", "stacked"], name="add") + "\n" + # 饼图只对第一列有效 + elif s["style"] == "pie": + s += "The pie graph style should be: " + sglang.gen(choices=["normal", "ring"], name="add") + "\n" + s += "The scale style should be: " + sglang.gen(choices=["linear", "log"], name="scale") + s += sglang.assistant_end() + + async def _generate_option_vllm(self, backend: openai.AsyncOpenAI, question: str) -> Dict[str, Any]: + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.user_prompt.format(question=question)}, + ] + + stream = await create_vllm_stream(backend, messages, max_tokens=200, extra_body={ + "guided_regex": r"## Answer\n\nThe chart style should be: (bar|pie|line|scatter)\n" + }) + result = await stream_to_str(stream) + + result_dict = {} + if "bar" in result: + result_dict["style"] = "bar" + messages += [ + {"role": "assistant", "content": result}, + ] + stream = await create_vllm_stream(backend, messages, max_tokens=200, extra_body={ + "guided_regex": r"The bar graph style should be: (normal|stacked)\n" + }) + result = await stream_to_str(stream) + if "normal" in result: + result_dict["add"] = "normal" + elif "stacked" in result: + result_dict["add"] = "stacked" + messages += [ + {"role": "assistant", "content": result}, + ] + elif "pie" in result: + result_dict["style"] = "pie" + messages += [ + {"role": "assistant", "content": result}, + ] + stream = await create_vllm_stream(backend, messages, max_tokens=200, extra_body={ + "guided_regex": r"The pie graph style should be: (normal|ring)\n" + }) + result = await stream_to_str(stream) + if "normal" in result: + result_dict["add"] = "normal" + elif "ring" in result: + result_dict["add"] = "ring" + messages += [ + {"role": "assistant", "content": result}, + ] + elif "line" in result: + result_dict["style"] = "line" + elif "scatter" in result: + result_dict["style"] = "scatter" + + stream = await create_vllm_stream(backend, messages, max_tokens=200, extra_body={ + "guided_regex": r"The scale style should be: (linear|log)\n" + }) + result = await stream_to_str(stream) + if "linear" in result: + result_dict["scale"] = "linear" + elif "log" in result: + result_dict["scale"] = "log" + + return result_dict + + async def generate_option(self, question: str) -> Dict[str, Any]: + backend = get_scheduler() + if isinstance(backend, sglang.RuntimeEndpoint): + state_future = ProcessThreadPool().thread_executor.submit( + RenderStyle._generate_option_sglang.run, + question=question, + system_prompt=self.system_prompt, + 
user_prompt=self.user_prompt + ) + state = await asyncio.wrap_future(state_future) + result_dict = { + "style": state["style"], + "scale": state["scale"], + } + if state["style"] == "bar" or state["style"] == "pie": + result_dict["add"] = state["add"] + + return result_dict + + else: + return await self._generate_option_vllm(backend, question) \ No newline at end of file diff --git a/apps/scheduler/call/sql.py b/apps/scheduler/call/sql.py new file mode 100644 index 000000000..cd97822f8 --- /dev/null +++ b/apps/scheduler/call/sql.py @@ -0,0 +1,83 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +from typing import Dict, Any, Union +import aiohttp +import json +from sqlalchemy import create_engine, Engine, text +import logging + +from apps.scheduler.call.core import CoreCall, CallParams +from apps.common.config import config +from apps.scheduler.encoder import JSONSerializer + + +logger = logging.getLogger('gunicorn.error') + + +class SQL(CoreCall): + """ + SQL工具。用于调用外置的Chat2DB工具的API,获得SQL语句;再在PostgreSQL中执行SQL语句,获得数据。 + """ + + name: str = "sql" + description: str = "SQL工具,用于查询数据库中的结构化数据" + params_obj: CallParams + + session: aiohttp.ClientSession + engine: Engine + + def __init__(self, params: Dict[str, Any]): + """ + 初始化SQL工具。 + 解析SQL工具参数,拼接PostgreSQL连接字符串,创建SQLAlchemy Engine。 + :param params: SQL工具需要的参数。 + """ + self.session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(300)) + self.params_obj = CallParams(**params) + + db_url = f'postgresql+psycopg2://{config["POSTGRES_USER"]}:{config["POSTGRES_PWD"]}@{config["POSTGRES_HOST"]}/{config["POSTGRES_DATABASE"]}' + self.engine = create_engine(db_url, pool_size=20, max_overflow=80, pool_recycle=300, pool_pre_ping=True) + + + async def call(self, fixed_params: Union[Dict[str, Any], None] = None) -> Dict[str, Any]: + """ + 运行SQL工具。 + 访问Chat2DB工具API,拿到针对用户输入的最多5条SQL语句。依次尝试每一条语句,直到查询出数据或全部不可用。 + :param fixed_params: 经用户确认后的参数(目前未使用) + :return: 从数据库中查询得到的数据,或报错信息 + """ + post_data = { + "question": self.params_obj.question, + "topk_sql": 5, + "use_llm_enhancements": True + } + headers = { + "Content-Type": "application/json" + } + + async with self.session.post(config["SQL_URL"], ssl=False, json=post_data, headers=headers) as response: + if response.status != 200: + return { + "output": "", + "message": "SQL查询错误:API返回状态码{}, 详细原因为{},附加信息为{}。".format(response.status, response.reason, await response.text()) + } + else: + result = json.loads(await response.text()) + logger.info(f"SQL工具返回的信息为:{result}") + + await self.session.close() + for item in result["sql_list"]: + try: + with self.engine.connect() as connection: + db_result = connection.execute(text(item["sql"])).all() + dataset_list = [] + for db_item in db_result: + dataset_list.append(db_item._asdict()) + return { + "output": json.dumps(dataset_list, cls=JSONSerializer, ensure_ascii=False), + "message": "数据库查询成功!" + } + except Exception: + continue + + raise ValueError("数据库查询出现错误!") diff --git a/apps/scheduler/core.py b/apps/scheduler/core.py new file mode 100644 index 000000000..db4932a1f --- /dev/null +++ b/apps/scheduler/core.py @@ -0,0 +1,49 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
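Note on `apps/scheduler/call/sql.py` above: `call()` returns early when the Chat2DB API answers with a non-200 status, and on that path the `aiohttp.ClientSession` is never closed, so the connection lingers until garbage collection. A minimal leak-safe sketch, assuming the same `config` keys and the `CallParams`/`JSONSerializer` helpers defined elsewhere in this patch:

```python
# Sketch only, not part of the diff: close the session on every exit path.
async def call(self, fixed_params=None):
    post_data = {"question": self.params_obj.question, "topk_sql": 5, "use_llm_enhancements": True}
    try:
        async with self.session.post(config["SQL_URL"], ssl=False, json=post_data) as response:
            if response.status != 200:
                return {"output": "", "message": f"SQL查询错误:API返回状态码{response.status}。"}
            result = json.loads(await response.text())
    finally:
        await self.session.close()  # runs for the error return as well

    for item in result["sql_list"]:
        try:
            with self.engine.connect() as connection:
                rows = [row._asdict() for row in connection.execute(text(item["sql"])).all()]
            return {"output": json.dumps(rows, cls=JSONSerializer, ensure_ascii=False),
                    "message": "数据库查询成功!"}
        except Exception:
            continue
    raise ValueError("数据库查询出现错误!")
```

A module-level engine, created once and shared across `SQL` instances, would also avoid rebuilding a 20-connection pool for every tool invocation.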
+# Executor基础类 + +from abc import ABC, abstractmethod +from typing import Any, List + +from pydantic import BaseModel, Field + + +class ExecutorParameters(BaseModel): + """ + 一个基础的Executor需要接受的参数 + """ + name: str = Field(..., description="Executor的名字") + question: str = Field(..., description="Executor所需的输入问题") + context: str = Field(..., description="Executor所需的上下文信息") + files: List[str] = Field(..., description="适用于该Executor") + + +class Executor(ABC): + """ + Executor抽象类,每一个Executor都需要继承此类并实现方法 + """ + + # Executor名称 + name: str = "" + # Executor描述 + description: str = "" + + # Executor保存LLM总结后的上下文,当前Call的原始输出 + context: str = "" + output: Any = None + + # 用户上传的文件ID + files: List[str] = [] + + @abstractmethod + def __init__(self, params: ExecutorParameters): + """ + 初始化Executor,并对参数进行解析和处理 + """ + raise NotImplementedError + + @abstractmethod + async def run(self): + """ + 运行Executor,返回最终结果(message)与最后一个Call的原始输出(output) + """ + raise NotImplementedError diff --git a/apps/scheduler/encoder.py b/apps/scheduler/encoder.py new file mode 100644 index 000000000..6805f505d --- /dev/null +++ b/apps/scheduler/encoder.py @@ -0,0 +1,29 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +from json import JSONEncoder +import logging + +import numpy + + +logger = logging.getLogger('gunicorn.error') + + +class JSONSerializer(JSONEncoder): + """ + 自定义的JSON序列化方法。 + 当一个字段无法被序列化时,会使用掩码`[Data unable to represent in string]`替代 + """ + def default(self, o): + try: + if isinstance(o, numpy.integer): + return int(o) + elif isinstance(o, numpy.floating): + return float(o) + elif isinstance(o, numpy.ndarray): + return o.tolist() + result = JSONEncoder.default(self, o) + except TypeError as e: + logger.error(f"工具输出无法被序列化为字符串:{str(e)}") + result = "[Data unable to represent in string]" + return result diff --git a/apps/scheduler/executor/__init__.py b/apps/scheduler/executor/__init__.py new file mode 100644 index 000000000..d6c36304f --- /dev/null +++ b/apps/scheduler/executor/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +from apps.scheduler.executor.flow import FlowExecuteExecutor + +__all__ = [ + 'FlowExecuteExecutor' +] diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py new file mode 100644 index 000000000..941e32bdd --- /dev/null +++ b/apps/scheduler/executor/flow.py @@ -0,0 +1,178 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
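Note on `apps/scheduler/core.py` above: `context`, `output` and `files` are declared as class attributes with mutable defaults (`files: List[str] = []`), so any subclass that appends to them instead of reassigning mutates state shared by every Executor instance. An advisory sketch keeping per-run state on the instance; this assumes subclasses call the base initializer, which the current abstract `__init__` does not support:

```python
# Sketch: per-instance state instead of shared class-level defaults.
class Executor(ABC):
    name: str = ""
    description: str = ""

    def __init__(self, params: ExecutorParameters):
        self.context: str = ""
        self.output: Any = None
        self.files: List[str] = list(params.files)  # private copy per run
```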
+# Flow执行Executor,动态构建 + +from __future__ import annotations + +import json +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field +import logging +import traceback + +from apps.entities.plugin import Flow, Step +from apps.scheduler.core import Executor +from apps.scheduler.pool.pool import Pool +from apps.scheduler.utils import Summary, Evaluate, Reflect, BackProp + +logger = logging.getLogger('gunicorn.error') + + +class FlowExecutorInput(BaseModel): + name: str = Field(description="Flow的名称,格式为“插件名.工作流名”") + question: str = Field(description="Flow所需要的输入") + context: str = Field(description="Flow调用时的上下文信息") + files: Optional[List[str]] = Field(description="适用于当前Flow调用的用户文件名") + session_id: str = Field(description="当前的SessionID") + + +# 单个流的执行工具 +class FlowExecuteExecutor(Executor): + name: str + description: str + output: Dict[str, Any] + + question: str + origin_question: str + context: str + error: str = "当前输入的效果不佳" + + flow: Flow | None + files: List[str] | None + session_id: str + plugin: str + retry: int = 3 + + def __init__(self, params: Dict[str, Any]): + params_obj = FlowExecutorInput(**params) + # 指令与上下文 + self.question = params_obj.question + self.origin_question = params_obj.question + self.context = params_obj.context + self.files = params_obj.files + self.output = {} + self.session_id = params_obj.session_id + + # 名字与插件 + self.plugin, self.name = params_obj.name.split(".") + + # 载入对应的Flow全部信息和Step信息 + flow, flow_data = Pool().get_flow(name=self.name, plugin_name=self.plugin) + if flow is None or flow_data is None: + raise ValueError("Flow不合法!") + self.description = flow.description + self.plugin = flow.plugin + self.flow = flow_data + + # 运行流,返回各步骤经大模型总结后的内容,以及最后一步的工具原始输出 + async def run(self): + current_step: Step | None = self.flow.steps.get("start", None) + + stop_flag = False + while not stop_flag: + # 当前步骤不存在,结束执行 + if current_step is None or current_step.call_type == "none": + stop_flag = True + continue + + # 当步骤为end,最后一步 + if current_step.name == "end": + stop_flag = True + call_data, call_cls = Pool().get_call(current_step.call_type, self.plugin) + if call_data is None or call_cls is None: + yield "data: 尝试执行工具{}时发生错误:找不到该工具。\n\n".format(current_step.call_type) + stop_flag = True + continue + + # 向Call传递已知参数,Call完成参数生成 + call_param = current_step.params + call_param.update({ + "background": self.context, + "files": self.files, + "question": self.question, + "plugin": self.plugin, + "previous_data": self.output, + "session_id": self.session_id + }) + call_obj = call_cls(params=call_param) + + # 运行Call + yield "data: 正在调用{},请稍等...\n\n".format(current_step.call_type) + try: + result = await call_obj.call(fixed_params=call_param) + except Exception as e: + # 运行Call发生错误, + logger.error(msg="尝试使用工具{}时发生错误:{}".format(current_step.call_type, traceback.format_exc())) + self.error = str(e) + yield "data: " + "尝试使用工具{}时发生错误,任务无法继续执行。\n\n".format(current_step.call_type) + current_step = self.flow.on_error + continue + yield "data: 解析返回结果...\n\n" + + # 针对特殊Call进行特判 + if call_data.name == "choice": + # Choice选择了Step,直接跳转,不保存信息 + current_step = self.flow.steps.get(result["next_step"], None) + continue + else: + # 样例:{"type": "api", "data": {"message": "API返回值总结信息", "output": "API返回值原始数据(string)"}} + self.output["type"] = current_step.call_type + self.output["data"] = result + + # 需要进行打分的Call;执行Call完成后,进行打分 + if call_data.name in ["api",]: + score, reason = await Evaluate().generate_evaluation( + user_question=self.question, + tool_output=result, + 
tool_description=self.description + ) + + # 效果低于预期时,进行重试 + if score < 2.0 and self.retry > 0: + reflection = await Reflect().generate_reflect( + self.question, + { + "name": current_step.call_type, + "description": call_data.description + }, + call_input=call_param, + call_score_reason=reason + ) + + self.question = await BackProp().backprop( + user_input=self.question, + exception=self.error, + evaluation=reflection, + background=self.context + ) + + yield "data: 尝试执行{}时发生错误,正在尝试自我修正...\n\n".format(current_step.call_type) + self.retry -= 1 + if self.retry == 0: + yield "data: 调用{}失败,将使用模型能力作答。\n\n".format(current_step.call_type) + self.question = self.origin_question + current_step = self.flow.on_error + continue + yield "data: 生成摘要...\n\n" + # 默认行为:达到效果,或者达到最高重试次数,完成调用 + self.context = await Summary().generate_summary( + last_summary=self.context, + qa_pair=[ + self.origin_question, + result + ], + tool_info=[ + current_step.call_type, + call_data.description, + json.dumps(call_param, ensure_ascii=False) + ] + ) + self.question = self.origin_question + current_step = self.flow.steps.get(current_step.next, None) + + # 全部执行完成,输出最终结果 + flow_result = { + "message": self.context, + "output": self.output, + } + yield "final: " + json.dumps(flow_result, ensure_ascii=False) diff --git a/apps/scheduler/files.py b/apps/scheduler/files.py new file mode 100644 index 000000000..61544b750 --- /dev/null +++ b/apps/scheduler/files.py @@ -0,0 +1,138 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +from __future__ import annotations + +import json +import os +import threading +import time +from typing import Any, Dict, List + +import sglang + +from apps.common.config import config +from apps.scheduler.gen_json import gen_json +from apps.llm import get_scheduler + + +class Files: + mapping_lock = threading.Lock() + mapping: Dict[str, Dict[str, Any]] = {} + timeout: int = 24 * 60 * 60 + + def __init__(self): + raise RuntimeError("Files类不需要实例化") + + @classmethod + def add(cls, file_id: str, file_metadata: Dict[str, Any]): + cls.mapping_lock.acquire() + cls.mapping[file_id] = file_metadata + cls.mapping_lock.release() + + @classmethod + def _check_metadata(cls, file_id: str, metadata: Dict[str, Any]) -> bool: + if os.path.exists(os.path.join(config["TEMP_DIR"], metadata["path"])): + return True + else: + cls.mapping_lock.acquire() + cls.mapping.pop(file_id, None) + cls.mapping_lock.release() + return False + + """ + 样例: + { + "time": 1720840438.8062727, + "name": "test.txt", + "path": "/tmp/7fbe0b8f-ea8d-4ab9-a1cf-a4661bcd07bb.txt" + } + """ + @classmethod + def get_by_id(cls, file_id: str) -> Dict[str, Any] | None: + metadata = cls.mapping.get(file_id, None) + if metadata is None: + return None + else: + if cls._check_metadata(file_id, metadata): + return metadata + else: + return None + + @classmethod + def get_by_name(cls, file_name: str) -> Dict[str, Any] | None: + metadata = None + for key, val in cls.mapping.items(): + if file_name == val.get("name"): + metadata = val + + if metadata is None: + return None + else: + if cls._check_metadata(file_name, metadata): + return metadata + else: + return None + + @classmethod + def delete_old_files(cls): + cls.mapping_lock.acquire() + popped_key = [] + for key, val in cls.mapping.items(): + if time.time() - val["time"] >= cls.timeout: + popped_key.append(key) + continue + if not cls._check_metadata(key, val): + popped_key.append(key) + continue + for key in popped_key: + cls.mapping.pop(key) + cls.mapping_lock.release() + + +# 
通过工具名称选择文件 +def choose_file(file_names: List[str], file_spec: dict, question: str, background: str, tool_usage: str): + def __choose_file(s): + s += sglang.system("""You are a helpful assistant who can select the files needed by the tool based on the tool's usage and the user's instruction. + + EXAMPLE + **Context:** + 此时为第一次调用工具,无上下文信息。 + + **Instruction:** + 帮我将上传的txt文件和Excel文件转换为Word文档 + + **Tool Usage:** + 获取用户上传文件,并将其转换为Word。 + + **Avaliable Files:** + ["1.txt", "log.txt", "sample.xlsx"] + + **Schema:** + {"type": "object", "properties": {"file_xlsx": {"type": "string", "pattern": "(1.txt|log.txt|sample.xlsx)"}, "file_txt": {"type": "array", "items": {"type": "string", "pattern": "(1.txt|log.txt|sample.xlsx)"}, "minItems": 1}}} + + Output: + {"file_xlsx": "sample.xlsx", "file_txt": ["1.txt", "log.txt"]}""") + s += sglang.user(f"""**Context:** + {background} + + **Instruction:** + {question} + + **Tool Usage:** + {tool_usage} + + **Available Files:** + {file_names} + + **Schema:** + {json.dumps(file_spec, ensure_ascii=False)}""") + + s += sglang.assistant("Output:\n" + sglang.gen(max_tokens=300, name="files", regex=gen_json(file_spec))) + + backend = get_scheduler() + if isinstance(backend, sglang.RuntimeEndpoint): + sglang.set_default_backend(backend) + + return sglang.function(__choose_file)()["files"] + else: + return [] diff --git a/apps/scheduler/gen_json.py b/apps/scheduler/gen_json.py new file mode 100644 index 000000000..b65df6307 --- /dev/null +++ b/apps/scheduler/gen_json.py @@ -0,0 +1,167 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +from __future__ import annotations + +from typing import List + + +# 检查API是否需要上传文件;当前不支持文件上传嵌套进表单object内的情况 +def check_upload_file(schema: dict, available_files: List[str]) -> dict: + file_details = { + "type": "object", + "properties": {} + } + + pattern = "(" + for name in available_files: + pattern += name + "|" + pattern = pattern[:-1] + ")" + + for key, val in schema.items(): + if "format" in val and val["format"] == "binary": + file_details["properties"][key] = { + "type": "string", + "pattern": pattern + } + if val["type"] == "array": + if "format" in val["items"] and val["items"]["format"] == "binary": + file_details["properties"][key] = { + "type": "array", + "items": { + "type": "string", + "pattern": pattern + }, + "minItems": 1 + } + return file_details + + +# 处理字符串中的特殊字符 +def _process_string(string: str) -> str: + string = string.replace("$", r"\$") + string = string.replace("{", r"\{") + string = string.replace("}", r"\}") + # string = string.replace(".", r"\.") + string = string.replace("[", r"\[") + string = string.replace("]", r"\]") + string = string.replace("(", r"\(") + string = string.replace(")", r"\)") + string = string.replace("|", r"\|") + string = string.replace("?", r"\?") + string = string.replace("*", r"\*") + string = string.replace("+", r"\+") + string = string.replace("\\", "\\\\") + string = string.replace("^", r"\^") + return string + + +# 生成JSON正则字段;不支持主动$ref语法;不支持oneOf;allOf只支持1个schema的情况 +def gen_json(schema: dict) -> str: + if "anyOf" in schema: + regex = "(" + for item in schema["anyOf"]: + regex += gen_json(item) + regex += "|" + regex = regex.rstrip("|") + ")" + return regex + + if "allOf" in schema: + if len(schema["allOf"]) != 1: + raise NotImplementedError("allOf只支持1个schema的情况") + return gen_json(schema["allOf"][0]) + + if "enum" in schema: + choice_regex = "" + for item in schema["enum"]: + if schema["type"] == "boolean": + if item is True: + choice_regex += 
"true" + else: + choice_regex += "false" + elif schema["type"] == "string": + choice_regex += "\"" + _process_string(str(item)) + "\"" + else: + choice_regex += _process_string(str(item)) + choice_regex += "|" + return '(' + choice_regex.rstrip("|") + '),' + + if "pattern" in schema: + if schema["type"] == "string": + return "\"" + schema["pattern"] + "\"," + return schema["pattern"] + "," + + if "type" in schema: + # 布尔类型,例子:true + if schema["type"] == "boolean": + return r"(true|false)," + # 整数类型,例子:-100;最多支持9位 + if schema["type"] == "integer": + return r"[-\+]?[\d]{0,9}," + # 浮点数类型,例子:-1.2e+10;每一段数字最多支持9位 + if schema["type"] == "number": + return r"""[-\+]?[\d]{0,9}[.][\d]{0,9}(e[-\+]?[\d]{0,9})?,""" + # 字符串类型,例子:最小长度0,最大长度10 + if schema["type"] == "string": + regex = r'"([^"\\\x00-\x1F\x7F-\x9F]|\\["\\])' + min_len = schema.get("minLength", 0) + regex += "{" + str(min_len) + if "maxLength" in schema: + if schema["maxLength"] < min_len: + raise ValueError("字符串最大长度不能小于最小长度") + regex += "," + str(schema["maxLength"]) + "}\"," + else: + regex += ",}\"," + return regex + # 数组 + if schema["type"] == "array": + min_len = schema.get("minItems", 0) + max_len = schema.get("maxItems", None) + if isinstance(max_len, int) and min_len > max_len: + raise ValueError("数组最大长度不能小于最小长度") + return _json_array(schema, min_len, max_len) + # 对象 + if schema["type"] == "object": + regex = _json_object(schema) + return regex + + +# 数组:暂时不支持PrefixItems;只支持数组中数据结构都一致的情况 +def _json_array(schema: dict, min_len: int, max_len: int | None) -> str: + if max_len is None: + num_repeats = rf"{{{max(min_len - 1, 0)},}}" + else: + num_repeats = rf"{{{max(min_len - 1, 0)},{max_len - 1}}}" + + item_regex = gen_json(schema["items"]).rstrip(",") + if not item_regex: + return "" + + regex = rf"\[(({item_regex})(,{item_regex}){num_repeats})?\]," + return regex + + +def _json_object(schema: dict) -> str: + if "required" in schema: + required = schema["required"] + else: + required = [] + + regex = r'\{' + + if "additionalProperties" in schema: + regex += gen_json({"type": "string"}) + "[ ]?:[ ]?" + gen_json(schema["additionalProperties"]) + + if "properties" in schema: + for key, val in schema["properties"].items(): + current_regex = gen_json(val) + if not current_regex: + continue + + regex += r'[ ]?"' + _process_string(key) + r'"[ ]?:[ ]?' + if key not in required: + regex += r"(null|" + current_regex.rstrip(",") + ")," + else: + regex += current_regex + + regex = regex.rstrip(",") + r'[ ]?\}' + return regex diff --git a/apps/scheduler/parse_json.py b/apps/scheduler/parse_json.py new file mode 100644 index 000000000..ab5c927d0 --- /dev/null +++ b/apps/scheduler/parse_json.py @@ -0,0 +1,55 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+ +from datetime import datetime +import pytz +from typing import Any, Dict, Union + + +def parse_json(json_value: Any, spec_data: Dict[str, Any]): + """ + 使用递归的方式对JSON返回值进行处理 + :param json_value: 返回值中的字段 + :param spec_data: 返回值字段对应的JSON Schema + :return: 处理后的这部分返回值字段 + """ + + if "allOf" in spec_data: + processed_dict = {} + for item in spec_data["allOf"]: + processed_dict.update(parse_json(json_value, item)) + return processed_dict + + if "type" in spec_data: + if spec_data["type"] == "timestamp" and (isinstance(json_value, str) or isinstance(json_value, int)): + processed_timestamp = _process_timestamp(json_value) + return processed_timestamp + if spec_data["type"] == "array" and isinstance(json_value, list): + processed_list = [] + for item in json_value: + processed_list.append(parse_json(item, spec_data["items"])) + return processed_list + if spec_data["type"] == "object" and isinstance(json_value, dict): + processed_dict = {} + for key, val in json_value.items(): + if key not in spec_data["properties"]: + processed_dict[key] = val + continue + processed_dict[key] = parse_json(val, spec_data["properties"][key]) + return processed_dict + + return json_value + + +def _process_timestamp(timestamp_str: Union[str, int]) -> str: + """ + 将type为timestamp的字段转换为大模型可读的日期表示 + :param timestamp_str: 时间戳 + :return: 转换后的北京时间 + """ + try: + timestamp_int = int(timestamp_str) + except Exception: + return timestamp_str + + time = datetime.fromtimestamp(timestamp_int, tz=pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") + return time diff --git a/apps/scheduler/pool/__init__.py b/apps/scheduler/pool/__init__.py new file mode 100644 index 000000000..821dc0853 --- /dev/null +++ b/apps/scheduler/pool/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. diff --git a/apps/scheduler/pool/entities.py b/apps/scheduler/pool/entities.py new file mode 100644 index 000000000..dae18ac5b --- /dev/null +++ b/apps/scheduler/pool/entities.py @@ -0,0 +1,33 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +from sqlalchemy import Column, Integer, String, LargeBinary +from sqlalchemy.orm import declarative_base + +Base = declarative_base() + + +class FlowItem(Base): + __tablename__ = 'flow' + id = Column(Integer, primary_key=True, autoincrement=True) + plugin = Column(String(length=100), nullable=False) + name = Column(String(length=100), nullable=False, unique=True) + description = Column(String(length=1500), nullable=False) + + +class PluginItem(Base): + __tablename__ = 'plugin' + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(length=100), nullable=False, unique=True) + show_name = Column(String(length=100), nullable=False, unique=True) + description = Column(String(length=1500), nullable=False) + auth = Column(String(length=500), nullable=True) + spec = Column(LargeBinary, nullable=False) + signature = Column(String(length=100), nullable=False) + + +class CallItem(Base): + __tablename__ = 'call' + id = Column(Integer, primary_key=True, autoincrement=True) + plugin = Column(String(length=100), nullable=True) + name = Column(String(length=100), nullable=False) + description = Column(String(length=1500), nullable=False) diff --git a/apps/scheduler/pool/loader.py b/apps/scheduler/pool/loader.py new file mode 100644 index 000000000..5dfe0ce28 --- /dev/null +++ b/apps/scheduler/pool/loader.py @@ -0,0 +1,233 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
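Note on `apps/scheduler/parse_json.py` above: the `object` branch indexes `spec_data["properties"]` directly, so a schema such as `{"type": "object", "additionalProperties": {...}}` with no `properties` key raises `KeyError`. A tolerant sketch, assuming the rest of the function stays as written:

```python
if spec_data["type"] == "object" and isinstance(json_value, dict):
    properties = spec_data.get("properties", {})  # free-form objects may have none
    return {
        key: parse_json(val, properties[key]) if key in properties else val
        for key, val in json_value.items()
    }
```

Relatedly, `_process_timestamp` is annotated `-> str` but returns the untouched input (possibly an `int`) when parsing fails; widening the annotation to `Union[str, int]` or stringifying the fallback would make the contract honest.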
+ +from __future__ import annotations + +import os +import sys +from typing import Dict, Any, List +import json +import importlib.util +import logging + +from apps.common.config import config +from apps.common.singleton import Singleton +from apps.entities.plugin import Flow, Step +from apps.scheduler.pool.pool import Pool +from apps.scheduler.call import exported + +import yaml +from langchain_community.agent_toolkits.openapi.spec import reduce_openapi_spec, ReducedOpenAPISpec + + +OPENAPI_FILENAME = "openapi.yaml" +METADATA_FILENAME = "plugin.json" +FLOW_DIR = "flows" +LIB_DIR = "lib" + +logger = logging.getLogger('gunicorn.error') + + +class PluginLoader: + """ + 载入单个插件的Loader。 + """ + plugin_location: str + plugin_name: str + + def __init__(self, name: str): + """ + 初始化Loader。 + 设置插件目录,随后遍历每一个 + """ + + self.plugin_location = os.path.join(config["PLUGIN_DIR"], name) + self.plugin_name = name + + metadata = self._load_metadata() + spec = self._load_openapi_spec() + Pool().add_plugin(name=name, spec=spec, metadata=metadata) + + if "automatic_flow" in metadata and metadata["automatic_flow"] is True: + flows = self._single_api_to_flow(spec) + else: + flows = [] + flows += self._load_flow() + Pool().add_flows(plugin=name, flows=flows) + + calls = self._load_lib() + Pool().add_calls(plugin=name, calls=calls) + + def _load_openapi_spec(self) -> ReducedOpenAPISpec | None: + spec_path = os.path.join(self.plugin_location, OPENAPI_FILENAME) + + if os.path.exists(spec_path): + spec = yaml.safe_load(open(spec_path, "r", encoding="utf-8")) + return reduce_openapi_spec(spec) + else: + return None + + def _load_metadata(self) -> Dict[str, Any]: + metadata_path = os.path.join(self.plugin_location, METADATA_FILENAME) + metadata = json.load(open(metadata_path, "r", encoding="utf-8")) + return metadata + + @staticmethod + def _single_api_to_flow(spec: ReducedOpenAPISpec | None = None) -> List[Dict[str, Any]]: + if not spec: + return [] + + flows = [] + for endpoint in spec.endpoints: + # 构造Step + step_dict = { + "start": Step( + name="start", + call_type="api", + params={ + "endpoint": endpoint[0] + }, + next="end" + ), + "end": Step( + name="end", + call_type="none" + ) + } + + # 构造Flow + flow = { + "name": endpoint[0], + "description": endpoint[1], + "data": Flow(steps=step_dict) + } + + flows.append(flow) + return flows + + def _load_flow(self) -> List[Dict[str, Any]]: + flow_path = os.path.join(self.plugin_location, FLOW_DIR) + flows = [] + if os.path.isdir(flow_path): + for item in os.listdir(flow_path): + current_flow_path = os.path.join(flow_path, item) + logger.info("载入Flow: {}".format(current_flow_path)) + + flow_yaml = yaml.safe_load(open(current_flow_path, "r", encoding="utf-8")) + + if "." in flow_yaml["name"]: + raise ValueError("Flow名称包含非法字符!") + + if "on_error" in flow_yaml: + error_step = Step(name="error", **flow_yaml["on_error"]) + else: + error_step = Step( + name="error", + call_type="llm", + params={ + "user_prompt": "当前工具执行发生错误,原始错误信息为:{data}. 
请向用户展示错误信息,并给出可能的解决方案。\n\n背景信息:{context}" + } + ) + + steps = {} + for step in flow_yaml["steps"]: + steps[step["name"]] = Step(**step) + + if "next_flow" not in flow_yaml: + next_flow = None + else: + next_flow = flow_yaml["next_flow"] + + flows.append({ + "name": flow_yaml["name"], + "description": flow_yaml["description"], + "data": Flow(on_error=error_step, steps=steps, next_flow=next_flow), + }) + return flows + + def _load_lib(self) -> List[Any]: + lib_path = os.path.join(self.plugin_location, LIB_DIR) + if os.path.isdir(lib_path): + logger.info("载入Lib:{}".format(lib_path)) + # 插件lib载入到特定模块 + try: + spec = importlib.util.spec_from_file_location( + "apps.plugins." + self.plugin_name, + os.path.join(self.plugin_location, "lib") + ) + module = importlib.util.module_from_spec(spec) + sys.modules["apps.plugins." + self.plugin_name] = module + spec.loader.exec_module(module) + except Exception as e: + logger.info(msg=f"Failed to load plugin lib: {e}") + return [] + + # 注册模块所有工具 + calls = [] + for cls in sys.modules["apps.plugins." + self.plugin_name].exported: + try: + if self.check_user_class(cls): + calls.append(cls) + except Exception as e: + logger.info(msg=f"Failed to register tools: {e}") + continue + return calls + return [] + + @staticmethod + # 用户工具不强绑定父类,而是满足要求即可 + def check_user_class(cls) -> bool: + flag = True + + if not hasattr(cls, "name") or not isinstance(cls.name, str): + flag = False + if not hasattr(cls, "description") or not isinstance(cls.description, str): + flag = False + if not hasattr(cls, "spec") or not isinstance(cls.spec, dict): + flag = False + if not hasattr(cls, "__call__") or not callable(cls.__call__): + flag = False + + if not flag: + logger.info(msg="类{}不符合Call标准要求。".format(cls.__name__)) + + return flag + + +# 载入全部插件 +class Loader(metaclass=Singleton): + exclude_list: List[str] = [ + ".git", + "example" + ] + path: str = config["PLUGIN_DIR"] + + def __init__(self): + raise NotImplementedError("Loader无法被实例化") + + # 载入apps/scheduler/call下面的所有工具 + @classmethod + def load_predefined_call(cls): + calls = [] + for item in exported: + calls.append(item) + try: + Pool().add_calls(None, calls) + except Exception as e: + logger.info(msg=f"Failed to load predefined call: {str(e)}") + + # 首次初始化 + @classmethod + def init(cls): + cls.load_predefined_call() + for item in os.scandir(cls.path): + if item.is_dir() and item.name not in cls.exclude_list: + try: + PluginLoader(name=item.name) + except Exception as e: + logger.error(msg=f"Failed to load plugin: {str(e)}") + + # 后续热重载 + @classmethod + def reload(cls): + Pool().clean_db() + cls.init() diff --git a/apps/scheduler/pool/pool.py b/apps/scheduler/pool/pool.py new file mode 100644 index 000000000..88e565f6d --- /dev/null +++ b/apps/scheduler/pool/pool.py @@ -0,0 +1,305 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
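Note on `apps/scheduler/pool/loader.py` above: `_load_openapi_spec`, `_load_metadata` and `_load_flow` all pass a bare `open(...)` into `yaml.safe_load`/`json.load`, leaving the file handles to be closed at garbage-collection time. Also, `check_user_class` tests `hasattr(cls, "__call__")`, which is true for every class (classes are callable), so that clause never filters anything; checking for the entry point the scheduler actually invokes (presumably an async `call` method, as with the built-in tools) would be stricter. File-handle sketch:

```python
def _load_metadata(self) -> Dict[str, Any]:
    metadata_path = os.path.join(self.plugin_location, METADATA_FILENAME)
    with open(metadata_path, "r", encoding="utf-8") as f:  # closed deterministically
        return json.load(f)
```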
+ +from __future__ import annotations + +import hashlib +from threading import Lock +import pickle +import hmac +import json +from typing import Tuple, Dict, Any, List +import logging + +from sqlalchemy import create_engine, Engine, or_ +from sqlalchemy.orm import sessionmaker +from langchain_community.agent_toolkits.openapi.spec import ReducedOpenAPISpec +import chromadb + +from apps.common.singleton import Singleton +from apps.scheduler.vector import VectorDB, DocumentWrapper +from apps.scheduler.pool.entities import Base +from apps.common.config import config +from apps.scheduler.pool.entities import PluginItem, FlowItem, CallItem +from apps.entities.plugin import PluginData +from apps.entities.plugin import Flow + + +logger = logging.getLogger('gunicorn.error') + + +class Pool(metaclass=Singleton): + write_lock: Lock = Lock() + relation_db: Engine + flow_collection: chromadb.Collection + plugin_collection: chromadb.Collection + + flow_pool: Dict[str, Any] = {} + call_pool: Dict[str, Any] = {} + + def __init__(self): + with self.write_lock: + # Init SQLite + self.relation_db = create_engine('sqlite:///:memory:') + Base.metadata.create_all(self.relation_db) + + # Init ChromaDB + self.create_collection() + + @staticmethod + def serialize_data(origin_data) -> Tuple[bytes, str]: + data = pickle.dumps(origin_data) + hmac_obj = hmac.new(key=bytes.fromhex(config["PICKLE_KEY"]), msg=data, digestmod=hashlib.sha256) + signature = hmac_obj.hexdigest() + return data, signature + + @staticmethod + def deserialize_data(data: bytes, signature: str): + hmac_obj = hmac.new(key=bytes.fromhex(config["PICKLE_KEY"]), msg=data, digestmod=hashlib.sha256) + current_signature = hmac_obj.hexdigest() + if current_signature != signature: + raise AssertionError("Pickle data has been modified!") + + return pickle.loads(data) + + def create_collection(self): + self.flow_collection = VectorDB.get_collection("flow") + self.plugin_collection = VectorDB.get_collection("plugin") + + def add_plugin(self, name: str, metadata: dict, spec: ReducedOpenAPISpec | None): + spec_data, signature = self.serialize_data(spec) + + if "auth" in metadata: + auth = json.dumps(metadata["auth"]) + else: + auth = "{}" + + plugin = PluginItem( + name=name, + show_name=metadata["name"], + description=metadata["description"], + auth=auth, + spec=spec_data, + signature=signature, + ) + with self.write_lock: + try: + with sessionmaker(bind=self.relation_db)() as session: + session.add(plugin) + session.commit() + except Exception as e: + logger.error(f"Import plugin failed: {str(e)}") + + doc = DocumentWrapper( + data=metadata["description"], + id=name + ) + + with self.write_lock: + VectorDB.add_docs(self.plugin_collection, [doc]) + + def add_flows(self, plugin: str, flows: List[Dict[str, Any]]): + docs = [] + flow_rows = [] + + # 此处,flow在向量数据库中名字加上了plugin前缀,防止ID冲突 + for item in flows: + current_row = FlowItem( + plugin=plugin, + name=item["name"], + description=item["description"] + ) + flow_rows.append(current_row) + + doc = DocumentWrapper( + id=plugin + "." + item["name"], + data=item["description"], + metadata={ + "plugin": plugin + } + ) + docs.append(doc) + with self.write_lock: + self.flow_pool[plugin + "." 
+ item["name"]] = item["data"] + + with self.write_lock: + try: + with sessionmaker(bind=self.relation_db)() as session: + session.add_all(flow_rows) + session.commit() + except Exception as e: + logger.error(f"Import flow failed: {str(e)}") + VectorDB.add_docs(self.flow_collection, docs) + + def add_calls(self, plugin: str | None, calls: List[Any]): + call_metadata = [] + + for item in calls: + current_metadata = CallItem( + plugin=plugin, + name=item.name, + description=item.description + ) + call_metadata.append(current_metadata) + + with self.write_lock: + call_prefix = "" + if plugin is not None: + call_prefix += plugin + "." + self.call_pool[call_prefix + item.name] = item + + with self.write_lock: + with sessionmaker(bind=self.relation_db)() as session: + try: + session.add_all(call_metadata) + session.commit() + except Exception as e: + logger.error(f"Import plugin {plugin} call failed: {str(e)}") + + def clean_db(self): + try: + with self.write_lock: + Base.metadata.drop_all(bind=self.relation_db) + Base.metadata.create_all(bind=self.relation_db) + + VectorDB.delete_collection("flow") + VectorDB.delete_collection("plugin") + self.create_collection() + + self.flow_pool = {} + self.call_pool = {} + except Exception as e: + logger.error(f"Clean DB failed: {str(e)}") + + + def get_plugin_list(self) -> List[PluginData]: + plugin_list: List[PluginData] = [] + try: + with sessionmaker(bind=self.relation_db)() as session: + result = session.query(PluginItem).all() + except Exception as e: + logger.error(f"Get Plugin from DB failed: {str(e)}") + return [] + + for item in result: + plugin_list.append(PluginData( + id=item.name, + plugin_name=item.show_name, + plugin_description=item.description, + plugin_auth=json.loads(item.auth) + )) + + return plugin_list + + def get_flow(self, name: str, plugin_name: str) -> Tuple[FlowItem | None, Flow | None]: + # 查找Flow名对应的 信息和Step + if "." in name: + plugin, flow = name.split(".") + else: + plugin, flow = plugin_name, name + + try: + with sessionmaker(bind=self.relation_db)() as session: + result = session.query(FlowItem).filter_by(name=flow, plugin=plugin).first() + except Exception as e: + logger.error(f"Get Flow from DB failed: {str(e)}") + return None, None + + return result, self.flow_pool.get(plugin + "." 
+ flow, None) + + + def get_plugin(self, name: str) -> PluginItem | None: + # 查找Plugin名对应的 信息 + try: + with sessionmaker(bind=self.relation_db)() as session: + result = session.query(PluginItem).filter_by(name=name).first() + except Exception as e: + logger.error(f"Get Plugin from DB failed: {str(e)}") + return None + + return result + + def get_k_plugins(self, question: str, top_k: int = 3): + result = self.plugin_collection.query( + query_texts=[question], + n_results=top_k + ) + + ids = result.get("ids", None) + if ids is None: + logger.error(f"Vector search failed: {result}") + return [] + + result_list = [] + with sessionmaker(bind=self.relation_db)() as session: + for current_id in ids[0]: + try: + result_item = session.query(PluginItem).filter_by(name=current_id).first() + if result_item is None: + continue + result_list.append(result_item) + except Exception as e: + logger.error(f"Get data from VectorDB failed: {str(e)}") + + return result_list + + def get_k_flows(self, question: str, plugin_list: List[str] | None = None, top_k: int = 3) -> List: + result = self.flow_collection.query( + query_texts=[question], + n_results=top_k, + where=Pool._construct_vector_query(plugin_list) + ) + + ids = result.get("ids", None) + if ids is None: + logger.error(f"Vector search failed: {result}") + return [] + + result_list = [] + with sessionmaker(bind=self.relation_db)() as session: + for current_id in ids[0]: + plugin_name, flow_name = current_id.split(".") + try: + result_item = session.query(FlowItem).filter_by(name=flow_name, plugin=plugin_name).first() + if result_item is None: + continue + result_list.append(result_item) + except Exception as e: + logger.error(f"Get data from VectorDB failed: {str(e)}") + + return result_list + + @staticmethod + def _construct_vector_query(plugin_list: List[str]) -> Dict[str, Any]: + constraint = {} + if len(plugin_list) == 0: + return {} + elif len(plugin_list) == 1: + constraint["plugin"] = { + "$eq": plugin_list[0] + } + else: + constraint["$or"] = [] + for plugin in plugin_list: + constraint["$or"].append({ + "plugin": { + "$eq": plugin + } + }) + return constraint + + + def get_call(self, name: str, plugin: str) -> Tuple[CallItem | None, Any]: + if "." not in name: + call_name = name + call_plugin = plugin + else: + call_name, call_plugin = name.split(".", 1) + + try: + with sessionmaker(bind=self.relation_db)() as session: + call_item = session.query(CallItem).filter_by(name=call_name).filter(or_(CallItem.plugin == call_plugin, CallItem.plugin == None)).first() + except Exception as e: + logger.error(f"Get Call from DB failed: {str(e)}") + return None, None + + return call_item, self.call_pool.get(name, None) diff --git a/apps/scheduler/scheduler.py b/apps/scheduler/scheduler.py new file mode 100644 index 000000000..8af8e29cb --- /dev/null +++ b/apps/scheduler/scheduler.py @@ -0,0 +1,206 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
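Note on `apps/scheduler/pool/pool.py` above: the qualified form used throughout this patch is `plugin.call` (plugin first, as in `get_flow` and the `call_pool` keys), yet `get_call` unpacks `call_name, call_plugin = name.split(".", 1)`, assigning the plugin segment to the call name. The trailing `call_pool.get(name, None)` also looks up the raw `name`, which misses entries keyed `plugin.call` whenever a bare name plus `plugin` argument is passed. A sketch of the resolution implied by that convention:

```python
def get_call(self, name: str, plugin: str) -> Tuple[CallItem | None, Any]:
    if "." in name:
        call_plugin, call_name = name.split(".", 1)  # "plugin.call": plugin first
    else:
        call_plugin, call_name = plugin, name

    try:
        with sessionmaker(bind=self.relation_db)() as session:
            call_item = session.query(CallItem).filter_by(name=call_name).filter(
                or_(CallItem.plugin == call_plugin, CallItem.plugin == None)).first()
    except Exception as e:
        logger.error(f"Get Call from DB failed: {str(e)}")
        return None, None

    # Plugin-provided calls carry the plugin prefix in the pool; built-ins do not.
    return call_item, self.call_pool.get(f"{call_plugin}.{call_name}",
                                         self.call_pool.get(call_name, None))
```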
+# Agent调度器 + +from __future__ import annotations + +import json +from typing import List + +from apps.scheduler.executor import FlowExecuteExecutor +from apps.scheduler.pool.pool import Pool +from apps.scheduler.utils import Select, Recommend +from apps.llm import get_llm, get_message_model + + +MAX_RECOMMEND = 3 + + +class Scheduler: + """ + “调度器”,是最顶层的、控制Executor执行顺序和状态的逻辑。 + + 目前,Scheduler只会构造并执行1个Flow。后续可以改造为Router,用于连接多个Executor(Multi-Agent) + """ + + # 上下文 + context: str = "" + # 用户原始问题 + question: str + + def __init__(self): + raise NotImplementedError("Scheduler无法被实例化!") + + @staticmethod + async def choose_flow(question: str, user_selected_plugins: List[str]) -> str | None: + """ + 依据用户的输入和选择,构造对应的Flow。 + + - 当用户没有选择任何Plugin时,直接进行智能问答 + - 当用户选择Plugin时,挑选最适合的Flow + + :param question: 用户输入(用户问题) + :param user_selected_plugins: 用户选择的插件,可以一次选择多个 + :result: 经LLM选择的Flow Name + """ + + # 用户什么都不选,直接智能问答 + if len(user_selected_plugins) == 0: + return None + + # 自动识别:选择TopK插件 + elif len(user_selected_plugins) == 1 and user_selected_plugins[0] == "auto": + # 用户要求自动识别 + plugin_top = Pool().get_k_plugins(question) + # 聚合插件的Flow + plugin_top_list = [] + for plugin in plugin_top: + plugin_top_list.append(plugin.name) + + else: + # 用户指定了插件 + plugin_top_list = user_selected_plugins + + flows = Pool().get_k_flows(question, plugin_top_list, 2) + + # 使用大模型选择Top1 + flow_list = [] + for item in flows: + flow_list.append({ + "name": item.plugin + "." + item.name, + "description": item.description + }) + if len(user_selected_plugins) == 1 and user_selected_plugins[0] == "auto": + # 用户选择自动识别时,包含智能问答 + flow_list.append({ + "name": "KnowledgeBase", + "description": "回答上述工具无法直接进行解决的用户问题。" + }) + top_flow = await Select().top_flow(choice=flow_list, instruction=question) + + if top_flow == "KnowledgeBase": + return None + # 返回流的ID + return top_flow + + @staticmethod + async def run_certain_flow(context: str, question: str, user_selected_flow: str, session_id: str, files: List[str] | None): + """ + 构造FlowExecutor,并执行所选择的流 + + :param context: 上下文信息 + :param question: 用户输入(用户问题) + :param user_selected_flow: 用户所选择的Flow的Name + :param session_id: 当前用户的登录Session。目前用于部分插件鉴权,后续将用于Flow与用户交互过程中的暂停与恢复。 + :param files: 用户上传的文件的ID(暂未使用,后续配合LogGPT上传文件分析等需要文件的功能) + """ + flow_exec = FlowExecuteExecutor(params={ + "name": user_selected_flow, + "question": question, + "context": context, + "files": files, + "session_id": session_id + }) + + response = { + "message": "", + "output": {} + } + async for chunk in flow_exec.run(): + if "data" in chunk[:6]: + yield "data: " + json.dumps({"content": chunk[6:]}, ensure_ascii=False) + "\n\n" + else: + response = json.loads(chunk[7:]) + + # 返回自然语言结果和结构化数据结果 + llm = get_llm() + msg_cls = get_message_model(llm) + messages = [ + msg_cls(role="system", content="详细回答用户的问题,保留一切必要信息。工具输出中包含的Markdown代码段、Markdown表格等内容必须原封不动输出。"), + msg_cls(role="user", content=f"""## 用户问题 +{question} + +## 工具描述 +{flow_exec.description} + +## 工具输出 +{response}""") + ] + async for chunk in llm.astream(messages): + yield "data: " + json.dumps({"content": chunk.content}, ensure_ascii=False) + "\n\n" + + # 提取出最终的结构化信息 + # 样例:{"type": "api", "data": "API返回值原始数据(string)"} + structured_data = { + "type": response["output"]["type"], + "data": response["output"]["data"]["output"], + } + yield "data: " + json.dumps(structured_data, ensure_ascii=False) + "\n\n" + + @staticmethod + async def plan_next_flow(summary: str, current_flow_name: str | None, user_selected_plugins: List[str], question: str): + """ + 生成用户“下一步”Flow的推荐。 
+ + - 若Flow的配置文件中已定义`next_flow[]`字段,则直接使用该字段给定的值 + - 否则,使用LLM进行选择。将根据用户的插件选择情况限定范围 + + 选择“下一步”Flow后,根据当前Flow的执行结果和“下一步”Flow的描述,生成改写的或预测的问题。 + + :param summary: 上下文总结,包含当前Flow的执行结果。 + :param current_flow_name: 当前执行的Flow的Name,用于避免重复选择同一个Flow + :param user_selected_plugins: 用户选择的插件列表,用于限定推荐范围 + :param question: 用户当前Flow的问题输入 + :return: 列表,包含“下一步”Flow的Name和预测问题 + """ + if current_flow_name is not None: + # 是否有预定义的Flow关系?有就直接展示这些关系 + next_flow_data = [] + plugin_name, flow_name = current_flow_name.split(".") + _, current_flow_data = Pool().get_flow(flow_name, plugin_name) + predefined_next_flow_name = current_flow_data.next_flow + + if predefined_next_flow_name is not None: + result_num = 0 + # 最多只能有3个推荐Flow + for current_flow in predefined_next_flow_name: + result_num += 1 + if result_num > MAX_RECOMMEND: + break + # 从Pool中查找该Flow + flow_metadata, _ = Pool().get_flow(current_flow, plugin_name) + # 根据该Flow对应的Description,改写问题 + rewrite_question = await Recommend().recommend(action_description=flow_metadata.description, background=summary) + # 将改写后的问题与Flow名字的对应关系关联起来 + plugin_metadata = Pool().get_plugin(plugin_name) + next_flow_data.append({ + "id": plugin_name + "." + current_flow, + "name": plugin_metadata.show_name, + "question": rewrite_question + }) + + # 返回改写后的问题 + return next_flow_data + + # 没有预定义的Flow,走一次choose_flow + if len(user_selected_plugins) == 1 and user_selected_plugins[0] == "auto": + plugin_top = Pool().get_k_plugins(question) + user_selected_plugins = [] + for plugin in plugin_top: + user_selected_plugins.append(plugin.name) + + next_flow_data = [] + result = Pool().get_k_flows(question, user_selected_plugins) + for current_flow in result: + if current_flow.name == current_flow_name: + continue + + flow_metadata, _ = Pool().get_flow(current_flow.name, current_flow.plugin) + rewrite_question = await Recommend().recommend(action_description=flow_metadata.description, background=summary) + plugin_metadata = Pool().get_plugin(current_flow.plugin) + next_flow_data.append({ + "id": current_flow.plugin + "." + current_flow.name, + "name": plugin_metadata.show_name, + "question": rewrite_question + }) + + return next_flow_data diff --git a/apps/scheduler/utils/__init__.py b/apps/scheduler/utils/__init__.py new file mode 100644 index 000000000..8f3afc2fc --- /dev/null +++ b/apps/scheduler/utils/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +from apps.scheduler.utils.consistency import Consistency +from apps.scheduler.utils.evaluate import Evaluate +from apps.scheduler.utils.json import Json +from apps.scheduler.utils.reflect import Reflect +from apps.scheduler.utils.recommend import Recommend +from apps.scheduler.utils.select import Select +from apps.scheduler.utils.summary import Summary +from apps.scheduler.utils.backprop import BackProp + + +__all__ = [ + 'Consistency', + 'Evaluate', + 'Json', + 'Reflect', + 'Recommend', + 'Select', + 'Summary', +] diff --git a/apps/scheduler/utils/backprop.py b/apps/scheduler/utils/backprop.py new file mode 100644 index 000000000..40ede6e99 --- /dev/null +++ b/apps/scheduler/utils/backprop.py @@ -0,0 +1,47 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +from __future__ import annotations + +from apps.llm import get_llm, get_message_model + + +class BackProp: + system_prompt: str = """根据提供的错误日志、评估结果和背景信息,优化原始用户输入内容。 + +要求: +1. 优化后的用户输入能够最大程度避免错误,同时其中的数据保持与原始用户输入一致。 +2. 不得编造数据。所有数据必须从原始用户输入和背景信息中获得。 +3. 
优化后的用户输入应最大程度上保留原始用户输入中的所有信息,不要遗漏数据或细节。""" + user_prompt: str = """## 原始用户输入 +{user_input} + +## 错误日志 +{exception} + +## 评估结果 +{evaluation} + +## 背景信息 +{background} + +## 优化后的用户输入 +""" + + def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None): + if system_prompt is not None: + self.system_prompt = system_prompt + if user_prompt is not None: + self.user_prompt = user_prompt + + async def backprop(self, user_input: str, exception: str, evaluation: str, background: str) -> str: + llm = get_llm() + msg_cls = get_message_model(llm) + messages = [ + msg_cls(role="system", content=self.system_prompt), + msg_cls(role="user", content=self.user_prompt.format( + user_input=user_input, exception=exception, evaluation=evaluation, background=background) + ) + ] + + result = llm.invoke(messages).content + return result diff --git a/apps/scheduler/utils/consistency.py b/apps/scheduler/utils/consistency.py new file mode 100644 index 000000000..554f04062 --- /dev/null +++ b/apps/scheduler/utils/consistency.py @@ -0,0 +1,159 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +# 使用大模型的随机化+投票方式选择最优答案 + +from __future__ import annotations +from typing import List, Dict, Any, Tuple +import asyncio +from collections import Counter + +import sglang +import openai + +from apps.common.thread import ProcessThreadPool +from apps.llm import get_scheduler, create_vllm_stream, stream_to_str + + +class Consistency: + system_prompt: str = """Your task is: choose the answer that best matches user instructions and contextual information. \ +The instruction and context information will be given in a certain format. Here are some examples: + +EXAMPLE +## Instruction + +用户是否询问了openEuler相关知识? + +## Context + +User asked whether iSula is better than Docker. iSula is a tool developed by the openEuler Community. iSula contains \ +features such as security-enhanced containers, performance optimizations and openEuler compatibility. + +## Choice + +The available choices are: + +- Yes +- No + +## Thought + +Let's think step by step. User mentioned 'iSula', which is a tool related to the openEuler Community. So the user \ +question is related to openEuler. 
+ +## Answer + +Yes + +END OF EXAMPLE""" + user_prompt: str = """## Instruction + +{question} + +## Context + +{background} +Previous Output: {data} + +## Choice + +The available choices are: + +{choice_list} + +## Thought +""" + + def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None): + if system_prompt is not None: + self.system_prompt = system_prompt + if user_prompt is not None: + self.user_prompt = user_prompt + + @staticmethod + def _choices_to_prompt(choices: List[Dict[str, Any]]) -> Tuple[str, List[str]]: + choices_prompt = "" + choice_str_list = [] + for choice in choices: + choices_prompt += "- {}: {}\n".format(choice["step"], choice["description"]) + choice_str_list.append(choice["step"]) + return choices_prompt, choice_str_list + + @staticmethod + @sglang.function + def _generate_consistency_sglang(s, system_prompt: str, user_prompt: str, instruction: str, + background: str, data: Dict[str, Any], choices: List[Dict[str, Any]], answer_num: int): + s += sglang.system(system_prompt) + + choice_prompt, choice_str_list = Consistency._choices_to_prompt(choices) + + s += sglang.user(user_prompt.format( + question=instruction, + background=background, + choice_list=choice_prompt, + data=data + )) + forks = s.fork(answer_num) + + for i, f in enumerate(forks): + f += sglang.assistant_begin() + f += "Let's think step by step. " + sglang.gen(max_tokens=512, stop="\n\n") + f += "\n\n## Answer\n\n" + sglang.gen(choices=choice_str_list, name="result") + f += sglang.assistant_end() + + result_list = [] + for item in forks: + result_list.append(item["result"]) + + s["major"] = result_list + + async def _generate_consistency_vllm(self, backend: openai.AsyncOpenAI, instruction: str, background: str, data: Dict[str, Any], choices: List[Dict[str, Any]], answer_num: int) -> List[str]: + choice_prompt, choice_str_list = Consistency._choices_to_prompt(choices) + + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.user_prompt.format( + question=instruction, + background=background, + data=data, + choice_list=choice_prompt + ) + "\nLet's think step by step."}, + ] + + result_list = [] + for i in range(answer_num): + message_branch = messages + stream = await create_vllm_stream(backend, message_branch, max_tokens=512, extra_body={}) + reasoning = await stream_to_str(stream) + message_branch += [ + {"role": "assistant", "content": reasoning}, + {"role": "user", "content": "## Answer\n\n"} + ] + + choice_regex = "(" + for choice in choice_str_list: + choice_regex += choice + "|" + choice_regex = choice_regex.rstrip("|") + ")" + + stream = await create_vllm_stream(backend, message_branch, max_tokens=16, extra_body={ + "guided_regex": choice_regex + }) + result_list.append(await stream_to_str(stream)) + + return result_list + + async def consistency(self, instruction: str, background: str, data: Dict[str, Any], choices: List[Dict[str, Any]], answer_num: int = 3) -> str: + backend = get_scheduler() + if isinstance(backend, openai.AsyncOpenAI): + result_list = await self._generate_consistency_vllm(backend, instruction, background, data, choices, answer_num) + else: + sglang.set_default_backend(backend) + state_future = ProcessThreadPool().thread_executor.submit( + Consistency._generate_consistency_sglang.run, + instruction=instruction, choices=choices, answer_num=answer_num, + system_prompt=self.system_prompt, user_prompt=self.user_prompt, + background=background, data=data + ) + state = await asyncio.wrap_future(state_future) + 
result_list = state["major"] + + count = Counter(result_list) + return count.most_common(1)[0][0] diff --git a/apps/scheduler/utils/evaluate.py b/apps/scheduler/utils/evaluate.py new file mode 100644 index 000000000..116515b33 --- /dev/null +++ b/apps/scheduler/utils/evaluate.py @@ -0,0 +1,127 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +# 使用大模型进行结果评价 + +from __future__ import annotations +from typing import Tuple + +import sglang +import openai + +from apps.llm import get_scheduler, create_vllm_stream, stream_to_str + + +class Evaluate: + system_prompt = """You are an expert evaluation system for a tool calling chatbot. +You are given the following information: +- a user query, and +- a tool output + +You may also be given a reference description to use for reference in your evaluation. + +Your job is to judge the relevance and correctness of the tool output. \ +Output a single score that represents a holistic evaluation. You must return your response in a line with only the score. \ +Do not return answers in any other format. On a separate line provide your reasoning for the score as well. + +Follow these guidelines for scoring: +- Your score has to be between 1 and 5, where 1 is the worst and 5 is the best. +- If the tool output is not relevant to the user query, you should give a score of 1. +- If the tool output is relevant but contains mistakes, you should give a score between 2 and 3. +- If the tool output is relevant and fully correct, you should give a score between 4 and 5. +- If 'error', code '500', 'failed' appeared in the tool output, it's more likely a mistake, you should give a score lower than 3. +- If 'success', code '200', 'succeed' appeared in the tool output, it's more likely a correct output, you should give a score higher than 4. + +Example response is given below: + +EXAMPLE +## Score + +4.0 + +## Reason + +The tool output is relevant to the user query, \ +but it made up the data for one field and didn't use the default value from the reference description. 
+
+END OF EXAMPLE"""
+    user_prompt = """## User Query
+
+{user_question}
+
+## Tool Output
+
+{tool_output}
+
+## Reference Description
+
+{tool_description}"""
+
+
+    def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None):
+        if system_prompt is not None:
+            self.system_prompt = system_prompt
+        if user_prompt is not None:
+            self.user_prompt = user_prompt
+
+    @staticmethod
+    @sglang.function
+    def _generate_evaluation_sglang(s, system_prompt: str, user_prompt: str, user_question: str, tool_output: str, tool_description: str):
+        s += sglang.system(system_prompt)
+
+        s += sglang.user(user_prompt.format(
+            user_question=user_question,
+            tool_output=tool_output,
+            tool_description=tool_description
+        ))
+
+        s += sglang.assistant_begin()
+        s += "## Score\n\n" + sglang.gen(name="score", regex=r"[\d]\.[\d]") + "\n\n"
+        s += "## Reason\n\n" + sglang.gen(name="reason", max_tokens=500)
+        s += sglang.assistant_end()
+
+    async def _generate_evaluation_vllm(self, backend: openai.AsyncOpenAI, user_question: str,
+                                        tool_output: str, tool_description: str) -> Tuple[float, str]:
+        messages = [
+            {"role": "system", "content": self.system_prompt},
+            {"role": "user", "content": self.user_prompt.format(
+                user_question=user_question,
+                tool_output=tool_output,
+                tool_description=tool_description
+            )}
+        ]
+
+        stream = await create_vllm_stream(backend, messages, max_tokens=50, extra_body={
+            "guided_regex": r"## Score\n\n[0-5]\.[0-9]"  # 小数点需要转义,否则可匹配任意字符
+        })
+        # 先await取回完整字符串再切片;对协程对象直接切片会抛出TypeError
+        score = (await stream_to_str(stream))[-3:]
+
+        messages += [
+            {"role": "assistant", "content": score},
+            {"role": "user", "content": "## Reason\n\n"}
+        ]
+
+        stream = await create_vllm_stream(backend, messages, max_tokens=500, extra_body={})
+        reason = await stream_to_str(stream)
+
+        return float(score), reason
+
+    async def generate_evaluation(self, user_question: str, tool_output: str, tool_description: str) -> Tuple[float, str]:
+        backend = get_scheduler()
+        if isinstance(backend, sglang.RuntimeEndpoint):
+            sglang.set_default_backend(backend)
+            state = Evaluate._generate_evaluation_sglang.run(
+                system_prompt=self.system_prompt,
+                user_prompt=self.user_prompt,
+                user_question=user_question,
+                tool_output=tool_output,
+                tool_description=tool_description,
+                stream=True
+            )
+
+            reason = ""
+            async for chunk in state.text_async_iter(var_name="reason"):
+                reason += chunk
+
+            score = float(state["score"])
+            return score, reason
+        else:
+            return await self._generate_evaluation_vllm(backend, user_question, tool_output, tool_description)
diff --git a/apps/scheduler/utils/json.py b/apps/scheduler/utils/json.py
new file mode 100644
index 000000000..fcc6c6bf5
--- /dev/null
+++ b/apps/scheduler/utils/json.py
@@ -0,0 +1,227 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+
+from __future__ import annotations
+import json
+from typing import Dict, Any
+import logging
+from datetime import datetime
+import pytz
+import re
+
+import sglang
+import openai
+
+from apps.llm import get_scheduler, create_vllm_stream, stream_to_str, get_llm, get_message_model, get_json_code_block
+from apps.scheduler.gen_json import gen_json
+
+logger = logging.getLogger('gunicorn.error')
+
+
+class Json:
+    system_prompt = r"""You must call the following function one time to answer the given question. For each function call \
+return a valid json object with only function parameters.
+
+Output must be in { } XML tags.
For example: +{"parameter_name": "value"} + +Requirements: +- Output as few parameters as possible, and avoid using optional parameters in the generated results unless necessary. +- If a parameter is not mentioned in the user's instruction, use its default value. +- If no default value is specified, use `0` for integers, `0.0` for numbers, `null` for strings, `[]` for arrays \ +and `{}` for objects. +- Don’t make up parameters. Values can only be obtained from given user input, background information, and JSON Schema. +- The example values are only used to demonstrate the data format. Do not fill the example values in the generated results. + +Here is an example: + +EXAMPLE +## Question +查询杭州天气信息 + +## Parameters JSON Schema + +```json +{"properties":{"city":{"type":"string","example":"London","default":"London","description":"City name."},"country":{"type":"string","example":"UK","description":"Optional parameter. If not set, auto-detection is performed."},"date":{"type":"string","example":"2024-09-01","description":"The date of the weather."},"meter":{"type":"integer","default":"c","description":"If the units are in Celsius, the value is \"c\"; if the units are in Fahrenheit, the value is \"f\".","enum":["c","f"]}},"required":["city","meter"]} +``` + +## Background Information + +Empty. + +## Current Time + +2024-09-02 10:00:00 + +## Thought + +The user needs to query the weather information of Hangzhou. According to the given JSON Schema, city and meter are required parameters. The user did not explicitly provide the query date, so date should be empty. The user is querying the weather in Hangzhou, so the value of city should be Hangzhou. The user did not specify the temperature unit type, so the default value "c" is used. + +## Result + +```json +{"city": "Hangzhou", "meter": "c"} +``` +END OF EXAMPLE""" + user_prompt = """## Question + +{question} + +## Parameters JSON Schema + +```json +{spec_data} +``` + +## Background Information + +{background} + +## Current Time +{time}""" + + + def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None): + if system_prompt is not None: + self.system_prompt = system_prompt + if user_prompt is not None: + self.user_prompt = user_prompt + + @staticmethod + @sglang.function + def _generate_json_sglang(s, system_prompt: str, user_prompt: str, background: str, question: str, spec_regex: str, spec: str): + s += sglang.system(system_prompt) + s += sglang.user(user_prompt.format( + question=question, + spec_data=spec, + background=background, + time=datetime.now(tz=pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") + ) + "# Thought\n\n") + + s += sglang.assistant(sglang.gen(max_tokens=1000, temperature=0.5)) + s += sglang.user("## Result\n\n") + s += sglang.assistant("" + \ + sglang.gen(name="data", max_tokens=1000, regex=spec_regex, temperature=0.01) \ + + "") + + async def _generate_json_vllm(self, backend: openai.AsyncOpenAI, background: str, question: str, spec: str) -> str: + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.user_prompt.format( + question=question, + spec_data=spec, + background=background, + time=datetime.now(tz=pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") + ) + "# Thought\n\n"}, + {"role": "assistant", "content": ""}, + ] + + stream = await create_vllm_stream(backend, messages, max_tokens=1000, extra_body={ + "guided_json": spec, + "guided_decoding_backend": "lm-format-enforcer" + }) + + json_str = await stream_to_str(stream) + 
return json_str + + @staticmethod + def _remove_null_params(input_val): + if isinstance(input_val, dict): + new_dict = {} + for key, value in input_val.items(): + nested = Json._remove_null_params(value) + if isinstance(nested, bool) or isinstance(nested, int) or isinstance(nested, float): + new_dict[key] = nested + elif nested: + new_dict[key] = nested + return new_dict + elif isinstance(input_val, list): + new_list = [] + for v in input_val: + cleaned_v = Json._remove_null_params(v) + if cleaned_v: + new_list.append(cleaned_v) + if len(new_list) > 0: + return new_list + else: + return input_val + + @staticmethod + def _check_json_valid(spec: dict, json_data: dict): + pass + + async def generate_json(self, background: str, question: str, spec: dict) -> Dict[str, Any]: + spec_regex = gen_json(spec) + if not spec_regex: + spec_regex = "{}" + logger.info(f"JSON正则:{spec_regex}") + + if not background: + background = "Empty." + + llm = get_llm() + msg_cls = get_message_model(llm) + messages = [ + msg_cls(role="system", content="""## Role + +You are a assistant who generates API call parameters. Your task is generating API call parameters according to the JSON Schema and user input. +The call parameters must be in JSON format and must be wrapped in the following Markdown code block: + +```json +// Here are the generated JSON parameters. +``` + +## Requirements + +When generating, You must follow these requirements: + +1. Use as few parameters as possible. Optional parameters should be 'null' unless it's necessary. e.g. `{"search_key": null}` +2. The order of keys in the generated JSON data must be the same as the order in the JSON Schema. +3. Do not add comments, instructions, or other irrelevant text to the generated code block; +4. Don’t make up parameters, don’t assume parameters. The value of the parameter can only be obtained from the given user input, background information and JSON Schema. +5. Before generating JSON, give your thought of the given question. Be helpful and concise. +6. Output strictly in the format described by JSON Schema. +7. The examples are only used to demonstrate the data format. 
Do not use the examples directly in the generated results."""), + msg_cls(role="user", content=self.user_prompt.format( + question=question, + spec_data=spec, + background=background, + time=datetime.now(tz=pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") + ) + "\n\nLet's think step by step.") + ] + + result = llm.invoke(messages).content + logger.info(f"生成的JSON参数为:{result}") + try: + result_str = get_json_code_block(result) + + if not re.match(spec_regex, result_str): + raise ValueError("JSON not valid.") + data = Json._remove_null_params(json.loads(result_str)) + + return data + except Exception as e: + logger.error(f"直接生成JSON失败:{e}") + + + backend = get_scheduler() + if isinstance(backend, sglang.RuntimeEndpoint): + sglang.set_default_backend(backend) + state = Json._generate_json_sglang.run( + system_prompt=self.system_prompt, + user_prompt=self.user_prompt, + background=background, + question=question, + spec_regex=spec_regex, + spec=spec + ) + + result = "" + async for chunk in state.text_async_iter(var_name="data"): + result += chunk + logger.info(f'Structured Output生成的参数为: {result}') + return Json._remove_null_params(json.loads(result)) + else: + spec_str = json.dumps(spec, ensure_ascii=False) + result = await self._generate_json_vllm(backend, background, question, spec_str) + logger.info(f"Structured Output生成的参数为:{result}") + return Json._remove_null_params(json.loads(result)) diff --git a/apps/scheduler/utils/recommend.py b/apps/scheduler/utils/recommend.py new file mode 100644 index 000000000..899ce7941 --- /dev/null +++ b/apps/scheduler/utils/recommend.py @@ -0,0 +1,59 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +# 使用大模型进行问题改写 + +from __future__ import annotations + +from apps.llm import get_llm, get_message_model + + +class Recommend: + system_prompt: str = """依照给出的工具描述和其他信息,生成符合用户目标的改写问题,或符合逻辑的预测问题。生成的问题将用于指导用户进行下一步的提问。 +要求: +1. 以用户身份进行问题生成。 +2. 工具描述的优先级高于用户问题或用户目标。当工具描述和用户问题相关性较小时,优先使用工具描述结合其他信息进行预测问题生成。 +3. 必须为疑问句或祈使句。 +4. 不得超过30个字。 +5. 不要输出任何额外信息。 + +下面是一组示例: + +EXAMPLE +## 工具描述 +查询天气数据 + +## 背景信息 +人类向AI询问杭州的著名旅游景点,大模型提供了杭州西湖、杭州钱塘江等多个著名景点的信息。 + +## 问题 +帮我查询今天的杭州天气数据 +END OF EXAMPLE""" + user_prompt: str = """ +## 工具描述 +{action_description} + +## 背景信息 +{background} + +## 问题 +""" + + def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None): + if system_prompt is not None: + self.system_prompt = system_prompt + if user_prompt is not None: + self.user_prompt = user_prompt + + async def recommend(self, action_description: str, background: str = "Empty.") -> str: + llm = get_llm() + msg_cls = get_message_model(llm) + + messages = [ + msg_cls(role="system", content=self.system_prompt), + msg_cls(role="user", content=self.user_prompt.format( + action_description=action_description, + background=background + )) + ] + + result = llm.invoke(messages) + return result.content diff --git a/apps/scheduler/utils/reflect.py b/apps/scheduler/utils/reflect.py new file mode 100644 index 000000000..1d34dfe99 --- /dev/null +++ b/apps/scheduler/utils/reflect.py @@ -0,0 +1,125 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +# 使用大模型进行解析和改错 + +from __future__ import annotations +from typing import Dict, Any +import json + +import sglang +import openai + +from apps.llm import get_scheduler, create_vllm_stream, stream_to_str + + +class Reflect: + system_prompt = """You are an advanced reasoning agent that can improve based \ +on self-reflection. 
You will be given a previous reasoning trial in which you are given a task and a action. \ +You tried to accomplish the task with the certain action and generated input but failed. Your goal is to write a \ +few sentences to explain why your attempt is wrong, and write a guidance according to your explanation. \ +You will need this as guidance when you try again later. Only provide a few sentence description in your answer, \ +not the future action and inputs. + +Here are some examples: + +EXAMPLE +## Previous Trial Instruction + +查询机器192.168.100.1的CVE信息。 + +## Action + +使用的工具是 A-Ops.CVE,作用为:查询特定主机IP的全部CVE信息。 + +## Action Input + +```json +{"host_ip": "192.168.100.1", "num": 0} +``` + +## Observation + +采取该Action后,输出的信息为空,不符合用户的指令要求。这可能是由请求参数设置不正确导致的结果,也可能是Action本身存在问题,或机器中并不存在CVE。 + +## Guidance + +Action Input中,"num"字段被设置为了0。这个字段可能与最终显示的CVE条目数量有关。可以将该字段的值修改为100,再次尝试使用该接口。在获得有效的\ +CVE信息后我,将继续后续步骤。我将继续优化Action Input,以获得更多符合用户指令的结果。 + +END OF EXAMPLE""" + user_prompt = """## Previous Trial Instruction + +{instruction} + +## Action + +{call} + +## Action Input + +```json +{call_input} +``` + +## Observation + +{call_score_reason} + +## Guidance""" + + def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None): + if system_prompt is not None: + self.system_prompt = system_prompt + if user_prompt is not None: + self.user_prompt = user_prompt + + @staticmethod + @sglang.function + def _generate_reflect_sglang(s, system_prompt: str, user_prompt: str, instruction: str, call: str, call_input: str, call_score_reason: str): + s += sglang.system(system_prompt) + s += sglang.user(user_prompt.format( + instruction=instruction, + call=call, + call_input=call_input, + call_score_reason=call_score_reason + )) + s += sglang.assistant(sglang.gen(name="result", max_tokens=1500)) + + async def _generate_reflect_vllm(self, backend: openai.AsyncOpenAI, instruction: str, + call: str, call_input: str, call_score_reason: str) -> str: + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.user_prompt.format( + instruction=instruction, + call_name=call, + call_input=call_input, + call_score_reason=call_score_reason + )}, + ] + + stream = create_vllm_stream(backend, messages, max_tokens=1500, extra_body={}) + return await stream_to_str(stream) + + async def generate_reflect(self, instruction: str, call: Dict[str, Any], call_input: Dict[str, Any], call_score_reason: str) -> str: + backend = get_scheduler() + call_str = "使用的工具是 {},作用为:{}".format(call["name"], call["description"]) + call_input_str = json.dumps(call_input, ensure_ascii=False) + + if isinstance(backend, sglang.RuntimeEndpoint): + sglang.set_default_backend(backend) + state = Reflect._generate_reflect_sglang.run( + system_prompt=self.system_prompt, + user_prompt=self.user_prompt, + instruction=instruction, + call=call_str, + call_input=call_input_str, + call_score_reason=call_score_reason, + stream=True + ) + + result = "" + async for chunk in state.text_async_iter(var_name="result"): + result += chunk + return result + + else: + return await self._generate_reflect_vllm(backend, instruction, call_str, call_input_str, call_score_reason) diff --git a/apps/scheduler/utils/select.py b/apps/scheduler/utils/select.py new file mode 100644 index 000000000..52055c99f --- /dev/null +++ b/apps/scheduler/utils/select.py @@ -0,0 +1,159 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+# 使用大模型选择Top N最匹配语义的项 + +from __future__ import annotations + +import asyncio +from typing import List, Any, Dict + +import sglang +import openai + +from apps.llm import create_vllm_stream, get_scheduler, stream_to_str +from apps.common.thread import ProcessThreadPool + + +class Select: + system_prompt = """Your task is: choose the tool that best matches user instructions and contextual information \ +based on the description of the tool. + +Tool name and its description will be given in the format: + +```xml + + + Tool Name + Tool Description + + +``` + +Here are some examples: + +EXAMPLE + +## Instruction + +使用天气API,查询明天的天气信息 + +## Tools + +```xml + + + API + 请求特定API,获得返回的JSON数据 + + + SQL + 查询数据库,获得table中的数据 + + +``` + +## Thinking + +Let's think step by step. There's no tool available to get weather forecast directly, so I need to try using other \ +tools to obtain weather information. API tools can retrieve external data through the use of APIs, and weather \ +information may be stored in external data. As the user instructions explicitly mentioned the use of the weather API, \ +the API tool should be prioritized. SQL tools are used to retrieve information from databases. Given the variable \ +and dynamic nature of weather data, it is unlikely to be stored in a database. Therefore, the priority of \ +SQL tools is relatively low. + +## Answer + +Thus the selected tool is: API. + +END OF EXAMPLE""" + user_prompt = """## Instruction + +{question} + +## Tools + +```xml +{tools} +``` + +## Thinking + +Let's think step by step.""" + + def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None): + if system_prompt is not None: + self.system_prompt = system_prompt + if user_prompt is not None: + self.user_prompt = user_prompt + + @staticmethod + def _flows_to_xml(choice: List[Dict[str, Any]]) -> str: + result = "\n" + for tool in choice: + result += "\t\n\t\t{name}\n\t\t{description}\n\t\n".format( + name=tool["name"], description=tool["description"] + ) + result += "" + return result + + @staticmethod + @sglang.function + def _top_flows_sglang(s, system_prompt: str, user_prompt: str, choice: List[Dict[str, Any]], instruction: str): + s += sglang.system(system_prompt) + s += sglang.user(user_prompt.format( + question=instruction, + tools=Select._flows_to_xml(choice), + )) + s += sglang.assistant(sglang.gen(max_tokens=1500, stop="\n\n")) + s += sglang.user("\n\n##Answer\n\nThus the selected tool is: ") + s += sglang.assistant_begin() + + choice_list = [] + for item in choice: + choice_list.append(item["name"]) + s += sglang.gen(choices=choice_list, name="choice") + s += sglang.assistant_end() + + async def _top_flows_vllm(self, backend: openai.AsyncOpenAI, choice: List[Dict[str, Any]], instruction: str) -> str: + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.user_prompt.format( + question=instruction, + tools=Select._flows_to_xml(choice), + )} + ] + + stream = await create_vllm_stream(backend, messages, max_tokens=1500, extra_body={}) + result = await stream_to_str(stream) + + messages += [ + {"role": "assistant", "content": result}, + {"role": "user", "content": "## Answer\n\nThus the selected tool is: "} + ] + + choice_regex = "(" + for item in choice: + choice_regex += item["name"] + "|" + choice_regex = choice_regex.rstrip("|") + ")" + + stream = await create_vllm_stream(backend, messages, max_tokens=200, extra_body={ + "guided_regex": choice_regex + }) + result = await stream_to_str(stream) + + return result + 
+ async def top_flow(self, choice: List[Dict[str, Any]], instruction: str) -> str: + backend = get_scheduler() + if isinstance(backend, sglang.RuntimeEndpoint): + sglang.set_default_backend(backend) + state_future = ProcessThreadPool().thread_executor.submit( + Select._top_flows_sglang.run, + system_prompt=self.system_prompt, + user_prompt=self.user_prompt, + choice=choice, + instruction=instruction + ) + state = await asyncio.wrap_future(state_future) + return state["choice"] + else: + return await self._top_flows_vllm(backend, choice, instruction) diff --git a/apps/scheduler/utils/summary.py b/apps/scheduler/utils/summary.py new file mode 100644 index 000000000..e5d25387a --- /dev/null +++ b/apps/scheduler/utils/summary.py @@ -0,0 +1,91 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +# 使用大模型生成对话总结 + +from __future__ import annotations +from typing import List + +from apps.llm import get_llm, get_message_model + + +class Summary: + system_prompt = """Progressively summarize the lines of conversation provided, adding onto the previous summary \ +returning a new summary. Summary should be less than 2000 words. Examples are given below. + +EXAMPLE +## Previous Summary + +人类询问AI有关openEuler容器平台应当使用哪个软件的问题,AI向其推荐了iSula安全容器平台。 + +## Conversations + +### User + +iSula有什么特点? + +### Assistant + +iSula 的特点如下: +轻量语言:C/C++,Rust on the way +北向接口:提供CRI接口,支持对接Kubernetes; 同时提供便捷使用的命令行 +南向接口:支持OCI runtime和镜像规范,支持平滑替换 +容器形态:支持系统容器、虚机容器等多种容器形态 +扩展能力:提供插件化架构,可根据用户需要开发定制化插件 + +### Used Tool + +- name: Search +- description: 查询关键字对应的openEuler产品简介。 +- output: `{"total":1,"data":["iSula是openEuler推出的一个安全容器运行平台。"]}` + +## Summary +人类询问AI有关openEuler容器平台应当使用哪个软件的问题,AI向其推荐了iSula安全容器平台。人类询问iSula有何特点,\ +AI使用Search工具搜索了“iSula”关键字,获得了1条搜索结果,即iSula的定义。AI列举了轻量语言、北向接口、\ +南向接口、容器形态、扩展能力五种特点。 + +END OF EXAMPLE""" + user_prompt = """## Previous Summary +{last_summary} + +## Conversations + +### User + +{user_question} + +### Assistant + +{llm_output} + +### Used Tool + +- name: {tool_name} +- description: {tool_description} +- output: `{tool_output}` + +## Summary +""" + + def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None): + if system_prompt is not None: + self.system_prompt = system_prompt + if user_prompt is not None: + self.user_prompt = user_prompt + + async def generate_summary(self, last_summary: str, qa_pair: List[str], tool_info: List[str]) -> str: + llm = get_llm() + msg_cls = get_message_model(llm) + + messages = [ + msg_cls(role="system", content=self.system_prompt), + msg_cls(role="user", content=self.user_prompt.format( + last_summary=last_summary, + user_question=qa_pair[0], + llm_output=qa_pair[1], + tool_name=tool_info[0], + tool_description=tool_info[1], + tool_output=tool_info[2] + )) + ] + + result = llm.invoke(messages) + return result.content diff --git a/apps/scheduler/vector.py b/apps/scheduler/vector.py new file mode 100644 index 000000000..b17047ce9 --- /dev/null +++ b/apps/scheduler/vector.py @@ -0,0 +1,132 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+ +from typing import List, Optional + +import chromadb +from chromadb import Documents, Embeddings, EmbeddingFunction, Collection +from pydantic import BaseModel, Field +import requests +import logging + +from apps.common.config import config + + +logger = logging.getLogger('gunicorn.error') + + +def get_embedding(text: List[str]): + """ + 访问Vectorize的Embedding API,获得向量化数据 + :param text: 待向量化文本(多条文本组成List) + :return: 文本对应的向量(顺序与text一致,也为List) + """ + + api = config["VECTORIZE_HOST"].rstrip("/") + "/embedding" + response = requests.post( + api, + json={"texts": text} + ) + + return response.json() + + +# 模块内部类,不应在模块外部使用 +class DocumentWrapper(BaseModel): + """ + 单个ChromaDB文档的结构 + """ + data: str = Field(description="文档内容") + id: str = Field(description="文档ID,用于确保唯一性") + metadata: Optional[dict] = Field(description="文档元数据", default=None) + + +class RAGEmbedding(EmbeddingFunction): + """ + ChromaDB用于进行文本向量化的函数 + """ + def __call__(self, input: Documents) -> Embeddings: + return get_embedding(input) + + +class VectorDB: + """ + ChromaDB单例 + """ + client: chromadb.ClientAPI = chromadb.Client() + + def __init__(self): + raise NotImplementedError("VectorDB不应被实例化") + + @classmethod + def get_collection(cls, collection_name: str) -> Collection: + """ + 创建并返回ChromaDB集合 + :param collection_name: 集合名称,字符串 + :return: ChromaDB集合对象 + """ + + try: + return cls.client.get_or_create_collection(collection_name, embedding_function=RAGEmbedding(), + metadata={"hnsw:space": "cosine"}) + except Exception as e: + logger.error(f"Get collection failed: {e}") + + @classmethod + def delete_collection(cls, collection_name: str): + """ + 删除ChromaDB集合 + :param collection_name: 集合名称,字符串 + :return: + """ + cls.client.delete_collection(collection_name) + + @classmethod + def add_docs(cls, collection: Collection, docs: List[DocumentWrapper]): + """ + 向ChromaDB集合中添加文档 + :param collection: ChromaDB集合对象 + :param docs: 待向量化的文档List + :return: + """ + + doc_list = [] + metadata_list = [] + id_list = [] + for doc in docs: + doc_list.append(doc.data) + id_list.append(doc.id) + metadata_list.append(doc.metadata) + + collection.add( + ids=id_list, + metadatas=metadata_list, + documents=doc_list + ) + + @classmethod + def get_docs(cls, collection: Collection, question: str, requirements: dict, num: int = 3) -> List[DocumentWrapper]: + """ + 根据输入,从ChromaDB中查询K个向量最相似的文档 + :param collection: ChromaDB集合对象 + :param question: 查询输入 + :param requirements: 查询过滤条件 + :param num: Top K中K的值 + :return: 文档List,包含文档内容、元数据、ID + """ + result = collection.query( + query_texts=[question], + where=requirements, + n_results=num, + include=["documents", "metadatas"] + ) + + item_list = [] + length = min(num, len(result["ids"])) + for i in range(length): + item_list.append(DocumentWrapper( + id=result["ids"][i], + metadata=result["metadatas"][i], + documents=result["documents"][i] + )) + + return item_list diff --git a/apps/service/__init__.py b/apps/service/__init__.py new file mode 100644 index 000000000..a280be295 --- /dev/null +++ b/apps/service/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+ +from apps.service.domain import Domain +from apps.service.activity import Activity +from apps.service.rag import RAG +from apps.service.history import History +from apps.service.suggestion import Suggestion +from apps.service.summary import ChatSummary + +__all__ = ["Domain", "Activity", "RAG", "History", "Suggestion", "ChatSummary"] \ No newline at end of file diff --git a/apps/service/activity.py b/apps/service/activity.py new file mode 100644 index 000000000..d2622e374 --- /dev/null +++ b/apps/service/activity.py @@ -0,0 +1,34 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +from apps.models.redis import RedisConnectionPool + +class Activity: + """ + 用户活动控制,限制单用户同一时间只能提问一个问题 + """ + def __init__(self): + raise NotImplementedError("Activity无法被实例化!") + + @staticmethod + def is_active(user_sub) -> bool: + """ + 判断当前用户是否正在提问(占用GPU资源) + :param user_sub: 用户实体ID + :return: 判断结果,正在提问则返回True + """ + with RedisConnectionPool.get_redis_connection() as r: + if not r.get(f'{user_sub}_active'): + return False + else: + r.expire(f'{user_sub}_active', 300) + return True + + @staticmethod + def remove_active(user_sub): + """ + 清除用户的活动标识,释放GPU资源 + :param user_sub: 用户实体ID + :return: + """ + with RedisConnectionPool.get_redis_connection() as r: + r.delete(f'{user_sub}_active') diff --git a/apps/service/domain.py b/apps/service/domain.py new file mode 100644 index 000000000..c8b4de0fe --- /dev/null +++ b/apps/service/domain.py @@ -0,0 +1,98 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +import json +import logging + +from apps.llm import get_json_code_block, get_llm, get_message_model + + +logger = logging.getLogger('gunicorn.error') + + +class Domain: + """ + 用户领域画像 + """ + def __init__(self): + raise NotImplementedError("Domain无法被实例化!") + + @staticmethod + def check_domain(question, answer, domain): + llm = get_llm() + prompt = f""" + 请判断以下对话内容涉及领域列表中的哪几个领域 + + 请按照以下json格式输出: + ```json + {{ + "domain":["domain1","domain2","domain3",...] 
//可能属于一个或者多个领域,必须出现在领域列表中,如果都不涉及可以为空 + }} + ``` + + 对话内容: + 提问: {question} + 回答: {answer} + + 领域列表: + {domain} + """ + output = llm.invoke(prompt) + logger.info("domain_output: {}".format(output)) + try: + json_str = get_json_code_block(output.content) + result = json.loads(json_str) + return result['domain'] + except Exception as e: + logger.error(f"检测领域信息出错:{str(e)}") + return [] + + @staticmethod + def generate_suggestion(summary, last_chat, domain): + llm = get_llm() + msg_cls = get_message_model(llm) + + system_prompt = """根据提供的用户领域和历史对话内容,生成三条预测问题,用于指导用户进行下一步的提问。搜索建议必须遵从用户领域,并结合背景信息。 + 要求:生成的问题必须为祈使句或疑问句,不得超过30字,生成的问题不要与用户提问完全相同。严格按照以下JSON格式返回: + ```json + {{ + "suggestions":["Q:suggestion1","Q:suggestion2","Q:suggestion3"] //返回三条问题 + }} + ```""" + + user_prompt = """## 背景信息 + {summary} + + ## 最近对话 + {last_chat} + + ## 用户领域 + {domain}""" + + messages = [ + msg_cls(role="system", content=system_prompt), + msg_cls(role="user", content=user_prompt.format(summary=summary, last_chat=last_chat, domain=domain)) + ] + + output = llm.invoke(messages) + print(output) + try: + json_str = get_json_code_block(output.content) + result = json.loads(json_str) + format_result = [] + for item in result['suggestions']: + if item.startswith("Q:"): + format_result.append({ + "id": "", + "name": "", + "question": item[2:] + }) + else: + format_result.append({ + "id": "", + "name": "", + "question": item + }) + return format_result + except Exception as e: + logger.error(f"生成推荐问题出错:{str(e)}") + return [] diff --git a/apps/service/history.py b/apps/service/history.py new file mode 100644 index 000000000..07ede45fb --- /dev/null +++ b/apps/service/history.py @@ -0,0 +1,59 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +from __future__ import annotations + +from typing import List, Dict +import uuid + +from apps.manager.record import RecordManager +from apps.common.security import Security +from apps.manager.conversation import ConversationManager + + +class History: + """ + 获取对话历史记录 + """ + def __init__(self): + raise NotImplementedError("History类无法被实例化!") + + @staticmethod + def get_latest_records(conversation_id: str, record_id: str | None = None, n: int = 1): + # 是重新生成,从record_id中拿出group_id + if record_id is not None: + record = RecordManager().query_encrypted_data_by_record_id(record_id) + group_id = record.group_id + # 全新生成,创建新的group_id + else: + group_id = str(uuid.uuid4().hex) + + record_list = RecordManager().query_encrypted_data_by_conversation_id( + conversation_id, n, group_id) + record_list_sorted = sorted(record_list, key=lambda x: x.created_time) + + return group_id, record_list_sorted + + @staticmethod + def get_history_messages(conversation_id, record_id): + group_id, record_list_sorted = History.get_latest_records(conversation_id, record_id) + history: List[Dict[str, str]] = [] + for item in record_list_sorted: + tmp_question = Security.decrypt( + item.encrypted_question, item.question_encryption_config) + tmp_answer = Security.decrypt( + item.encrypted_answer, item.answer_encryption_config) + history.append({"role": "user", "content": tmp_question}) + history.append({"role": "assistant", "content": tmp_answer}) + return group_id, history + + @staticmethod + def get_summary(conversation_id): + """ + 根据对话ID,从数据库中获取对话的总结 + :param conversation_id: 对话ID + :return: 对话总结信息,字符串或None + """ + conv = ConversationManager.get_conversation_by_conversation_id(conversation_id) + if conv.summary is None: + return "" + return conv.summary diff --git a/apps/service/rag.py 
b/apps/service/rag.py new file mode 100644 index 000000000..f855751a5 --- /dev/null +++ b/apps/service/rag.py @@ -0,0 +1,46 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +import json +import aiohttp + +from apps.common.config import config +from apps.service import Activity + + +class RAG: + """ + 调用RAG服务,获取知识库答案 + """ + + def __init__(self): + raise NotImplementedError("RAG类无法被实例化!") + + @staticmethod + async def get_rag_result(user_sub: str, question: str, language: str, history: list): + url = config["RAG_HOST"].rstrip("/") + "/kb/get_stream_answer" + headers = { + "Content-Type": "application/json" + } + data = { + "question": question, + "history": history, + "language": language, + "kb_sn": f'{language}_default_test', + "top_k": 5, + "fetch_source": False + } + if config['RAG_KB_SN']: + data.update({"kb_sn": config['RAG_KB_SN']}) + payload = json.dumps(data, ensure_ascii=False) + + yield "data: " + json.dumps({"content": "正在查询知识库,请稍等...\n\n"}) + "\n\n" + # asyncio HTTP请求 + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=300)) as session: + async with session.post(url, headers=headers, data=payload, ssl=False) as response: + async for line in response.content: + line_str = line.decode('utf-8') + + if line_str != "data: [DONE]" and Activity.is_active(user_sub): + yield line_str + else: + return diff --git a/apps/service/suggestion.py b/apps/service/suggestion.py new file mode 100644 index 000000000..72e0c0445 --- /dev/null +++ b/apps/service/suggestion.py @@ -0,0 +1,34 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +from apps.manager.domain import DomainManager +from apps.manager.user_domain import UserDomainManager +from apps.service.domain import Domain + + +class Suggestion: + def __init__(self): + raise NotImplementedError("Suggestion类无法被实例化!") + + @staticmethod + def update_user_domain(user_sub: str, question: str, answer: str): + domain_list = DomainManager.get_domain() + domain = {} + for item in domain_list: + domain[item.domain_name] = item.domain_description + + domain_list = Domain.check_domain(question, answer, domain) + for item in domain_list: + UserDomainManager.update_user_domain_by_user_sub_and_domain_name(user_sub=user_sub, domain_name=item) + return + + @staticmethod + def generate_suggestions(user_sub, summary, question, answer): + user_domain = UserDomainManager.get_user_domain_by_user_sub_and_topk(user_sub, 1) + domain = {} + for item in user_domain: + domain[item.domain_name] = item.domain_description + format_result = Domain.generate_suggestion(summary, { + "question": question, + "answer": answer + }, domain) + return format_result \ No newline at end of file diff --git a/apps/service/summary.py b/apps/service/summary.py new file mode 100644 index 000000000..9903c3c7d --- /dev/null +++ b/apps/service/summary.py @@ -0,0 +1,19 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+ +from apps.llm import get_llm, get_message_model + +class ChatSummary: + def __init__(self): + raise NotImplementedError("Summary类无法被实例化!") + + @staticmethod + async def generate_chat_summary(last_summary: str, question: str, answer: str): + llm = get_llm() + msg_cls = get_message_model(llm) + messages = [ + msg_cls(role="system", content="Progressively summarize the lines of conversation provided, adding onto the previous summary."), + msg_cls(role="user", content=f"{last_summary}\n\nQuestion: {question}\nAnswer: {answer}"), + ] + + result = llm.invoke(messages) + return result.content diff --git a/apps/utils/user_exporter.py b/apps/utils/user_exporter.py new file mode 100644 index 000000000..58b07eebc --- /dev/null +++ b/apps/utils/user_exporter.py @@ -0,0 +1,242 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + +import argparse +import datetime +import os +import re +import sys +import secrets +import shutil +import zipfile + +from openpyxl import Workbook + +from apps.common.security import Security +from apps.manager.audit_log import AuditLogData, AuditLogManager +from apps.manager.record import RecordManager +from apps.manager.user import UserManager +from apps.manager.conversation import ConversationManager + + +class UserExporter: + start_row_id = 1 + chat_xlsx_column = ['question', 'answer', 'created_time'] + chat_column_map = { + 'question_column': 1, + 'answer_column': 2, + 'created_time_column': 3 + } + user_info_xlsx_column = [ + 'user_sub', 'organization', + 'created_time', 'login_time', 'revision_number' + ] + user_info_column_map = { + 'user_sub_column': 1, + 'organization_column': 2, + 'created_time_column': 3, + 'login_time_column': 4, + 'revision_number_column': 5 + } + + @staticmethod + def get_datetime_from_str(date_str, date_format): + date_time_obj = datetime.datetime.strptime(date_str, date_format) + date_time_obj = datetime.datetime(date_time_obj.year, date_time_obj.month, date_time_obj.day) + timestamp = date_time_obj.timestamp() + return timestamp + + @staticmethod + def zip_xlsx_folder(tmp_out_dir): + dir_name = os.path.dirname(tmp_out_dir) + last_dir_name = os.path.basename(tmp_out_dir) + xlsx_file_name_list = os.listdir(tmp_out_dir) + zip_file_dir = os.path.join(dir_name, last_dir_name+'.zip') + with zipfile.ZipFile(zip_file_dir, 'w') as zip_file: + for xlsx_file_name in xlsx_file_name_list: + xlsx_file_path = os.path.join(tmp_out_dir, xlsx_file_name) + zip_file.write(xlsx_file_path) + return zip_file_dir + + @staticmethod + def save_chat_to_xlsx(xlsx_dir, chat_list): + workbook = Workbook() + sheet = workbook.active + for i, column in enumerate(UserExporter.chat_xlsx_column): + sheet.cell(row=UserExporter.start_row_id, column=i+1, value=column) + row_id = UserExporter.start_row_id + 1 + for chat in chat_list: + question = chat[0] + answer = chat[1] + created_time = chat[2] + sheet.cell(row=row_id, + column=UserExporter.chat_column_map['question_column'], + value=question) + sheet.cell(row=row_id, + column=UserExporter.chat_column_map['answer_column'], + value=answer) + sheet.cell(row=row_id, + column=UserExporter.chat_column_map['created_time_column'], + value=created_time) + row_id += 1 + workbook.save(xlsx_dir) + + @staticmethod + def save_user_info_to_xlsx(xlsx_dir, user_info): + workbook = Workbook() + sheet = workbook.active + for i, column in enumerate(UserExporter.user_info_xlsx_column): + sheet.cell(row=UserExporter.start_row_id, column=i+1, value=column) + row_id = UserExporter.start_row_id + 1 + user_sub = 
user_info.user_sub + organization = user_info.organization + created_time = user_info.created_time + login_time = user_info.login_time + revision_number = user_info.revision_number + sheet.cell(row=row_id, + column=UserExporter.user_info_column_map['user_sub_column'], + value=user_sub) + sheet.cell(row=row_id, + column=UserExporter.user_info_column_map['organization_column'], + value=organization) + sheet.cell(row=row_id, + column=UserExporter.user_info_column_map['created_time_column'], + value=created_time) + sheet.cell(row=row_id, + column=UserExporter.user_info_column_map['login_time_column'], + value=login_time) + sheet.cell(row=row_id, + column=UserExporter.user_info_column_map['revision_number_column'], + value=revision_number) + workbook.save(xlsx_dir) + + @staticmethod + def export_user_info_to_xlsx(tmp_out_dir, user_sub): + user_info = UserManager.get_userinfo_by_user_sub(user_sub) + xlsx_file_name = 'user_info_'+user_sub+'.xlsx' + xlsx_file_name = re.sub(r'[<>:"/\\|?*]', '_', xlsx_file_name) + xlsx_file_name = xlsx_file_name.replace(' ', '_') + xlsx_dir = os.path.join(tmp_out_dir, xlsx_file_name) + UserExporter.save_user_info_to_xlsx(xlsx_dir, user_info) + + @staticmethod + def export_chats_to_xlsx(tmp_out_dir, user_sub, start_day, end_day): + user_qa_records = ConversationManager.get_conversation_by_user_sub( + user_sub) + for user_qa_record in user_qa_records: + chat_id = user_qa_record.conversation_id + chat_tile = re.sub(r'[<>:"/\\|?*]', '_', user_qa_record.title) + chat_tile = chat_tile.replace(' ', '_')[:20] + chat_created_time = str(user_qa_record.created_time) + encrypted_qa_records = RecordManager.query_encrypted_data_by_conversation_id( + chat_id) + chat = [] + for record in encrypted_qa_records: + question = Security.decrypt(record.encrypted_question, + record.question_encryption_config) + answer = Security.decrypt(record.encrypted_answer, + record.answer_encryption_config) + qa_record_created_time = record.created_time + if start_day is not None: + if UserExporter.get_datetime_from_str(record.created_time, "%Y-%m-%d %H:%M:%S") < start_day: + continue + if end_day is not None: + if UserExporter.get_datetime_from_str(record.created_time, "%Y-%m-%d %H:%M:%S") > end_day: + continue + chat.append([question, answer, qa_record_created_time]) + xlsx_file_name = 'chat_'+chat_tile[:20] + '_'+chat_created_time+'.xlsx' + xlsx_file_name = xlsx_file_name.replace(' ', '') + xlsx_dir = os.path.join(tmp_out_dir, xlsx_file_name) + UserExporter.save_chat_to_xlsx(xlsx_dir, chat) + + @staticmethod + def export_user_data(users_dir, user_sub, export_preferences=None, start_day=None, end_day=None): + export_preferences = export_preferences or ['user_info', 'chat'] + rand_num = secrets.randbits(128) + tmp_out_dir = os.path.join('./', users_dir, str(rand_num)) + if os.path.exists(tmp_out_dir): + shutil.rmtree(tmp_out_dir) + os.mkdir(tmp_out_dir) + os.chmod(tmp_out_dir, 0o750) + if 'user_info' in export_preferences: + UserExporter.export_user_info_to_xlsx(tmp_out_dir, user_sub) + if 'chat' in export_preferences: + UserExporter.export_chats_to_xlsx(tmp_out_dir, user_sub, start_day, end_day) + zip_file_path = UserExporter.zip_xlsx_folder(tmp_out_dir) + shutil.rmtree(tmp_out_dir) + return zip_file_path + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--user_sub", type=str, required=True, + help='''Please provide usr_sub identifier for the export \ + process. 
This ID ensures that the exported data is \ + accurately associated with your user profile.If this \ + field is \"all\", then all user information will be \ + exported''') + parser.add_argument("--export_preferences", type=str, required=True, + help='''Please enter your export preferences by specifying \ + 'chat' and/or 'user_info', separated by a space \ + if including both. Ensure that your input is limited to \ + these options for accurate data export processing.''') + parser.add_argument("--start_day", type=str, required=False, + help='''User record export start date, format reference is \ + as follows: 2024_03_23''') + parser.add_argument("--end_day", type=str, required=False, + help='''User record export end date, format reference is \ + as follows: 2024_03_23''') + args = vars(parser.parse_args()) + arg_user_sub = args['user_sub'] + arg_export_preferences = args['export_preferences'].split(' ') + start_day = args['start_day'] + end_day = args['end_day'] + try: + if start_day is not None: + start_day = UserExporter.get_datetime_from_str(start_day, "%Y_%m_%d") + except Exception as e: + data = AuditLogData( + method_type='internal_user_exporter', source_name='start_day_exchange', ip='internal', + result=f'start_day_exchange failed due error: {e}', + reason=f'导出用户数据时,起始时间填写有误' + ) + AuditLogManager.add_audit_log(arg_user_sub, data) + try: + if end_day is not None: + end_day = UserExporter.get_datetime_from_str(end_day, "%Y_%m_%d") + except Exception as e: + data = AuditLogData( + method_type='internal_user_exporter', source_name='end_day_exchange', ip='internal', + result=f'end_day_exchange failed due error: {e}', + reason=f'导出用户数据时,结束时间填写有误' + ) + AuditLogManager.add_audit_log(arg_user_sub, data) + if arg_user_sub == "all": + user_sub_list = UserManager.get_all_user_sub() + else: + user_sub_list = [arg_user_sub] + users_dir = str(secrets.randbits(128)) + if os.path.exists(users_dir): + shutil.rmtree(users_dir) + os.mkdir(users_dir) + os.chmod(users_dir, 0o750) + for arg_user_sub in user_sub_list: + arg_user_sub = arg_user_sub[0] + try: + export_path = UserExporter.export_user_data( + users_dir, arg_user_sub, arg_export_preferences, start_day, end_day) + audit_export_preference = f', preference: {arg_export_preferences}' if arg_export_preferences else '' + data = AuditLogData( + method_type='internal_user_exporter', source_name='export_user_data', ip='internal', + result=f'exported user data of id: {arg_user_sub}{audit_export_preference}, path: {export_path}', + reason=f'用户(id: {arg_user_sub})请求导出数据' + ) + AuditLogManager.add_audit_log(arg_user_sub, data) + except Exception as e: + data = AuditLogData( + method_type='internal_user_exporter', source_name='export_user_data', ip='internal', + result=f'export_user_data failed due error: {e}', + reason=f'用户(id: {arg_user_sub})请求导出数据失败' + ) + AuditLogManager.add_audit_log(arg_user_sub, data) + zip_file_path = UserExporter.zip_xlsx_folder(users_dir) + shutil.rmtree(users_dir) diff --git a/assets/euler-copilot-frame.sql b/assets/euler-copilot-frame.sql new file mode 100644 index 000000000..5b0bdd205 --- /dev/null +++ b/assets/euler-copilot-frame.sql @@ -0,0 +1,77 @@ +CREATE TABLE `user` ( + `id` bigint unsigned NOT NULL AUTO_INCREMENT, + `user_sub` varchar(100) NOT NULL, + `passwd` varchar(100) DEFAULT NULL, + `organization` varchar(100) DEFAULT NULL, + `created_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, + `login_time` datetime DEFAULT NULL, + `revision_number` varchar(100) DEFAULT NULL, + `credit` int unsigned NOT NULL DEFAULT 100, + 
`is_whitelisted` boolean NOT NULL DEFAULT 0, + PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci; + +CREATE TABLE `audit_log` ( + `id` bigint unsigned NOT NULL AUTO_INCREMENT, + `user_sub` varchar(100) DEFAULT NULL, + `method_type` varchar(100) DEFAULT NULL, + `source_name` varchar(100) DEFAULT NULL, + `ip` varchar(100) DEFAULT NULL, + `result` varchar(100) DEFAULT NULL, + `reason` varchar(100) DEFAULT NULL, + `created_time` datetime DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (`id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci; + +CREATE TABLE `comment` ( + `id` bigint unsigned NOT NULL AUTO_INCREMENT, + `qa_record_id` varchar(100) NOT NULL UNIQUE, + `is_like` boolean DEFAULT NULL, + `dislike_reason` varchar(100) DEFAULT NULL, + `reason_link` varchar(200) DEFAULT NULL, + `reason_description` varchar(500) DEFAULT NULL, + `created_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, + `user_sub` varchar(100) DEFAULT NULL, + PRIMARY KEY (`id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci; + +CREATE TABLE `user_qa_record` ( + `id` bigint unsigned NOT NULL AUTO_INCREMENT, + `user_qa_record_id` varchar(100) NOT NULL UNIQUE, + `user_sub` varchar(100) NOT NULL, + `title` varchar(200) NOT NULL, + `created_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (`id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci; + +CREATE TABLE `qa_record` ( + `id` bigint unsigned NOT NULL AUTO_INCREMENT, + `user_qa_record_id` varchar(100) NOT NULL, + `encrypted_question` text NOT NULL, + `question_encryption_config` varchar(1000) NOT NULL, + `encrypted_answer` text NOT NULL, + `answer_encryption_config` varchar(1000) NOT NULL, + `created_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, + `qa_record_id` varchar(100) NOT NULL UNIQUE, + `group_id` varchar(100) DEFAULT NULL, + PRIMARY KEY (`id`), + KEY `idx_user_qa_record_id` (`user_qa_record_id`) USING BTREE +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci; + +CREATE TABLE `api_key` ( + `id` bigint unsigned NOT NULL AUTO_INCREMENT, + `user_sub` varchar(100) NOT NULL, + `api_key_hash` varchar(16) NOT NULL, + `created_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (`id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci; + +CREATE TABLE `question_blacklist` ( + `id` bigint unsigned NOT NULL AUTO_INCREMENT, + `question` text NOT NULL, + `answer` text NOT NULL, + `is_audited` boolean NOT NULL DEFAULT FALSE, + `reason_description` varchar(200) DEFAULT NULL, + `created_time` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (`id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci; diff --git a/assets/host.example.json b/assets/host.example.json new file mode 100644 index 000000000..424d71f99 --- /dev/null +++ b/assets/host.example.json @@ -0,0 +1,12 @@ +{ + "hosts": [ + { + "name": "host name", + "desc": "description of host", + "ip": "host_ip", + "port": 22, + "username": "username", + "pkey_path": "/xx/xx/xx/host/pkey.pem" + } + ] +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000..87862a026 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,41 @@ +pytz==2024.1 +pydantic==2.7.1 +vanna==0.6.2 +pandas==2.2.2 +langchain==0.1.16 +langchain-openai==0.1.6 +pgvector==0.2.5 +sqlalchemy==2.0.23 +sglang==0.3.0 +requests==2.32.3 
+ipython==8.18.1 +python-dotenv==1.0.0 +cryptography==42.0.2 +redis==4.5.4 +uvicorn==0.21.0 +apscheduler==3.10.0 +fastapi==0.110.2 +aiohttp==3.9.5 +paramiko==3.4.0 +asgiref==3.8.1 +starlette==0.37.2 +pyyaml==6.0.1 +chromadb==0.5.0 +pyjwt==2.8.0 +limits==3.7.0 +paramiko==3.4.0 +more-itertools==10.2.0 +spark-ai-python==0.4.1 +psycopg2-binary==2.9.9 +PyMySQL==1.1.1 +python-multipart==0.0.9 +aiofiles==24.1.0 +coverage==7.6.0 +numpy==1.26.4 +openpyxl==3.1.5 +openai==1.41.0 +langchain-core==0.1.52 +langchain-community==0.0.38 +gunicorn==23.0.0 +untruncate-json==1.0.0 +JSON-minify==0.3.0 \ No newline at end of file diff --git a/sdk/example_plugin/flows/flow.yaml b/sdk/example_plugin/flows/flow.yaml new file mode 100644 index 000000000..318d681e1 --- /dev/null +++ b/sdk/example_plugin/flows/flow.yaml @@ -0,0 +1,53 @@ +name: test +description: 测试工作流 +on_error: + call_type: llm + params: + user_prompt: | + 背景信息: + {context} + + 错误信息: + {output} + + 使用自然语言解释这一信息,并给出可能的解决方法。 +steps: + - name: start + call_type: api + dangerous: true + params: + endpoint: GET /api/test + next: flow_choice + - name: flow_choice + call_type: choice + params: + instruction: 工具的返回值是否为Markdown报告? + choices: + - step: end + description: 返回值为Markdown格式时,选择此项 + - step: report_gen + description: 返回值不是Markdown格式时,选择此项 + - name: report_gen + call_type: llm + params: + system_prompt: 你是一个擅长Linux系统性能优化,且能够根据具体情况撰写分析报告的智能助手。 + user_prompt: | + 用户问题: + {question} + + 工具的输出信息: + {message} + + 背景信息: + {context} + + 根据上述信息,撰写系统性能分析报告。 + next: end + - name: end + call_type: extract + params: + keys: + - content +next_flow: + - test2 + - test3 diff --git a/sdk/example_plugin/lib/__init__.py b/sdk/example_plugin/lib/__init__.py new file mode 100644 index 000000000..c4b58a3ce --- /dev/null +++ b/sdk/example_plugin/lib/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +# 这里应当导入所有工具类型 +from .user_tool import UserTool + +# 在_exported变量中,加入所有工具类型 +_exported = [ + UserTool +] \ No newline at end of file diff --git a/sdk/example_plugin/lib/sub_lib/__init__.py b/sdk/example_plugin/lib/sub_lib/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sdk/example_plugin/lib/user_tool.py b/sdk/example_plugin/lib/user_tool.py new file mode 100644 index 000000000..c65eb98be --- /dev/null +++ b/sdk/example_plugin/lib/user_tool.py @@ -0,0 +1,45 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +# Python工具基本形式,供用户参考 +from __future__ import annotations + +from typing import Optional, Any, List, Dict + +# 可以使用子模块 +from . 
import sub_lib + +from pydantic import BaseModel, Field + + +# 此处为工具接受的各项参数。参数可在flow中配置,也可由大模型自动填充 +class UserCallParams(BaseModel): + background: str = Field(description="上下文信息,由Executor自动传递") + question: str = Field(description="给Call提供的用户输入,由Executor自动传递") + files: List[str] = Field(description="用户询问问题时上传的文件,由Executor自动传递") + previous_data: Optional[Dict[str, Any]] = Field(description="Flow中前一个Call输出的结构化数据") + must: str = Field(description="这是必填参数的实例", default="这是默认值的示例") + opt: Optional[int] = Field(description="这是可选参数的示例") + + +# 这是工具类的基础形式 +class UserTool: + name: str = "user_tool" # 工具名称,会体现在flow中的on_error[].tool和steps[].tool字段内 + description: str = "用户自定义工具样例" # 工具描述,后续将用于自动编排工具 + params_obj: UserCallParams + + def __init__(self, params: Dict[str, Any]): + # 此处验证传递给Call的参数是否合法 + self.params_obj = UserCallParams(**params) + pass + + # 此处为工具调用逻辑。注意:函数的参数名称与类型不可改变 + async def call(self, fixed_params: dict) -> Dict[str, Any]: + # fixed_params:如果用户因为dangerous等原因修改了params,则此处修改params_obj + self.params_obj = UserCallParams(**fixed_params) + + output = "" + message = "" + # 返回值为dict类型,其中output字段为工具的原始数据(带格式);message字段为工具经LLM处理后的数据(仅字符串);您还可以提供其他数据字段 + return { + "output": output, + "message": message, + } diff --git a/sdk/example_plugin/openapi.yaml b/sdk/example_plugin/openapi.yaml new file mode 100644 index 000000000..56f2e5e14 --- /dev/null +++ b/sdk/example_plugin/openapi.yaml @@ -0,0 +1,37 @@ +openapi: 3.0.0 +info: + version: 1.0.0 + title: "文档标题" + +servers: + - url: "http://example.com:port/suffix" + +paths: + /url: + post: + dangerous: true + description: "API的描述信息" + requestBody: + description: "API请求体的总描述" + content: + application/json: + schema: + type: object + properties: + data: + type: string + example: "字段的样例值" + description: "字段的描述信息" + responses: + '200': + description: "API返回体的总描述" + content: + application/json: + schema: + type: object + properties: + name: + type: string + example: "字段的样例值" + description: "字段的描述信息" + pattern: "[\d].[\d]" \ No newline at end of file diff --git a/sdk/example_plugin/plugin.json b/sdk/example_plugin/plugin.json new file mode 100644 index 000000000..31d4307d3 --- /dev/null +++ b/sdk/example_plugin/plugin.json @@ -0,0 +1,11 @@ +{ + "id": "1", + "name": "示例插件", + "description": "这是示例插件,用于演示编写插件的格式。\n 'type' 可以为 'api' 、 'local',指定了插件的类型。", + "auth": { + "type": "header", + "args": { + "Authorization": "" + } + } +} \ No newline at end of file -- Gitee