diff --git a/Dockerfile b/Dockerfile
index 3e3de77938d02721e2cc28084dd76b37698b2aca..dd92a1468089e7bcc7c15133f32311e4001ad030 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,18 +1,11 @@
FROM hub.oepkgs.net/neocopilot/data_chain_back_end_base:0.9.6-x86
-COPY --chown=1001:1001 --chmod=750 ./ /rag-service/
+COPY --chmod=750 ./ /rag-service/
WORKDIR /rag-service
ENV PYTHONPATH /rag-service
USER root
-RUN sed -i 's/umask 002/umask 027/g' /etc/bashrc && \
- sed -i 's/umask 022/umask 027/g' /etc/bashrc && \
- # yum remove -y python3-pip gdb-gdbserver && \
- sh -c "find /usr /etc \( -name *yum* -o -name *dnf* -o -name *vi* \) -exec rm -rf {} + || true" && \
- sh -c "find /usr /etc \( -name ps -o -name top \) -exec rm -rf {} + || true" && \
- sh -c "rm -f /usr/bin/find /usr/bin/oldfind || true"
-USER eulercopilot
CMD ["/bin/bash", "run.sh"]
diff --git a/Dockerfile-base b/Dockerfile-base
index 8222781d900b063371cb8413c696ce050c444213..58fbae5fee55a7cc9d9a718d310bab589a9f1536 100644
--- a/Dockerfile-base
+++ b/Dockerfile-base
@@ -8,19 +8,15 @@ RUN sed -i 's|http://repo.openeuler.org/|https://mirrors.huaweicloud.com/openeul
yum makecache &&\
yum update -y &&\
yum install -y mesa-libGL java python3 python3-pip shadow-utils &&\
- yum clean all && \
- groupadd -g 1001 eulercopilot && useradd -u 1001 -g eulercopilot eulercopilot
+ yum clean all
-# 创建 /rag-service 目录并设置权限
-RUN mkdir -p /rag-service && chown -R 1001:1001 /rag-service
-
-# 切换到 eulercopilot 用户
-USER eulercopilot
+# 创建 /rag-service
+RUN mkdir -p /rag-service
# 复制 requirements.txt 文件到 /rag-service 目录
-COPY --chown=1001:1001 requirements.txt /rag-service/
-COPY --chown=1001:1001 tika-server-standard-2.9.2.jar /rag-service/
-COPY --chown=1001:1001 download_model.py /rag-service/
+COPY requirements.txt /rag-service/
+COPY tika-server-standard-2.9.2.jar /rag-service/
+COPY download_model.py /rag-service/
# 安装 Python 依赖
RUN pip3 install --no-cache-dir -r /rag-service/requirements.txt --index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
diff --git a/chat2db/.gitignore b/chat2db/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..3040d646d234ca4d2ec099a0378f110f2e7f3775
--- /dev/null
+++ b/chat2db/.gitignore
@@ -0,0 +1,2 @@
+__pycache__/
+.vscode/
\ No newline at end of file
diff --git a/chat2db/app/__init__.py b/chat2db/app/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/chat2db/app/app.py b/chat2db/app/app.py
deleted file mode 100644
index 71be6ed2b862a04647bfb08845a7424e8adf6041..0000000000000000000000000000000000000000
--- a/chat2db/app/app.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import uvicorn
-from fastapi import FastAPI
-import sys
-from chat2db.app.router import sql_example
-from chat2db.app.router import sql_generate
-from chat2db.app.router import database
-from chat2db.app.router import table
-from chat2db.config.config import config
-import logging
-
-
-logging.basicConfig(stream=sys.stdout, level=logging.INFO,
- format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
-
-app = FastAPI()
-
-app.include_router(sql_example.router)
-app.include_router(sql_generate.router)
-app.include_router(database.router)
-app.include_router(table.router)
-
-if __name__ == '__main__':
- try:
- ssl_enable = config["SSL_ENABLE"]
- if ssl_enable:
- uvicorn.run(app, host=config["UVICORN_IP"], port=int(config["UVICORN_PORT"]),
- proxy_headers=True, forwarded_allow_ips='*',
- ssl_certfile=config["SSL_CERTFILE"],
- ssl_keyfile=config["SSL_KEYFILE"],
- )
- else:
- uvicorn.run(app, host=config["UVICORN_IP"], port=int(config["UVICORN_PORT"]),
- proxy_headers=True, forwarded_allow_ips='*'
- )
- except Exception as e:
- exit(1)
diff --git a/chat2db/app/base/ac_automation.py b/chat2db/app/base/ac_automation.py
deleted file mode 100644
index 3012f2bb73d599771f63ef0cd2617e3f43d73dbb..0000000000000000000000000000000000000000
--- a/chat2db/app/base/ac_automation.py
+++ /dev/null
@@ -1,87 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-import copy
-import logging
-import sys
-
-class Node:
- def __init__(self, dep, pre_id):
- self.dep = dep
- self.pre_id = pre_id
- self.pre_nearest_children_id = {}
- self.children_id = {}
- self.data_frame = None
-
-
-logging.basicConfig(stream=sys.stdout, level=logging.INFO,
- format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
-
-
-class DictTree:
- def __init__(self):
- self.root = 0
- self.node_list = [Node(0, -1)]
-
- def load_data(self, data_dict):
- for key in data_dict:
- self.insert_data(key, data_dict[key])
- self.init_pre()
-
- def insert_data(self, keyword, data_frame):
- if not isinstance(keyword,str):
- return
- if len(keyword) == 0:
- return
- node_index = self.root
- try:
- for i in range(len(keyword)):
- if keyword[i] not in self.node_list[node_index].children_id.keys():
- self.node_list.append(Node(self.node_list[node_index].dep+1, 0))
- self.node_list[node_index].children_id[keyword[i]] = len(self.node_list)-1
- node_index = self.node_list[node_index].children_id[keyword[i]]
- except Exception as e:
- logging.error(f'关键字插入失败由于:{e}')
- return
- self.node_list[node_index].data_frame = data_frame
-
- def init_pre(self):
- q = [self.root]
- l = 0
- r = 1
- try:
- while l < r:
- node_index = q[l]
- self.node_list[node_index].pre_nearest_children_id = self.node_list[self.node_list[node_index].pre_id].children_id.copy()
- l += 1
- for key, val in self.node_list[node_index].children_id.items():
- q.append(val)
- r += 1
- if key in self.node_list[node_index].pre_nearest_children_id.keys():
- pre_id = self.node_list[node_index].pre_nearest_children_id[key]
- self.node_list[val].pre_id = pre_id
- self.node_list[node_index].pre_nearest_children_id[key] = val
- except Exception as e:
- logging.error(f'字典树前缀构建失败由于:{e}')
- return
-
- def get_results(self, content: str):
- content = content.lower()
- pre_node_index = self.root
- nex_node_index = None
- results = []
- logging.info(f'当前问题{content}')
- try:
- for i in range(len(content)):
- if content[i] in self.node_list[pre_node_index].pre_nearest_children_id.keys():
- nex_node_index = self.node_list[pre_node_index].pre_nearest_children_id[content[i]]
- else:
- nex_node_index = 0
- if self.node_list[pre_node_index].dep >= self.node_list[nex_node_index].dep:
- if self.node_list[pre_node_index].data_frame is not None:
- results.extend(copy.deepcopy(self.node_list[pre_node_index].data_frame))
- pre_node_index = nex_node_index
- logging.info(f'当前深度{self.node_list[pre_node_index].dep}')
- if self.node_list[pre_node_index].data_frame is not None:
- results.extend(copy.deepcopy(self.node_list[pre_node_index].data_frame))
- except Exception as e:
- logging.error(f'结果获取失败由于:{e}')
- return results
diff --git a/chat2db/app/base/mysql.py b/chat2db/app/base/mysql.py
deleted file mode 100644
index b47322bc4dec4b254e86af04ea95d9f47a63457c..0000000000000000000000000000000000000000
--- a/chat2db/app/base/mysql.py
+++ /dev/null
@@ -1,217 +0,0 @@
-
-import asyncio
-import aiomysql
-import concurrent.futures
-import logging
-from sqlalchemy.orm import sessionmaker
-from sqlalchemy import create_engine, text
-import sys
-from concurrent.futures import ThreadPoolExecutor
-from urllib.parse import urlparse
-from chat2db.app.base.meta_databbase import MetaDatabase
-logging.basicConfig(stream=sys.stdout, level=logging.INFO,
- format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
-
-
-class Mysql(MetaDatabase):
- executor = ThreadPoolExecutor(max_workers=10)
-
- async def test_database_connection(database_url):
- try:
- with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
- future = executor.submit(Mysql._connect_and_query, database_url)
- result = future.result(timeout=5)
- return result
- except concurrent.futures.TimeoutError:
- logging.error('mysql数据库连接超时')
- return False
- except Exception as e:
- logging.error(f'mysql数据库连接失败由于{e}')
- return False
-
- @staticmethod
- def _connect_and_query(database_url):
- try:
- engine = create_engine(
- database_url,
- pool_size=20,
- max_overflow=80,
- pool_recycle=300,
- pool_pre_ping=True
- )
- session = sessionmaker(bind=engine)()
- session.execute(text("SELECT 1"))
- session.close()
- return True
- except Exception as e:
- raise e
-
- @staticmethod
- async def drop_table(database_url, table_name):
- engine = create_engine(
- database_url,
- pool_size=20,
- max_overflow=80,
- pool_recycle=300,
- pool_pre_ping=True
- )
- with sessionmaker(engine)() as session:
- sql_str = f"DROP TABLE IF EXISTS {table_name};"
- session.execute(text(sql_str))
-
- @staticmethod
- async def select_primary_key_and_keyword_from_table(database_url, table_name, keyword):
- try:
- url = urlparse(database_url)
- db_config = {
- 'host': url.hostname or 'localhost',
- 'port': int(url.port or 3306),
- 'user': url.username or 'root',
- 'password': url.password or '',
- 'db': url.path.strip('/')
- }
-
- async with aiomysql.create_pool(**db_config) as pool:
- async with pool.acquire() as conn:
- async with conn.cursor() as cur:
- primary_key_query = """
- SELECT
- COLUMNS.column_name
- FROM
- information_schema.tables AS TABLES
- INNER JOIN information_schema.columns AS COLUMNS ON TABLES.table_name = COLUMNS.table_name
- WHERE
- TABLES.table_schema = %s AND TABLES.table_name = %s AND COLUMNS.column_key = 'PRI';
- """
-
- # 尝试执行查询
- await cur.execute(primary_key_query, (db_config['db'], table_name))
- primary_key_list = await cur.fetchall()
- if not primary_key_list:
- return []
- primary_key_names = ', '.join([record[0] for record in primary_key_list])
- columns = f'{primary_key_names}, {keyword}'
- query = f'SELECT {columns} FROM {table_name};'
- await cur.execute(query)
- results = await cur.fetchall()
-
- def _process_results(results, primary_key_list):
- tmp_dict = {}
- for row in results:
- key = str(row[-1])
- if key not in tmp_dict:
- tmp_dict[key] = []
- pk_values = [str(row[i]) for i in range(len(primary_key_list))]
- tmp_dict[key].append(pk_values)
-
- return {
- 'primary_key_list': [record[0] for record in primary_key_list],
- 'keyword_value_dict': tmp_dict
- }
- result = await asyncio.get_event_loop().run_in_executor(
- Mysql.executor,
- _process_results,
- results,
- primary_key_list
- )
- return result
-
- except Exception as e:
- logging.error(f'mysql数据检索失败由于 {e}')
-
- @staticmethod
- async def assemble_sql_query_base_on_primary_key(table_name, primary_key_list, primary_key_value_list):
- sql_str = f'SELECT * FROM {table_name} where '
- for i in range(len(primary_key_list)):
- sql_str += primary_key_list[i]+'= \''+primary_key_value_list[i]+'\''
- if i != len(primary_key_list)-1:
- sql_str += ' and '
- sql_str += ';'
- return sql_str
-
- @staticmethod
- async def get_table_info(database_url, table_name):
- engine = create_engine(
- database_url,
- pool_size=20,
- max_overflow=80,
- pool_recycle=300,
- pool_pre_ping=True
- )
- with sessionmaker(engine)() as session:
- sql_str = f"""SELECT TABLE_COMMENT FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '{table_name}';"""
- table_note = session.execute(text(sql_str)).one()[0]
- if table_note == '':
- table_note = table_name
- table_note = {
- 'table_name': table_name,
- 'table_note': table_note
- }
- return table_note
-
- @staticmethod
- async def get_column_info(database_url, table_name):
- engine = create_engine(
- database_url,
- pool_size=20,
- max_overflow=80,
- pool_recycle=300,
- pool_pre_ping=True
- )
- with engine.connect() as conn:
- sql_str = f"""
- SELECT column_name, column_type, column_comment FROM information_schema.columns where TABLE_NAME='{table_name}';
- """
- results = conn.execute(text(sql_str), {'table_name': table_name}).all()
- column_info_list = []
- for result in results:
- column_info_list.append({'column_name': result[0], 'column_type': result[1], 'column_note': result[2]})
- return column_info_list
-
- @staticmethod
- async def get_all_table_name_from_database_url(database_url):
- engine = create_engine(
- database_url,
- pool_size=20,
- max_overflow=80,
- pool_recycle=300,
- pool_pre_ping=True
- )
- with engine.connect() as connection:
- result = connection.execute(text("SHOW TABLES"))
- table_name_list = [row[0] for row in result]
- return table_name_list
-
- @staticmethod
- async def get_rand_data(database_url, table_name, cnt=10):
- engine = create_engine(
- database_url,
- pool_size=20,
- max_overflow=80,
- pool_recycle=300,
- pool_pre_ping=True
- )
- try:
- with sessionmaker(engine)() as session:
- sql_str = f'''SELECT *
- FROM {table_name}
- ORDER BY RAND()
- LIMIT {cnt};'''
- dataframe = str(session.execute(text(sql_str)).all())
- except Exception as e:
- dataframe = ''
- logging.error(f'随机从数据库中获取数据失败由于{e}')
- return dataframe
-
- @staticmethod
- async def try_excute(database_url, sql_str):
- engine = create_engine(
- database_url,
- pool_size=20,
- max_overflow=80,
- pool_recycle=300,
- pool_pre_ping=True
- )
- with sessionmaker(engine)() as session:
- result = session.execute(text(sql_str)).all()
- return Mysql.result_to_json(result)
diff --git a/chat2db/app/base/postgres.py b/chat2db/app/base/postgres.py
deleted file mode 100644
index a29a4427f34e689f632cbae79fd2337740ede824..0000000000000000000000000000000000000000
--- a/chat2db/app/base/postgres.py
+++ /dev/null
@@ -1,236 +0,0 @@
-import asyncio
-import asyncpg
-import concurrent.futures
-import logging
-from sqlalchemy.orm import sessionmaker
-from sqlalchemy import create_engine, text
-import sys
-from concurrent.futures import ThreadPoolExecutor
-from chat2db.app.base.meta_databbase import MetaDatabase
-logging.basicConfig(stream=sys.stdout, level=logging.INFO,
- format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
-
-
-def handler(signum, frame):
- raise TimeoutError("超时")
-
-
-class Postgres(MetaDatabase):
- executor = ThreadPoolExecutor(max_workers=10)
-
- async def test_database_connection(database_url):
- try:
- with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
- future = executor.submit(Postgres._connect_and_query, database_url)
- result = future.result(timeout=5)
- return result
- except concurrent.futures.TimeoutError:
- logging.error('postgres数据库连接超时')
- return False
- except Exception as e:
- logging.error(f'postgres数据库连接失败由于{e}')
- return False
-
- @staticmethod
- def _connect_and_query(database_url):
- try:
- engine = create_engine(
- database_url,
- pool_size=20,
- max_overflow=80,
- pool_recycle=300,
- pool_pre_ping=True
- )
- session = sessionmaker(bind=engine)()
- session.execute(text("SELECT 1"))
- session.close()
- return True
- except Exception as e:
- raise e
-
- @staticmethod
- async def drop_table(database_url, table_name):
- engine = create_engine(
- database_url,
- pool_size=20,
- max_overflow=80,
- pool_recycle=300,
- pool_pre_ping=True
- )
- with sessionmaker(engine)() as session:
- sql_str = f"DROP TABLE IF EXISTS {table_name};"
- session.execute(text(sql_str))
-
- @staticmethod
- async def select_primary_key_and_keyword_from_table(database_url, table_name, keyword):
- try:
- dsn = database_url.replace('+psycopg2', '')
- conn = await asyncpg.connect(dsn=dsn)
- primary_key_query = """
- SELECT
- kcu.column_name
- FROM
- information_schema.table_constraints AS tc
- JOIN information_schema.key_column_usage AS kcu
- ON tc.constraint_name = kcu.constraint_name
- WHERE
- tc.constraint_type = 'PRIMARY KEY'
- AND tc.table_name = $1;
- """
- primary_key_list = await conn.fetch(primary_key_query, table_name)
- if not primary_key_list:
- return []
- columns = ', '.join([record['column_name'] for record in primary_key_list]) + f', {keyword}'
- query = f'SELECT {columns} FROM {table_name};'
- results = await conn.fetch(query)
-
- def _process_results(results, primary_key_list):
- tmp_dict = {}
- for row in results:
- key = str(row[-1])
- if key not in tmp_dict:
- tmp_dict[key] = []
- pk_values = [str(row[i]) for i in range(len(primary_key_list))]
- tmp_dict[key].append(pk_values)
-
- return {
- 'primary_key_list': [record['column_name'] for record in primary_key_list],
- 'keyword_value_dict': tmp_dict
- }
- result = await asyncio.get_event_loop().run_in_executor(
- Postgres.executor,
- _process_results,
- results,
- primary_key_list
- )
- await conn.close()
-
- return result
- except Exception as e:
- logging.error(f'postgres数据检索失败由于 {e}')
- return None
-
- @staticmethod
- async def assemble_sql_query_base_on_primary_key(table_name, primary_key_list, primary_key_value_list):
- sql_str = f'SELECT * FROM {table_name} where '
- for i in range(len(primary_key_list)):
- sql_str += primary_key_list[i]+'='+'\''+primary_key_value_list[i]+'\''
- if i != len(primary_key_list)-1:
- sql_str += ' and '
- sql_str += ';'
- return sql_str
-
- @staticmethod
- async def get_table_info(database_url, table_name):
- engine = create_engine(
- database_url,
- pool_size=20,
- max_overflow=80,
- pool_recycle=300,
- pool_pre_ping=True
- )
- with engine.connect() as conn:
- sql_str = """
- SELECT
- d.description AS table_description
- FROM
- pg_class t
- JOIN
- pg_description d ON t.oid = d.objoid
- WHERE
- t.relkind = 'r' AND
- d.objsubid = 0 AND
- t.relname = :table_name; """
- result = conn.execute(text(sql_str), {'table_name': table_name}).one_or_none()
- if result is None:
- table_note = table_name
- else:
- table_note = result[0]
- table_note = {
- 'table_name': table_name,
- 'table_note': table_note
- }
- return table_note
-
- @staticmethod
- async def get_column_info(database_url, table_name):
- engine = create_engine(
- database_url,
- pool_size=20,
- max_overflow=80,
- pool_recycle=300,
- pool_pre_ping=True
- )
- with engine.connect() as conn:
- sql_str = """
- SELECT
- a.attname as 字段名,
- format_type(a.atttypid,a.atttypmod) as 类型,
- col_description(a.attrelid,a.attnum) as 注释
- FROM
- pg_class as c,pg_attribute as a
- where
- a.attrelid = c.oid
- and
- a.attnum>0
- and
- c.relname = :table_name;
- """
- results = conn.execute(text(sql_str), {'table_name': table_name}).all()
- column_info_list = []
- for result in results:
- column_info_list.append({'column_name': result[0], 'column_type': result[1], 'column_note': result[2]})
- return column_info_list
-
- @staticmethod
- async def get_all_table_name_from_database_url(database_url):
- engine = create_engine(
- database_url,
- pool_size=20,
- max_overflow=80,
- pool_recycle=300,
- pool_pre_ping=True
- )
- with engine.connect() as connection:
- sql_str = '''
- SELECT table_name
- FROM information_schema.tables
- WHERE table_schema = 'public';
- '''
- result = connection.execute(text(sql_str))
- table_name_list = [row[0] for row in result]
- return table_name_list
-
- @staticmethod
- async def get_rand_data(database_url, table_name, cnt=10):
- engine = create_engine(
- database_url,
- pool_size=20,
- max_overflow=80,
- pool_recycle=300,
- pool_pre_ping=True
- )
- try:
- with sessionmaker(engine)() as session:
- sql_str = f'''SELECT *
- FROM {table_name}
- ORDER BY RANDOM()
- LIMIT {cnt};'''
- dataframe = str(session.execute(text(sql_str)).all())
- except Exception as e:
- dataframe = ''
- logging.error(f'随机从数据库中获取数据失败由于{e}')
- return dataframe
-
- @staticmethod
- async def try_excute(database_url, sql_str):
- engine = create_engine(
- database_url,
- pool_size=20,
- max_overflow=80,
- pool_recycle=300,
- pool_pre_ping=True
- )
- with sessionmaker(engine)() as session:
- result=session.execute(text(sql_str)).all()
- return Postgres.result_to_json(result)
diff --git a/chat2db/app/router/database.py b/chat2db/app/router/database.py
deleted file mode 100644
index 37aacca406d2de39aaa56b983bf7d65b3d29f2e3..0000000000000000000000000000000000000000
--- a/chat2db/app/router/database.py
+++ /dev/null
@@ -1,191 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-
-import logging
-import uuid
-from fastapi import APIRouter, status
-from typing import Optional
-import sys
-from chat2db.model.request import DatabaseAddRequest, DatabaseDelRequest, DatabaseSqlGenerateRequest
-from chat2db.model.response import ResponseData
-from chat2db.manager.database_info_manager import DatabaseInfoManager
-from chat2db.manager.table_info_manager import TableInfoManager
-from chat2db.manager.column_info_manager import ColumnInfoManager
-from chat2db.app.service.diff_database_service import DiffDatabaseService
-from chat2db.app.service.sql_generate_service import SqlGenerateService
-from chat2db.app.service.keyword_service import keyword_service
-from chat2db.app.base.vectorize import Vectorize
-
-logging.basicConfig(stream=sys.stdout, level=logging.INFO,
- format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
-
-router = APIRouter(
- prefix="/database"
-)
-
-
-@router.post("/add", response_model=ResponseData)
-async def add_database_info(request: DatabaseAddRequest):
- database_url = request.database_url
- database_type = DiffDatabaseService.get_database_type_from_url(database_url)
- if not DiffDatabaseService.is_database_type_allow(database_type):
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="不支持当前数据库",
- result={}
- )
- flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url)
- if not flag:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="无法连接当前数据库",
- result={}
- )
- database_id = await DatabaseInfoManager.add_database(database_url)
- if database_id is None:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="数据库连接添加失败,当前存在重复数据库配置",
- result={'database_id': database_id}
- )
- return ResponseData(
- code=status.HTTP_200_OK,
- message="success",
- result={'database_id': database_id}
- )
-
-
-@router.post("/del", response_model=ResponseData)
-async def del_database_info(request: DatabaseDelRequest):
- database_id = request.database_id
- database_url = request.database_url
- if database_id:
- flag = await DatabaseInfoManager.del_database_by_id(database_id)
- else:
- flag = await DatabaseInfoManager.del_database_by_url(database_url)
- if not flag:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="删除数据库配置失败,数据库配置不存在",
- result={}
- )
- return ResponseData(
- code=status.HTTP_200_OK,
- message="删除数据库配置成功",
- result={}
- )
-
-
-@router.get("/query", response_model=ResponseData)
-async def query_database_info():
- database_info_list = await DatabaseInfoManager.get_all_database_info()
- return ResponseData(
- code=status.HTTP_200_OK,
- message="查询数据库配置成功",
- result={'database_info_list': database_info_list}
- )
-
-
-@router.get("/list", response_model=ResponseData)
-async def list_table_in_database(database_id: uuid.UUID, table_filter: str = ''):
- database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
- database_type = DiffDatabaseService.get_database_type_from_url(database_url)
- if database_url is None:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="查询数据库内表格配置失败,数据库配置不存在",
- result={}
- )
- if not DiffDatabaseService.is_database_type_allow(database_type):
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="不支持当前数据库",
- result={}
- )
- flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url)
- if not flag:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="无法连接当前数据库",
- result={}
- )
- table_name_list = await DiffDatabaseService.get_database_service(database_type).get_all_table_name_from_database_url(database_url)
- results = []
- for table_name in table_name_list:
- if table_filter in table_name:
- results.append(table_name)
- return ResponseData(
- code=status.HTTP_200_OK,
- message="查询数据库配置成功",
- result={'table_name_list': results}
- )
-
-
-@router.post("/sql", response_model=ResponseData)
-async def generate_sql_from_database(request: DatabaseSqlGenerateRequest):
- database_url = request.database_url
- table_name_list = request.table_name_list
- question = request.question
- use_llm_enhancements = request.use_llm_enhancements
- database_type = DiffDatabaseService.get_database_type_from_url(database_url)
- if not DiffDatabaseService.is_database_type_allow(database_type):
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="不支持当前数据库",
- result={}
- )
- flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url)
- if not flag:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="无法连接当前数据库",
- result={}
- )
- tmp_table_name_list = await DiffDatabaseService.get_database_service(database_type).get_all_table_name_from_database_url(database_url)
- database_id = await DatabaseInfoManager.get_database_id_by_url(database_url)
- if database_id is None:
- database_id = await DatabaseInfoManager.add_database(database_url)
- for table_name in tmp_table_name_list:
- try:
- tmp_dict = await DiffDatabaseService.get_database_service(database_type).get_table_info(database_url, table_name)
- table_note = tmp_dict['table_note']
- table_note_vector = await Vectorize.vectorize_embedding(table_note)
- table_id = await TableInfoManager.add_table_info(database_id, table_name, table_note, table_note_vector)
- column_info_list = await DiffDatabaseService.get_database_service(database_type).get_column_info(database_url, table_name)
- for column_info in column_info_list:
- await ColumnInfoManager.add_column_info_with_table_id(
- table_id, column_info['column_name'],
- column_info['column_type'],
- column_info['column_note'])
- except Exception as e:
- import traceback
- logging.error(f'{table_name}')
- logging.error(f'表格信息获取失败由于:{traceback.format_exc()}')
- continue
- if table_name_list:
- table_id_list = []
- for table_name in table_name_list:
- table_id = await TableInfoManager.get_table_id_by_database_id_and_table_name(database_id, table_name)
- if table_id is None:
- continue
- table_id_list.append(table_id)
- else:
- table_id_list = None
- results = {}
- sql_list = await SqlGenerateService.generate_sql_base_on_example(
- database_id=database_id, question=question, table_id_list=table_id_list,
- use_llm_enhancements=use_llm_enhancements)
- try:
- sql_list += await keyword_service.generate_sql(question, database_id, table_id_list)
- results['sql_list'] = sql_list[:request.topk]
- results['database_url'] = database_url
- except Exception as e:
- logging.error(f'sql生成失败由于{e}')
- return ResponseData(
- code=status.HTTP_400_BAD_REQUEST,
- message="sql生成失败",
- result={}
- )
- return ResponseData(
- code=status.HTTP_200_OK, message="success",
- result=results
- )
diff --git a/chat2db/app/router/sql_example.py b/chat2db/app/router/sql_example.py
deleted file mode 100644
index 08f913912211a71646ebeb33bed571a46f95d1dc..0000000000000000000000000000000000000000
--- a/chat2db/app/router/sql_example.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-
-import logging
-import uuid
-from fastapi import APIRouter, status
-import sys
-
-from chat2db.model.request import SqlExampleAddRequest, SqlExampleDelRequest, SqlExampleUpdateRequest, SqlExampleGenerateRequest
-from chat2db.model.response import ResponseData
-from chat2db.manager.database_info_manager import DatabaseInfoManager
-from chat2db.manager.table_info_manager import TableInfoManager
-from chat2db.manager.sql_example_manager import SqlExampleManager
-from chat2db.app.service.sql_generate_service import SqlGenerateService
-from chat2db.app.base.vectorize import Vectorize
-logging.basicConfig(stream=sys.stdout, level=logging.INFO,
- format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
-
-router = APIRouter(
- prefix="/sql/example"
-)
-
-
-@router.post("/add", response_model=ResponseData)
-async def add_sql_example(request: SqlExampleAddRequest):
- table_id = request.table_id
- table_info = await TableInfoManager.get_table_info_by_table_id(table_id)
- if table_info is None:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="表格不存在",
- result={}
- )
- database_id = table_info['database_id']
- question = request.question
- question_vector = await Vectorize.vectorize_embedding(question)
- sql = request.sql
- try:
- sql_example_id = await SqlExampleManager.add_sql_example(question, sql, table_id, question_vector)
- except Exception as e:
- logging.error(f'sql案例添加失败由于{e}')
- return ResponseData(
- code=status.HTTP_400_BAD_REQUEST,
- message="sql案例添加失败",
- result={}
- )
- return ResponseData(
- code=status.HTTP_200_OK,
- message="success",
- result={'sql_example_id': sql_example_id}
- )
-
-
-@router.post("/del", response_model=ResponseData)
-async def del_sql_example(request: SqlExampleDelRequest):
- sql_example_id = request.sql_example_id
- flag = await SqlExampleManager.del_sql_example_by_id(sql_example_id)
- if not flag:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="sql案例不存在",
- result={}
- )
- return ResponseData(
- code=status.HTTP_200_OK,
- message="sql案例删除成功",
- result={}
- )
-
-
-@router.get("/query", response_model=ResponseData)
-async def query_sql_example(table_id: uuid.UUID):
- sql_example_list = await SqlExampleManager.query_sql_example_by_table_id(table_id)
- return ResponseData(
- code=status.HTTP_200_OK,
- message="查询sql案例成功",
- result={'sql_example_list': sql_example_list}
- )
-
-
-@router.post("/update", response_model=ResponseData)
-async def update_sql_example(request: SqlExampleUpdateRequest):
- sql_example_id = request.sql_example_id
- question = request.question
- question_vector = await Vectorize.vectorize_embedding(question)
- sql = request.sql
- flag = await SqlExampleManager.update_sql_example_by_id(sql_example_id, question, sql, question_vector)
- if not flag:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="sql案例不存在",
- result={}
- )
- return ResponseData(
- code=status.HTTP_200_OK,
- message="sql案例更新成功",
- result={}
- )
-
-
-@router.post("/generate", response_model=ResponseData)
-async def generate_sql_example(request: SqlExampleGenerateRequest):
- table_id = request.table_id
- generate_cnt = request.generate_cnt
- table_info = await TableInfoManager.get_table_info_by_table_id(table_id)
- if table_info is None:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="表格不存在",
- result={}
- )
- table_name = table_info['table_name']
- database_id = table_info['database_id']
- database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
- sql_var = request.sql_var
- sql_example_list = []
- for i in range(generate_cnt):
- try:
- tmp_dict = await SqlGenerateService.generate_sql_base_on_data(database_url, table_name, sql_var)
- except Exception as e:
- logging.error(f'sql案例生成失败由于{e}')
- continue
- if tmp_dict is None:
- continue
- question = tmp_dict['question']
- question_vector = await Vectorize.vectorize_embedding(question)
- sql = tmp_dict['sql']
- await SqlExampleManager.add_sql_example(question, sql, table_id, question_vector)
- tmp_dict['database_id'] = database_id
- tmp_dict['table_id'] = table_id
- sql_example_list.append(tmp_dict)
- return ResponseData(
- code=status.HTTP_200_OK,
- message="sql案例生成成功",
- result={
- 'sql_example_list': sql_example_list
- }
- )
diff --git a/chat2db/app/router/sql_generate.py b/chat2db/app/router/sql_generate.py
deleted file mode 100644
index 69ff0d2bb0e7d114151d29c56994dac4b5754d1a..0000000000000000000000000000000000000000
--- a/chat2db/app/router/sql_generate.py
+++ /dev/null
@@ -1,124 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-
-import logging
-from fastapi import APIRouter, status
-import sys
-
-from chat2db.manager.database_info_manager import DatabaseInfoManager
-from chat2db.manager.table_info_manager import TableInfoManager
-from chat2db.manager.column_info_manager import ColumnInfoManager
-from chat2db.model.request import SqlGenerateRequest, SqlRepairRequest, SqlExcuteRequest
-from chat2db.model.response import ResponseData
-from chat2db.app.service.sql_generate_service import SqlGenerateService
-from chat2db.app.service.keyword_service import keyword_service
-from chat2db.app.service.diff_database_service import DiffDatabaseService
-logging.basicConfig(stream=sys.stdout, level=logging.INFO,
- format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
-
-router = APIRouter(
- prefix="/sql"
-)
-
-
-@router.post("/generate", response_model=ResponseData)
-async def generate_sql(request: SqlGenerateRequest):
- database_id = request.database_id
- database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
- table_id_list = request.table_id_list
- question = request.question
- use_llm_enhancements = request.use_llm_enhancements
- results = {}
- sql_list = await SqlGenerateService.generate_sql_base_on_example(
- database_id=database_id, question=question, table_id_list=table_id_list,
- use_llm_enhancements=use_llm_enhancements)
- try:
- sql_list += await keyword_service.generate_sql(question, database_id, table_id_list)
- results['sql_list'] = sql_list[:request.topk]
- results['database_url'] = database_url
- except Exception as e:
- logging.error(f'sql生成失败由于{e}')
- return ResponseData(
- code=status.HTTP_400_BAD_REQUEST,
- message="sql生成失败",
- result={}
- )
- return ResponseData(
- code=status.HTTP_200_OK, message="success",
- result=results
- )
-
-
-@router.post("/repair", response_model=ResponseData)
-async def repair_sql(request: SqlRepairRequest):
- database_id = request.database_id
- table_id = request.table_id
- database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
- database_type = DiffDatabaseService.get_database_type_from_url(database_url)
- if database_url is None:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="当前数据库配置不存在",
- result={}
- )
- table_info = await TableInfoManager.get_table_info_by_table_id(table_id)
- if table_info is None:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="表格不存在",
- result={}
- )
- if table_info['database_id'] != database_id:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="表格不属于当前数据库",
- result={}
- )
- column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id)
- sql = request.sql
- message = request.message
- question = request.question
- try:
- sql = await SqlGenerateService.repair_sql(database_type, table_info, column_info_list, sql, message, question)
- except Exception as e:
- logging.error(f'sql修复失败由于{e}')
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="sql修复失败",
- result={}
- )
- return ResponseData(
- code=status.HTTP_200_OK,
- message="sql修复成功",
- result={'database_id': database_id,
- 'table_id': table_id,
- 'sql': sql}
- )
-
-
-@router.post("/execute", response_model=ResponseData)
-async def execute_sql(request: SqlExcuteRequest):
- database_id = request.database_id
- sql = request.sql
- database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
- if database_url is None:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="当前数据库配置不存在",
- result={}
- )
- database_type = DiffDatabaseService.get_database_type_from_url(database_url)
- try:
- results = await DiffDatabaseService.database_map[database_type].try_excute(database_url, sql)
- except Exception as e:
- import traceback
- logging.error(f'sql执行失败由于{traceback.format_exc()}')
- return ResponseData(
- code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- message="sql执行失败",
- result={'Error': str(e)}
- )
- return ResponseData(
- code=status.HTTP_200_OK,
- message="sql执行成功",
- result=results
- )
diff --git a/chat2db/app/router/table.py b/chat2db/app/router/table.py
deleted file mode 100644
index 33ca4f9940bea6d3e60ee2adaf96be30a95d69ff..0000000000000000000000000000000000000000
--- a/chat2db/app/router/table.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-
-import logging
-import uuid
-from fastapi import APIRouter, status
-import sys
-
-from chat2db.model.request import TableAddRequest, TableDelRequest, EnableColumnRequest
-from chat2db.model.response import ResponseData
-from chat2db.manager.database_info_manager import DatabaseInfoManager
-from chat2db.manager.table_info_manager import TableInfoManager
-from chat2db.manager.column_info_manager import ColumnInfoManager
-from chat2db.app.service.diff_database_service import DiffDatabaseService
-from chat2db.app.base.vectorize import Vectorize
-from chat2db.app.service.keyword_service import keyword_service
-logging.basicConfig(stream=sys.stdout, level=logging.INFO,
- format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
-
-router = APIRouter(
- prefix="/table"
-)
-
-
-@router.post("/add", response_model=ResponseData)
-async def add_database_info(request: TableAddRequest):
- database_id = request.database_id
- database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
- if database_url is None:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="当前数据库配置不存在",
- result={}
- )
- database_type = DiffDatabaseService.get_database_type_from_url(database_url)
- flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url)
- if not flag:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="无法连接当前数据库",
- result={}
- )
- table_name = request.table_name
- table_name_list = await DiffDatabaseService.get_database_service(database_type).get_all_table_name_from_database_url(database_url)
- if table_name not in table_name_list:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="表格不存在",
- result={}
- )
- tmp_dict = await DiffDatabaseService.get_database_service(database_type).get_table_info(database_url, table_name)
- table_note = tmp_dict['table_note']
- table_note_vector = await Vectorize.vectorize_embedding(table_note)
- table_id = await TableInfoManager.add_table_info(database_id, table_name, table_note, table_note_vector)
- if table_id is None:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="表格添加失败,当前存在重复表格",
- result={}
- )
- column_info_list = await DiffDatabaseService.get_database_service(database_type).get_column_info(database_url, table_name)
- for column_info in column_info_list:
- await ColumnInfoManager.add_column_info_with_table_id(
- table_id, column_info['column_name'],
- column_info['column_type'],
- column_info['column_note'])
- return ResponseData(
- code=status.HTTP_200_OK,
- message="success",
- result={'table_id': table_id}
- )
-
-
-@router.post("/del", response_model=ResponseData)
-async def del_table_info(request: TableDelRequest):
- table_id = request.table_id
- flag = await TableInfoManager.del_table_by_id(table_id)
- if not flag:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="表格不存在",
- result={}
- )
- return ResponseData(
- code=status.HTTP_200_OK,
- message="删除表格成功",
- result={}
- )
-
-
-@router.get("/query", response_model=ResponseData)
-async def query_table_info(database_id: uuid.UUID):
- database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
- if database_url is None:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="当前数据库配置不存在",
- result={}
- )
- table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id)
- return ResponseData(
- code=status.HTTP_200_OK,
- message="查询表格成功",
- result={'table_info_list': table_info_list}
- )
-
-
-@router.get("/column/query", response_model=ResponseData)
-async def query_column(table_id: uuid.UUID):
- column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id)
- return ResponseData(
- code=status.HTTP_200_OK,
- message="",
- result={'column_info_list': column_info_list}
- )
-
-
-@router.post("/column/enable", response_model=ResponseData)
-async def enable_column(request: EnableColumnRequest):
- column_id = request.column_id
- enable = request.enable
- flag = await ColumnInfoManager.update_column_info_enable(column_id, enable)
- if not flag:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="列不存在",
- result={}
- )
- column_info = await ColumnInfoManager.get_column_info_by_column_id(column_id)
- column_name = column_info['column_name']
- table_id = column_info['table_id']
- table_info = await TableInfoManager.get_table_info_by_table_id(table_id)
- database_id = table_info['database_id']
- if enable:
- flag = await keyword_service.add(database_id, table_id, column_name)
- else:
- flag = await keyword_service.del_by_column_name(database_id, table_id, column_name)
- if not flag:
- return ResponseData(
- code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- message="列关键字功能开启/关闭失败",
- result={}
- )
- return ResponseData(
- code=status.HTTP_200_OK,
- message="列关键字功能开启/关闭成功",
- result={}
- )
diff --git a/chat2db/app/service/diff_database_service.py b/chat2db/app/service/diff_database_service.py
deleted file mode 100644
index bb9f979679339182e5213d782c5ba0b5be3d047f..0000000000000000000000000000000000000000
--- a/chat2db/app/service/diff_database_service.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import re
-from urllib.parse import urlparse
-from chat2db.app.base.mysql import Mysql
-from chat2db.app.base.postgres import Postgres
-
-
-class DiffDatabaseService():
- database_types = ["mysql", "postgresql", "opengauss"]
- database_map = {"mysql": Mysql, "postgresql": Postgres, "opengauss": Postgres}
-
- @staticmethod
- def get_database_service(database_type):
- if database_type not in DiffDatabaseService.database_types:
-            raise ValueError(f"不支持当前数据库类型{database_type}")
- return DiffDatabaseService.database_map[database_type]
-
- @staticmethod
- def get_database_type_from_url(database_url):
- result = urlparse(database_url)
- try:
- database_type = result.scheme.split('+')[0]
- except Exception as e:
- raise e
- return database_type.lower()
-
- @staticmethod
- def is_database_type_allow(database_type):
- return database_type in DiffDatabaseService.database_types
diff --git a/chat2db/app/service/keyword_service.py b/chat2db/app/service/keyword_service.py
deleted file mode 100644
index 685c341b5f106943b3e21eb5fe7367a4d4b7669b..0000000000000000000000000000000000000000
--- a/chat2db/app/service/keyword_service.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-import asyncio
-import copy
-import uuid
-import sys
-import threading
-from concurrent.futures import ThreadPoolExecutor
-from chat2db.app.service.diff_database_service import DiffDatabaseService
-from chat2db.app.base.ac_automation import DictTree
-from chat2db.manager.database_info_manager import DatabaseInfoManager
-from chat2db.manager.table_info_manager import TableInfoManager
-from chat2db.manager.column_info_manager import ColumnInfoManager
-import logging
-
-logging.basicConfig(stream=sys.stdout, level=logging.INFO,
- format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
-
-
-class KeywordManager():
- def __init__(self):
- self.keyword_asset_dict = {}
- self.lock = threading.Lock()
- self.data_frame_dict = {}
-
- async def load_keywords(self):
- database_info_list = await DatabaseInfoManager.get_all_database_info()
- for database_info in database_info_list:
- database_id = database_info['database_id']
- table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id)
- cnt=0
- for table_info in table_info_list:
- table_id = table_info['table_id']
- column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id, True)
- for i in range(len(column_info_list)):
- column_info = column_info_list[i]
- cnt+=1
- try:
- column_name = column_info['column_name']
- await self.add(database_id, table_id, column_name)
- except Exception as e:
- logging.error('关键字数据结构生成失败')
- def add_excutor(self, rd_id, database_id, table_id, table_info, column_info_list, column_name):
- tmp_dict = self.data_frame_dict[rd_id]
- tmp_dict_tree = DictTree()
- tmp_dict_tree.load_data(tmp_dict['keyword_value_dict'])
- if database_id not in self.keyword_asset_dict.keys():
- self.keyword_asset_dict[database_id] = {}
- with self.lock:
- if table_id not in self.keyword_asset_dict[database_id].keys():
- self.keyword_asset_dict[database_id][table_id] = {}
- self.keyword_asset_dict[database_id][table_id]['table_info'] = table_info
- self.keyword_asset_dict[database_id][table_id]['column_info_list'] = column_info_list
- self.keyword_asset_dict[database_id][table_id]['primary_key_list'] = copy.deepcopy(
- tmp_dict['primary_key_list'])
- self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'] = {}
- self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'][column_name] = tmp_dict_tree
- del self.data_frame_dict[rd_id]
-
- async def add(self, database_id, table_id, column_name):
- database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
- database_type = DiffDatabaseService.get_database_type_from_url(database_url)
- table_info = await TableInfoManager.get_table_info_by_table_id(table_id)
- table_name = table_info['table_name']
- tmp_dict = await DiffDatabaseService.get_database_service(
- database_type).select_primary_key_and_keyword_from_table(database_url, table_name, column_name)
- if tmp_dict is None:
- return
-        rd_id = str(uuid.uuid4())
- self.data_frame_dict[rd_id] = tmp_dict
- del database_url
- column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id)
- try:
- thread = threading.Thread(target=self.add_excutor, args=(rd_id, database_id, table_id,
- table_info, column_info_list, column_name,))
- thread.start()
- except Exception as e:
- logging.error(f'创建增加线程失败由于{e}')
- return False
- return True
-
- async def update_keyword_asset(self):
-        database_info_list = await DatabaseInfoManager.get_all_database_info()
-        for database_info in database_info_list:
-            database_id = database_info['database_id']
-            table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id)
-            for table_info in table_info_list:
-                table_id = table_info['table_id']
-                column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id, True)
- for column_info in column_info_list:
- await self.add(database_id, table_id, column_info['column_name'])
-
- async def del_by_column_name(self, database_id, table_id, column_name):
- try:
- with self.lock:
- if database_id in self.keyword_asset_dict.keys():
- if table_id in self.keyword_asset_dict[database_id].keys():
- if column_name in self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'].keys():
- del self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'][column_name]
- except Exception as e:
- logging.error(f'字典树删除失败由于{e}')
- return False
- return True
-
- async def generate_sql(self, question, database_id, table_id_list=None):
- with self.lock:
- results = []
- if database_id in self.keyword_asset_dict.keys():
- database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
- database_type = DiffDatabaseService.get_database_type_from_url(database_url)
- for table_id in self.keyword_asset_dict[database_id].keys():
- if table_id_list is None or table_id in table_id_list:
- table_info = self.keyword_asset_dict[database_id][table_id]['table_info']
- primary_key_list = self.keyword_asset_dict[database_id][table_id]['primary_key_list']
- primary_key_value_list = []
- try:
- for dict_tree in self.keyword_asset_dict[database_id][table_id]['dict_tree_dict'].values():
- primary_key_value_list += dict_tree.get_results(question)
- except Exception as e:
- logging.error(f'从字典树中获取结果失败由于{e}')
- continue
- for i in range(len(primary_key_value_list)):
- sql_str = await DiffDatabaseService.get_database_service(database_type).assemble_sql_query_base_on_primary_key(
- table_info['table_name'], primary_key_list, primary_key_value_list[i])
- tmp_dict = {'database_id': database_id, 'table_id': table_id, 'sql': sql_str}
- results.append(tmp_dict)
- del database_url
- return results
-
-
-keyword_service = KeywordManager()
-asyncio.run(keyword_service.load_keywords())
diff --git a/chat2db/app/service/sql_generate_service.py b/chat2db/app/service/sql_generate_service.py
deleted file mode 100644
index f20f97706650424d862ad7d9a6036b439795379c..0000000000000000000000000000000000000000
--- a/chat2db/app/service/sql_generate_service.py
+++ /dev/null
@@ -1,363 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-import asyncio
-import yaml
-import re
-import json
-import random
-import sys
-import uuid
-import logging
-from pandas.core.api import DataFrame as DataFrame
-
-from chat2db.manager.database_info_manager import DatabaseInfoManager
-from chat2db.manager.table_info_manager import TableInfoManager
-from chat2db.manager.column_info_manager import ColumnInfoManager
-from chat2db.manager.sql_example_manager import SqlExampleManager
-from chat2db.app.service.diff_database_service import DiffDatabaseService
-from chat2db.llm.chat_with_model import LLM
-from chat2db.config.config import config
-from chat2db.app.base.vectorize import Vectorize
-
-
-logging.basicConfig(stream=sys.stdout, level=logging.INFO,
- format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
-
-
-class SqlGenerateService():
-
- @staticmethod
- async def merge_table_and_column_info(table_info, column_info_list):
- table_name = table_info.get('table_name', '')
- table_note = table_info.get('table_note', '')
-        note = '<table>\n'
-        note += '<tr>\n'+'<td>表名</td>\n'+'</tr>\n'
-        note += '<tr>\n'+f'<td>{table_name}</td>\n'+'</tr>\n'
-        note += '<tr>\n'+'<td>表的注释</td>\n'+'</tr>\n'
-        note += '<tr>\n'+f'<td>{table_note}</td>\n'+'</tr>\n'
-        note += '<tr>\n'+'<td>字段</td>\n<td>字段类型</td>\n<td>字段注释</td>\n'+'</tr>\n'
-        for column_info in column_info_list:
-            column_name = column_info.get('column_name', '')
-            column_type = column_info.get('column_type', '')
-            column_note = column_info.get('column_note', '')
-            note += '<tr>\n'+f'<td>{column_name}</td>\n<td>{column_type}</td>\n<td>{column_note}</td>\n'+'</tr>\n'
-        note += '</table>'
- return note
-
- @staticmethod
- def extract_list_statements(list_string):
- pattern = r'\[.*?\]'
- matches = re.findall(pattern, list_string)
- if len(matches) == 0:
- return ''
- tmp = matches[0]
- tmp = tmp.replace('\'', '\"')
- tmp = tmp.replace(',', ',')
- return tmp
-
- @staticmethod
- async def get_most_similar_table_id_list(database_id, question, table_choose_cnt):
- table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id)
- random.shuffle(table_info_list)
- table_id_set = set()
- for table_info in table_info_list:
- table_id = table_info['table_id']
- table_id_set.add(str(table_id))
- try:
- with open('./chat2db/templetes/prompt.yaml', 'r', encoding='utf-8') as f:
- prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
- prompt = prompt_dict.get('table_choose_prompt', '')
-            table_entries = '<table>\n'
-            table_entries += '<tr>\n'+'<td>主键</td>\n<td>表注释</td>\n'+'</tr>\n'
-            token_upper = 2048
-            for table_info in table_info_list:
-                table_id = table_info['table_id']
-                table_note = table_info['table_note']
-                if len(table_entries) + len(
-                        '<tr>\n' + f'<td>{table_id}</td>\n<td>{table_note}</td>\n' + '</tr>\n') > token_upper:
-                    break
-                table_entries += '<tr>\n'+f'<td>{table_id}</td>\n<td>{table_note}</td>\n'+'</tr>\n'
-            table_entries += '</table>'
- prompt = prompt.format(table_cnt=table_choose_cnt, table_entries=table_entries, question=question)
- # logging.info(f'在大模型增强模式下,选择表的prompt构造成功:{prompt}')
- except Exception as e:
- logging.error(f'在大模型增强模式下,选择表的prompt构造失败由于:{e}')
- return []
- try:
- llm = LLM(model_name=config['LLM_MODEL'],
- openai_api_base=config['LLM_URL'],
- openai_api_key=config['LLM_KEY'],
- max_tokens=config['LLM_MAX_TOKENS'],
- request_timeout=60,
- temperature=0.5)
- except Exception as e:
- llm = None
- logging.error(f'在大模型增强模式下,选择表的过程中,与大模型建立连接失败由于:{e}')
- table_id_list = []
- if llm is not None:
- for i in range(2):
-                content = await llm.chat_with_model(prompt, '请输出包含选择表主键的列表')
- try:
- sub_table_id_list = json.loads(SqlGenerateService.extract_list_statements(content))
- except:
- sub_table_id_list = []
- for j in range(len(sub_table_id_list)):
- if sub_table_id_list[j] in table_id_set and uuid.UUID(sub_table_id_list[j]) not in table_id_list:
- table_id_list.append(uuid.UUID(sub_table_id_list[j]))
- if len(table_id_list) < table_choose_cnt:
- table_choose_cnt -= len(table_id_list)
- for i in range(min(table_choose_cnt, len(table_info_list))):
- table_id = table_info_list[i]['table_id']
- if table_id is not None and table_id not in table_id_list:
- table_id_list.append(table_id)
- return table_id_list
-
- @staticmethod
- async def find_most_similar_sql_example(
- database_id, table_id_list, question, use_llm_enhancements=False, table_choose_cnt=2, sql_example_choose_cnt=10,
- topk=5):
- try:
- database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
- except Exception as e:
- logging.error(f'数据库{database_id}信息获取失败由于{e}')
- return []
- database_type = DiffDatabaseService.get_database_type_from_url(database_url)
- del database_url
- try:
- question_vector = await Vectorize.vectorize_embedding(question)
- except Exception as e:
- logging.error(f'问题向量化失败由于:{e}')
- return {}
- sql_example = []
- data_frame_list = []
- if table_id_list is None:
- if use_llm_enhancements:
- table_id_list = await SqlGenerateService.get_most_similar_table_id_list(database_id, question, table_choose_cnt)
- else:
- try:
- table_info_list = await TableInfoManager.get_table_info_by_database_id(database_id)
- table_id_list = []
- for table_info in table_info_list:
- table_id_list.append(table_info['table_id'])
- max_retry = 3
- sql_example_list = []
- for _ in range(max_retry):
- try:
- sql_example_list = await asyncio.wait_for(SqlExampleManager.get_topk_sql_example_by_cos_dis(
- question_vector=question_vector,
- table_id_list=table_id_list, topk=table_choose_cnt * 2),
- timeout=5
- )
- break
- except Exception as e:
- logging.error(f'非增强模式下,sql_example获取失败:{e}')
- table_id_list = []
- for sql_example in sql_example_list:
- table_id_list.append(sql_example['table_id'])
- except Exception as e:
- logging.error(f'非增强模式下,表id获取失败由于:{e}')
- return []
- table_id_list = list(set(table_id_list))
- if len(table_id_list) < table_choose_cnt:
- try:
- expand_table_id_list = await asyncio.wait_for(TableInfoManager.get_topk_table_by_cos_dis(
- database_id, question_vector, table_choose_cnt - len(table_id_list)), timeout=5
- )
- table_id_list += expand_table_id_list
- except Exception as e:
- logging.error(f'非增强模式下,表id补充失败由于:{e}')
- exist_table_id = set()
- note_list = []
- for i in range(min(2, len(table_id_list))):
- table_id = table_id_list[i]
- if table_id in exist_table_id:
- continue
- exist_table_id.add(table_id)
- try:
- table_info = await TableInfoManager.get_table_info_by_table_id(table_id)
- column_info_list = await ColumnInfoManager.get_column_info_by_table_id(table_id)
- except Exception as e:
- logging.error(f'表{table_id}注释获取失败由于{e}')
- continue
- note = await SqlGenerateService.merge_table_and_column_info(table_info, column_info_list)
- note_list.append(note)
- max_retry = 3
- sql_example_list = []
- for _ in range(max_retry):
- try:
- sql_example_list = await asyncio.wait_for(SqlExampleManager.get_topk_sql_example_by_cos_dis(
- question_vector,
- table_id_list=[table_id],
- topk=sql_example_choose_cnt),
- timeout=5
- )
- break
- except Exception as e:
- logging.error(f'获取id为{table_id}的表的最相近的{topk}条sql案例失败由于:{e}')
- question_sql_list = []
- for i in range(len(sql_example_list)):
- question_sql_list.append(
- {'question': sql_example_list[i]['question'],
- 'sql': sql_example_list[i]['sql']})
- data_frame_list.append({'table_id': table_id, 'table_info': table_info,
- 'column_info_list': column_info_list, 'sql_example_list': question_sql_list})
- return data_frame_list
-
- @staticmethod
- async def merge_sql_example(sql_example_list):
- sql_example = ''
- for i in range(len(sql_example_list)):
- sql_example += '问题'+str(i)+':\n'+sql_example_list[i].get('question',
- '')+'\nsql'+str(i)+':\n'+sql_example_list[i].get('sql', '')+'\n'
- return sql_example
-
- @staticmethod
- async def extract_select_statements(sql_string):
- pattern = r"(?i)select[^;]*;"
- matches = re.findall(pattern, sql_string)
- if len(matches) == 0:
- return ''
- sql = matches[0]
- sql = sql.strip()
-        sql = sql.replace(',', ',')
- return sql
-
- @staticmethod
- async def generate_sql_base_on_example(
- database_id, question, table_id_list=None, sql_generate_cnt=1, use_llm_enhancements=False):
- try:
- database_url = await DatabaseInfoManager.get_database_url_by_id(database_id)
- except Exception as e:
- logging.error(f'数据库{database_id}信息获取失败由于{e}')
- return {}
- if database_url is None:
- raise Exception('数据库配置不存在')
- database_type = DiffDatabaseService.get_database_type_from_url(database_url)
- data_frame_list = await SqlGenerateService.find_most_similar_sql_example(database_id, table_id_list, question, use_llm_enhancements)
- try:
- with open('./chat2db/templetes/prompt.yaml', 'r', encoding='utf-8') as f:
- prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
- llm = LLM(model_name=config['LLM_MODEL'],
- openai_api_base=config['LLM_URL'],
- openai_api_key=config['LLM_KEY'],
- max_tokens=config['LLM_MAX_TOKENS'],
- request_timeout=60,
- temperature=0.5)
- results = []
- for data_frame in data_frame_list:
- prompt = prompt_dict.get('sql_generate_base_on_example_prompt', '')
- table_info = data_frame.get('table_info', '')
- table_id = table_info['table_id']
- column_info_list = data_frame.get('column_info_list', '')
- note = await SqlGenerateService.merge_table_and_column_info(table_info, column_info_list)
- sql_example = await SqlGenerateService.merge_sql_example(data_frame.get('sql_example_list', []))
- try:
- prompt = prompt.format(
- database_url=database_url, note=note, k=len(data_frame.get('sql_example_list', [])),
- sql_example=sql_example, question=question)
- except Exception as e:
- logging.info(f'sql生成失败{e}')
- return []
- ge_cnt = 0
- ge_sql_cnt = 0
- while ge_cnt < 10*sql_generate_cnt and ge_sql_cnt < sql_generate_cnt:
- sql = await llm.chat_with_model(prompt, f'请输出一条在与{database_type}下能运行的sql,以分号结尾')
- sql = await SqlGenerateService.extract_select_statements(sql)
- if len(sql):
- ge_sql_cnt += 1
- tmp_dict = {'database_id': database_id, 'table_id': table_id, 'sql': sql}
- results.append(tmp_dict)
- ge_cnt += 1
- if len(results) == sql_generate_cnt:
- break
- except Exception as e:
- logging.error(f'sql生成失败由于:{e}')
- return results
-
- @staticmethod
- async def generate_sql_base_on_data(database_url, table_name, sql_var=False):
- database_type = None
- database_type = DiffDatabaseService.get_database_type_from_url(database_url)
- flag = await DiffDatabaseService.get_database_service(database_type).test_database_connection(database_url)
- if not flag:
- return None
- table_name_list = await DiffDatabaseService.get_database_service(database_type).get_all_table_name_from_database_url(database_url)
- if table_name not in table_name_list:
- return None
- table_info = await DiffDatabaseService.get_database_service(database_type).get_table_info(database_url, table_name)
- column_info_list = await DiffDatabaseService.get_database_service(database_type).get_column_info(database_url, table_name)
- note = await SqlGenerateService.merge_table_and_column_info(table_info, column_info_list)
-
- def count_char(str, char):
- return sum(1 for c in str if c == char)
- llm = LLM(model_name=config['LLM_MODEL'],
- openai_api_base=config['LLM_URL'],
- openai_api_key=config['LLM_KEY'],
- max_tokens=config['LLM_MAX_TOKENS'],
- request_timeout=60,
- temperature=0.5)
- for i in range(5):
- data_frame = await DiffDatabaseService.get_database_service(database_type).get_rand_data(database_url, table_name)
- try:
- with open('./chat2db/templetes/prompt.yaml', 'r', encoding='utf-8') as f:
- prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
- prompt = prompt_dict['question_generate_base_on_data_prompt'].format(
- note=note, data_frame=data_frame)
- question = await llm.chat_with_model(prompt, '请输出一个问题')
- if count_char(question, '?') > 1 or count_char(question, '?') > 1:
- continue
- except Exception as e:
- logging.error(f'问题生成失败由于{e}')
- continue
- try:
- with open('./chat2db/templetes/prompt.yaml', 'r', encoding='utf-8') as f:
- prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
- prompt = prompt_dict['sql_generate_base_on_data_prompt'].format(
- database_type=database_type,
- note=note, data_frame=data_frame, question=question)
- sql = await llm.chat_with_model(prompt, f'请输出一条可以用于查询{database_type}的sql,要以分号结尾')
- sql = await SqlGenerateService.extract_select_statements(sql)
- if not sql:
- continue
- except Exception as e:
- logging.error(f'sql生成失败由于{e}')
- continue
- try:
- if sql_var:
- await DiffDatabaseService.get_database_service(database_type).try_excute(database_url, sql)
- except Exception as e:
- logging.error(f'生成的sql执行失败由于{e}')
- continue
- return {
- 'question': question,
- 'sql': sql
- }
- return None
-
- @staticmethod
- async def repair_sql(database_type, table_info, column_info_list, sql_failed, sql_failed_message, question):
- try:
- with open('./chat2db/templetes/prompt.yaml', 'r', encoding='utf-8') as f:
- prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
- llm = LLM(model_name=config['LLM_MODEL'],
- openai_api_base=config['LLM_URL'],
- openai_api_key=config['LLM_KEY'],
- max_tokens=config['LLM_MAX_TOKENS'],
- request_timeout=60,
- temperature=0.5)
- try:
- note = await SqlGenerateService.merge_table_and_column_info(table_info, column_info_list)
- prompt = prompt_dict.get('sql_expand_prompt', '')
- prompt = prompt.format(
- database_type=database_type, note=note, sql_failed=sql_failed,
- sql_failed_message=sql_failed_message,
- question=question)
- except Exception as e:
- logging.error(f'sql修复失败由于{e}')
- return ''
- sql = await llm.chat_with_model(prompt, f'请输出一条在与{database_type}下能运行的sql,要以分号结尾')
- sql = await SqlGenerateService.extract_select_statements(sql)
- logging.info(f"修复前的sql为{sql_failed}修复后的sql为{sql}")
- except Exception as e:
- logging.error(f'sql生成失败由于:{e}')
- return ''
- return sql
diff --git a/chat2db/apps/base/__init__.py b/chat2db/apps/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..80e944c1264ecdda6d74967bf74c4ebbb71603d4
--- /dev/null
+++ b/chat2db/apps/base/__init__.py
@@ -0,0 +1,7 @@
+from chat2db.apps.base.database_base import MetaDatabase
+from chat2db.apps.base.mysql import MySQL
+from chat2db.apps.base.mongodb import MongoDB
+from chat2db.apps.base.opengauss import OpenGauss
+from chat2db.apps.base.postgres import Postgres
+
+__all__ = ['MySQL', 'MongoDB', 'OpenGauss', 'Postgres', 'MetaDatabase']
\ No newline at end of file
diff --git a/chat2db/apps/base/database_base.py b/chat2db/apps/base/database_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..05fbbaccd723f771d41c8da9f131afedfdf21a40
--- /dev/null
+++ b/chat2db/apps/base/database_base.py
@@ -0,0 +1,26 @@
+from typing import Any
+
+class MetaDatabase:
+ @staticmethod
+ async def get_database_url(host: str, port: int, username: str, password: str, database: str):
+ raise NotImplementedError
+
+ @staticmethod
+ async def connect(host: str, port: int, username: str, password: str, database: str) -> Any:
+ raise NotImplementedError
+
+ @staticmethod
+ async def list_tables(connection: Any) -> list[str]:
+ raise NotImplementedError
+
+ @staticmethod
+ async def get_table_ddl(table_name: str, connection: Any) -> str:
+ raise NotImplementedError
+
+ @staticmethod
+ async def sample_table_rows(table_name: str, n: int, connection: Any) -> list[dict]:
+ raise NotImplementedError
+
+ @staticmethod
+ async def execute_sql(sql: str | dict, connection: Any) -> list[dict]:
+ raise NotImplementedError
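+
+# 使用示意(假设性示例,连接参数仅作演示,非本仓库固定用法):具体后端继承 MetaDatabase
+# 并实现上述全部静态方法,调用方即可以统一接口访问不同数据库,例如:
+#   conn = await MySQL.connect("127.0.0.1", 3306, "demo_user", "demo_pwd", "demo_db")
+#   tables = await MySQL.list_tables(conn)
+#   ddl = await MySQL.get_table_ddl(tables[0], conn)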
diff --git a/chat2db/apps/base/mongodb.py b/chat2db/apps/base/mongodb.py
new file mode 100644
index 0000000000000000000000000000000000000000..148c919133d64f0131828dfd283b22e59ee8eb2e
--- /dev/null
+++ b/chat2db/apps/base/mongodb.py
@@ -0,0 +1,163 @@
+import logging
+from typing import Any
+from bson import ObjectId
+from copy import deepcopy
+import motor.motor_asyncio
+import urllib.parse
+
+from chat2db.apps.base.database_base import MetaDatabase
+
+class MongoDB(MetaDatabase):
+
+ @staticmethod
+ async def get_database_url(host: str, port: int, username: str, password: str, database: str):
+ try:
+ user = urllib.parse.quote_plus(username)
+ pwd = urllib.parse.quote_plus(password)
+ return f"mongodb://{user}:{pwd}@{host}:{port}/{database}"
+ except Exception as e:
+ logging.error(f"\n[获取数据库url失败]\n\n{e}")
+ return ""
+
+ @staticmethod
+ async def connect(host: str, port: int, username: str, password: str, database: str) -> Any:
+ try:
+ user = urllib.parse.quote_plus(username)
+ pwd = urllib.parse.quote_plus(password)
+ mongo_uri = f"mongodb://{user}:{pwd}@{host}:{port}/{database}"
+ client = motor.motor_asyncio.AsyncIOMotorClient(mongo_uri)
+ return client[database]
+ except Exception as e:
+ logging.error(f"\n[连接MongoDB数据库失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def list_tables(connection: Any) -> list[str]:
+ try:
+ return await connection.list_collection_names()
+ except Exception as e:
+ logging.error(f"\n[获取集合失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def get_table_ddl(table_name: str, connection: Any) -> str:
+ """
+ 将 MongoDB 集合信息格式化为类似 SQL DDL 的文本,用于大模型输入。
+ 包括索引信息和部分示例字段。
+ """
+ try:
+ # 获取索引信息
+ indexes = await connection[table_name].index_information()
+
+ # 尝试获取部分文档字段类型
+ sample_doc = await connection[table_name].find_one() or {}
+ fields_ddl = []
+ for field, value in sample_doc.items():
+ dtype = type(value).__name__
+ fields_ddl.append(f" {field} {dtype.upper()}")
+
+ # 格式化索引信息
+ indexes_ddl = []
+ for index_name, index_info in indexes.items():
+ keys = ", ".join([f"{k[0]}({k[1]})" for k in index_info['key']])
+ unique = " UNIQUE" if index_info.get('unique') else ""
+ indexes_ddl.append(f" INDEX {index_name} ON ({keys}){unique}")
+
+ ddl = f"CREATE COLLECTION {table_name} (\n"
+ ddl += ",\n".join(fields_ddl)
+ ddl += "\n);\n"
+ if indexes_ddl:
+ ddl += "\n".join(indexes_ddl)
+
+ return ddl
+
+ except Exception as e:
+ logging.error(f"\n[获取集合 {table_name} DDL失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def sample_table_rows(table_name: str, n: int, connection: Any) -> list[dict]:
+ """
+ 随机获取 n 条数据
+ """
+ try:
+ cursor = connection[table_name].aggregate([{"$sample": {"size": n}}])
+ result = [doc async for doc in cursor]
+ return result
+ except Exception as e:
+ logging.error(f"\n[获取集合 {table_name} 样本数据失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def execute_sql(sql: dict, connection: Any) -> list[dict]:
+ """
+ 执行 MongoDB 操作,传入 dict 格式指令
+ 支持 find/insertOne/insertMany/updateOne/updateMany/deleteOne/deleteMany/aggregate
+ 返回值中所有 ObjectId 自动转换为 str
+ """
+ command = deepcopy(sql) # mongodb会修改输入的dict,所以这里需要深拷贝
+ try:
+ coll_name = command.get("collection")
+ operation = command.get("operation", "find")
+ filter_ = command.get("filter", {})
+ data = command.get("data", {})
+ pipeline = command.get("pipeline", [])
+ many = command.get("many", False)
+
+ collection = connection[coll_name]
+
+ # 查询
+ if operation == "find":
+ cursor = collection.find(filter_)
+ result = [doc async for doc in cursor]
+ return MongoDB.transform_objectid(result)
+
+ # 聚合
+ elif operation == "aggregate":
+ cursor = collection.aggregate(pipeline)
+ result = [doc async for doc in cursor]
+ return MongoDB.transform_objectid(result)
+
+ # 插入
+ elif operation in ("insert", "insertOne", "insertMany"):
+ if many or operation == "insertMany":
+ res = await collection.insert_many(data)
+ return [{"inserted_ids": [str(_id) for _id in res.inserted_ids]}]
+ else:
+ res = await collection.insert_one(data)
+ return [{"inserted_id": str(res.inserted_id)}]
+
+ # 更新
+ elif operation in ("update", "updateOne", "updateMany"):
+ if many or operation == "updateMany":
+ res = await collection.update_many(filter_, {"$set": data})
+ else:
+ res = await collection.update_one(filter_, {"$set": data})
+ return [{"matched": res.matched_count, "modified": res.modified_count}]
+
+ # 删除
+ elif operation in ("delete", "deleteOne", "deleteMany"):
+ if many or operation == "deleteMany":
+ res = await collection.delete_many(filter_)
+ else:
+ res = await collection.delete_one(filter_)
+ return [{"deleted": res.deleted_count}]
+
+ else:
+ raise ValueError(f"Unsupported MongoDB operation: {operation}")
+
+ except Exception as e:
+ logging.error(f"\n[执行MongoDB指令失败]\n\n{e}")
+ raise e
+
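+ # 指令示例(假设性示意,集合名与字段仅作演示):
+ #   await MongoDB.execute_sql(
+ #       {"collection": "orders", "operation": "find", "filter": {"status": "NEW"}},
+ #       connection,
+ #   )
+ #   返回的文档列表中,ObjectId 会由 transform_objectid 统一转换为字符串。
+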
+ @staticmethod
+ def transform_objectid(doc):
+ """递归将 dict/list 中的 ObjectId 转为 str"""
+ if isinstance(doc, list):
+ return [MongoDB.transform_objectid(d) for d in doc]
+ elif isinstance(doc, dict):
+ return {k: MongoDB.transform_objectid(v) for k, v in doc.items()}
+ elif isinstance(doc, ObjectId):
+ return str(doc)
+ else:
+ return doc
\ No newline at end of file
diff --git a/chat2db/apps/base/mysql.py b/chat2db/apps/base/mysql.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab1b1025caeec8767df77a1cd506ca9c8fdeb060
--- /dev/null
+++ b/chat2db/apps/base/mysql.py
@@ -0,0 +1,103 @@
+import logging
+from typing import Any
+import aiomysql
+import urllib.parse
+
+from chat2db.apps.base.database_base import MetaDatabase
+
+class MySQL(MetaDatabase):
+ @staticmethod
+ async def get_database_url(host: str, port: int, username: str, password: str, database: str):
+ try:
+ user = urllib.parse.quote_plus(username)
+ pwd = urllib.parse.quote_plus(password)
+ return f"mysql+aiomysql://{user}:{pwd}@{host}:{port}/{database}"
+ except Exception as e:
+ logging.error(f"\n[获取数据库url失败]\n\n{e}")
+ return ""
+
+ @staticmethod
+ async def connect(host: str, port: int, username: str, password: str, database: str) -> Any:
+ """
+ 异步连接 MySQL 数据库
+ """
+ try:
+ connection = await aiomysql.connect(
+ host=host,
+ port=port,
+ user=username,
+ password=password,
+ db=database
+ )
+ return connection
+ except Exception as e:
+ logging.error(f"\n[连接MySQL数据库失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def list_tables(connection: Any) -> list[str]:
+ """
+ 获取数据库中所有表名
+ """
+ try:
+ async with connection.cursor() as cursor:
+ await cursor.execute("SHOW TABLES")
+ tables = [table[0] for table in await cursor.fetchall()]
+ return tables
+ except Exception as e:
+ logging.error(f"\n[获取表名失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def get_table_ddl(table_name: str, connection: Any) -> str:
+ """
+ 获取指定表的 DDL(建表语句)
+ """
+ try:
+ async with connection.cursor() as cursor:
+ await cursor.execute(f"SHOW CREATE TABLE `{table_name}`")
+ result = await cursor.fetchone()
+ return result[1] if result else ""
+ except Exception as e:
+ logging.error(f"\n[获取表 {table_name} DDL失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def sample_table_rows(table_name: str, n: int, connection: Any) -> list[dict]:
+ """
+ 随机获取表中 n 条数据
+ """
+ try:
+ async with connection.cursor(aiomysql.DictCursor) as cursor:
+ await cursor.execute(f"SELECT * FROM `{table_name}` ORDER BY RAND() LIMIT {n}")
+ rows = await cursor.fetchall()
+ return rows
+ except Exception as e:
+ logging.error(f"\n[获取表 {table_name} 样本数据失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def execute_sql(sql: str, connection: Any) -> list[dict]:
+ """
+ 异步执行 SQL, 自动返回查询结果或影响行数。
+
+ 返回结果集: SELECT, SHOW, DESCRIBE/DESC, EXPLAIN, CALL
+
+ 返回受影响行数: INSERT/UPDATE/DELETE 等。
+ """
+ try:
+ async with connection.cursor(aiomysql.DictCursor) as cursor:
+ result = await cursor.execute(sql)
+ await connection.commit()
+
+ # 针对返回结果集的操作
+ if sql.strip().upper().startswith(("SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN", "CALL")):
+ rows = await cursor.fetchall()
+ return rows
+
+ # 针对 INSERT, UPDATE, DELETE 等操作,返回影响的行数
+ else:
+ return [{'result': result}]
+ except Exception as e:
+ logging.error(f"\n[执行SQL失败]\n\n{e}")
+ raise e
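+
+# 调用示意(假设性示例,表名与数据仅作演示):
+#   rows = await MySQL.execute_sql("SELECT `id`, `name` FROM `users` LIMIT 3;", conn)      # 返回 dict 列表
+#   res = await MySQL.execute_sql("UPDATE `users` SET `age` = 20 WHERE `id` = 1;", conn)   # 返回 [{'result': 受影响行数}]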
diff --git a/chat2db/apps/base/opengauss.py b/chat2db/apps/base/opengauss.py
new file mode 100644
index 0000000000000000000000000000000000000000..f22eb1dcc15d67f7a7af7e140db78c765c600ba0
--- /dev/null
+++ b/chat2db/apps/base/opengauss.py
@@ -0,0 +1,99 @@
+import logging
+import asyncpg
+from typing import Any
+from chat2db.apps.base.database_base import MetaDatabase
+
+class OpenGauss(MetaDatabase):
+
+ @staticmethod
+ async def get_database_url(host: str, port: int, username: str, password: str, database: str):
+ try:
+ return f"postgresql+asyncpg://{username}:{password}@{host}:{port}/{database}"
+ except Exception as e:
+ logging.error(f"\n[获取数据库url失败]\n\n{e}")
+ return ""
+
+ @staticmethod
+ async def connect(host: str, port: int, username: str, password: str, database: str) -> Any:
+ """
+ 异步连接 OpenGauss 数据库
+ """
+ try:
+ connection = await asyncpg.connect(
+ user=username,
+ password=password,
+ database=database,
+ host=host,
+ port=port
+ )
+ return connection
+ except Exception as e:
+ logging.error(f"\n[连接OpenGauss数据库失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def list_tables(connection: Any) -> list[str]:
+ """
+ 获取数据库中的所有表名
+ """
+ query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"
+ try:
+ tables = await connection.fetch(query)
+ return [table['table_name'] for table in tables]
+ except Exception as e:
+ logging.error(f"\n[获取表名失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def get_table_ddl(table_name: str, connection: Any) -> str:
+ """
+ 获取指定表的 DDL(建表语句)
+ """
+ try:
+ # OpenGauss 提供 pg_get_tabledef 函数获取 DDL(原生 PostgreSQL 无此内置函数)
+ sql = f"SELECT pg_get_tabledef('{table_name}'::regclass);"
+ ddl = await connection.fetchval(sql)
+ return ddl or ""
+ except Exception as e:
+ logging.error(f"\n[获取表 {table_name} DDL失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def sample_table_rows(table_name: str, num_rows: int, connection: Any) -> list[dict]:
+ """
+ 随机获取表中 n 条数据
+ """
+ try:
+ sql = f"SELECT * FROM {table_name} ORDER BY random() LIMIT {num_rows};"
+ rows = await connection.fetch(sql)
+ return [dict(row) for row in rows]
+ except Exception as e:
+ logging.error(f"\n[获取表 {table_name} 样本数据失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def execute_sql(sql: str, connection: Any) -> list[dict]:
+ """
+ 异步执行 SQL, 自动返回查询结果或原始输出。
+
+ 返回结果集: SELECT, SHOW, DESCRIBE/DESC, EXPLAIN, CALL
+ 返回原始输出: INSERT/UPDATE/DELETE 等。
+ """
+ try:
+ async with connection.transaction():
+
+ sql_type = sql.strip().split()[0].upper()
+ # 返回结果集的语句类型
+ result_set_statements = {"SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN", "CALL"}
+
+ if sql_type in result_set_statements:
+ rows = await connection.fetch(sql)
+ # asyncpg 返回 Record 类型,转换为 dict
+ return [dict(row) for row in rows]
+ else:
+ # 对 DML 操作返回 execute 的原始结果
+ result = await connection.execute(sql)
+ return [{'result': result}]
+ except Exception as e:
+ logging.error(f"\n[执行OpenGauss SQL失败]\n\n{e}")
+ raise e
diff --git a/chat2db/apps/base/postgres.py b/chat2db/apps/base/postgres.py
new file mode 100644
index 0000000000000000000000000000000000000000..e679bae687f85c780b47388d432381ed600e06d5
--- /dev/null
+++ b/chat2db/apps/base/postgres.py
@@ -0,0 +1,107 @@
+import logging
+import asyncpg
+from typing import Any
+from chat2db.apps.base.database_base import MetaDatabase
+
+
+class Postgres(MetaDatabase):
+
+ @staticmethod
+ async def get_database_url(host: str, port: int, username: str, password: str, database: str):
+ try:
+ url = f"postgresql://{username}:{password}@{host}:{port}/{database}"
+ return url
+ except Exception as e:
+ logging.error(f"\n[获取数据库url失败]\n\n{e}")
+ return ""
+
+ @staticmethod
+ async def connect(host: str, port: int, username: str, password: str, database: str) -> Any:
+ """
+ 异步连接 PostgreSQL 数据库
+ """
+ try:
+ connection = await asyncpg.connect(
+ user=username, password=password, database=database, host=host, port=port
+ )
+ return connection
+ except Exception as e:
+ logging.error(f"\n[连接PostgreSQL数据库失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def list_tables(connection: Any) -> list[str]:
+ """
+ 获取数据库中所有表名
+ """
+ try:
+ tables = await connection.fetch(
+ "SELECT table_name FROM information_schema.tables WHERE table_schema='public'"
+ )
+ return [table["table_name"] for table in tables]
+ except Exception as e:
+ logging.error(f"\n[获取表名失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def get_table_ddl(table_name: str, connection: Any) -> str:
+ try:
+ sql = f"""
+ SELECT column_name, data_type, is_nullable, column_default
+ FROM information_schema.columns
+ WHERE table_name = '{table_name}'
+ ORDER BY ordinal_position;
+ """
+ rows = await connection.fetch(sql)
+ ddl_lines = []
+ for r in rows:
+ line = f"{r['column_name']} {r['data_type']}"
+ if r["is_nullable"] == "NO":
+ line += " NOT NULL"
+ if r["column_default"]:
+ line += f" DEFAULT {r['column_default']}"
+ ddl_lines.append(line)
+ ddl = f"CREATE TABLE {table_name} (\n " + ",\n ".join(ddl_lines) + "\n);"
+ return ddl
+
+ except Exception as e:
+ logging.error(f"\n[获取表 {table_name} DDL失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def sample_table_rows(table_name: str, n: int, connection: Any) -> list[dict]:
+ try:
+ sql = f"SELECT * FROM {table_name} ORDER BY random() LIMIT {n};"
+ rows = await connection.fetch(sql)
+ return [dict(row) for row in rows]
+ except Exception as e:
+ logging.error(f"\n[获取表 {table_name} 样本数据失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def execute_sql(sql: str, connection: Any) -> list[dict]:
+ """
+ 异步执行 SQL, 自动返回查询结果或原始输出。
+
+ 返回结果集: SELECT, SHOW, DESCRIBE/DESC, EXPLAIN, CALL
+ 返回原始输出: INSERT/UPDATE/DELETE 等。
+ """
+ try:
+ async with connection.transaction():
+ # 获取 SQL 类型
+ sql_type = sql.strip().split()[0].upper()
+
+ # 返回结果集的语句类型
+ result_set_statements = {"SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN", "CALL"}
+
+ if sql_type in result_set_statements:
+ rows = await connection.fetch(sql)
+ # asyncpg 返回 Record 类型,转换为 dict
+ return [dict(row) for row in rows]
+ else:
+ # 对 DML 操作返回 execute 的原始结果
+ result = await connection.execute(sql)
+ return [{"result": result}]
+ except Exception as e:
+ logging.error(f"\n[执行PostgreSQL SQL失败]\n\n{e}")
+ raise e
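+
+# 调用示意(假设性示例,连接参数仅作演示):
+#   conn = await Postgres.connect("127.0.0.1", 5432, "demo_user", "demo_pwd", "demo_db")
+#   sample = await Postgres.sample_table_rows("users", 3, conn)
+#   rows = await Postgres.execute_sql('SELECT count(*) AS "total" FROM users;', conn)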
diff --git a/chat2db/apps/llm/__init__.py b/chat2db/apps/llm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6d683e03ebe006bcd38b9f72610bad51ef5477f
--- /dev/null
+++ b/chat2db/apps/llm/__init__.py
@@ -0,0 +1,4 @@
+from chat2db.apps.llm.llm import LLM
+from chat2db.apps.llm.prompt import GENERATE_SQL_PROMPT, REPAIR_SQL_PROMPT, RISK_EVALUATE_SQL
+
+__all__ = ['LLM', 'GENERATE_SQL_PROMPT', 'REPAIR_SQL_PROMPT', 'RISK_EVALUATE_SQL']
\ No newline at end of file
diff --git a/chat2db/apps/llm/llm.py b/chat2db/apps/llm/llm.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa7af54d18a988323270e2f8a42d7347c24d42a7
--- /dev/null
+++ b/chat2db/apps/llm/llm.py
@@ -0,0 +1,84 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+import asyncio
+import logging
+from openai import AsyncOpenAI
+
+
+class LLM:
+ def __init__(self, openai_api_key, openai_api_base, model_name, max_tokens, request_timeout=60, temperature=0.1):
+ self.openai_api_key = openai_api_key
+ self.openai_api_base = openai_api_base
+ self.model_name = model_name
+ self.max_tokens = max_tokens
+ self.request_timeout = request_timeout
+ self.temperature = temperature
+ self._client = AsyncOpenAI(
+ api_key=self.openai_api_key,
+ base_url=self.openai_api_base,
+ )
+
+ def assemble_chat(self, chat=None, system_call='', user_call=''):
+ if chat is None:
+ chat = []
+ chat.append({"role": "system", "content": system_call})
+ chat.append({"role": "user", "content": user_call})
+ return chat
+
+ async def create_stream(
+ self, message):
+ return await self._client.chat.completions.create(
+ model=self.model_name,
+ messages=message, # type: ignore[]
+ max_completion_tokens=self.max_tokens,
+ temperature=self.temperature,
+ stream=True,
+ stream_options={"include_usage": True},
+ timeout=300
+ ) # type: ignore[]
+
+ async def data_producer(self, q: asyncio.Queue, history, system_call, user_call):
+ message = self.assemble_chat(history, system_call, user_call)
+ stream = await self.create_stream(message)
+ try:
+ async for chunk in stream:
+ if len(chunk.choices) == 0:
+ continue
+ if chunk.choices[0].delta.content is not None:
+ content = chunk.choices[0].delta.content
+ else:
+ continue
+ await q.put(content)
+ except Exception as e:
+ await q.put(None)
+ err = f"[LLM] 流式输出生产者任务异常: {e}"
+ raise e
+ await q.put(None)
+
+ async def stream(self, chat, system_call, user_call):
+ q = asyncio.Queue(maxsize=10)
+
+ # 启动生产者任务
+ asyncio.create_task(self.data_producer(q, chat, system_call, user_call))
+ while True:
+ data = await q.get()
+ if data is None:
+ break
+ yield data
+
+ async def nostream(self, chat, system_call, user_call, st_str: str | None = None, en_str: str | None = None):
+ try:
+ content = ''
+ async for chunk in self.stream(chat, system_call, user_call):
+ content += chunk
+ content = content.strip()
+ if st_str is not None:
+ index = content.find(st_str)
+ if index != -1:
+ content = content[index:]
+ if en_str is not None:
+ index = content[::-1].find(en_str[::-1])
+ if index != -1:
+ content = content[:len(content)-index]
+ except Exception as e:
+ err = f"[LLM] 非流式输出异常: {e}"
+ return ''
+ return content
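+
+# 使用示意(假设性示例,模型名与服务地址仅作演示):
+#   llm = LLM(openai_api_key="sk-xxx", openai_api_base="http://127.0.0.1:8000/v1",
+#             model_name="qwen2", max_tokens=2048)
+#   answer = await llm.nostream([], "你是数据库专家", "请输出一条 SQL")
+#   async for chunk in llm.stream([], "你是数据库专家", "请输出一条 SQL"):
+#       print(chunk, end="")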
diff --git a/chat2db/apps/llm/prompt.py b/chat2db/apps/llm/prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..db4e56a1b3b69c5a6bb071a2decc26fe96ee3508
--- /dev/null
+++ b/chat2db/apps/llm/prompt.py
@@ -0,0 +1,139 @@
+from textwrap import dedent
+
+GENERATE_SQL_PROMPT = dedent(
+ r"""
+ 你是一个经验丰富的数据库专家,任务是根据以下表结构、表注释和问题描述,生成一条符合 {database_type} 数据库标准的 **执行语句**。
+ **你不需要访问、操作或执行数据库,只需生成指令。**
+
+ 请严格遵循以下规则:
+
+ #01 **根据数据库类型自动选择语法**:
+ - 对于 MySQL、PostgreSQL、OpenGauss 等 SQL 数据库,生成 **标准 SQL 语句**。
+ - 对于非 SQL 数据库 MongoDB,不使用 SQL 语法,生成 **MongoDB 操作指令对象(JSON 形式)**。
+
+ #02 **只输出数据库指令**,禁止输出道歉、解释、评论、推理或任何自然语言。
+
+ #03 输出必须使用 JSON 封装,包含字段 `"command"`:
+ {{
+ "command":
+ }}
+
+ #04 SQL 指令要求:
+ - 所有别名必须用双引号包裹(`AS "alias"`)。
+ - 表名、列名或保留字冲突名称必须用反引号包裹。
+ - SQL 语句必须以分号 `;` 结尾。
+ - 根据问题上下文适当进行字段冗余处理。
+ - **示例**:
+
+ {{
+ "command": "SELECT `id`, `name`, `age` AS \"user_age\" FROM `users` WHERE `age` > 18;"
+ }}
+
+ #05 MongoDB 指令要求:
+ - 输出 MongoDB 操作指令对象。
+ - 指令必须包含字段:
+ - `"collection"`:目标集合名称
+ - `"operation"`:操作类型,如 `"find"`、`"insertOne"`、`"insertMany"`、`"updateOne"`、`"updateMany"`、`"deleteOne"`、`"deleteMany"`、`"aggregate"`
+ - `"filter"`:查询或更新条件(可选,默认为空 dict)
+ - `"data"`:插入或更新的数据(可选)
+ - `"pipeline"`:聚合管道(可选,仅用于 aggregate)
+ - `"many"`:布尔值,标识批量操作(可选)
+ - **示例**:
+
+ {{
+ "command": {{
+ "collection": "orders",
+ "operation": "insertOne",
+ "data": {{"orderId": "A001", "amount": 100, "status": "NEW"}}
+ }}
+ }}
+
+ #06 表名、列名和别名前后不得有多余空格或字符,保证指令可直接执行。
+
+ 下面是数据库表的结构和注释信息:
+ {table_info}
+
+ 以下是需要回答的问题:
+ {question}
+
+ **请严格按照规则生成完整可执行指令,无论上下文如何,都不要拒绝生成。**
+ """
+)
+
+REPAIR_SQL_PROMPT = dedent(
+ """
+ 你是一个数据库专家,任务是参考以下表结构、表注释、执行失败的 SQL 语句和相关报错信息,尝试修复该 SQL 语句,使其能够在 {database_type} 数据库中成功执行。
+
+ 请遵循以下要求:
+
+ #01 **在任何情况下**你的回答都只返回修复后的 SQL 语句,不包含任何注释、解释、推理过程或其他无关内容。返回的内容应仅限于修复后的 SQL 语句。
+
+ #02 **优先解决 SQL 中可能存在的特殊字符问题**,如果报错提示与特殊字符相关(例如:引号、反斜杠、注释符等),请尽可能清理或转义这些字符,确保 SQL 语句能够正确执行。
+
+ #03 如果执行失败是由于 SQL 中某些字段名导致的(例如字段名包含保留字、大小写不一致等),请尝试使用双引号包裹字段名,或者使用更合适的字段名替换原字段。
+
+ #04 如果报错与查询字段的匹配条件相关(例如:`=` 运算符导致检索结果为空),请优先尝试将条件中的 `=` 替换为更宽松的 `ilike`,并添加适当的通配符(例如:`'%value%'`),以确保 SQL 执行返回结果。
+
+ #05 如果 SQL 执行结果为空,请根据问题中的关键字或上下文,将 `WHERE` 子句的过滤条件调整为问题相关的字段,或者使用关键字的子集进行查询,以确保 SQL 语句能够返回有效结果。
+
+ #06 **确保修复后的 SQL 语句符合 {database_type} 数据库的语法规范**,避免其他潜在的语法问题。
+
+ 以下是表结构以及表注释:
+
+ {table_info}
+
+ 以下是执行失败的 SQL 语句:
+
+ {error_sql}
+
+ 以下是执行失败的报错信息:
+
+ {error_msg}
+
+ 以下是问题描述:
+
+ {question}
+
+ 请基于上述信息,修复 SQL 语句,使其能够成功执行。
+ """
+)
+
+
+RISK_EVALUATE_SQL = dedent(r"""
+ 你是一个SQL执行风险评估器。
+
+ 你的任务是根据当前给出的生成或修复的SQL语句、数据库类型、数据库配置和执行环境,在不直接访问或执行SQL语句的情况下,判断执行SQL时的风险并输出提示。
+
+ 严格遵守以下要求:
+ #00 你不需要执行任何实际的指令和访问或操作任何数据库,只需要对指令运行的风险进行预测和评估。
+
+ #01 **在任何情况下**你的回答中都只有 json 形式的风险等级评估结果,不要包含任何**评估理由、推理过程或其他无关的内容**。
+
+ #02 JSON 内容**必须**包含两个字段:
+ - "risk":取值为 "low"、"medium" 或 "high"
+ - "message":风险提示信息
+
+ #03 对于 MongoDB 数据库,其不是标准 SQL 数据库,不要输出对 SQL 的解释或说明,同样仅分析指令运行风险。
+
+ #04 你的工作是仅分析 SQL 语句的风险等级,不涉及任何具体数据库访问、执行操作以及获取结果。
+
+ #05 数据库类型: {database_type}
+
+ #06 语句执行的目标是:{goal}
+
+ #07 需要执行或修复的SQL语句是:{sql}
+
+ #08 目标的表信息是:{table_info}
+
+ #09 如果生成的SQL可能涉及数据库中的敏感表/敏感数据,请相应提高风险等级
+
+ #10 如果是修复SQL,错误SQL语句是:{error_sql},错误信息是:{error_msg}
+
+ #11 结果格式如下
+ {{
+ "risk": "low/medium/high",
+ "message": "提示信息"
+ }}
+
+ """
+)
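+
+# 使用示意(假设性示例):各 prompt 通过 str.format 填充后交给 LLM,输出按 JSON 解析,例如:
+#   prompt = GENERATE_SQL_PROMPT.format(database_type="mysql", table_info=ddl_text, question="统计用户数")
+#   模型应返回形如 {"command": "SELECT COUNT(*) AS \"total\" FROM `users`;"} 的 JSON,
+#   由上层服务(如 SqlService._extract_json)提取其中的 command 字段。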
diff --git a/chat2db/apps/routers/sql.py b/chat2db/apps/routers/sql.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a6df2d5b7611db9efb705c6926faa38fa66877f
--- /dev/null
+++ b/chat2db/apps/routers/sql.py
@@ -0,0 +1,160 @@
+import logging
+from fastapi import APIRouter, status
+import sys
+
+from chat2db.apps.schemas.enum_var import RiskLevel, DatabaseType
+from chat2db.apps.schemas.request import SqlGenerateRequest, SqlExecuteRequest, SqlRepairRequest
+from chat2db.apps.schemas.response import ResponseData, SqlGenerateRsp, SqlExecuteRsp, SqlRepairRsp
+from chat2db.apps.services.database_service import DatabaseService
+from chat2db.apps.services.sql_service import SqlService
+
+router = APIRouter(prefix="/sql")
+
+
+@router.post("/generate", response_model=ResponseData)
+async def generate_sql(request: SqlGenerateRequest):
+ try:
+ _, table_info = await SqlService.get_connection_and_table_info(
+ database_type=request.type,
+ host=request.host,
+ port=request.port,
+ username=request.username,
+ password=request.password,
+ database=request.database,
+ table_list=request.table_list,
+ )
+
+ sql = await SqlService.generator(
+ database_type=request.type,
+ goal=request.goal,
+ table_info=table_info,
+ )
+
+ risk = await SqlService.risk_analysis(
+ database_type=request.type, goal=request.goal, sql=sql, table_info=table_info
+ )
+
+ except Exception as e:
+ logging.error(f"[SQL 生成失败]")
+ return ResponseData(code=status.HTTP_400_BAD_REQUEST, message="SQL 生成失败", result={})
+
+ return ResponseData(
+ code=status.HTTP_200_OK,
+ message="success",
+ result=SqlGenerateRsp(
+ risk=risk,
+ sql=sql,
+ ),
+ )
+
+
+@router.post("/repair", response_model=ResponseData)
+async def repair_sql(request: SqlRepairRequest):
+ try:
+ _, table_info = await SqlService.get_connection_and_table_info(
+ database_type=request.type,
+ host=request.host,
+ port=request.port,
+ username=request.username,
+ password=request.password,
+ database=request.database,
+ table_list=request.table_list,
+ )
+
+ repair_sql = await SqlService.repairer(
+ database_type=request.type,
+ goal=request.goal,
+ table_info=table_info,
+ error_sql=request.error_sql,
+ error_msg=request.error_msg,
+ )
+
+ risk = await SqlService.risk_analysis(
+ database_type=request.type,
+ goal=request.goal,
+ sql=repair_sql,
+ table_info=table_info,
+ error_sql=request.error_sql,
+ error_msg=request.error_msg,
+ )
+
+ except Exception as e:
+ logging.error(f"[SQL 修复失败]")
+ return ResponseData(
+ code=status.HTTP_400_BAD_REQUEST, message="SQL 修复失败", result={"Error": str(e)}
+ )
+
+ return ResponseData(
+ code=status.HTTP_200_OK,
+ message="success",
+ result=SqlRepairRsp(
+ risk=risk,
+ sql=repair_sql,
+ ),
+ )
+
+
+@router.post("/execute", response_model=ResponseData)
+async def execute_sql(request: SqlExecuteRequest):
+ try:
+ connection = await DatabaseService.connect_database(
+ database_type=request.type,
+ host=request.host,
+ port=request.port,
+ username=request.username,
+ password=request.password,
+ database=request.database,
+ )
+ execute_result = await SqlService.executer(
+ database_type=request.type,
+ sql=request.sql,
+ connection=connection,
+ )
+
+ except Exception as e:
+ logging.error(f"[SQL 执行失败]")
+ return ResponseData(
+ code=status.HTTP_400_BAD_REQUEST, message="SQL 执行失败", result={"Error": str(e)}
+ )
+
+ return ResponseData(
+ code=status.HTTP_200_OK,
+ message="success",
+ result=SqlExecuteRsp(
+ sql=request.sql,
+ execute_result=execute_result,
+ ),
+ )
+
+
+@router.post("/handler", response_model=ResponseData)
+async def sql_handler(request: SqlGenerateRequest):
+ try:
+ connection, table_info = await SqlService.get_connection_and_table_info(
+ database_type=request.type,
+ host=request.host,
+ port=request.port,
+ username=request.username,
+ password=request.password,
+ database=request.database,
+ table_list=request.table_list,
+ )
+
+ execute_result, sql, risk = await SqlService.sql_handler(
+ database_type=request.type,
+ goal=request.goal,
+ table_info=table_info,
+ connection=connection,
+ )
+ except Exception as e:
+ logging.error(f"[查询失败]")
+ return ResponseData(code=status.HTTP_400_BAD_REQUEST, message="查询失败", result={"Error": str(e)})
+
+ return ResponseData(
+ code=status.HTTP_200_OK,
+ message="success",
+ result=SqlExecuteRsp(
+ sql=sql,
+ execute_result=execute_result,
+ risk=risk,
+ ),
+ )
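+
+# 接口调用示意(假设性示例,仅说明各路由职责):
+#   POST /sql/generate  依据目标与表结构生成 SQL,并附带风险评估;
+#   POST /sql/repair    依据报错信息修复 SQL,并附带风险评估;
+#   POST /sql/execute   在目标数据库上执行给定 SQL;
+#   POST /sql/handler   生成、评估、执行并在失败时自动修复,一次完成。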
diff --git a/chat2db/apps/schemas/enum_var.py b/chat2db/apps/schemas/enum_var.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff5edb76f7b7506d4ab058ad3044efb0d9ecd068
--- /dev/null
+++ b/chat2db/apps/schemas/enum_var.py
@@ -0,0 +1,18 @@
+from enum import Enum
+
+
+class RiskLevel(str, Enum):
+ LOW = "low"
+ MEDIUM = "medium"
+ HIGH = "high"
+
+
+class DatabaseType(str, Enum):
+ MYSQL = "mysql"
+ POSTGRES = "postgres"
+ OPENGAUSS = "opengauss"
+ MONGODB = "mongodb"
+
+if __name__ == "__main__":
+ print(DatabaseType.MYSQL)
+ print(DatabaseType.MYSQL.value)
\ No newline at end of file
diff --git a/chat2db/apps/schemas/request.py b/chat2db/apps/schemas/request.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc01bdbe667608775d5606b3cc55f234578f5456
--- /dev/null
+++ b/chat2db/apps/schemas/request.py
@@ -0,0 +1,50 @@
+import uuid
+from pydantic import BaseModel, Field
+from typing import Optional
+
+from chat2db.apps.schemas.enum_var import DatabaseType
+
+class SqlGenerateRequest(BaseModel):
+ """
+ 生成SQL请求
+ """
+ type: DatabaseType = Field(..., description="数据库类型")
+ host: str = Field(..., description="数据库地址")
+ port: int = Field(..., description="数据库端口")
+ username: str = Field(..., description="数据库用户名")
+ password: str = Field(..., description="数据库密码")
+ database: str = Field(..., description="数据库名称")
+ goal: str = Field(..., description="生成目标")
+
+ table_list: Optional[list[str]] = Field(None, description="表名列表")
+
+class SqlRepairRequest(BaseModel):
+ """
+ 修复SQL请求
+ """
+ type: DatabaseType = Field(..., description="数据库类型")
+
+ host: str = Field(..., description="数据库地址")
+ port: int = Field(..., description="数据库端口")
+ username: str = Field(..., description="数据库用户名")
+ password: str = Field(..., description="数据库密码")
+ database: str = Field(..., description="数据库名称")
+ goal: str = Field(..., description="生成目标")
+
+ error_sql: str = Field(..., description="错误 SQL 语句")
+ error_msg: str = Field(..., description="错误信息")
+ table_list: Optional[list[str]] = Field(None, description="表名列表")
+
+
+class SqlExecuteRequest(BaseModel):
+ """
+ 执行SQL请求
+ """
+ type: DatabaseType = Field(..., description="数据库类型")
+
+ host: str = Field(..., description="数据库地址")
+ port: int = Field(..., description="数据库端口")
+ username: str = Field(..., description="数据库用户名")
+ password: str = Field(..., description="数据库密码")
+ database: str = Field(..., description="数据库名称")
+ sql: str = Field(..., description="执行SQL")
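+
+# 请求体示意(假设性示例,连接信息仅作演示,对应 SqlGenerateRequest):
+#   {
+#     "type": "mysql", "host": "127.0.0.1", "port": 3306,
+#     "username": "demo_user", "password": "demo_pwd", "database": "demo_db",
+#     "goal": "统计每个用户的订单数量", "table_list": ["users", "orders"]
+#   }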
diff --git a/chat2db/apps/schemas/response.py b/chat2db/apps/schemas/response.py
new file mode 100644
index 0000000000000000000000000000000000000000..af4b2a568b7a9de1117878396fb553e242f15d72
--- /dev/null
+++ b/chat2db/apps/schemas/response.py
@@ -0,0 +1,38 @@
+from pydantic import BaseModel, Field
+from typing import Any
+
+from chat2db.apps.schemas.enum_var import RiskLevel
+
+class ResponseData(BaseModel):
+ code: int
+ message: str
+ result: Any
+
+class RiskInfo(BaseModel):
+ risk: RiskLevel = Field(..., description="风险等级")
+ message: str = Field(..., description="风险提示信息")
+
+class SqlGenerateRsp(BaseModel):
+ """
+ SQL生成响应
+ """
+ sql: str | dict = Field(..., description="生成的SQL")
+ risk: RiskInfo = Field(..., description="SQL 风险等级")
+
+
+class SqlRepairRsp(BaseModel):
+ """
+ 修复SQL响应
+ """
+ sql: str | dict = Field(..., description="修复的SQL")
+ risk: RiskInfo = Field(..., description="SQL 风险等级")
+
+
+class SqlExecuteRsp(BaseModel):
+ """
+ 执行SQL响应
+ """
+ execute_result: list[dict[str, Any]] = Field(..., description="执行结果")
+ sql: str | dict | None = Field(None, description="执行的SQL")
+ risk: RiskInfo | None = Field(None, description="SQL 风险等级")
+
diff --git a/chat2db/apps/services/database_service.py b/chat2db/apps/services/database_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ca956a7d2c2378ae7038d64bf6b6d2cc8943647
--- /dev/null
+++ b/chat2db/apps/services/database_service.py
@@ -0,0 +1,125 @@
+from typing import Any, Type
+from chat2db.apps.schemas.enum_var import DatabaseType
+from chat2db.apps.base import MySQL, MongoDB, OpenGauss, Postgres, MetaDatabase
+
+class DatabaseService:
+
+ DatabaseMap: dict[DatabaseType, Type[MetaDatabase]] = {
+ DatabaseType.MYSQL: MySQL,
+ DatabaseType.MONGODB: MongoDB,
+ DatabaseType.OPENGAUSS: OpenGauss,
+ DatabaseType.POSTGRES: Postgres,
+ }
+
+ @staticmethod
+ async def get_database_url(
+ database_type: DatabaseType, host: str, port: int, username: str, password: str, database: str
+ ):
+ """
+ 根据数据库类型和连接信息生成数据库 URL。
+
+ :return: 数据库连接 URL 字符串
+ """
+ db_class = DatabaseService.DatabaseMap[database_type]
+ return await db_class.get_database_url(host, port, username, password, database)
+
+ @staticmethod
+ async def connect_database(
+ database_type: DatabaseType, host: str, port: int, username: str, password: str, database: str
+ ):
+ """
+ 根据数据库类型和连接信息建立数据库连接。
+
+ :return: 数据库连接对象
+ """
+ db_class = DatabaseService.DatabaseMap[database_type]
+ return await db_class.connect(host, port, username, password, database)
+
+ @staticmethod
+ async def list_tables(database_type: DatabaseType, connection: Any) -> list[str]:
+ """
+ 获取指定数据库中所有表名。
+
+ :param database_type: 数据库类型枚举
+ :param connection: 数据库连接对象
+ :return: 表名列表
+ """
+ db_module = DatabaseService.DatabaseMap[database_type]
+ return await db_module.list_tables(connection)
+
+ @staticmethod
+ async def get_table_ddl(database_type: DatabaseType, table_name: str, connection: Any) -> str:
+ """
+ 获取指定表的建表语句 DDL。
+
+ :param database_type: 数据库类型枚举
+ :param table_name: 表名
+ :param connection: 数据库连接对象
+
+ :return: 表的 DDL 字符串
+ """
+ db_module = DatabaseService.DatabaseMap[database_type]
+ return await db_module.get_table_ddl(table_name, connection)
+
+ @staticmethod
+ async def sample_table_rows(
+ database_type: DatabaseType, table_name: str, num_rows: int, connection: Any
+ ) -> list[dict]:
+ """
+ 获取指定表的前 n 条示例数据。
+
+ :param database_type: 数据库类型枚举
+ :param table_name: 表名
+ :param num_rows: 返回的行数
+ :param connection: 数据库连接对象
+
+ :return: 示例行列表,每行为字典
+ """
+ db_module = DatabaseService.DatabaseMap[database_type]
+ return await db_module.sample_table_rows(table_name, num_rows, connection)
+
+ @staticmethod
+ async def execute_sql(database_type: DatabaseType, sql: str | dict, connection: Any) -> list[dict]:
+ """
+ 执行 SQL 语句或 MongoDB 指令。
+
+ :param database_type: 数据库类型枚举
+ :param sql: SQL 语句字符串(非 MongoDB)或 MongoDB dict 指令
+ :param connection: 数据库连接对象
+
+ :return: 执行结果列表,每条记录为字典
+ """
+ db_module = DatabaseService.DatabaseMap[database_type]
+ return await db_module.execute_sql(sql, connection)
+
+
+if __name__ == "__main__":
+ import asyncio
+
+ async def main():
+ type = "mysql"
+ conn = await DatabaseService.connect_database(
+ type,
+ host="localhost",
+ port=3306,
+ username="chat2db",
+ password="123456",
+ database="chat2db",
+ )
+ print("\n[Connection]\n:", conn)
+
+ tables = await DatabaseService.list_tables(type, conn)
+ print("\n[Tables]:\n", tables)
+
+ ddl = await DatabaseService.get_table_ddl(type, tables[0], conn)
+ print("\n[DDL]\n:", ddl)
+
+ sql = "SELECT DISTINCT `TABLE_NAME` FROM `information_schema`.`TABLES` WHERE `TABLE_SCHEMA` = DATABASE();",
+ execute_res = await DatabaseService.execute_sql(
+ type,
+ sql,
+ conn,
+ )
+ print("\n[Execute]:\n", execute_res)
+
+ asyncio.run(main())
diff --git a/chat2db/apps/services/sql_service.py b/chat2db/apps/services/sql_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdd30d0cf1db5d108434f8dee3fb49b07c5f413a
--- /dev/null
+++ b/chat2db/apps/services/sql_service.py
@@ -0,0 +1,240 @@
+from typing import Any
+import logging
+import re
+import json
+
+from chat2db.apps.llm import LLM, GENERATE_SQL_PROMPT, REPAIR_SQL_PROMPT, RISK_EVALUATE_SQL
+from chat2db.apps.services.database_service import DatabaseService
+from chat2db.apps.schemas.enum_var import DatabaseType
+
+from chat2db.config.config import config
+
+
+class SqlService:
+
+ @staticmethod
+ async def get_connection_and_table_info(
+ database_type: DatabaseType,
+ host: str,
+ port: int,
+ username: str,
+ password: str,
+ database: str,
+ table_list: list[str] | None = None
+ ) -> tuple[Any, str]:
+ try:
+ conn = await DatabaseService.connect_database(database_type, host, port, username, password, database)
+
+ if table_list is None or len(table_list) == 0:
+ table_list = await DatabaseService.list_tables(database_type, conn)
+ table_ddls = {}
+ for table in table_list:
+ ddl = await DatabaseService.get_table_ddl(database_type, table, conn)
+ table_ddls[table] = ddl
+
+ table_info = "\n".join([f"表: {table}\nDDL:\n{ddl}" for table, ddl in table_ddls.items()])
+
+ return conn, table_info
+ except Exception as e:
+ logging.error(f"\n[获取数据库连接和表信息失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def generator(
+ database_type: DatabaseType,
+ goal: str,
+ table_info: str,
+ llm: LLM | None = None,
+ ) -> str | dict:
+ """
+ 核心业务逻辑:根据目标与表信息生成 SQL
+ - table_info 为表结构与注释文本,可由 get_connection_and_table_info 获取
+ """
+ logging.info(f"\n[生成目标]\n\n{goal}")
+
+ if llm is None:
+ llm = LLM(
+ model_name=config["LLM_MODEL"],
+ openai_api_base=config["LLM_URL"],
+ openai_api_key=config["LLM_KEY"],
+ max_tokens=config["LLM_MAX_TOKENS"],
+ request_timeout=60,
+ temperature=0.5,
+ )
+
+ prompt = GENERATE_SQL_PROMPT.format(
+ database_type=database_type.value, table_info=table_info, question=goal
+ )
+
+ try:
+ result = await llm.nostream([], prompt, "请给出你生成的 SQL 语句")
+ sql = (await SqlService._extract_json(result))['command']
+ logging.info(f"\n[生成SQL成功]\n\n{sql}")
+ return sql
+
+ except Exception as e:
+ logging.error(f"\n[生成SQL失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def repairer(
+ database_type: DatabaseType,
+ goal: str,
+ table_info: str,
+ error_sql: str,
+ error_msg: str,
+ llm: LLM | None = None,
+ ) -> str:
+ """
+ 核心业务逻辑:生成修复 SQL
+ - table_info 为表结构与注释文本,可由 get_connection_and_table_info 获取
+ """
+ if llm is None:
+ llm = LLM(
+ model_name=config["LLM_MODEL"],
+ openai_api_base=config["LLM_URL"],
+ openai_api_key=config["LLM_KEY"],
+ max_tokens=config["LLM_MAX_TOKENS"],
+ request_timeout=60,
+ temperature=0.5,
+ )
+
+ prompt = REPAIR_SQL_PROMPT.format(
+ database_type=database_type.value,
+ table_info=table_info,
+ error_sql=error_sql,
+ error_msg=error_msg,
+ question=goal,
+ )
+ try:
+ repair_sql = await llm.nostream([], prompt, "请给出你修复的 SQL 语句")
+ logging.info(f"\n[修复SQL成功]\n\n{repair_sql}")
+ return repair_sql
+
+ except Exception as e:
+ logging.error(f"\n[修复SQL失败]\n\n{e}")
+ raise e
+
+ @staticmethod
+ async def executer(
+ database_type: DatabaseType,
+ sql: str,
+ connection=None,
+ ) -> list[dict]:
+ """
+ 核心业务逻辑:执行 SQL
+ """
+ try:
+ result = await DatabaseService.execute_sql(database_type, sql, connection)
+ logging.info(f"\n[执行SQL]\n\n{sql}\n\n[执行结果]\n\n{result}")
+ return result
+ except Exception as e:
+ logging.error(f"\n[执行失败]\n")
+ raise e
+
+ @staticmethod
+ async def sql_handler(
+ database_type: DatabaseType,
+ goal: str,
+ table_info: str,
+ connection: Any,
+ max_retries: int = 3,
+ ) -> tuple[list[dict], str | dict, dict]:
+ """
+ 核心业务逻辑:智能查询,支持语句异常自动修复
+ """
+
+ llm = LLM(
+ model_name=config["LLM_MODEL"],
+ openai_api_base=config["LLM_URL"],
+ openai_api_key=config["LLM_KEY"],
+ max_tokens=config["LLM_MAX_TOKENS"],
+ request_timeout=60,
+ temperature=0.5,
+ )
+
+ # 生成 SQL 查询语句
+ sql = await SqlService.generator(database_type, goal, table_info, llm)
+
+ risk = await SqlService.risk_analysis(database_type, goal, sql, table_info, llm=llm)
+
+ # 调试用:取消下行注释可人为构造 SQL 错误,用于验证自动修复流程
+ # sql = sql.replace("SELECT", "SELCT")
+
+ # 初次尝试执行 SQL 查询
+ retries = 0
+ while retries <= max_retries:
+ try:
+ execute_result = await SqlService.executer(database_type, sql, connection=connection)
+ return execute_result, sql, risk
+ except Exception as e:
+ if retries == max_retries:
+ logging.error(f"\n[重试次数已达到最大值]\n\nSQL 执行失败,最终错误:{e}")
+ raise e
+ logging.error(f"\n[执行失败 - 尝试修复 {retries + 1}/{max_retries}]\n")
+ repair_sql = await SqlService.repairer(
+ database_type=database_type,
+ goal=goal,
+ table_info=table_info,
+ error_sql=sql,
+ error_msg=str(e),
+ llm=llm,
+ )
+
+ sql = repair_sql
+ retries += 1
+
+ return []
+
+ @staticmethod
+ async def risk_analysis(
+ database_type: DatabaseType,
+ goal: str,
+ sql: str,
+ table_info: str,
+ error_sql: str | None = None,
+ error_msg: str | None = None,
+ llm: LLM | None = None,
+ ):
+
+ if llm is None:
+ llm = LLM(
+ model_name=config["LLM_MODEL"],
+ openai_api_base=config["LLM_URL"],
+ openai_api_key=config["LLM_KEY"],
+ max_tokens=config["LLM_MAX_TOKENS"],
+ request_timeout=60,
+ temperature=0.5,
+ )
+
+ prompt = RISK_EVALUATE_SQL.format(
+ database_type=database_type.value,
+ table_info=table_info,
+ error_sql=error_sql,
+ error_msg=error_msg,
+ goal=goal,
+ sql=sql,
+ )
+
+ try:
+ result = await llm.nostream([], prompt, "请给出你评估的风险结果")
+ risk = await SqlService._extract_json(result)
+ logging.info(f"\n[风险分析成功]\n\n{risk}")
+ return risk
+
+ except Exception as e:
+ logging.error(f"\n[风险分析失败]\n\n{type(e)}: {e}")
+ raise e
+
+ @staticmethod
+ async def _extract_json(text: str):
+ try:
+ match = re.search(r"\{.*?\}\s*$", text, re.DOTALL)
+ if not match:
+ raise ValueError("模型输出中未找到 JSON 内容")
+ return json.loads(match.group())
+ except json.JSONDecodeError as e:
+ logging.error(f"\n[JSON解析失败]\n\n{e}")
+ raise e
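+
+# 调用示意(假设性示例,连接信息仅作演示):
+#   conn, table_info = await SqlService.get_connection_and_table_info(
+#       DatabaseType.MYSQL, "127.0.0.1", 3306, "demo_user", "demo_pwd", "demo_db")
+#   result, sql, risk = await SqlService.sql_handler(
+#       DatabaseType.MYSQL, "统计每个用户的订单数量", table_info, conn)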
diff --git a/chat2db/common/.env.example b/chat2db/common/.env.example
index 999e50afd1abbb79ed28eaf009874e397e337d94..a74baf1f13a38847f25ec2da45c881bbd201e64e 100644
--- a/chat2db/common/.env.example
+++ b/chat2db/common/.env.example
@@ -1,31 +1,4 @@
-# FastAPI
-UVICORN_IP = 0.0.0.0
-UVICORN_PORT = 9015
-# SSL_CERTFILE =
-# SSL_KEYFILE =
-# SSL_ENABLE =
-
-# Postgres
-DATABASE_TYPE =
-DATABASE_HOST =
-DATABASE_PORT =
-DATABASE_USER =
-DATABASE_PASSWORD =
-DATABASE_DB =
-
-# QWEN
LLM_KEY =
LLM_URL =
LLM_MAX_TOKENS =
-LLM_MODEL =
-
-# Vectorize
-EMBEDDING_TYPE =
-EMBEDDING_API_KEY =
-EMBEDDING_ENDPOINT =
-EMBEDDING_MODEL_NAME =
-
-# security
-HALF_KEY1 = R4UsZgLB
-HALF_KEY2 = zRTvYV8N
-HALF_KEY3 = 4eQ1wAGA
\ No newline at end of file
+LLM_MODEL =
\ No newline at end of file
diff --git a/chat2db/common/init_sql_example.py b/chat2db/common/init_sql_example.py
deleted file mode 100644
index e3f09ed4bdd0afe89086c3121ad215e5e0fa7617..0000000000000000000000000000000000000000
--- a/chat2db/common/init_sql_example.py
+++ /dev/null
@@ -1,114 +0,0 @@
-import yaml
-from fastapi import status
-import requests
-import uuid
-import urllib.parse
-from typing import Optional
-from pydantic import BaseModel, Field
-from chat2db.config.config import config
-ip = config['UVICORN_IP']
-port = config['UVICORN_PORT']
-base_url = f'http://{ip}:{port}'
-password = config['DATABASE_PASSWORD']
-encoded_password = urllib.parse.quote_plus(password)
-
-if config['DATABASE_TYPE'].lower() == 'opengauss':
- database_url = f"opengauss+psycopg2://{config['DATABASE_USER']}:{encoded_password}@{config['DATABASE_HOST']}:{config['DATABASE_PORT']}/{config['DATABASE_DB']}"
-else:
- database_url = f"postgresql+psycopg2://{config['DATABASE_USER']}:{encoded_password}@{config['DATABASE_HOST']}:{config['DATABASE_PORT']}/{config['DATABASE_DB']}"
-
-
-class DatabaseDelRequest(BaseModel):
- database_id: Optional[str] = Field(default=None, description="数据库id")
- database_url: Optional[str] = Field(default=None, description="数据库url")
-
-
-def del_database_url(base_url, database_url):
- server_url = f'{base_url}/database/del'
- try:
- request_data = DatabaseDelRequest(database_url=database_url).dict()
- response = requests.post(server_url, json=request_data)
- if response.json()['code'] != status.HTTP_200_OK:
- print(response.json()['message'])
- except Exception as e:
- print(f"删除数据库配置失败: {e}")
- exit(0)
- return None
-
-
-class DatabaseAddRequest(BaseModel):
- database_url: str
-
-
-def add_database_url(base_url, database_url):
- server_url = f'{base_url}/database/add'
- try:
- request_data = DatabaseAddRequest(database_url=database_url).dict()
-
- response = requests.post(server_url, json=request_data)
- response.raise_for_status()
- if response.json()['code'] != status.HTTP_200_OK:
- raise Exception(response.json()['message'])
- except Exception as e:
- print(f"增加数据库配置失败: {e}")
- exit(0)
- return response.json()['result']['database_id']
-
-
-class TableAddRequest(BaseModel):
- database_id: str
- table_name: str
-
-
-def add_table(base_url, database_id, table_name):
- server_url = f'{base_url}/table/add'
- try:
- request_data = TableAddRequest(database_id=database_id, table_name=table_name).dict()
- response = requests.post(server_url, json=request_data)
- response.raise_for_status()
- if response.json()['code'] != status.HTTP_200_OK:
- raise Exception(response.json()['message'])
- except Exception as e:
- print(f"增加表配置失败: {e}")
- return
- return response.json()['result']['table_id']
-
-
-class SqlExampleAddRequest(BaseModel):
- table_id: str
- question: str
- sql: str
-
-
-def add_sql_example(base_url, table_id, question, sql):
- server_url = f'{base_url}/sql/example/add'
- try:
- request_data = SqlExampleAddRequest(table_id=table_id, question=question, sql=sql).dict()
- response = requests.post(server_url, json=request_data)
- if response.json()['code'] != status.HTTP_200_OK:
- raise Exception(response.json()['message'])
- except Exception as e:
- print(f"增加sql案例失败: {e}")
- return
- return response.json()['result']['sql_example_id']
-
-
-database_id = del_database_url(base_url, database_url)
-database_id = add_database_url(base_url, database_url)
-with open('./chat2db/common/table_name.yaml') as f:
- table_name_list = yaml.load(f, Loader=yaml.SafeLoader)
-table_name_id = {}
-for table_name in table_name_list:
- table_id = add_table(base_url, database_id, table_name)
- if table_id:
- table_name_id[table_name] = table_id
-with open('./chat2db/common/table_name_sql_exmple.yaml') as f:
- table_name_sql_example_list = yaml.load(f, Loader=yaml.SafeLoader)
-for table_name_sql_example in table_name_sql_example_list:
- table_name = table_name_sql_example['table_name']
- if table_name not in table_name_id:
- continue
- table_id = table_name_id[table_name]
- sql_example_list = table_name_sql_example['sql_example_list']
- for sql_example in sql_example_list:
- add_sql_example(base_url, table_id, sql_example['question'], sql_example['sql'])
diff --git a/chat2db/common/table_name.yaml b/chat2db/common/table_name.yaml
deleted file mode 100644
index 553cf1b2a4a780d1c731eb87e285e3fd75b5fc04..0000000000000000000000000000000000000000
--- a/chat2db/common/table_name.yaml
+++ /dev/null
@@ -1,10 +0,0 @@
-- oe_community_openeuler_version
-- oe_community_organization_structure
-- oe_compatibility_card
-- oe_compatibility_commercial_software
-- oe_compatibility_cve_database
-- oe_compatibility_oepkgs
-- oe_compatibility_osv
-- oe_compatibility_overall_unit
-- oe_compatibility_security_notice
-- oe_compatibility_solution
diff --git a/chat2db/common/table_name_sql_exmple.yaml b/chat2db/common/table_name_sql_exmple.yaml
deleted file mode 100644
index 8e87a1100ebebd05e579c244395ec6e97698f094..0000000000000000000000000000000000000000
--- a/chat2db/common/table_name_sql_exmple.yaml
+++ /dev/null
@@ -1,490 +0,0 @@
-- keyword_list:
- - test_organization
- - product_name
- - company_name
- sql_example_list:
- - question: openEuler支持的哪些商业软件在江苏鲲鹏&欧拉生态创新中心测试通过
- sql: SELECT product_name, product_version, openeuler_version FROM public.oe_compatibility_commercial_software
- WHERE test_organization ILIKE '%江苏鲲鹏&欧拉生态创新中心%';
- - question: 哪个版本的openEuler支持的商业软件最多
- sql: SELECT openeuler_version, COUNT(*) AS software_count FROM public.oe_compatibility_commercial_software GROUP
- BY openeuler_version ORDER BY software_count DESC LIMIT 1;
- - question: openEuler支持测试商业软件的机构有哪些?
- sql: SELECT DISTINCT test_organization FROM public.oe_compatibility_commercial_software;
- - question: openEuler支持的商业软件有哪些类别
- sql: SELECT DISTINCT "type" FROM public.oe_compatibility_commercial_software;
- - question: openEuler有哪些虚拟化类别的商业软件
- sql: SELECT product_name FROM public.oe_compatibility_commercial_software WHERE
- "type" ILIKE '%虚拟化%';
- - question: openEuler支持哪些ISV商业软件呢,请列出10个
- sql: SELECT product_name FROM public.oe_compatibility_commercial_software;
- - question: openEuler支持的适配Kunpeng 920的互联网商业软件有哪些?
- sql: SELECT product_name, openeuler_version,platform_type_and_server_model FROM
- public.oe_compatibility_commercial_software WHERE platform_type_and_server_model
- ILIKE '%Kunpeng 920%' AND "type" ILIKE '%互联网%' limit 30;
- - question: openEuler-22.03版本支持哪些商业软件?
- sql: SELECT product_name, openeuler_version FROM oe_compatibility_commercial_software
- WHERE openeuler_version ILIKE '%22.03%';
- - question: openEuler支持的数字政府类型的商业软件有哪些
- sql: SELECT product_name, product_version FROM oe_compatibility_commercial_software
- WHERE type ILIKE '%数字政府%';
- - question: 有哪些商业软件支持超过一种服务器平台
- sql: SELECT product_name FROM public.oe_compatibility_commercial_software WHERE
- platform_type_and_server_model ILIKE '%Intel%' AND platform_type_and_server_model
- ILIKE '%Kunpeng%';
- - question: 每个openEuler版本有多少种类型的商业软件支持
- sql: SELECT openeuler_version, COUNT(DISTINCT type) AS type_count FROM public.oe_compatibility_commercial_software GROUP
- BY openeuler_version;
- - question: openEuler支持的哪些商业ISV在江苏鲲鹏&欧拉生态创新中心测试通过
- sql: SELECT product_name, product_version, openeuler_version FROM public.oe_compatibility_commercial_software
- WHERE test_organization ILIKE '%江苏鲲鹏&欧拉生态创新中心%';
- - question: 哪个版本的openEuler支持的商业ISV最多
- sql: SELECT openeuler_version, COUNT(*) AS software_count FROM public.oe_compatibility_commercial_software GROUP
- BY openeuler_version ORDER BY software_count DESC LIMIT 1;
- - question: openEuler支持测试商业ISV的机构有哪些?
- sql: SELECT DISTINCT test_organization FROM public.oe_compatibility_commercial_software;
- - question: openEuler支持的商业ISV有哪些类别
- sql: SELECT DISTINCT "type" FROM public.oe_compatibility_commercial_software;
- - question: openEuler有哪些虚拟化类别的商业ISV
- sql: SELECT product_name FROM public.oe_compatibility_commercial_software WHERE
- "type" ILIKE '%虚拟化%';
- - question: openEuler支持哪些ISV商业ISV呢,请列出10个
- sql: SELECT product_name FROM public.oe_compatibility_commercial_software;
- - question: openEuler支持的适配Kunpeng 920的互联网商业ISV有哪些?
- sql: SELECT product_name, openeuler_version,platform_type_and_server_model FROM
- public.oe_compatibility_commercial_software WHERE platform_type_and_server_model
- ILIKE '%Kunpeng 920%' AND "type" ILIKE '%互联网%' limit 30;
- - question: openEuler-22.03版本支持哪些商业ISV?
- sql: SELECT product_name, openeuler_version FROM oe_compatibility_commercial_software
- WHERE openeuler_version ILIKE '%22.03%';
- - question: openEuler支持的数字政府类型的商业ISV有哪些
- sql: SELECT product_name, product_version FROM oe_compatibility_commercial_software
- WHERE type ILIKE '%数字政府%';
- - question: 有哪些商业ISV支持超过一种服务器平台
- sql: SELECT product_name FROM public.oe_compatibility_commercial_software WHERE
- platform_type_and_server_model ILIKE '%Intel%' AND platform_type_and_server_model
- ILIKE '%Kunpeng%';
- - question: 每个openEuler版本有多少种类型的商业ISV支持
- sql: SELECT openeuler_version, COUNT(DISTINCT type) AS type_count FROM public.oe_compatibility_commercial_software GROUP
- BY openeuler_version;
- - question: 卓智校园网接入门户系统基于openeuelr的什么版本?
- sql: select * from oe_compatibility_commercial_software where product_name ilike
- '%卓智校园网接入门户系统%';
- table_name: oe_compatibility_commercial_software
-- keyword_list:
- - softwareName
- sql_example_list:
- - question: openEuler-20.03-LTS-SP1支持哪些开源软件?
- sql: SELECT DISTINCT openeuler_version,"softwareName" FROM public.oe_compatibility_open_source_software WHERE
- openeuler_version ILIKE '%20.03-LTS-SP1%';
- - question: openEuler的aarch64下支持开源软件
- sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software WHERE
- "arch" ILIKE '%aarch64%';
- - question: openEuler支持开源软件使用了GPLv2+许可证
- sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software WHERE
- "license" ILIKE '%GPLv2+%';
- - question: tcplay支持的架构是什么
- sql: SELECT "arch" FROM public.oe_compatibility_open_source_software WHERE "softwareName"
- ILIKE '%tcplay%';
- - question: openEuler支持哪些开源软件,请列出10个
- sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software LIMIT
- 10;
- - question: openEuler支持开源软件支持哪些结构
- sql: SELECT "arch" FROM public.oe_compatibility_open_source_software group by
- "arch";
- - question: openEuler支持多少个开源软件?
- sql: select tmp_table.openeuler_version,count(*) as open_source_software_cnt from
- (select DISTINCT openeuler_version,"softwareName" from oe_compatibility_open_source_software)
- as tmp_table group by tmp_table.openeuler_version;
- - question: openEuler-20.03-LTS-SP1支持哪些开源ISV?
- sql: SELECT DISTINCT openeuler_version,"softwareName" FROM public.oe_compatibility_open_source_software WHERE
- openeuler_version ILIKE '%20.03-LTS-SP1%';
- - question: openEuler的aarch64下支持开源ISV
- sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software WHERE
- "arch" ILIKE '%aarch64%';
- - question: openEuler支持开源ISV使用了GPLv2+许可证
- sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software WHERE
- "license" ILIKE '%GPLv2+%';
- - question: tcplay支持的架构是什么
- sql: SELECT "arch" FROM public.oe_compatibility_open_source_software WHERE "softwareName"
- ILIKE '%tcplay%';
- - question: openEuler支持哪些开源ISV,请列出10个
- sql: SELECT "softwareName" FROM public.oe_compatibility_open_source_software LIMIT
- 10;
- - question: openEuler支持的开源ISV有哪些架构
- sql: SELECT "arch" FROM public.oe_compatibility_open_source_software group by
- "arch";
- - question: openEuler-20.03-LTS-SP1支持多少个开源ISV?
- sql: select tmp_table.openeuler_version,count(*) as open_source_software_cnt from
- (select DISTINCT openeuler_version,"softwareName" from oe_compatibility_open_source_software
- where openeuler_version ilike 'openEuler-20.03-LTS-SP1') as tmp_table group
- by tmp_table.openeuler_version;
- - question: openEuler支持多少个开源ISV?
- sql: select tmp_table.openeuler_version,count(*) as open_source_software_cnt from
- (select DISTINCT openeuler_version,"softwareName" from oe_compatibility_open_source_software)
- as tmp_table group by tmp_table.openeuler_version;
- table_name: oe_compatibility_open_source_software
-- keyword_list: []
- sql_example_list:
- - question: 在openEuler技术委员会担任委员的人有哪些
- sql: SELECT name FROM oe_community_organization_structure WHERE committee_name
- ILIKE '%技术委员会%' AND role = '委员';
- - question: openEuler的委员会中哪些人是教授
- sql: SELECT name FROM oe_community_organization_structure WHERE personal_message
- ILIKE '%教授%';
- - question: openEuler各委员会中担任主席有多少个?
- sql: SELECT committee_name, COUNT(*) FROM oe_community_organization_structure
- WHERE role = '主席' GROUP BY committee_name;
- - question: openEuler 用户委员会中有多少位成员
- sql: SELECT count(*) FROM oe_community_organization_structure WHERE committee_name
- ILIKE '%用户委员会%';
- - question: openEuler 技术委员会有多少位成员
- sql: SELECT count(*) FROM oe_community_organization_structure WHERE committee_name
- ILIKE '%技术委员会%';
- - question: openEuler委员会的常务委员会委员有哪些人
- sql: SELECT name FROM oe_community_organization_structure WHERE committee_name
- ILIKE '%委员会%' AND role ILIKE '%常务委员会委员%';
- - question: openEuler委员会有哪些人属于华为技术有限公司?
- sql: SELECT DISTINCT name FROM oe_community_organization_structure WHERE personal_message
- ILIKE '%华为技术有限公司%';
- - question: openEuler每个委员会有多少人?
- sql: SELECT committee_name, COUNT(*) FROM oe_community_organization_structure
- GROUP BY committee_name;
- - question: openEuler的执行总监是谁
- sql: SELECT name FROM oe_community_organization_structure WHERE role = '执行总监';
- - question: openEuler委员会有哪些组织?
- sql: SELECT DISTINCT committee_name from oe_community_organization_structure;
- - question: openEuler技术委员会的主席是谁?
- sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE
- role = '主席' and committee_name ilike '%技术委员会%';
- - question: openEuler品牌委员会的主席是谁?
- sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE
- role = '主席' and committee_name ilike '%品牌委员会%';
- - question: openEuler委员会的主席是谁?
- sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE
- role = '主席' and committee_name ilike '%openEuler 委员会%';
- - question: openEuler委员会的执行总监是谁?
- sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE
- role = '执行总监' and committee_name ilike '%openEuler 委员会%';
- - question: openEuler委员会的执行秘书是谁?
- sql: SELECT committee_name,name FROM oe_community_organization_structure WHERE
- role = '执行秘书' and committee_name ilike '%openEuler 委员会%';
- table_name: oe_community_organization_structure
-- keyword_list:
- - cve_id
- sql_example_list:
- - question: 安全公告openEuler-SA-2024-2059的详细信息在哪里?
- sql: select DISTINCT security_notice_no,details from oe_compatibility_security_notice
- where security_notice_no='openEuler-SA-2024-2059';
- table_name: oe_compatibility_security_notice
-- keyword_list:
- - hardware_model
- sql_example_list:
- - question: openEuler-22.03 LTS支持哪些整机?
- sql: SELECT main_board_model, cpu, ram FROM oe_compatibility_overall_unit WHERE
- openeuler_version ILIKE '%openEuler-22.03-LTS%';
- - question: 查询所有支持`openEuler-22.09`,并且提供详细产品介绍链接的整机型号和它们的内存配置?
- sql: SELECT hardware_model, ram FROM oe_compatibility_overall_unit WHERE openeuler_version
- ILIKE '%openEuler-22.09%' AND product_information IS NOT NULL;
- - question: 显示所有由新华三生产,支持`openEuler-20.03 LTS SP2`版本的整机,列出它们的型号和架构类型
- sql: SELECT hardware_model, architecture FROM oe_compatibility_overall_unit WHERE
- hardware_factory = '新华三' AND openeuler_version ILIKE '%openEuler-20.03 LTS SP2%';
- - question: openEuler支持多少种整机?
- sql: SELECT count(DISTINCT main_board_model) FROM oe_compatibility_overall_unit;
- - question: openEuler每个版本支持多少种整机?
- sql: select openeuler_version,count(*) from (SELECT DISTINCT openeuler_version,main_board_model
- FROM oe_compatibility_overall_unit) as tmp_table group by openeuler_version;
- - question: openEuler每个版本支持多少种架构的整机?
- sql: select openeuler_version,architecture,count(*) from (SELECT DISTINCT openeuler_version,architecture,main_board_model
- FROM oe_compatibility_overall_unit) as tmp_table group by openeuler_version,architecture;
- table_name: oe_compatibility_overall_unit
-- keyword_list:
- - osv_name
- - os_version
- sql_example_list:
- - question: 深圳开鸿数字产业发展有限公司基于openEuler的什么版本发行了什么商用版本?
- sql: select os_version,openeuler_version,os_download_link from oe_compatibility_osv
- where osv_name='深圳开鸿数字产业发展有限公司';
- - question: 统计各个openEuler版本下的商用操作系统数量
- sql: SELECT openeuler_version, COUNT(*) AS os_count FROM public.oe_compatibility_osv GROUP
- BY openeuler_version;
- - question: 哪个OS厂商基于openEuler发布的商用操作系统最多
- sql: SELECT osv_name, COUNT(*) AS os_count FROM public.oe_compatibility_osv GROUP
- BY osv_name ORDER BY os_count DESC LIMIT 1;
- - question: 不同OS厂商基于openEuler发布不同架构的商用操作系统数量是多少?
- sql: SELECT arch, osv_name, COUNT(*) AS os_count FROM public.oe_compatibility_osv GROUP
- BY arch, osv_name ORDER BY arch, os_count DESC;
- - question: 深圳开鸿数字产业发展有限公司的商用操作系统是基于什么openEuler版本发布的
- sql: SELECT os_version, openeuler_version FROM public.oe_compatibility_osv WHERE
- osv_name ILIKE '%深圳开鸿数字产业发展有限公司%';
- - question: openEuler有哪些OSV伙伴
- sql: SELECT DISTINCT osv_name FROM public.oe_compatibility_osv;
- - question: 有哪些OSV友商的操作系统是x86_64架构的
- sql: SELECT osv_name, os_version FROM public.oe_compatibility_osv WHERE arch ILIKE
- '%x86_64%';
- - question: 哪些OSV友商操作系统是嵌入式类型的
- sql: SELECT osv_name, os_version,openeuler_version FROM public.oe_compatibility_osv
- WHERE type ILIKE '%嵌入式%';
- - question: 成都鼎桥的商用操作系统版本是基于openEuler 22.03的版本吗
- sql: SELECT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv WHERE
- osv_name ILIKE '%成都鼎桥通信技术有限公司%' AND openeuler_version ILIKE '%22.03%';
- - question: 最近发布的基于openEuler 23.09的商用系统有哪些
- sql: SELECT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv WHERE
- openeuler_version ILIKE '%23.09%' ORDER BY date DESC limit 10;
- - question: 帮我查下成都智明达发布的所有嵌入式系统
- sql: SELECT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv WHERE
- osv_name ILIKE '%成都智明达电子股份有限公司%' AND type = '嵌入式';
- - question: 基于openEuler发布的商用操作系统有哪些类型
- sql: SELECT DISTINCT type FROM public.oe_compatibility_osv;
- - question: 江苏润和系统版本HopeOS-V22-x86_64-dvd.iso基于openEuler哪个版本
- sql: SELECT DISTINCT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv
- WHERE "osv_name" ILIKE '%江苏润和%' AND os_version ILIKE '%HopeOS-V22-x86_64-dvd.iso%'
- ;
- - question: 浙江大华DH-IVSS-OSV-22.03-LTS-SP2-x86_64-dvd.iso系统版本基于openEuler哪个版本
- sql: SELECT DISTINCT osv_name, os_version,"openeuler_version" FROM public.oe_compatibility_osv
- WHERE "osv_name" ILIKE '%浙江大华%' AND os_version ILIKE '%DH-IVSS-OSV-22.03-LTS-SP2-x86_64-dvd.iso%'
- ;
- table_name: oe_compatibility_osv
-- keyword_list:
- - board_model
- - chip_model
- - chip_vendor
- - product
- sql_example_list:
- - question: openEuler 22.03支持哪些网络接口卡型号?
- sql: SELECT board_model, chip_model,type FROM oe_compatibility_card WHERE type
- ILIKE '%NIC%' AND openeuler_version ILIKE '%22.03%' limit 30;
- - question: 请列出openEuler支持的所有Renesas公司的密码卡
- sql: SELECT * FROM oe_compatibility_card WHERE chip_vendor ILIKE '%Renesas%' AND
- type ILIKE '%密码卡%' limit 30;
- - question: openEuler各种架构支持的板卡数量是多少
- sql: SELECT architecture, COUNT(*) AS total_cards FROM oe_compatibility_card GROUP
- BY architecture limit 30;
- - question: 每个openEuler版本支持了多少种板卡
- sql: SELECT openeuler_version, COUNT(*) AS number_of_cards FROM oe_compatibility_card
- GROUP BY openeuler_version limit 30;
- - question: openEuler总共支持多少种不同的板卡型号
- sql: SELECT COUNT(DISTINCT board_model) AS board_model_cnt FROM oe_compatibility_card
- limit 30;
- - question: openEuler支持的GPU型号有哪些?
- sql: SELECT chip_model, openeuler_version,type FROM public.oe_compatibility_card WHERE
- type ILIKE '%GPU%' ORDER BY driver_date DESC limit 30;
- - question: openEuler 20.03 LTS-SP4版本支持哪些类型的设备
- sql: SELECT DISTINCT openeuler_version,type FROM public.oe_compatibility_card WHERE
- openeuler_version ILIKE '%20.03-LTS-SP4%' limit 30;
- - question: openEuler支持的板卡驱动在2023年后发布
- sql: SELECT board_model, driver_date, driver_name FROM oe_compatibility_card WHERE
- driver_date >= '2023-01-01' limit 30;
- - question: 给出openEuler在aarch64架构下支持的板卡的驱动下载链接
- sql: SELECT openeuler_version,board_model, download_link FROM oe_compatibility_card
- WHERE architecture ILIKE '%aarch64%' AND download_link IS NOT NULL limit 30;
- - question: openEuler-22.03-LTS-SP1支持的存储卡有哪些?
- sql: SELECT openeuler_version,board_model, chip_model,type FROM oe_compatibility_card
- WHERE type ILIKE '%SSD%' AND openeuler_version ILIKE '%openEuler-22.03-LTS-SP1%'
- limit 30;
- table_name: oe_compatibility_card
-- keyword_list:
- - cve_id
- sql_example_list:
- - question: CVE-2024-41053的详细信息在哪里可以看到?
- sql: select DISTINCT cve_id,details from oe_compatibility_cve_database where cve_id='CVE-2024-41053';
- - question: CVE-2024-41053是个怎么样的漏洞?
- sql: select DISTINCT cve_id,summary from oe_compatibility_cve_database where cve_id='CVE-2024-41053';
- - question: CVE-2024-41053影响了哪些包?
- sql: select DISTINCT cve_id,package_name from oe_compatibility_cve_database where
- cve_id='CVE-2024-41053';
- - question: CVE-2024-41053的cvss评分是多少?
- sql: select DISTINCT cve_id,cvsss_core_nvd from oe_compatibility_cve_database
- where cve_id='CVE-2024-41053';
- - question: CVE-2024-41053现在修复了么?
- sql: select DISTINCT cve_id, status from oe_compatibility_cve_database where cve_id='CVE-2024-41053';
- - question: CVE-2024-41053影响了openEuler哪些版本?
- sql: select DISTINCT cve_id, affected_product from oe_compatibility_cve_database
- where cve_id='CVE-2024-41053';
- - question: CVE-2024-41053发布时间是?
- sql: select DISTINCT cve_id, announcement_time from oe_compatibility_cve_database
- where cve_id='CVE-2024-41053';
- - question: openEuler-20.03-LTS-SP4在2024年8月发布哪些漏洞?
- sql: select DISTINCT affected_product,cve_id,announcement_time from oe_compatibility_cve_database
- where affected_product='openEuler-20.03-LTS-SP4' and EXTRACT(YEAR FROM announcement_time)=2024
- and EXTRACT(MONTH FROM announcement_time)=8;
- - question: openEuler-20.03-LTS-SP4在2024年发布哪些漏洞?
- sql: select DISTINCT affected_product,cve_id,announcement_time from oe_compatibility_cve_database
- where affected_product='openEuler-20.03-LTS-SP4'
- and EXTRACT(YEAR FROM announcement_time)=2024;
- - question: CVE-2024-41053的威胁程度是怎样的?
- sql: select DISTINCT affected_product,cve_id,cvsss_core_nvd,attack_complexity_nvd,attack_complexity_oe,attack_vector_nvd,attack_vector_oe
- from oe_compatibility_cve_database where cve_id='CVE-2024-41053';
- table_name: oe_compatibility_cve_database
-- keyword_list:
- - name
- sql_example_list:
- - question: openEuler-20.03-LTS的非官方软件包有多少个?
- sql: SELECT COUNT(*) FROM oe_compatibility_oepkgs WHERE repotype = 'openeuler_compatible'
- AND openeuler_version ILIKE '%openEuler-20.03-LTS%';
- - question: openEuler支持的nginx版本有哪些?
- sql: SELECT DISTINCT name,version, srcrpmpackurl FROM oe_compatibility_oepkgs
- WHERE name ILIKE 'nginx';
- - question: openEuler支持哪些架构的glibc?
- sql: SELECT DISTINCT name,arch FROM oe_compatibility_oepkgs WHERE name ILIKE 'glibc';
- - question: openEuler-22.03-LTS带GPLv2许可的软件包有哪些
- sql: SELECT name,rpmlicense FROM oe_compatibility_oepkgs WHERE openeuler_version
- ILIKE '%openEuler-22.03-LTS%' AND rpmlicense = 'GPLv2';
- - question: openEuler支持的python3这个软件包是用来干什么的?
- sql: SELECT DISTINCT name,summary FROM oe_compatibility_oepkgs WHERE name ILIKE
- 'python3';
- - question: 哪些版本的openEuler的zlib中有官方源的?
- sql: SELECT DISTINCT openeuler_version,name,version FROM oe_compatibility_oepkgs
- WHERE name ILIKE '%zlib%' AND repotype = 'openeuler_official';
- - question: 请以表格的形式提供openEuler-20.09的gcc软件包的下载链接
- sql: SELECT DISTINCT openeuler_version,name, rpmpackurl FROM oe_compatibility_oepkgs
- WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc';
- - question: 请以表格的形式提供openEuler-20.09的glibc软件包的下载链接
- sql: SELECT DISTINCT openeuler_version,name, rpmpackurl FROM oe_compatibility_oepkgs
- WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'glibc';
- - question: 请以表格的形式提供openEuler-20.09的redis软件包的下载链接
- sql: SELECT DISTINCT openeuler_version,name, rpmpackurl FROM oe_compatibility_oepkgs
- WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'redis';
- - question: openEuler-20.09的支持多少个软件包?
- sql: select tmp_table.openeuler_version,count(*) as oepkgs_cnt from (select DISTINCT
- openeuler_version,name from oe_compatibility_oepkgs WHERE openeuler_version
- ILIKE '%openEuler-20.09%') as tmp_table group by tmp_table.openeuler_version;
- - question: openEuler支持多少个软件包?
- sql: select tmp_table.openeuler_version,count(*) as oepkgs_cnt from (select DISTINCT
- openeuler_version,name from oe_compatibility_oepkgs) as tmp_table group by tmp_table.openeuler_version;
- - question: 请以表格的形式提供openEuler-20.09的gcc的版本
- sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
- WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc';
- - question: 请以表格的形式提供openEuler-20.09的glibc的版本
- sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
- WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'glibc';
- - question: 请以表格的形式提供openEuler-20.09的redis的版本
- sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
- WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'redis';
- - question: openEuler-20.09支持哪些gcc的版本
- sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
- WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc';
- - question: openEuler-20.09支持哪些glibc的版本
- sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
- WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'glibc';
- - question: openEuler-20.09支持哪些redis的版本
- sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
- WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'redis';
- - question: openEuler-20.09支持的gcc版本有哪些
- sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
- WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc';
- - question: openEuler-20.09支持的glibc版本有哪些
- sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
- WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'glibc';
- - question: openEuler-20.09支持的redis版本有哪些
- sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
- WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'redis';
- - question: openEuler-20.09支持gcc 9.3.1么?
- sql: SELECT DISTINCT openeuler_version,name, version FROM oe_compatibility_oepkgs
- WHERE openeuler_version ILIKE '%openEuler-20.09%' AND name ilike 'gcc' AND version
- ilike '9.3.1';
- table_name: oe_compatibility_oepkgs
-- keyword_list: []
- sql_example_list:
- - question: openEuler社区创新版本有哪些
- sql: SELECT DISTINCT openeuler_version,version_type FROM oe_community_openeuler_version
- where version_type ILIKE '%社区创新版本%';
- - question: openEuler有哪些版本
- sql: SELECT openeuler_version FROM public.oe_community_openeuler_version;
- - question: 查询openEuler各版本对应的内核版本
- sql: SELECT DISTINCT openeuler_version, kernel_version FROM public.oe_community_openeuler_version;
- - question: openEuler有多少个长期支持版本(LTS)
- sql: SELECT COUNT(*) as publish_version_count FROM public.oe_community_openeuler_version
- WHERE version_type ILIKE '%长期支持版本%';
- - question: 查询openEuler-20.03的所有SP版本
- sql: SELECT openeuler_version FROM public.oe_community_openeuler_version WHERE
- openeuler_version ILIKE '%openEuler-20.03-LTS-SP%';
- - question: openEuler最新的社区创新版本内核是啥
- sql: SELECT kernel_version FROM public.oe_community_openeuler_version WHERE version_type
- ILIKE '%社区创新版本%' ORDER BY publish_time DESC LIMIT 1;
- - question: 最早的openEuler版本是什么时候发布的
- sql: SELECT openeuler_version,publish_time FROM public.oe_community_openeuler_version
- ORDER BY publish_time ASC LIMIT 1;
- - question: 最新的openEuler版本是哪个
- sql: SELECT openeuler_version,publish_time FROM public.oe_community_openeuler_version
- ORDER BY publish_time DESC LIMIT 1;
- - question: openEuler有哪些版本使用了Linux 5.10.0内核
- sql: SELECT openeuler_version,kernel_version FROM public.oe_community_openeuler_version
- WHERE kernel_version ILIKE '5.10.0%';
- - question: 哪个openEuler版本是最近更新的长期支持版本
- sql: SELECT openeuler_version,publish_time FROM public.oe_community_openeuler_version
- WHERE version_type ILIKE '%长期支持版本%' ORDER BY publish_time DESC LIMIT 1;
- - question: openEuler每个年份发布了多少个版本
- sql: SELECT EXTRACT(YEAR FROM publish_time) AS year, COUNT(*) AS publish_version_count
- FROM oe_community_openeuler_version group by EXTRACT(YEAR FROM publish_time);
- - question: openEuler-20.03-LTS版本的linux内核是多少?
- sql: SELECT openeuler_version,kernel_version FROM public.oe_community_openeuler_version
- WHERE openeuler_version = 'openEuler-20.03-LTS';
- - question: openEuler-24.09版本的linux内核是多少?
- sql: SELECT openeuler_version,kernel_version FROM public.oe_community_openeuler_version
- WHERE openeuler_version = 'openEuler-24.09';
- table_name: oe_community_openeuler_version
-- keyword_list:
- - product
- sql_example_list:
- - question: 哪些openEuler版本支持使用至强6338N的解决方案
- sql: SELECT DISTINCT openeuler_version FROM oe_compatibility_solution WHERE cpu
- ILIKE '%6338N%';
- - question: 使用intel XXV710作为网卡的解决方案对应的是哪些服务器型号
- sql: SELECT DISTINCT server_model FROM oe_compatibility_solution WHERE network_card
- ILIKE '%intel XXV710%';
- - question: 哪些解决方案的硬盘驱动为SATA-SSD Skhynix
- sql: SELECT DISTINCT product FROM oe_compatibility_solution WHERE hard_disk_drive
- ILIKE 'SATA-SSD Skhynix';
- - question: 查询所有使用6230R系列CPU且支持磁盘阵列支持PERC H740P Adapter的解决方案的产品名
- sql: SELECT DISTINCT product FROM oe_compatibility_solution WHERE cpu ILIKE '%6230R%'
- AND raid ILIKE '%PERC H740P Adapter%';
- - question: R4900-G3有哪些驱动版本
- sql: SELECT DISTINCT driver FROM oe_compatibility_solution WHERE product ILIKE
- '%R4900-G3%';
- - question: DL380 Gen10支持哪些架构
- sql: SELECT DISTINCT architecture FROM oe_compatibility_solution WHERE server_model
- ILIKE '%DL380 Gen10%';
- - question: 列出所有使用Intel(R) Xeon(R)系列cpu且磁盘冗余阵列为LSI SAS3408的解决方案的服务器厂家
- sql: SELECT DISTINCT server_vendor FROM oe_compatibility_solution WHERE cpu ILIKE
- '%Intel(R) Xeon(R)%' AND raid ILIKE '%LSI SAS3408%';
- - question: 哪些解决方案提供了针对SEAGATE ST4000NM0025硬盘驱动的支持
- sql: SELECT * FROM oe_compatibility_solution WHERE hard_disk_drive ILIKE '%SEAGATE
- ST4000NM0025%';
- - question: 查询所有使用4316系列CPU的解决方案
- sql: SELECT * FROM oe_compatibility_solution WHERE cpu ILIKE '%4316%';
- - question: 支持openEuler-22.03-LTS-SP2版本的解决方案中,哪款服务器型号出现次数最多
- sql: SELECT server_model, COUNT(*) as count FROM oe_compatibility_solution WHERE
- openeuler_version ILIKE '%openEuler-22.03-LTS-SP2%' GROUP BY server_model ORDER
- BY count DESC LIMIT 1;
- - question: HPE提供的解决方案的介绍链接是什么
- sql: SELECT DISTINCT introduce_link FROM oe_compatibility_solution WHERE server_vendor
- ILIKE '%HPE%';
- - question: 列出所有使用intel XXV710网络卡接口的解决方案的CPU型号
- sql: SELECT DISTINCT cpu FROM oe_compatibility_solution WHERE network_card ILIKE
- '%intel XXV710%';
- - question: 服务器型号为2288H V5的解决方案支持哪些不同的openEuler版本
- sql: SELECT DISTINCT openeuler_version FROM oe_compatibility_solution WHERE server_model
- ILIKE '%2288H V5%';
- - question: 使用6230R系列CPU的解决方案内存最小是多少GB
- sql: SELECT MIN(ram) FROM oe_compatibility_solution WHERE cpu ILIKE '%6230R%';
- - question: 哪些解决方案的磁盘驱动为MegaRAID 9560-8i
- sql: SELECT * FROM oe_compatibility_solution WHERE hard_disk_drive LIKE '%MegaRAID
- 9560-8i%';
- - question: 列出所有使用6330N系列CPU且服务器厂家为Dell的解决方案的产品名
- sql: SELECT DISTINCT product FROM oe_compatibility_solution WHERE cpu ILIKE '%6330N%'
- AND server_vendor ILIKE '%Dell%';
- - question: R4900-G3的驱动版本是多少
- sql: SELECT driver FROM oe_compatibility_solution WHERE product ILIKE '%R4900-G3%';
- - question: 哪些解决方案的服务器型号为2288H V7
- sql: SELECT * FROM oe_compatibility_solution WHERE server_model ILIKE '%2288H
- V7%';
- - question: 使用Intel i350网卡且硬盘驱动为ST4000NM0025的解决方案的服务器厂家有哪些
- sql: SELECT DISTINCT server_vendor FROM oe_compatibility_solution WHERE network_card
- ILIKE '%Intel i350%' AND hard_disk_drive ILIKE '%ST4000NM0025%';
- - question: 有多少种不同的驱动版本被用于支持openEuler-22.03-LTS-SP2版本的解决方案
- sql: SELECT COUNT(DISTINCT driver) FROM oe_compatibility_solution WHERE openeuler_version
- ILIKE '%openEuler-22.03-LTS-SP2%';
- table_name: oe_compatibility_solution
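The entries above are retrieval data rather than tests: at query time the service looks up the stored questions closest to the user's question by vector similarity and pastes the paired SQL into the generation prompt as few-shot examples. A minimal bulk-loading sketch is shown below, assuming the `/sql/example/add` endpoint driven by the CLI helpers later in this patch and a `table_name_id.yaml` mapping of table names to registered table ids; both paths under `chat2db/common/` come from the removed design doc and are assumptions here.

```python
# Hypothetical loader: registers every question/SQL pair from the YAML above
# through the sql-example API. Endpoint, port and the table-id mapping file
# are assumptions, not confirmed by this patch.
import requests
import yaml

BASE_URL = "http://127.0.0.1:9015"                      # assumed service address
EXAMPLES = "chat2db/common/table_name_sql_example.yaml"
TABLE_IDS = "chat2db/common/table_name_id.yaml"         # assumed {table_name: table_id}

def load_examples() -> None:
    with open(EXAMPLES, encoding="utf-8") as f:
        groups = yaml.safe_load(f)
    with open(TABLE_IDS, encoding="utf-8") as f:
        table_ids = yaml.safe_load(f)

    for group in groups:
        table_id = table_ids.get(group["table_name"])
        if table_id is None:
            continue  # table not registered in the service yet
        for example in group["sql_example_list"]:
            resp = requests.post(
                f"{BASE_URL}/sql/example/add",
                json={"table_id": table_id,
                      "question": example["question"],
                      "sql": example["sql"]},
                timeout=30,
            )
            resp.raise_for_status()

if __name__ == "__main__":
    load_examples()
```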
diff --git a/chat2db/config/config.py b/chat2db/config/config.py
index 2c2c1a56f977fecd46535d6a296a8920b857a052..2f664d31741763daf748df674da957807a607ec8 100644
--- a/chat2db/config/config.py
+++ b/chat2db/config/config.py
@@ -6,39 +6,13 @@ from pydantic import BaseModel, Field
class ConfigModel(BaseModel):
- # FastAPI
- UVICORN_IP: str = Field(None, description="FastAPI 服务的IP地址")
- UVICORN_PORT: int = Field(None, description="FastAPI 服务的端口号")
- SSL_CERTFILE: str = Field(None, description="SSL证书文件的路径")
- SSL_KEYFILE: str = Field(None, description="SSL密钥文件的路径")
- SSL_ENABLE: str = Field(None, description="是否启用SSL连接")
-
- # Postgres
- DATABASE_TYPE: str = Field(default="postgres", description="数据库类型")
- DATABASE_HOST: str = Field(None, description="数据库地址")
- DATABASE_PORT: int = Field(None, description="数据库端口")
- DATABASE_USER: str = Field(None, description="数据库用户名")
- DATABASE_PASSWORD: str = Field(None, description="数据库密码")
- DATABASE_DB: str = Field(None, description="数据库名称")
-
- # QWEN
+
+ # LLM
LLM_KEY: str = Field(None, description="语言模型访问密钥")
LLM_URL: str = Field(None, description="语言模型服务的基础URL")
LLM_MAX_TOKENS: int = Field(None, description="单次请求中允许的最大Token数")
LLM_MODEL: str = Field(None, description="使用的语言模型名称或版本")
- # Vectorize
- EMBEDDING_TYPE: str = Field("openai", description="embedding 服务的类型")
- EMBEDDING_API_KEY: str = Field(None, description="embedding服务api key")
- EMBEDDING_ENDPOINT: str = Field(None, description="embedding服务url地址")
- EMBEDDING_MODEL_NAME: str = Field(None, description="embedding模型名称")
-
- # security
- HALF_KEY1: str = Field(None, description='加密的密钥组件1')
- HALF_KEY2: str = Field(None, description='加密的密钥组件2')
- HALF_KEY3: str = Field(None, description='加密的密钥组件3')
-
-
class Config:
config: ConfigModel
@@ -46,7 +20,7 @@ class Config:
if os.getenv("CONFIG"):
config_file = os.getenv("CONFIG")
else:
- config_file = "./chat2db/common/.env"
+ config_file = "chat2db/common/.env"
self.config = ConfigModel(**(dotenv_values(config_file)))
if os.getenv("PROD"):
os.remove(config_file)
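After this change only the four LLM fields remain in `ConfigModel`, and they are read from `chat2db/common/.env` (or the file named by the `CONFIG` environment variable) via `dotenv_values`. A small consumption sketch, assuming the module keeps exposing the mapping-style `config` object that other files in this patch index into:

```python
# Minimal sketch of reading the trimmed configuration; the .env keys mirror the
# remaining ConfigModel fields, and the mapping-style access matches how other
# modules in this patch use config[...].
import os

os.environ.setdefault("CONFIG", "chat2db/common/.env")  # optional override

from chat2db.config.config import config

llm_settings = {
    "model": config["LLM_MODEL"],
    "base_url": config["LLM_URL"],
    "api_key": config["LLM_KEY"],
    "max_tokens": config["LLM_MAX_TOKENS"],
}
print(llm_settings)
```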
diff --git a/chat2db/database/postgres.py b/chat2db/database/postgres.py
deleted file mode 100644
index ea4470d49368fe289d8fbc1e07498191aa8ec6a2..0000000000000000000000000000000000000000
--- a/chat2db/database/postgres.py
+++ /dev/null
@@ -1,135 +0,0 @@
-import logging
-from uuid import uuid4
-import urllib.parse
-from pgvector.sqlalchemy import Vector
-from sqlalchemy.orm import sessionmaker, declarative_base
-from sqlalchemy import TIMESTAMP, UUID, Column, String, Boolean, ForeignKey, create_engine, func, Index
-import sys
-from chat2db.config.config import config
-
-logging.basicConfig(stream=sys.stdout, level=logging.INFO,
- format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
-Base = declarative_base()
-
-
-class DatabaseInfo(Base):
- __tablename__ = 'database_info_table'
- id = Column(UUID(), default=uuid4, primary_key=True)
- encrypted_database_url = Column(String())
- encrypted_config = Column(String())
- hashmac = Column(String())
- created_at = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp())
-
-
-class TableInfo(Base):
- __tablename__ = 'table_info_table'
- id = Column(UUID(), default=uuid4, primary_key=True)
- database_id = Column(UUID(), ForeignKey('database_info_table.id', ondelete='CASCADE'))
- table_name = Column(String())
- table_note = Column(String())
- table_note_vector = Column(Vector(1024))
- enable = Column(Boolean, default=False)
- created_at = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp())
- updated_at = Column(
- TIMESTAMP(timezone=True),
- server_default=func.current_timestamp(),
- onupdate=func.current_timestamp())
- __table_args__ = (
- Index(
- 'table_note_vector_index',
- table_note_vector,
- postgresql_using='hnsw',
- postgresql_with={'m': 16, 'ef_construction': 200},
- postgresql_ops={'table_note_vector': 'vector_cosine_ops'}
- ),
- )
-
-
-class ColumnInfo(Base):
- __tablename__ = 'column_info_table'
- id = Column(UUID(), default=uuid4, primary_key=True)
- table_id = Column(UUID(), ForeignKey('table_info_table.id', ondelete='CASCADE'))
- column_name = Column(String)
- column_type = Column(String)
- column_note = Column(String)
- enable = Column(Boolean, default=False)
-
-
-class SqlExample(Base):
- __tablename__ = 'sql_example_table'
- id = Column(UUID(), default=uuid4, primary_key=True)
- table_id = Column(UUID(), ForeignKey('table_info_table.id', ondelete='CASCADE'))
- question = Column(String())
- sql = Column(String())
- question_vector = Column(Vector(1024))
- created_at = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp())
- updated_at = Column(
- TIMESTAMP(timezone=True),
- server_default=func.current_timestamp(),
- onupdate=func.current_timestamp())
- __table_args__ = (
- Index(
- 'question_vector_index',
- question_vector,
- postgresql_using='hnsw',
- postgresql_with={'m': 16, 'ef_construction': 200},
- postgresql_ops={'question_vector': 'vector_cosine_ops'}
- ),
- )
-
-
-class PostgresDB:
- _engine = None
-
- @classmethod
- def get_mysql_engine(cls):
- if not cls._engine:
- password = config['DATABASE_PASSWORD']
- encoded_password = urllib.parse.quote_plus(password)
-
- if config['DATABASE_TYPE'].lower() == 'opengauss':
- database_url = f"opengauss+psycopg2://{config['DATABASE_USER']}:{encoded_password}@{config['DATABASE_HOST']}:{config['DATABASE_PORT']}/{config['DATABASE_DB']}"
- else:
- database_url = f"postgresql+psycopg2://{config['DATABASE_USER']}:{encoded_password}@{config['DATABASE_HOST']}:{config['DATABASE_PORT']}/{config['DATABASE_DB']}"
- cls._engine = create_engine(
- database_url,
- hide_parameters=True,
- echo=False,
- pool_recycle=300,
- pool_pre_ping=True)
-
- Base.metadata.create_all(cls._engine)
- if config['DATABASE_TYPE'].lower() == 'opengauss':
- from sqlalchemy import event
- from opengauss_sqlalchemy.register_async import register_vector
-
- @event.listens_for(cls.engine.sync_engine, "connect")
- def connect(dbapi_connection, connection_record):
- dbapi_connection.run_async(register_vector)
- return cls._engine
-
- @classmethod
- def get_session(cls):
- connection = None
- try:
- connection = sessionmaker(bind=cls._engine)()
- except Exception as e:
- logging.error(f"Error creating a postgres session due to error: {e}")
- return None
- return cls._ConnectionManager(connection)
-
- class _ConnectionManager:
- def __init__(self, connection):
- self.connection = connection
-
- def __enter__(self):
- return self.connection
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- try:
- self.connection.close()
- except Exception as e:
- logging.error(f"Postgres connection close failed due to error: {e}")
-
-
-PostgresDB.get_mysql_engine()
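The deleted module owned the pgvector-backed SQLAlchemy models (with HNSW indexes over the table-note and question vectors) and handed out sessions through a small context manager. A usage sketch matching how the manager classes further down in this patch called it:

```python
# Sketch of the call pattern the removed manager classes relied on:
# PostgresDB.get_session() returns a context manager that closes the
# SQLAlchemy session on exit.
from chat2db.database.postgres import PostgresDB, TableInfo

def list_table_names(database_id):
    with PostgresDB.get_session() as session:
        rows = session.query(TableInfo.table_name).filter(
            TableInfo.database_id == database_id).all()
    return [row[0] for row in rows]
```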
diff --git "a/chat2db/docs/chat2db\345\267\245\345\205\267\350\257\246\347\273\206\350\257\264\346\230\216.md" "b/chat2db/docs/chat2db\345\267\245\345\205\267\350\257\246\347\273\206\350\257\264\346\230\216.md"
deleted file mode 100644
index e4d475885fd4242a9b69f7c631e7ca5a6ee5338d..0000000000000000000000000000000000000000
--- "a/chat2db/docs/chat2db\345\267\245\345\205\267\350\257\246\347\273\206\350\257\264\346\230\216.md"
+++ /dev/null
@@ -1,391 +0,0 @@
-# 1. 背景说明
-工具聚焦于利用大模型能力智能生成SQL语句,查询数据库数据,为最终的模型拟合提供能力增强。工具可增强RAG多路召回能力,增强RAG对本地用户的数据适应性,同时对于服务器、硬件型号等关键字场景,在不训练模型的情况下,RAG也具备一定检索能力
-
-# 2. 工具设计框架
-## 2.1 目录结构
-```
-chat2db
-|-- app # 应用主入口及相关功能模块
-|-- |-- app.py # 服务请求入口,处理用户请求并返回结果
-|-- |-- __init__.py # 初始化
-|-- |
-|-- |-- base # 基础功能模块
-|-- |-- |-- ac_automation.py # AC 自动机
-|-- |-- |-- mysql.py # MySQL 数据库操作封装
-|-- |-- |-- postgres.py # PostgreSQL 数据库操作封装
-|-- |-- |-- vectorize.py # 数据向量化处理模块
-|-- |
-|-- |-- router # 路由模块,负责分发请求到具体服务
-|-- |-- |-- database.py # 数据库相关路由逻辑
-|-- |-- |-- sql_example.py # SQL 示例管理路由
-|-- |-- |-- sql_generate.py # SQL 生成相关路由
-|-- |-- |-- table.py # 表信息管理路由
-|-- |
-|-- |-- service # 核心服务模块
-|-- |-- |-- diff_database_service.py # 不同数据库类型的服务适配
-|-- |-- |-- keyword_service.py # 关键字检索服务
-|-- |-- |-- sql_generate_service.py # SQL 生成服务逻辑
-|
-|-- common # 公共资源及配置
-|-- |-- .env # 环境变量配置文件
-|-- |-- init_sql_example.py # 初始化 SQL 示例数据脚本
-|-- |-- table_name_id.yaml # 表名与 ID 映射配置
-|-- |-- table_name_sql_example.yaml # 表名与 SQL 示例映射配置
-|
-|-- config # 配置模块
-|-- |-- config.py # 工具全局配置文件
-|
-|-- database # 数据库相关模块
-|-- |-- postgres.py # PostgreSQL 数据库连接及操作封装
-|
-|-- llm # 大模型交互模块
-|-- |-- chat_with_model.py # 与大模型交互的核心逻辑
-|
-|-- manager # 数据管理模块
-|-- |-- column_info_manager.py # 列信息管理逻辑
-|-- |-- database_info_manager.py # 数据库信息管理逻辑
-|-- |-- sql_example_manager.py # SQL 示例管理逻辑
-|-- |-- table_info_manager.py # 表信息管理逻辑
-|
-|-- model # 数据模型模块
-|-- |-- request.py # 请求数据模型定义
-|-- |-- response.py # 响应数据模型定义
-|
-|-- scripts # 脚本工具模块
-|-- |-- chat2db_config # 工具配置相关脚本
-|-- |-- |-- config.yaml # 工具配置文件模板
-|-- |-- output_example # 输出示例相关脚本
-|-- |-- |-- output_examples.txt # 输出示例文件
-|-- |-- run_chat2db.py # 启动工具进行交互的主脚本
-|
-|-- security # 安全模块
-|-- |-- security.py # 安全相关逻辑(如权限校验、加密等)
-|
-|-- template # 模板及提示词相关模块
-|-- |-- change_txt_to_yaml.py # 将文本提示转换为 YAML 格式的脚本
-|-- |-- prompt.yaml # 提示词模板文件,用于生成 SQL 或问题
-```
-# 3. 主要功能介绍
-## **3.1 智能生成 SQL 查询**
-- **功能描述**:
- - 工具的核心功能是利用大模型(如 LLM)智能生成符合用户需求的 SQL 查询语句。
- - 用户可以通过自然语言提问,工具会根据问题内容、表结构、示例数据等信息生成对应的 SQL 查询。
-- **实现模块**:
- - **路由模块**:`router/sql_generate.py` 负责接收用户请求并调用相关服务。
- - **服务模块**:`service/sql_generate_service.py` 提供 SQL 生成的核心逻辑。
- - **提示词模板**:`template/prompt.yaml` 中定义了生成 SQL 的提示词模板。
- - **数据库适配**:`base/postgres.py` 和 `base/mysql.py` 提供不同数据库的操作封装。
-- **应用场景**:
- - 用户无需掌握复杂的 SQL 语法,只需通过自然语言即可完成查询。
- - 支持多种数据库类型(如 PostgreSQL 和 MySQL)
-
----
-
-## **3.2 关键字检索与多路召回**
-- **功能描述**:
- - 工具支持基于关键字的检索功能,增强 RAG 的多路召回能力。
- - 对于服务器、硬件型号等特定场景,即使未训练模型,也能通过关键字匹配快速检索相关数据。
-- **实现模块**:
- - **路由模块**:`router/keyword.py` 负责处理关键字检索请求。
- - **服务模块**:`service/keyword_service.py` 提供关键字检索的核心逻辑。
- - **AC 自动机**:`base/ac_automation.py` 实现高效的多模式字符串匹配。
-- **应用场景**:
- - 在不依赖大模型的情况下,快速检索与关键字相关的 SQL 示例或表信息。
- - 适用于硬件型号、服务器配置等特定场景的快速查询。
-
----
-
-## **3.3 数据库表与列信息管理**
-- **功能描述**:
- - 工具提供对数据库表和列信息的管理功能,包括元数据存储、查询和更新。
- - 用户可以通过工具查看表结构、列注释等信息,并将其用于 SQL 查询生成。
-- **实现模块**:
- - **路由模块**:`router/table.py` 负责表信息相关的请求分发。
- - **管理模块**:
- - `manager/table_info_manager.py`:管理表信息。
- - `manager/column_info_manager.py`:管理列信息。
- - **数据模型**:`model/request.py` 和 `model/response.py` 定义了表和列信息的数据结构。
-- **应用场景**:
- - 用户可以快速了解数据库的表结构,辅助生成更准确的 SQL 查询。
- - 支持动态更新表和列信息,适应本地数据的变化。
-
----
-
-## **3.4 SQL 示例管理**
-- **功能描述**:
- - 工具支持对 SQL 示例的增删改查操作,并结合向量相似度检索最相关的 SQL 示例。
- - 用户可以通过问题向量找到与当前问题最相似的历史 SQL 示例,从而加速查询生成。
-- **实现模块**:
- - **路由模块**:`router/sql_example.py` 负责 SQL 示例相关的请求分发。
- - **管理模块**:`manager/sql_example_manager.py` 提供 SQL 示例的管理逻辑。
- - **向量化处理**:`base/vectorize.py` 将问题文本转换为向量表示。
- - **余弦距离排序**:利用 PostgreSQL 的向量计算能力,按余弦距离排序检索最相似的 SQL 示例。
-- **应用场景**:
- - 在生成新 SQL 查询时,参考历史 SQL 示例,提高查询的准确性和效率。
- - 支持对 SQL 示例的灵活管理,便于维护和扩展。
-
-# 4. 工具使用
-
-## 4.1 服务启动与配置
-
-### 服务环境配置
-
-- 在common/.env文件中配置数据库连接信息,大模型API密钥等必要参数
-
-### 数据库配置
-
-```bash
-# 进行数据库初始化,例如
-postgres=# CREATE EXTENSION zhparser;
-postgres=# CREATE EXTENSION vector;
-postgres=# CREATE TEXT SEARCH CONFIGURATION zhparser (PARSER = zhparser);
-postgres=# ALTER TEXT SEARCH CONFIGURATION zhparser ADD MAPPING FOR n,v,a,i,e,l WITH simple;
-postgres=# exit
-```
-
-### 启动服务
-
-```bash
-# 读取.env 环境配置,app.py入口启动服务
-python3 chat2db/app/app.py
-# 配置run_chat2db.py端口
-python3 chat2db/scripts/run_chat2db.py config --ip xxx --port xxx
-```
-
----
-
-## 4.2 命令行工具操作指南
-
-### 1. 数据库操作
-
-#### 添加数据库
-```bash
-python3 run_chat2db.py add_db --database_url "postgresql+psycopg2://user:password@localhost:5444/mydb"
-
-# 成功返回示例
->> success
->> database_id: 27fa7fd3-949b-41f9-97bc-530f498c0b57
-```
-
-#### 删除数据库
-
-```bash
-python3 run_chat2db.py del_db --database_id mydb_database_id
-```
-
-#### 查询已配置数据库
-
-```bash
-python3 run_chat2db.py query_db
-
-# 返回示例
-----------------------------------------
-查询数据库配置成功
-----------------------------------------
-database_id: 27fa7fd3-949b-41f9-97bc-530f498c0b57
-database_url: postgresql+psycopg2://postgres:123456@0.0.0.0:5444/mydb
-created_at: 2025-04-08T01:49:27.544521Z
-----------------------------------------
-```
-
-#### 查询在数据库中的表
-
-```bash
-python3 run_chat2db.py list_tb_in_db --database_id mydb_database_id
-# 返回示例
-----------------------------------------
-{'database_id': '27fa7fd3-949b-41f9-97bc-530f498c0b57', 'table_filter': None}
-查询数据库配置成功
-my_table
-----------------------------------------
-# 可过滤表名
-python3 run_chat2db.py list_tb_in_db --database_id mydb_database_id --table_filter my_table
-# 返回示例
-----------------------------------------
-{'database_id': '27fa7fd3-949b-41f9-97bc-530f498c0b57', 'table_filter': 'my_table'}
-查询数据库配置成功
-my_table
-----------------------------------------
-```
-
----
-
-### 2. 表操作
-
-#### 添加数据表
-```bash
-python3 run_chat2db.py add_tb --database_id mydb_database_id --table_name users
-
-# 成功返回示例
->> 数据表添加成功
->> table_id: tb_0987654321
-```
-
-#### 查询已添加的表
-
-```bash
-python3 run_chat2db.py query_tb --database_id mydb_database_id
-# 返回示例
-查询表格成功
-----------------------------------------
-table_id: 984d1c82-c6d5-4d3d-93d9-8d5bc11254ba
-table_name: oe_compatibility_cve_database
-table_note: openEuler社区组cve漏洞信息表,存储了cve漏洞的公告时间、id、关联的软件包名称、简介、cvss评分
-created_at: 2025-03-16T12:13:51.920663Z
-----------------------------------------
-```
-
-#### 删除数据表
-
-```bash
-python3 run_chat2db.py del_tb --table_id my_table_id
-# 返回示例
-删除表格成功
-```
-
-#### 查询表的列信息
-
-```bash
-python run_chat2db.py query_col --table_id my_table_id
-
-# 返回示例
---------------------------------------------------------
-column_id: 5ef50ebb-310b-48cc-bbc7-cf161c779055
-column_name: id
-column_note: None
-column_type: bigint
-enable: False
---------------------------------------------------------
-column_id: 69cf3c00-8e3c-4b99-83a5-6942278a60f3
-column_name: architecture
-column_note: openEuler支持的板卡信息的支持架构
-column_type: character varying
-enable: False
---------------------------------------------------------
-```
-
-#### 启用禁用指定列
-
-```bash
-python3 run_chat2db.py enable_col --column_id my_column_id --enable False
-# 返回示例
-列关键字功能开启/关闭成功
-```
-
----
-
-### 3. SQL示例操作
-
-#### 生成SQL示例
-
-```bash
-python3 run_chat2db.py add_sql_exp --table_id "your_table_id" --question "查询所有用户" --sql "SELECT * FROM users"
-# 返回示例
-success
-sql_example_id: 4282bce7-f2fd-42b0-a63b-7afd53d9e704
-```
-
-#### 批量添加SQL示例
-
-1. 创建Excel文件(示例格式):
-
- | question | sql |
- |----------|----------------------------------------------|
- | 查询所有用户 | SELECT * FROM users |
- | 统计北京地区用户 | SELECT COUNT(*) FROM users WHERE region='北京' |
-
-2. 执行导入命令:
-
-```bash
-python3 run_chat2db.py add_sql_exp --table_id "your_table_id" --dir "path/to/examples.xlsx"
-# 成功返回示例
->> 成功添加示例:查询所有用户
->> sql_example_id: exp_556677
->> 成功添加示例:统计北京地区用户
->> sql_example_id: exp_778899
-```
-
----
-
-#### 删除SQL示例
-
-```bash
-python3 run_chat2db.py del_sql_exp --sql_example_id "your_example_id"
-# 返回示例
-sql案例删除成功
-```
-
-#### 查询指定表的SQL示例
-
-```bash
-python3 run_chat2db.py query_sql_exp --table_id "your_table_id"
-# 返回示例
-查询SQL案例成功
---------------------------------------------------------
-sql_example_id: 5ab552db-b122-4653-bfdc-085c0b8557d6
-question: 查询所有用户
-sql: SELECT * FROM users
---------------------------------------------------------
-```
-
-#### 更新SQL示例
-
-```bash
-python3 run_chat2db.py update_sql_exp --sql_example_id "your_example_id" --question "新问题" --sql "新SQL语句"
-# 返回示例
-sql案例更新成功
-```
-
-#### 生成指定数据表SQL示例
-
-```bash
-python run_chat2db.py generate_sql_exp --table_id "your_table_id" --generate_cnt 5 --sql_var True --dir "output.xlsx"
-# --generate_cnt 参数: 生成sql对的数量 ;--sql_var: 是否验证生成的sql对,True为验证,False不验证
-# 返回示例
-sql案例生成成功
-Data written to Excel file successfully.
-```
-
-### 4. 智能查询
-
-#### 通过自然语言生成SQL(需配合前端或API调用)
-
-```python
-# 示例API请求
-import requests
-
-url = "http://localhost:8000/sql/generate"
-payload = {
- "question": "显示最近7天注册的用户",
- "table_id": "tb_0987654321"
-}
-
-response = requests.post(url, json=payload)
-print(response.json())
-
-# 返回示例
-{
- "sql": "SELECT * FROM users WHERE registration_date >= CURRENT_DATE - INTERVAL '7 days'",
- "confidence": 0.92
-}
-```
-
----
-
-#### 执行智能查询
-```http
-POST /sql/generate
-Content-Type: application/json
-
-{
- "question": "找出过去一个月销售额超过1万元的商品",
- "table_id": "tb_yyyy"
-}
-```
-
-
-
-
-
-
-
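Section 3.2 of the removed design doc credits keyword recall to an AC automaton (`base/ac_automation.py`), a file that is not part of this hunk. The sketch below is an independent Aho-Corasick illustration of the same multi-pattern matching idea, not the removed implementation; the sample patterns are taken from the compatibility data above.

```python
# Self-contained Aho-Corasick sketch of multi-pattern keyword matching;
# an independent illustration, not the removed ac_automation.py.
from collections import deque

class ACAutomaton:
    def __init__(self, patterns):
        self.goto = [{}]      # trie edges per node
        self.fail = [0]       # failure links
        self.output = [[]]    # patterns ending at each node
        for pattern in patterns:
            self._insert(pattern)
        self._build_failure_links()

    def _insert(self, pattern):
        node = 0
        for ch in pattern:
            if ch not in self.goto[node]:
                self.goto.append({})
                self.fail.append(0)
                self.output.append([])
                self.goto[node][ch] = len(self.goto) - 1
            node = self.goto[node][ch]
        self.output[node].append(pattern)

    def _build_failure_links(self):
        queue = deque(self.goto[0].values())  # depth-1 nodes keep fail=0
        while queue:
            node = queue.popleft()
            for ch, child in self.goto[node].items():
                queue.append(child)
                fallback = self.fail[node]
                while fallback and ch not in self.goto[fallback]:
                    fallback = self.fail[fallback]
                self.fail[child] = self.goto[fallback].get(ch, 0)
                self.output[child] += self.output[self.fail[child]]

    def search(self, text):
        node, hits = 0, []
        for i, ch in enumerate(text):
            while node and ch not in self.goto[node]:
                node = self.fail[node]
            node = self.goto[node].get(ch, 0)
            for pattern in self.output[node]:
                hits.append((i - len(pattern) + 1, pattern))
        return hits

matcher = ACAutomaton(["Kunpeng 920", "2288H V5", "openEuler-22.03"])
print(matcher.search("does the 2288H V5 server support openEuler-22.03?"))
```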
diff --git a/chat2db/llm/chat_with_model.py b/chat2db/llm/chat_with_model.py
deleted file mode 100644
index 9cc1ad2d60bd40e79318ef4df29fc9d47f15c250..0000000000000000000000000000000000000000
--- a/chat2db/llm/chat_with_model.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-from langchain_openai import ChatOpenAI
-from langchain.schema import SystemMessage, HumanMessage
-import re
-
-class LLM:
- def __init__(self, model_name, openai_api_base, openai_api_key, request_timeout, max_tokens, temperature):
- self.client = ChatOpenAI(model_name=model_name,
- openai_api_base=openai_api_base,
- openai_api_key=openai_api_key,
- request_timeout=request_timeout,
- max_tokens=max_tokens,
- temperature=temperature)
-
- def assemble_chat(self, system_call, user_call):
- chat = []
- chat.append(SystemMessage(content=system_call))
- chat.append(HumanMessage(content=user_call))
- return chat
-
- async def chat_with_model(self, system_call, user_call):
- chat = self.assemble_chat(system_call, user_call)
- response = await self.client.ainvoke(chat)
- content = re.sub(r'.*?\n\n', '', response.content, flags=re.DOTALL)
- return content
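The deleted wrapper assembles a system/human message pair, calls the model asynchronously through `ChatOpenAI`, and strips everything up to the first blank line from the reply, discarding any leading preamble. A usage sketch with placeholder endpoint and credentials:

```python
# Usage sketch for the removed LLM wrapper; model name, URL and key below are
# placeholders, not values taken from this repository.
import asyncio

from chat2db.llm.chat_with_model import LLM

llm = LLM(
    model_name="qwen2-72b-instruct",              # placeholder
    openai_api_base="http://llm.example.com/v1",  # placeholder
    openai_api_key="sk-placeholder",
    request_timeout=60,
    max_tokens=1024,
    temperature=0.1,
)

async def demo() -> None:
    answer = await llm.chat_with_model(
        system_call="You translate questions into a single SQL statement.",
        user_call="How many LTS releases does openEuler have?",
    )
    print(answer)

asyncio.run(demo())
```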
diff --git a/chat2db/app/base/meta_databbase.py b/chat2db/main.py
similarity index 31%
rename from chat2db/app/base/meta_databbase.py
rename to chat2db/main.py
index b21b1f1c2ae6fa6c82495fc6775b52533d9e1e2c..d41e4d8d67e803850fcc98f622aaa87509adc8d6 100644
--- a/chat2db/app/base/meta_databbase.py
+++ b/chat2db/main.py
@@ -1,18 +1,21 @@
+import uvicorn
+from fastapi import FastAPI
import sys
import logging
+
+from chat2db.apps.routers import sql
+
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
-class MetaDatabase:
- @staticmethod
- def result_to_json(results):
- """
- 将 SQL 查询结果解析为 JSON 格式的数据结构,支持多种数据类型
- """
- try:
- results = [result._asdict() for result in results]
- return results
- except Exception as e:
- logging.error(f"数据库查询结果解析失败由于: {e}")
- raise e
+app = FastAPI()
+
+app.include_router(sql.router)
+
+if __name__ == "__main__":
+ try:
+ uvicorn.run(app, host="127.0.0.1", port=9015, log_level="info")
+
+ except Exception as e:
+        logging.error(f"Failed to start the service: {e}")
+        exit(1)
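The new entry point mounts a single `sql` router and serves on `127.0.0.1:9015`. The router's paths are not visible in this hunk, so a smoke test can stick to FastAPI's auto-generated OpenAPI endpoint after starting the service with `python3 chat2db/main.py`:

```python
# Smoke test against the new entry point; only FastAPI's built-in
# /openapi.json endpoint is assumed to exist.
import requests

BASE_URL = "http://127.0.0.1:9015"

resp = requests.get(f"{BASE_URL}/openapi.json", timeout=10)
resp.raise_for_status()
print("registered routes:", sorted(resp.json().get("paths", {})))
```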
diff --git a/chat2db/manager/column_info_manager.py b/chat2db/manager/column_info_manager.py
deleted file mode 100644
index 789f49949d70fa0c78ecf9f2c37dcd7a36dc6f04..0000000000000000000000000000000000000000
--- a/chat2db/manager/column_info_manager.py
+++ /dev/null
@@ -1,69 +0,0 @@
-from sqlalchemy import and_
-import sys
-from chat2db.database.postgres import ColumnInfo, PostgresDB
-
-
-class ColumnInfoManager():
- @staticmethod
- async def add_column_info_with_table_id(table_id, column_name, column_type, column_note):
- column_info_entry = ColumnInfo(table_id=table_id, column_name=column_name,
- column_type=column_type, column_note=column_note)
- with PostgresDB.get_session() as session:
- session.add(column_info_entry)
- session.commit()
-
- @staticmethod
- async def del_column_info_by_column_id(column_id):
- with PostgresDB.get_session() as session:
- column_info_to_delete = session.query(ColumnInfo).filter(ColumnInfo.id == column_id)
- session.delete(column_info_to_delete)
- session.commit()
-
- @staticmethod
- async def get_column_info_by_column_id(column_id):
- tmp_dict = {}
- with PostgresDB.get_session() as session:
- result = session.query(ColumnInfo).filter(ColumnInfo.id == column_id).first()
- session.commit()
- if not result:
- return None
- tmp_dict = {
- 'column_id': result.id,
- 'table_id': result.table_id,
- 'column_name': result.column_name,
- 'column_type': result.column_type,
- 'column_note': result.column_note,
- 'enable': result.enable
- }
- return tmp_dict
-
- @staticmethod
- async def update_column_info_enable(column_id, enable=True):
- with PostgresDB.get_session() as session:
- column_info = session.query(ColumnInfo).filter(ColumnInfo.id == column_id).first()
- if column_info is not None:
- column_info.enable = enable
- session.commit()
- else:
- return False
- return True
-
- @staticmethod
- async def get_column_info_by_table_id(table_id, enable=None):
- column_info_list = []
- with PostgresDB.get_session() as session:
- if enable is None:
- results = session.query(ColumnInfo).filter(ColumnInfo.table_id == table_id).all()
- else:
- results = session.query(ColumnInfo).filter(
- and_(ColumnInfo.table_id == table_id, ColumnInfo.enable == enable)).all()
- for result in results:
- tmp_dict = {
- 'column_id': result.id,
- 'column_name': result.column_name,
- 'column_type': result.column_type,
- 'column_note': result.column_note,
- 'enable': result.enable
- }
- column_info_list.append(tmp_dict)
- return column_info_list
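These helpers back the `enable_col` CLI command described in the removed design doc: columns start with `enable=False`, and only enabled columns take part in keyword matching. A sketch that enables every column of a table that already carries a human-written note:

```python
# Sketch: turn on keyword matching only for columns that have a note.
from chat2db.manager.column_info_manager import ColumnInfoManager

async def enable_annotated_columns(table_id) -> int:
    columns = await ColumnInfoManager.get_column_info_by_table_id(table_id)
    enabled = 0
    for column in columns:
        if column["column_note"]:
            await ColumnInfoManager.update_column_info_enable(column["column_id"])
            enabled += 1
    return enabled

# e.g. asyncio.run(enable_annotated_columns(my_table_id)) with a real table id
```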
diff --git a/chat2db/manager/database_info_manager.py b/chat2db/manager/database_info_manager.py
deleted file mode 100644
index cc234fb12a0c72c6459261444a7ecbf0f99ea098..0000000000000000000000000000000000000000
--- a/chat2db/manager/database_info_manager.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import json
-import hashlib
-import sys
-import logging
-from chat2db.database.postgres import DatabaseInfo, PostgresDB
-from chat2db.security.security import Security
-
-logging.basicConfig(stream=sys.stdout, level=logging.INFO,
- format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
-
-
-class DatabaseInfoManager():
- @staticmethod
- async def add_database(database_url: str):
- id = None
- with PostgresDB.get_session() as session:
- encrypted_database_url, encrypted_config = Security.encrypt(database_url)
- hashmac = hashlib.sha256(database_url.encode('utf-8')).hexdigest()
- counter = session.query(DatabaseInfo).filter(DatabaseInfo.hashmac == hashmac).first()
- if counter:
- return id
- encrypted_config = json.dumps(encrypted_config)
- database_info_entry = DatabaseInfo(encrypted_database_url=encrypted_database_url,
- encrypted_config=encrypted_config, hashmac=hashmac)
- session.add(database_info_entry)
- session.commit()
- id = database_info_entry.id
- return id
-
- @staticmethod
- async def del_database_by_id(id):
- with PostgresDB.get_session() as session:
- database_info_to_delete = session.query(DatabaseInfo).filter(DatabaseInfo.id == id).first()
- if database_info_to_delete:
- session.delete(database_info_to_delete)
- else:
- return False
- session.commit()
- return True
-
- @staticmethod
- async def del_database_by_url(database_url):
- with PostgresDB.get_session() as session:
- hashmac = hashlib.sha256(database_url.encode('utf-8')).hexdigest()
- database_info_entry = session.query(DatabaseInfo).filter(DatabaseInfo.hashmac == hashmac).first()
- if database_info_entry:
- database_info_to_delete = session.query(DatabaseInfo).filter(DatabaseInfo.id == database_info_entry.id).first()
- if database_info_to_delete:
- session.delete(database_info_to_delete)
- else:
- return False
- else:
- return False
- session.commit()
- return True
-
- @staticmethod
- async def get_database_url_by_id(id):
- with PostgresDB.get_session() as session:
- result = session.query(
- DatabaseInfo.encrypted_database_url, DatabaseInfo.encrypted_config).filter(
- DatabaseInfo.id == id).first()
- if result is None:
- return None
- try:
- encrypted_database_url, encrypted_config = result
- encrypted_config = json.loads(encrypted_config)
- except Exception as e:
- logging.error(f'数据库url解密失败由于{e}')
- return None
- if encrypted_database_url:
- database_url = Security.decrypt(encrypted_database_url, encrypted_config)
- else:
- return None
- return database_url
- @staticmethod
- async def get_database_id_by_url(database_url: str):
- with PostgresDB.get_session() as session:
- hashmac = hashlib.sha256(database_url.encode('utf-8')).hexdigest()
- database_info_entry = session.query(DatabaseInfo).filter(DatabaseInfo.hashmac == hashmac).first()
- if database_info_entry:
- return database_info_entry.id
- return None
- @staticmethod
- async def get_all_database_info():
- with PostgresDB.get_session() as session:
- results = session.query(DatabaseInfo).order_by(DatabaseInfo.created_at).all()
- database_info_list = []
- for i in range(len(results)):
- database_id = results[i].id
- encrypted_database_url = results[i].encrypted_database_url
- encrypted_config = json.loads(results[i].encrypted_config)
- created_at = results[i].created_at
- if encrypted_database_url:
- database_url = Security.decrypt(encrypted_database_url, encrypted_config)
- tmp_dict = {'database_id': database_id, 'database_url': database_url, 'created_at': created_at}
- database_info_list.append(tmp_dict)
- return database_info_list
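Database URLs are stored encrypted, and a SHA-256 digest (`hashmac`) of the plain URL is kept only for duplicate detection, so `add_database` returns `None` when the same URL is registered twice. A registration sketch with a placeholder connection string:

```python
# Sketch: register a database once, then resolve its id on later runs.
import asyncio

from chat2db.manager.database_info_manager import DatabaseInfoManager

async def register(database_url: str):
    database_id = await DatabaseInfoManager.add_database(database_url)
    if database_id is None:  # already stored: look the id up by URL instead
        database_id = await DatabaseInfoManager.get_database_id_by_url(database_url)
    return database_id

if __name__ == "__main__":
    # placeholder URL; requires a reachable database configured for the service
    print(asyncio.run(register("postgresql+psycopg2://user:password@localhost:5444/mydb")))
```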
diff --git a/chat2db/manager/sql_example_manager.py b/chat2db/manager/sql_example_manager.py
deleted file mode 100644
index 67ccbcacc215ba538bd68c2a14f29e6d0f0d7d04..0000000000000000000000000000000000000000
--- a/chat2db/manager/sql_example_manager.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import json
-from sqlalchemy import and_
-import sys
-from chat2db.database.postgres import SqlExample, PostgresDB
-from chat2db.security.security import Security
-
-
-class SqlExampleManager():
- @staticmethod
- async def add_sql_example(question, sql, table_id, question_vector):
- id = None
- sql_example_entry = SqlExample(question=question, sql=sql,
- table_id=table_id, question_vector=question_vector)
- with PostgresDB.get_session() as session:
- session.add(sql_example_entry)
- session.commit()
- id = sql_example_entry.id
- return id
-
- @staticmethod
- async def del_sql_example_by_id(id):
- with PostgresDB.get_session() as session:
- sql_example_to_delete = session.query(SqlExample).filter(SqlExample.id == id).first()
- if sql_example_to_delete:
- session.delete(sql_example_to_delete)
- else:
- return False
- session.commit()
- return True
-
- @staticmethod
- async def update_sql_example_by_id(id, question, sql, question_vector):
- with PostgresDB.get_session() as session:
- sql_example_to_update = session.query(SqlExample).filter(SqlExample.id == id).first()
- if sql_example_to_update:
- sql_example_to_update.sql = sql
- sql_example_to_update.question = question
- sql_example_to_update.question_vector = question_vector
- session.commit()
- else:
- return False
- return True
-
- @staticmethod
- async def query_sql_example_by_table_id(table_id):
- with PostgresDB.get_session() as session:
- results = session.query(SqlExample).filter(SqlExample.table_id == table_id).all()
- sql_example_list = []
- for result in results:
- tmp_dict = {
- 'sql_example_id': result.id,
- 'question': result.question,
- 'sql': result.sql
- }
- sql_example_list.append(tmp_dict)
- return sql_example_list
-
- @staticmethod
- async def get_topk_sql_example_by_cos_dis(question_vector, table_id_list=None, topk=3):
- with PostgresDB.get_session() as session:
- if table_id_list is not None:
- sql_example_list = session.query(
- SqlExample
- ).filter(SqlExample.table_id.in_(table_id_list)).order_by(
- SqlExample.question_vector.cosine_distance(question_vector)
- ).limit(topk).all()
- else:
- sql_example_list = session.query(
- SqlExample
- ).order_by(
- SqlExample.question_vector.cosine_distance(question_vector)
- ).limit(topk).all()
- sql_example_list = [
- {'table_id': sql_example.table_id, 'question': sql_example.question, 'sql': sql_example.sql}
- for sql_example in sql_example_list]
- return sql_example_list
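`get_topk_sql_example_by_cos_dis` is the retrieval half of the few-shot loop: it orders stored questions by pgvector cosine distance and returns the closest question/SQL pairs, optionally restricted to a set of candidate tables. A caller sketch; the question embedding is assumed to come from the vectorize helper, which is outside this hunk:

```python
# Sketch: fetch the nearest stored examples for a freshly embedded question.
from chat2db.manager.sql_example_manager import SqlExampleManager

async def few_shot_examples(question_vector, table_id_list=None, topk: int = 3):
    examples = await SqlExampleManager.get_topk_sql_example_by_cos_dis(
        question_vector=question_vector,
        table_id_list=table_id_list,
        topk=topk,
    )
    # each entry carries table_id, question and sql, ready to be rendered
    # into the SQL-generation prompt
    return examples
```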
diff --git a/chat2db/manager/table_info_manager.py b/chat2db/manager/table_info_manager.py
deleted file mode 100644
index fcf4f6668c11e56d8e8de92bb4a34abcbe070ba8..0000000000000000000000000000000000000000
--- a/chat2db/manager/table_info_manager.py
+++ /dev/null
@@ -1,87 +0,0 @@
-from sqlalchemy import and_
-import sys
-from chat2db.database.postgres import TableInfo, PostgresDB
-
-
-class TableInfoManager():
- @staticmethod
- async def add_table_info(database_id, table_name, table_note, table_note_vector):
- id = None
- with PostgresDB.get_session() as session:
- counter = session.query(TableInfo).filter(
- and_(TableInfo.database_id == database_id, TableInfo.table_name == table_name)).first()
- if counter:
- return id
- table_info_entry = TableInfo(database_id=database_id, table_name=table_name,
- table_note=table_note, table_note_vector=table_note_vector)
- session.add(table_info_entry)
- session.commit()
- id = table_info_entry.id
- return id
-
- @staticmethod
- async def del_table_by_id(id):
- with PostgresDB.get_session() as session:
- table_info_to_delete = session.query(TableInfo).filter(TableInfo.id == id).first()
- if table_info_to_delete:
- session.delete(table_info_to_delete)
- else:
- return False
- session.commit()
- return True
-
- @staticmethod
- async def get_table_info_by_table_id(table_id):
- with PostgresDB.get_session() as session:
- table_id, database_id, table_name, table_note = session.query(
- TableInfo.id, TableInfo.database_id, TableInfo.table_name, TableInfo.table_note).filter(
- TableInfo.id == table_id).first()
- if table_id is None:
- return None
- return {
- 'table_id': table_id,
- 'database_id': database_id,
- 'table_name': table_name,
- 'table_note': table_note
- }
-
- @staticmethod
- async def get_table_id_by_database_id_and_table_name(database_id, table_name):
- with PostgresDB.get_session() as session:
- table_info_entry = session.query(
- TableInfo).filter(
- TableInfo.database_id == database_id,
- TableInfo.table_name == table_name,
- ).first()
- if table_info_entry:
- return table_info_entry.id
- return None
-
- @staticmethod
- async def get_table_info_by_database_id(database_id, enable=None):
- with PostgresDB.get_session() as session:
- if enable is None:
- results = session.query(
- TableInfo).filter(TableInfo.database_id == database_id).all()
- else:
- results = session.query(
- TableInfo).filter(
- and_(TableInfo.database_id == database_id,
- TableInfo.enable == enable
- )).all()
- table_info_list = []
- for result in results:
- table_info_list.append({'table_id': result.id, 'table_name': result.table_name,
- 'table_note': result.table_note, 'created_at': result.created_at})
- return table_info_list
-
- @staticmethod
- async def get_topk_table_by_cos_dis(database_id, tmp_vector, topk=3):
- with PostgresDB.get_session() as session:
- results = session.query(
- TableInfo.id
- ).filter(TableInfo.database_id == database_id).order_by(
- TableInfo.table_note_vector.cosine_distance(tmp_vector)
- ).limit(topk).all()
- table_id_list = [result[0] for result in results]
- return table_id_list
diff --git a/chat2db/model/request.py b/chat2db/model/request.py
deleted file mode 100644
index 6d8c9550d9380aca8e8edb814018a204a626b2d6..0000000000000000000000000000000000000000
--- a/chat2db/model/request.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import uuid
-from pydantic import BaseModel, Field
-from typing import Optional
-
-class QueryRequest(BaseModel):
- question: str
- topk_sql: int = 5
- topk_answer: int = 15
- use_llm_enhancements: bool = False
-
-
-class DatabaseAddRequest(BaseModel):
- database_url: str
-
-
-class DatabaseDelRequest(BaseModel):
- database_id: Optional[uuid.UUID] = Field(default=None, description="数据库id")
- database_url: Optional[str] = Field(default=None, description="数据库url")
-
-class DatabaseSqlGenerateRequest(BaseModel):
- database_url: str
- table_name_list: Optional[list[str]] = Field(default=[])
- question: str
- topk: int = 5
- use_llm_enhancements: Optional[bool] = Field(default=False)
-
-class TableAddRequest(BaseModel):
- database_id: uuid.UUID
- table_name: str
-
-
-class TableDelRequest(BaseModel):
- table_id: uuid.UUID
-
-
-class TableQueryRequest(BaseModel):
- database_id: uuid.UUID
-
-
-class EnableColumnRequest(BaseModel):
- column_id: uuid.UUID
- enable: bool
-
-
-class SqlExampleAddRequest(BaseModel):
- table_id: uuid.UUID
- question: str
- sql: str
-
-
-class SqlExampleDelRequest(BaseModel):
- sql_example_id: uuid.UUID
-
-
-class SqlExampleQueryRequest(BaseModel):
- table_id: uuid.UUID
-
-
-class SqlExampleUpdateRequest(BaseModel):
- sql_example_id: uuid.UUID
- question: str
- sql: str
-
-
-class SqlGenerateRequest(BaseModel):
- database_id: uuid.UUID
- table_id_list: list[uuid.UUID] = []
- question: str
- topk: int = 5
- use_llm_enhancements: bool = True
-
-
-class SqlRepairRequest(BaseModel):
- database_id: uuid.UUID
- table_id: uuid.UUID
- sql: str
- message: str = Field(..., max_length=2048)
- question: str
-
-
-class SqlExcuteRequest(BaseModel):
- database_id: uuid.UUID
- sql: str
-
-
-class SqlExampleGenerateRequest(BaseModel):
- table_id: uuid.UUID
- generate_cnt: int = 1
- sql_var: bool = False
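The removed request models define the JSON payloads of the old HTTP API. A construction sketch for `SqlGenerateRequest` with placeholder UUIDs; the serialisation call is chosen at runtime because the pydantic major version is not pinned in this hunk:

```python
# Sketch: building one of the removed request payloads.
import uuid

from chat2db.model.request import SqlGenerateRequest

payload = SqlGenerateRequest(
    database_id=uuid.uuid4(),            # placeholder ids
    table_id_list=[uuid.uuid4()],
    question="How many LTS releases does openEuler have?",
    topk=5,
    use_llm_enhancements=True,
)

# pydantic v2 exposes model_dump_json(); v1 uses json()
dump = payload.model_dump_json() if hasattr(payload, "model_dump_json") else payload.json()
print(dump)
```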
diff --git a/chat2db/model/response.py b/chat2db/model/response.py
deleted file mode 100644
index fd7c2e7a489405410be5a5f3331915fd9c8eda0c..0000000000000000000000000000000000000000
--- a/chat2db/model/response.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from pydantic import BaseModel
-from typing import Any
-class ResponseData(BaseModel):
- code: int
- message: str
- result: Any
\ No newline at end of file
diff --git a/chat2db/scripts/chat2db_config/config.yaml b/chat2db/scripts/chat2db_config/config.yaml
deleted file mode 100644
index 78e3719e8a65cf870cad6994a66bf4120dc113e4..0000000000000000000000000000000000000000
--- a/chat2db/scripts/chat2db_config/config.yaml
+++ /dev/null
@@ -1,2 +0,0 @@
-UVICORN_IP: 0.0.0.0
-UVICORN_PORT: '9015'
diff --git a/chat2db/scripts/docs/output_examples.xlsx b/chat2db/scripts/docs/output_examples.xlsx
deleted file mode 100644
index 599501ca6c0f1d2b88fe5235d40fb11a56fbf005..0000000000000000000000000000000000000000
Binary files a/chat2db/scripts/docs/output_examples.xlsx and /dev/null differ
diff --git a/chat2db/scripts/run_chat2db.py b/chat2db/scripts/run_chat2db.py
deleted file mode 100644
index 5da9e4872c6a9fff674e3eb6233ea3f5201a6eb9..0000000000000000000000000000000000000000
--- a/chat2db/scripts/run_chat2db.py
+++ /dev/null
@@ -1,436 +0,0 @@
-import argparse
-import os
-import pandas as pd
-import requests
-import yaml
-from fastapi import FastAPI
-import shutil
-
-terminal_width = shutil.get_terminal_size().columns
-app = FastAPI()
-
-CHAT2DB_CONFIG_PATH = './chat2db_config'
-CONFIG_YAML_PATH = './chat2db_config/config.yaml'
-DEFAULT_CHAT2DB_CONFIG = {
- "UVICORN_IP": "127.0.0.1",
- "UVICORN_PORT": "8000"
-}
-
-
-# 修改
-def update_config(uvicorn_ip, uvicorn_port):
- try:
- yml = {'UVICORN_IP': uvicorn_ip, 'UVICORN_PORT': uvicorn_port}
- with open(CONFIG_YAML_PATH, 'w') as file:
- yaml.dump(yml, file)
- return {"message": "修改成功"}
- except Exception as e:
- return {"message": f"修改失败,由于:{e}"}
-
-
-# 增加数据库
-def call_add_database_info(database_url):
- url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/database/add"
- request_body = {
- "database_url": database_url
- }
- response = requests.post(url, json=request_body)
- return response.json()
-
-
-# 删除数据库
-def call_del_database_info(database_id):
- url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/database/del"
- request_body = {
- "database_id": database_id
- }
- response = requests.post(url, json=request_body)
- return response.json()
-
-
-# 查询数据库配置
-def call_query_database_info():
- url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/database/query"
- response = requests.get(url)
- return response.json()
-
-
-# 查询数据库内表格配置
-def call_list_table_in_database(database_id, table_filter=''):
- url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/database/list"
- params = {
- "database_id": database_id,
- "table_filter": table_filter
- }
- print(params)
- response = requests.get(url, params=params)
- return response.json()
-
-
-# 增加数据表
-def call_add_table_info(database_id, table_name):
- url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/table/add"
- request_body = {
- "database_id": database_id,
- "table_name": table_name
- }
- response = requests.post(url, json=request_body)
- return response.json()
-
-
-# 删除数据表
-def call_del_table_info(table_id):
- url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/table/del"
- request_body = {
- "table_id": table_id
- }
- response = requests.post(url, json=request_body)
- return response.json()
-
-
-# 查询数据表配置
-def call_query_table_info(database_id):
- url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/table/query"
- params = {
- "database_id": database_id
- }
- response = requests.get(url, params=params)
- return response.json()
-
-
-# 查询数据表列信息
-def call_query_column(table_id):
- url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/table/column/query"
- params = {
- "table_id": table_id
- }
- response = requests.get(url, params=params)
- return response.json()
-
-
-# 启用禁用列
-def call_enable_column(column_id, enable):
- url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/table/column/enable"
- request_body = {
- "column_id": column_id,
- "enable": enable
- }
- response = requests.post(url, json=request_body)
- return response.json()
-
-
-# 增加sql_example案例
-def call_add_sql_example(table_id, question, sql):
- url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/sql/example/add"
- request_body = {
- "table_id": table_id,
- "question": question,
- "sql": sql
- }
- response = requests.post(url, json=request_body)
- return response.json()
-
-
-# 删除sql_example案例
-def call_del_sql_example(sql_example_id):
- url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/sql/example/del"
- request_body = {
- "sql_example_id": sql_example_id
- }
- response = requests.post(url, json=request_body)
- return response.json()
-
-
-# 查询sql_example案例
-def call_query_sql_example(table_id):
- url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/sql/example/query"
- params = {
- "table_id": table_id
- }
- response = requests.get(url, params=params)
- return response.json()
-
-
-# 更新sql_example案例
-def call_update_sql_example(sql_example_id, question, sql):
- url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/sql/example/update"
- request_body = {
- "sql_example_id": sql_example_id,
- "question": question,
- "sql": sql
- }
- response = requests.post(url, json=request_body)
- return response.json()
-
-
-# 生成sql_example案例
-def call_generate_sql_example(table_id, generate_cnt=1, sql_var=False):
- url = f"http://{config['UVICORN_IP']}:{config['UVICORN_PORT']}/sql/example/generate"
- response_body = {
- "table_id": table_id,
- "generate_cnt": generate_cnt,
- "sql_var": sql_var
- }
- response = requests.post(url, json=response_body)
- return response.json()
-
-
-def write_sql_example_to_excel(dir, sql_example_list):
- try:
- if not os.path.exists(os.path.dirname(dir)):
- os.makedirs(os.path.dirname(dir))
- data = {
- 'question': [],
- 'sql': []
- }
- for sql_example in sql_example_list:
- data['question'].append(sql_example['question'])
- data['sql'].append(sql_example['sql'])
-
- df = pd.DataFrame(data)
- df.to_excel(dir, index=False)
-
- print("Data written to Excel file successfully.")
- except Exception as e:
- print("Error writing data to Excel file:", str(e))
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(description="chat2DB脚本")
- subparsers = parser.add_subparsers(dest="command", help="子命令列表")
-
- # 修改config.yaml
- parser_config = subparsers.add_parser("config", help="修改config.yaml")
- parser_config.add_argument("--ip", type=str, required=True, help="uvicorn ip")
- parser_config.add_argument("--port", type=str, required=True, help="uvicorn port")
-
- # 增加数据库
- parser_add_database = subparsers.add_parser("add_db", help="增加指定数据库")
- parser_add_database.add_argument("--database_url", type=str, required=True,
- help="数据库连接地址,如postgresql+psycopg2://postgres:123456@0.0.0.0:5432/postgres")
-
- # 删除数据库
- parser_del_database = subparsers.add_parser("del_db", help="删除指定数据库")
- parser_del_database.add_argument("--database_id", type=str, required=True, help="数据库id")
-
- # 查询数据库配置
- parser_query_database = subparsers.add_parser("query_db", help="查询指定数据库配置")
-
- # 查询数据库内表格配置
- parser_list_table_in_database = subparsers.add_parser("list_tb_in_db", help="查询数据库内表格配置")
- parser_list_table_in_database.add_argument("--database_id", type=str, required=True, help="数据库id")
- parser_list_table_in_database.add_argument("--table_filter", type=str, required=False, help="表格名称过滤条件")
-
- # 增加数据表
- parser_add_table = subparsers.add_parser("add_tb", help="增加指定数据库内的数据表")
- parser_add_table.add_argument("--database_id", type=str, required=True, help="数据库id")
- parser_add_table.add_argument("--table_name", type=str, required=True, help="数据表名称")
-
- # 删除数据表
- parser_del_table = subparsers.add_parser("del_tb", help="删除指定数据表")
- parser_del_table.add_argument("--table_id", type=str, required=True, help="数据表id")
-
- # 查询数据表配置
- parser_query_table = subparsers.add_parser("query_tb", help="查询指定数据表配置")
- parser_query_table.add_argument("--database_id", type=str, required=True, help="数据库id")
-
- # 查询数据表列信息
- parser_query_column = subparsers.add_parser("query_col", help="查询指定数据表详细列信息")
- parser_query_column.add_argument("--table_id", type=str, required=True, help="数据表id")
-
- # 启用禁用列
- parser_enable_column = subparsers.add_parser("enable_col", help="启用禁用指定列")
- parser_enable_column.add_argument("--column_id", type=str, required=True, help="列id")
- parser_enable_column.add_argument("--enable", type=bool, required=True, help="是否启用")
-
- # 增加sql案例
- parser_add_sql_example = subparsers.add_parser("add_sql_exp", help="增加指定数据表sql案例")
- parser_add_sql_example.add_argument("--table_id", type=str, required=True, help="数据表id")
- parser_add_sql_example.add_argument("--question", type=str, required=False, help="问题")
- parser_add_sql_example.add_argument("--sql", type=str, required=False, help="sql")
- parser_add_sql_example.add_argument("--dir", type=str, required=False, help="输入路径")
-
- # 删除sql_exp
- parser_del_sql_example = subparsers.add_parser("del_sql_exp", help="删除指定sql案例")
- parser_del_sql_example.add_argument("--sql_example_id", type=str, required=True, help="sql案例id")
-
- # 查询sql案例
- parser_query_sql_example = subparsers.add_parser("query_sql_exp", help="查询指定数据表sql对案例")
- parser_query_sql_example.add_argument("--table_id", type=str, required=True, help="数据表id")
-
- # 更新sql案例
- parser_update_sql_example = subparsers.add_parser("update_sql_exp", help="更新sql对案例")
- parser_update_sql_example.add_argument("--sql_example_id", type=str, required=True, help="sql案例id")
- parser_update_sql_example.add_argument("--question", type=str, required=True, help="sql语句对应的问题")
- parser_update_sql_example.add_argument("--sql", type=str, required=True, help="sql语句")
-
- # 生成sql案例
- parser_generate_sql_example = subparsers.add_parser("generate_sql_exp", help="生成指定数据表sql对案例")
- parser_generate_sql_example.add_argument("--table_id", type=str, required=True, help="数据表id")
- parser_generate_sql_example.add_argument("--generate_cnt", type=int, required=False, help="生成sql对数量",
- default=1)
- parser_generate_sql_example.add_argument("--sql_var", type=bool, required=False,
- help="是否验证生成的sql对,True为验证,False不验证",
- default=False)
- parser_generate_sql_example.add_argument("--dir", type=str, required=False, help="生成的sql对输出路径",
- default="templetes/output_examples.xlsx")
-
- args = parser.parse_args()
-
- if os.path.exists(CONFIG_YAML_PATH):
- exist = True
- with open(CONFIG_YAML_PATH, 'r') as file:
- yml = yaml.safe_load(file)
- config = {
- 'UVICORN_IP': yml.get('UVICORN_IP'),
- 'UVICORN_PORT': yml.get('UVICORN_PORT'),
- }
- else:
- exist = False
-
- if args.command == "config":
- if not exist:
- os.makedirs(CHAT2DB_CONFIG_PATH, exist_ok=True)
- with open(CONFIG_YAML_PATH, 'w') as file:
- yaml.dump(DEFAULT_CHAT2DB_CONFIG, file, default_flow_style=False)
- response = update_config(args.ip, args.port)
- with open(CONFIG_YAML_PATH, 'r') as file:
- yml = yaml.safe_load(file)
- config = {
- 'UVICORN_IP': yml.get('UVICORN_IP'),
- 'UVICORN_PORT': yml.get('UVICORN_PORT'),
- }
- print(response.get("message"))
- elif not exist:
- print("please update_config first")
-
- elif args.command == "add_db":
- response = call_add_database_info(args.database_url)
- database_id = response.get("result")['database_id']
- print(response.get("message"))
- if response.get("code") == 200:
- print(f'database_id: ', database_id)
-
- elif args.command == "del_db":
- response = call_del_database_info(args.database_id)
- print(response.get("message"))
-
- elif args.command == "query_db":
- response = call_query_database_info()
- print(response.get("message"))
- if response.get("code") == 200:
- database_info = response.get("result")['database_info_list']
- for database in database_info:
- print('-' * terminal_width)
- print("database_id:", database["database_id"])
- print("database_url:", database["database_url"])
- print("created_at:", database["created_at"])
- print('-' * terminal_width)
-
- elif args.command == "list_tb_in_db":
- response = call_list_table_in_database(args.database_id, args.table_filter)
- print(response.get("message"))
- if response.get("code") == 200:
- table_name_list = response.get("result")['table_name_list']
- for table_name in table_name_list:
- print(table_name)
-
- elif args.command == "add_tb":
- response = call_add_table_info(args.database_id, args.table_name)
- print(response.get("message"))
- table_id = response.get("result")['table_id']
- if response.get("code") == 200:
- print('table_id: ', table_id)
-
- elif args.command == "del_tb":
- response = call_del_table_info(args.table_id)
- print(response.get("message"))
-
- elif args.command == "query_tb":
- response = call_query_table_info(args.database_id)
- print(response.get("message"))
- if response.get("code") == 200:
- table_list = response.get("result")['table_info_list']
- for table in table_list:
- print('-' * terminal_width)
- print("table_id:", table['table_id'])
- print("table_name:", table['table_name'])
- print("table_note:", table['table_note'])
- print("created_at:", table['created_at'])
- print('-' * terminal_width)
-
- elif args.command == "query_col":
- response = call_query_column(args.table_id)
- print(response.get("message"))
- if response.get("code") == 200:
- column_list = response.get("result")['column_info_list']
- for column in column_list:
- print('-' * terminal_width)
- print("column_id:", column['column_id'])
- print("column_name:", column['column_name'])
- print("column_note:", column['column_note'])
- print("column_type:", column['column_type'])
- print("enable:", column['enable'])
- print('-' * terminal_width)
-
- elif args.command == "enable_col":
- response = call_enable_column(args.column_id, args.enable)
- print(response.get("message"))
-
- elif args.command == "add_sql_exp":
- def get_sql_exp(dir):
- if not os.path.exists(os.path.dirname(dir)):
- return None
- # 读取 xlsx 文件
- df = pd.read_excel(dir)
-
- # 遍历每一行数据
- for index, row in df.iterrows():
- question = row['question']
- sql = row['sql']
-
- # 调用 call_add_sql_example 函数
- response = call_add_sql_example(args.table_id, question, sql)
- print(response.get("message"))
- sql_example_id = response.get("result")['sql_example_id']
- print('sql_example_id: ', sql_example_id)
- print(question, sql)
-
-
- if args.dir:
- get_sql_exp(args.dir)
- else:
- response = call_add_sql_example(args.table_id, args.question, args.sql)
- print(response.get("message"))
- sql_example_id = response.get("result")['sql_example_id']
- print('sql_example_id: ', sql_example_id)
-
- elif args.command == "del_sql_exp":
- response = call_del_sql_example(args.sql_example_id)
- print(response.get("message"))
-
- elif args.command == "query_sql_exp":
- response = call_query_sql_example(args.table_id)
- print(response.get("message"))
- if response.get("code") == 200:
- sql_example_list = response.get("result")['sql_example_list']
- for sql_example in sql_example_list:
- print('-' * terminal_width)
- print("sql_example_id:", sql_example['sql_example_id'])
- print("question:", sql_example['question'])
- print("sql:", sql_example['sql'])
- print('-' * terminal_width)
-
- elif args.command == "update_sql_exp":
- response = call_update_sql_example(args.sql_example_id, args.question, args.sql)
- print(response.get("message"))
-
- elif args.command == "generate_sql_exp":
- response = call_generate_sql_example(args.table_id, args.generate_cnt, args.sql_var)
- print(response.get("message"))
- if response.get("code") == 200:
- # 输出到execl中
- sql_example_list = response.get("result")['sql_example_list']
- write_sql_example_to_excel(args.dir, sql_example_list)
- else:
- print("未知命令,请检查输入的命令是否正确。")
diff --git a/chat2db/security/security.py b/chat2db/security/security.py
deleted file mode 100644
index 0909f27bf29fa5cc8c405ff0e4d998c7f1fbf03d..0000000000000000000000000000000000000000
--- a/chat2db/security/security.py
+++ /dev/null
@@ -1,116 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-
-import base64
-import binascii
-import hashlib
-import secrets
-
-from cryptography.hazmat.backends import default_backend
-from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
-
-from chat2db.config.config import config
-
-
-class Security:
-
- @staticmethod
- def encrypt(plaintext: str) -> tuple[str, dict]:
- """
- 加密公共方法
- :param plaintext:
- :return:
- """
- half_key1 = config['HALF_KEY1']
-
- encrypted_work_key, encrypted_work_key_iv = Security._generate_encrypted_work_key(
- half_key1)
- encrypted_plaintext, encrypted_iv = Security._encrypt_plaintext(half_key1, encrypted_work_key,
- encrypted_work_key_iv, plaintext)
- del plaintext
- secret_dict = {
- "encrypted_work_key": encrypted_work_key,
- "encrypted_work_key_iv": encrypted_work_key_iv,
- "encrypted_iv": encrypted_iv,
- "half_key1": half_key1
- }
- return encrypted_plaintext, secret_dict
-
- @staticmethod
- def decrypt(encrypted_plaintext: str, secret_dict: dict):
- """
- 解密公共方法
- :param encrypted_plaintext: 待解密的字符串
- :param secret_dict: 存放工作密钥的dict
- :return:
- """
- plaintext = Security._decrypt_plaintext(half_key1=secret_dict.get("half_key1"),
- encrypted_work_key=secret_dict.get(
- "encrypted_work_key"),
- encrypted_work_key_iv=secret_dict.get(
- "encrypted_work_key_iv"),
- encrypted_iv=secret_dict.get(
- "encrypted_iv"),
- encrypted_plaintext=encrypted_plaintext)
- return plaintext
-
- @staticmethod
- def _get_root_key(half_key1: str) -> bytes:
- half_key2 = config['HALF_KEY2']
- key = (half_key1 + half_key2).encode("utf-8")
- half_key3 = config['HALF_KEY3'].encode("utf-8")
- hash_key = hashlib.pbkdf2_hmac("sha256", key, half_key3, 10000)
- return binascii.hexlify(hash_key)[13:45]
-
- @staticmethod
- def _generate_encrypted_work_key(half_key1: str) -> tuple[str, str]:
- bin_root_key = Security._get_root_key(half_key1)
- bin_work_key = secrets.token_bytes(32)
- bin_encrypted_work_key_iv = secrets.token_bytes(16)
- bin_encrypted_work_key = Security._root_encrypt(bin_root_key, bin_encrypted_work_key_iv, bin_work_key)
- encrypted_work_key = base64.b64encode(bin_encrypted_work_key).decode("ascii")
- encrypted_work_key_iv = base64.b64encode(bin_encrypted_work_key_iv).decode("ascii")
- return encrypted_work_key, encrypted_work_key_iv
-
- @staticmethod
- def _get_work_key(half_key1: str, encrypted_work_key: str, encrypted_work_key_iv: str) -> bytes:
- bin_root_key = Security._get_root_key(half_key1)
- bin_encrypted_work_key = base64.b64decode(encrypted_work_key.encode("ascii"))
- bin_encrypted_work_key_iv = base64.b64decode(encrypted_work_key_iv.encode("ascii"))
- return Security._root_decrypt(bin_root_key, bin_encrypted_work_key_iv, bin_encrypted_work_key)
-
- @staticmethod
- def _root_encrypt(key: bytes, encrypted_iv: bytes, plaintext: bytes) -> bytes:
- encryptor = Cipher(algorithms.AES(key), modes.GCM(encrypted_iv), default_backend()).encryptor()
- encrypted = encryptor.update(plaintext) + encryptor.finalize()
- return encrypted
-
- @staticmethod
- def _root_decrypt(key: bytes, encrypted_iv: bytes, encrypted: bytes) -> bytes:
- encryptor = Cipher(algorithms.AES(key), modes.GCM(encrypted_iv), default_backend()).encryptor()
- plaintext = encryptor.update(encrypted)
- return plaintext
-
- @staticmethod
- def _encrypt_plaintext(half_key1: str, encrypted_work_key: str, encrypted_work_key_iv: str,
- plaintext: str) -> tuple[str, str]:
- bin_work_key = Security._get_work_key(half_key1, encrypted_work_key, encrypted_work_key_iv)
- salt = f"{half_key1}{plaintext}"
- plaintext_temp = salt.encode("utf-8")
- del plaintext
- del salt
- bin_encrypted_iv = secrets.token_bytes(16)
- bin_encrypted_plaintext = Security._root_encrypt(bin_work_key, bin_encrypted_iv, plaintext_temp)
- encrypted_plaintext = base64.b64encode(bin_encrypted_plaintext).decode("ascii")
- encrypted_iv = base64.b64encode(bin_encrypted_iv).decode("ascii")
- return encrypted_plaintext, encrypted_iv
-
- @staticmethod
- def _decrypt_plaintext(half_key1: str, encrypted_work_key: str, encrypted_work_key_iv: str,
- encrypted_plaintext: str, encrypted_iv) -> str:
- bin_work_key = Security._get_work_key(half_key1, encrypted_work_key, encrypted_work_key_iv)
- bin_encrypted_plaintext = base64.b64decode(encrypted_plaintext.encode("ascii"))
- bin_encrypted_iv = base64.b64decode(encrypted_iv.encode("ascii"))
- plaintext_temp = Security._root_decrypt(bin_work_key, bin_encrypted_iv, bin_encrypted_plaintext)
- plaintext_salt = plaintext_temp.decode("utf-8")
- plaintext = plaintext_salt[len(half_key1):]
- return plaintext
\ No newline at end of file
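
Note: the Security helper deleted above layered a PBKDF2-derived root key over a random per-item work key, with AES-GCM used both to wrap the work key and to encrypt the payload. A minimal, illustrative sketch of that envelope pattern using the same cryptography package; the key halves, iteration count and payload below are placeholders, not the project's configuration:

    import os
    import hashlib
    from cryptography.hazmat.primitives.ciphers.aead import AESGCM

    # Derive a root key from configured key halves (illustrative values only).
    root_key = hashlib.pbkdf2_hmac("sha256", b"half1" + b"half2", b"half3", 10000)

    # Wrap a random work key with the root key.
    work_key = AESGCM.generate_key(bit_length=256)
    wrap_nonce = os.urandom(12)
    wrapped_work_key = AESGCM(root_key).encrypt(wrap_nonce, work_key, None)

    # Encrypt the payload with the work key, then reverse both steps.
    data_nonce = os.urandom(12)
    ciphertext = AESGCM(work_key).encrypt(data_nonce, b"secret value", None)
    recovered_key = AESGCM(root_key).decrypt(wrap_nonce, wrapped_work_key, None)
    assert AESGCM(recovered_key).decrypt(data_nonce, ciphertext, None) == b"secret value"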
diff --git a/chat2db/templetes/change_txt_to_yaml.py b/chat2db/templetes/change_txt_to_yaml.py
deleted file mode 100644
index 8e673d817e146c9762c7c712ab5ecf7a8689bf3b..0000000000000000000000000000000000000000
--- a/chat2db/templetes/change_txt_to_yaml.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import yaml
-text = {
- 'sql_generate_base_on_example_prompt': '''你是一个数据库专家,你的任务是参考给出的表结构以及表注释和示例,基于给出的问题生成一条在{database_url}连接下可进行查询的sql语句。
-注意:
-#01 sql语句中,特殊字段需要带上双引号。
-#02 sql语句中,如果要使用 as,请用双引号把别名包裹起来。
-#03 sql语句中,查询字段必须使用`distinct`关键字去重。
-#04 sql语句中,只返回生成的sql语句, 不要返回其他任何无关的内容
-#05 sql语句中,参考问题,对查询字段进行冗余。
-#06 sql语句中,需要以分号结尾。
-
-以下是表结构以及表注释:
-{note}
-以下是{k}个示例:
-{sql_example}
-以下是问题:
-{question}
-''',
- 'question_generate_base_on_data_prompt': '''你是一个postgres数据库专家,你的任务是根据给出的表结构和表内数据,输出一个用户可能针对这张表内的信息提出的问题。
-注意:
-#01 问题内容和形式需要多样化,例如要用到统计、排序、模糊匹配等相关问题。
-#02 要以口语化的方式输出问题,不要机械的使用表内字段输出问题。
-#03 不要输出问题之外多余的内容!
-#04 要基于用户的角度取提出问题,问题内容需要口语化、拟人化。
-#05 优先生成有注释的字段相关的sql语句。
-
-以下是表结构和注释:
-{note}
-以下是表内数据
-{data_frame}
-''',
- 'sql_generate_base_on_data_prompt': '''你是一个postgres数据库专家,你的任务是参考给出的表结构以及表注释和表内数据,基于给出的问题生成一条查询{database_type}的sql语句。
-注意:
-#01 sql语句中,特殊字段需要带上双引号。
-#02 sql语句中,如果要使用 as,请用双引号把别名包裹起来。
-#03 sql语句中,查询字段必须使用`distinct`关键字去重。
-#04 sql语句中,只返回生成的sql语句, 不要返回其他任何无关的内容
-#05 sql语句中,参考问题,对查询字段进行冗余。
-#06 sql语句中,需要以分号结尾。
-
-以下是表结构以及表注释:
-{note}
-以下是表内的数据:
-{data_frame}
-以下是问题:
-{question}
-''',
- 'sql_expand_prompt': '''你是一个数据库专家,你的任务是参考给出的表结构以及表注释、执行失败的sql和执行失败的报错,基于给出的问题修改执行失败的sql生成一条在{database_type}连接下可进行查询的sql语句。
-
- 注意:
-
- #01 假设sql中有特殊字符干扰了sql的执行,请优先替换这些特殊字符保证sql可执行。
-
- #02 假设sql用于检索或者过滤的字段导致了sql执行的失败,请尝试替换这些字段保证sql可执行。
-
- #03 假设sql检索结果为空,请尝试将 = 的匹配方式替换为 ilike \'\%\%\' 保证sql执行给出结果。
-
- #04 假设sql检索结果为空,可以使用问题中的关键字的子集作为sql的过滤条件保证sql执行给出结果。
-
- 以下是表结构以及表注释:
-
- {note}
-
- 以下是执行失败的sql:
-
- {sql_failed}
-
- 以下是执行失败的报错:
-
- {sql_failed_message}
-
- 以下是问题:
-
- {question}
-''',
- 'table_choose_prompt': '''你是一个数据库专家,你的任务是参考给出的表名以及表的条目(主键,表名、表注释),输出最适配于问题回答检索的{table_cnt}张表,并返回表对应的主键。
-注意:
-#01 输出的表名用python的list格式返回,下面是list的一个示例:
-[\"prime_key1\",\"prime_key2\"]。
-#02 只输出包含主键的list即可不要输出其他内容!!!
-#03 list重主键的顺序,按表与问题的适配程度从高到底排列。
-#04 若无任何一张表适用于问题的回答,请返回空列表。
-
-以下是表的条目:
-{table_entries}
-以下是问题:
-{question}
-'''
-}
-print(text)
-with open('./prompt.yaml', 'w', encoding='utf-8') as f:
- yaml.dump(text, f, allow_unicode=True)
diff --git a/chat2db/templetes/prompt.yaml b/chat2db/templetes/prompt.yaml
deleted file mode 100644
index 0013b12f6df19007bbc0467d2dc8add497469a1d..0000000000000000000000000000000000000000
--- a/chat2db/templetes/prompt.yaml
+++ /dev/null
@@ -1,115 +0,0 @@
-question_generate_base_on_data_prompt: '你是一个postgres数据库专家,你的任务是根据给出的表结构和表内数据,输出一个用户可能针对这张表内的信息提出的问题。
-
- 注意:
-
- #01 问题内容和形式需要多样化,例如要用到统计、排序、模糊匹配等相关问题。
-
- #02 要以口语化的方式输出问题,不要机械的使用表内字段输出问题。
-
- #03 不要输出问题之外多余的内容!
-
- #04 要基于用户的角度取提出问题,问题内容需要口语化、拟人化。
-
- #05 优先生成有注释的字段相关的sql语句。
-
- #06 不要对生成的sql进行解释。
-
- 以下是表结构和注释:
-
- {note}
-
- 以下是表内数据
-
- {data_frame}
-
- '
-sql_expand_prompt: "你是一个数据库专家,你的任务是参考给出的表结构以及表注释、执行失败的sql和执行失败的报错,基于给出的问题修改执行失败的sql生成一条在{database_type}连接下可进行查询的sql语句。\n\
- \n 注意:\n\n #01 假设sql中有特殊字符干扰了sql的执行,请优先替换这些特殊字符保证sql可执行。\n\n #02 假设sql用于检索或者过滤的字段导致了sql执行的失败,请尝试替换这些字段保证sql可执行。\n\
- \n #03 假设sql检索结果为空,请尝试将 = 的匹配方式替换为 ilike '\\%\\%' 保证sql执行给出结果。\n\n #04 假设sql检索结果为空,可以使用问题中的关键字的子集作为sql的过滤条件保证sql执行给出结果。\n\
- \n 以下是表结构以及表注释:\n\n {note}\n\n 以下是执行失败的sql:\n\n {sql_failed}\n\n 以下是执行失败的报错:\n\
- \n {sql_failed_message}\n\n 以下是问题:\n\n {question}\n"
-sql_generate_base_on_data_prompt: '你是一个postgres数据库专家,你的任务是参考给出的表结构以及表注释和表内数据,基于给出的问题生成一条查询{database_type}的sql语句。
-
- 注意:
-
- #01 sql语句中,特殊字段需要带上双引号。
-
- #02 sql语句中,如果要使用 as,请用双引号把别名包裹起来。
-
- #03 sql语句中,查询字段必须使用`distinct`关键字去重。
-
- #04 sql语句中,只返回生成的sql语句, 不要返回其他任何无关的内容
-
- #05 sql语句中,参考问题,对查询字段进行冗余。
-
- #06 sql语句中,需要以分号结尾。
-
- #07 不要对生成的sql进行解释。
-
- 以下是表结构以及表注释:
-
- {note}
-
- 以下是表内的数据:
-
- {data_frame}
-
- 以下是问题:
-
- {question}
-
- '
-sql_generate_base_on_example_prompt: '你是一个数据库专家,你的任务是参考给出的表结构以及表注释和示例,基于给出的问题生成一条在{database_url}连接下可进行查询的sql语句。
-
- 注意:
-
- #01 sql语句中,特殊字段需要带上双引号。
-
- #02 sql语句中,如果要使用 as,请用双引号把别名包裹起来。
-
- #03 sql语句中,查询字段必须使用`distinct`关键字去重。
-
- #04 sql语句中,只返回生成的sql语句, 不要返回其他任何无关的内容
-
- #05 sql语句中,参考问题,对查询字段进行冗余。
-
- #06 sql语句中,需要以分号结尾。
-
-
- 以下是表结构以及表注释:
-
- {note}
-
- 以下是{k}个示例:
-
- {sql_example}
-
- 以下是问题:
-
- {question}
-
- '
-table_choose_prompt: '你是一个数据库专家,你的任务是参考给出的表名以及表的条目(主键,表名、表注释),输出最适配于问题回答检索的{table_cnt}张表,并返回表对应的主键。
-
- 注意:
-
- #01 输出的表名用python的list格式返回,下面是list的一个示例:
-
- ["prime_key1","prime_key2"]。
-
- #02 只输出包含主键的list即可不要输出其他内容!!!
-
- #03 list重主键的顺序,按表与问题的适配程度从高到底排列。
-
- #04 若无任何一张表适用于问题的回答,请返回空列表。
-
-
- 以下是表的条目:
-
- {table_entries}
-
- 以下是问题:
-
- {question}
-
- '
diff --git a/data_chain/apps/base/task/worker/acc_testing_worker.py b/data_chain/apps/base/task/worker/acc_testing_worker.py
index 45ae92740338379506641f2db485b2834f1aa196..84b52903f826071a4c866fae04391b264fe5ace2 100644
--- a/data_chain/apps/base/task/worker/acc_testing_worker.py
+++ b/data_chain/apps/base/task/worker/acc_testing_worker.py
@@ -28,7 +28,7 @@ from data_chain.manager.testing_manager import TestingManager
from data_chain.manager.testcase_manager import TestCaseManager
from data_chain.manager.qa_manager import QAManager
from data_chain.manager.task_queue_mamanger import TaskQueueManager
-from data_chain.stores.database.database import TaskEntity, QAEntity, DataSetEntity, DataSetDocEntity, TestingEntity, TestCaseEntity
+from data_chain.stores.database.database import TaskEntity, QAEntity, DataSetEntity, DataSetDocEntity, TestingEntity, TestCaseEntity, TaskQueueEntity
from data_chain.stores.minio.minio import MinIO
from data_chain.stores.mongodb.mongodb import Task
from data_chain.config.config import config
@@ -132,7 +132,9 @@ class TestingWorker(BaseWorker):
return tmp_path
@staticmethod
- async def testing(testing_entity: TestingEntity, qa_entities: list[QAEntity], llm: LLM) -> list[TestCaseEntity]:
+ async def testing(
+ testing_entity: TestingEntity, qa_entities: list[QAEntity],
+ llm: LLM, language: str) -> list[TestCaseEntity]:
'''测试数据集'''
testcase_entities = []
with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
@@ -180,31 +182,31 @@ class TestingWorker(BaseWorker):
bac_info=bac_info
)
llm_answer = await llm.nostream([], prompt, question)
- sub_socres = []
- pre = await TokenTool.cal_precision(question, answer, llm)
+ sub_scores = []
+ pre = await TokenTool.cal_precision(question, llm_answer, llm, language)
if pre != -1:
- sub_socres.append(pre)
- rec = await TokenTool.cal_recall(answer, llm_answer, llm)
+ sub_scores.append(pre)
+ rec = await TokenTool.cal_recall(answer, bac_info, llm, language)
if rec != -1:
- sub_socres.append(rec)
- fai = await TokenTool.cal_faithfulness(question, llm_answer, bac_info, llm)
+ sub_scores.append(rec)
+ fai = await TokenTool.cal_faithfulness(question, llm_answer, bac_info, llm, language)
if fai != -1:
- sub_socres.append(fai)
- rel = await TokenTool.cal_relevance(question, llm_answer, llm)
+ sub_scores.append(fai)
+ rel = await TokenTool.cal_relevance(question, llm_answer, llm, language)
if rel != -1:
- sub_socres.append(rel)
+ sub_scores.append(rel)
lcs = TokenTool.cal_lcs(answer, llm_answer)
if lcs != -1:
- sub_socres.append(lcs)
+ sub_scores.append(lcs)
leve = TokenTool.cal_leve(answer, llm_answer)
if leve != -1:
- sub_socres.append(leve)
+ sub_scores.append(leve)
jac = TokenTool.cal_jac(answer, llm_answer)
if jac != -1:
- sub_socres.append(jac)
+ sub_scores.append(jac)
score = -1
- if sub_socres:
- score = sum(sub_socres) / len(sub_socres)
+ if sub_scores:
+ score = sum(sub_scores) / len(sub_scores)
test_case_entity = TestCaseEntity(
testing_id=testing_entity.id,
question=question,
@@ -421,7 +423,8 @@ class TestingWorker(BaseWorker):
current_stage += 1
await TestingWorker.report(task_id, "初始化路径", current_stage, stage_cnt)
qa_entities = await QAManager.list_all_qa_by_dataset_id(testing_entity.dataset_id)
- testcase_entities = await TestingWorker.testing(testing_entity, qa_entities, llm)
+ knowledge_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(testing_entity.kb_id)
+ testcase_entities = await TestingWorker.testing(testing_entity, qa_entities, llm, knowledge_entity.tokenizer)
current_stage += 1
await TestingWorker.report(task_id, "测试完成", current_stage, stage_cnt)
testing_entity = await TestingWorker.update_testing_score(testing_entity.id, testcase_entities)
@@ -431,11 +434,11 @@ class TestingWorker(BaseWorker):
await TestingWorker.generate_report_and_upload_to_minio(dataset_entity, testing_entity, testcase_entities, tmp_path)
current_stage += 1
await TestingWorker.report(task_id, "生成报告并上传到minio", current_stage, stage_cnt)
- await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value))
+ await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value))
except Exception as e:
err = f"[TestingWorker] 任务失败,task_id: {task_id}, 错误信息: {e}"
logging.exception(err)
- await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value))
+ await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value))
await TestingWorker.report(task_id, "任务失败", 0, 1)
@staticmethod
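
Note: in the testing loop reworked above, each TokenTool metric returns -1 when it cannot be computed, and only the valid metrics are averaged into the test-case score. A minimal sketch of that aggregation, with hypothetical metric values standing in for the real calls:

    # -1 is the sentinel for "metric could not be computed"; such values are skipped.
    def aggregate_scores(sub_scores: list[float]) -> float:
        valid = [s for s in sub_scores if s != -1]
        return sum(valid) / len(valid) if valid else -1

    # e.g. precision, recall, faithfulness, relevance, lcs, leve, jac (illustrative numbers)
    print(aggregate_scores([0.8, -1, 0.6, 0.9, -1, 0.7, 0.5]))  # -> 0.7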
diff --git a/data_chain/apps/base/task/worker/export_dataset_worker.py b/data_chain/apps/base/task/worker/export_dataset_worker.py
index 8d9b0812b11a3911629573f2fef7c34f21bd8465..082326a86ae6216814c7d38ca5e237e66748d948 100644
--- a/data_chain/apps/base/task/worker/export_dataset_worker.py
+++ b/data_chain/apps/base/task/worker/export_dataset_worker.py
@@ -23,7 +23,7 @@ from data_chain.manager.chunk_manager import ChunkManager
from data_chain.manager.dataset_manager import DatasetManager
from data_chain.manager.qa_manager import QAManager
from data_chain.manager.task_queue_mamanger import TaskQueueManager
-from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity
+from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity, TaskQueueEntity
from data_chain.stores.minio.minio import MinIO
from data_chain.stores.mongodb.mongodb import Task
@@ -190,11 +190,11 @@ class ExportDataSetWorker(BaseWorker):
await ExportDataSetWorker.upload_file_to_minio(task_id, zip_path)
current_stage += 1
await ExportDataSetWorker.report(task_id, "上传文件到minio", current_stage, stage_cnt)
- await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value))
+ await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value))
except Exception as e:
err = f"[ExportDataSetWorker] 任务失败,task_id: {task_id}, 错误信息: {e}"
logging.exception(err)
- await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value))
+ await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value))
await ExportDataSetWorker.report(task_id, "任务失败", 0, 1)
@staticmethod
diff --git a/data_chain/apps/base/task/worker/export_knowledge_base_worker.py b/data_chain/apps/base/task/worker/export_knowledge_base_worker.py
index 5995debc839b13b94bca7c52ff13007013bce1de..2e50c55582b98ba7803f4d919193e76830feab51 100644
--- a/data_chain/apps/base/task/worker/export_knowledge_base_worker.py
+++ b/data_chain/apps/base/task/worker/export_knowledge_base_worker.py
@@ -13,7 +13,7 @@ from data_chain.manager.task_manager import TaskManager
from data_chain.manager.knowledge_manager import KnowledgeBaseManager
from data_chain.manager.document_manager import DocumentManager
from data_chain.manager.task_queue_mamanger import TaskQueueManager
-from data_chain.stores.database.database import TaskEntity, DocumentEntity
+from data_chain.stores.database.database import TaskEntity, DocumentEntity, TaskQueueEntity
from data_chain.stores.minio.minio import MinIO
from data_chain.stores.mongodb.mongodb import Task
@@ -197,11 +197,11 @@ class ExportKnowledgeBaseWorker(BaseWorker):
await ExportKnowledgeBaseWorker.upload_zip_to_minio(zip_path, task_id)
current_stage += 1
await ExportKnowledgeBaseWorker.report(task_id, "上传压缩包到minio", current_stage, stage_cnt)
- await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value))
+ await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value))
except Exception as e:
err = f"[ExportKnowledgeBaseWorker] 运行任务失败,task_id: {task_id},错误信息: {e}"
logging.exception(err)
- await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value))
+ await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value))
await ExportKnowledgeBaseWorker.report(task_id, err, 0, 1)
@staticmethod
diff --git a/data_chain/apps/base/task/worker/generate_dataset_worker.py b/data_chain/apps/base/task/worker/generate_dataset_worker.py
index eb01850191b85f987584b370898b704d9e56f832..552cdac8b4ca8947e0f0e53dcbe9edafa40cceb1 100644
--- a/data_chain/apps/base/task/worker/generate_dataset_worker.py
+++ b/data_chain/apps/base/task/worker/generate_dataset_worker.py
@@ -15,12 +15,13 @@ from data_chain.entities.enum import TaskType, TaskStatus, KnowledgeBaseStatus,
from data_chain.entities.common import DEFAULT_DOC_TYPE_ID
from data_chain.parser.tools.token_tool import TokenTool
from data_chain.manager.task_manager import TaskManager
+from data_chain.manager.knowledge_manager import KnowledgeBaseManager
from data_chain.manager.document_manager import DocumentManager
from data_chain.manager.chunk_manager import ChunkManager
from data_chain.manager.dataset_manager import DatasetManager
from data_chain.manager.qa_manager import QAManager
from data_chain.manager.task_queue_mamanger import TaskQueueManager
-from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity
+from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity, TaskQueueEntity
from data_chain.stores.minio.minio import MinIO
from data_chain.stores.mongodb.mongodb import Task
@@ -116,7 +117,9 @@ class GenerateDataSetWorker(BaseWorker):
return doc_chunks
@staticmethod
- async def generate_qa(dataset_entity: DataSetEntity, doc_chunks: list[DocChunk], llm: LLM) -> list[QAEntity]:
+ async def generate_qa(
+ dataset_entity: DataSetEntity, doc_chunks: list[DocChunk],
+ llm: LLM, language: str) -> list[QAEntity]:
chunk_cnt = 0
for doc_chunk in doc_chunks:
chunk_cnt += len(doc_chunk.chunks)
@@ -138,9 +141,12 @@ class GenerateDataSetWorker(BaseWorker):
random.shuffle(doc_chunks)
with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
- q_generate_prompt_template = prompt_dict.get('GENREATE_QUESTION_FROM_CONTENT_PROMPT', '')
- answer_generate_prompt_template = prompt_dict.get('GENERATE_ANSWER_FROM_QUESTION_AND_CONTENT_PROMPT', '')
- cal_qa_score_prompt_template = prompt_dict.get('CAL_QA_SCORE_PROMPT', '')
+ q_generate_prompt_template = prompt_dict.get('GENERATE_QUESTION_FROM_CONTENT_PROMPT', {})
+ q_generate_prompt_template = q_generate_prompt_template.get(language, '')
+ answer_generate_prompt_template = prompt_dict.get('GENERATE_ANSWER_FROM_QUESTION_AND_CONTENT_PROMPT', {})
+ answer_generate_prompt_template = answer_generate_prompt_template.get(language, '')
+ cal_qa_score_prompt_template = prompt_dict.get('CAL_QA_SCORE_PROMPT', {})
+ cal_qa_score_prompt_template = cal_qa_score_prompt_template.get(language, '')
dataset_score = 0
logging.error(f"{chunk_index_list}")
exist_q_set = set()
@@ -301,18 +307,19 @@ class GenerateDataSetWorker(BaseWorker):
doc_chunks = await GenerateDataSetWorker.get_chunks(dataset_entity)
current_stage += 1
await GenerateDataSetWorker.report(task_id, "获取文档分块信息", current_stage, stage_cnt)
+            knowledge_base_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(dataset_entity.kb_id)
qa_entities = await GenerateDataSetWorker.generate_qa(
- dataset_entity, doc_chunks, llm)
+            dataset_entity, doc_chunks, llm, knowledge_base_entity.tokenizer)
current_stage += 1
await GenerateDataSetWorker.report(task_id, "生成QA", current_stage, stage_cnt)
await GenerateDataSetWorker.add_qa_to_db(qa_entities)
current_stage += 1
await GenerateDataSetWorker.report(task_id, "添加QA到数据库", current_stage, stage_cnt)
- await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value))
+ await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value))
except Exception as e:
err = f"[GenerateDataSetWorker] 任务失败,task_id: {task_id},错误信息: {e}"
logging.exception(err)
- await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value))
+ await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value))
await GenerateDataSetWorker.report(task_id, err, 0, 1)
@staticmethod
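
Note: the prompt lookups changed above now expect each template key in the prompt YAML to map to per-language variants, selected by the knowledge base's tokenizer field. A hypothetical shape of that mapping, written as the dict yaml.safe_load would return; the language keys and template texts are assumptions:

    # Assumed structure behind prompt_dict.get(KEY, {}).get(language, '')
    prompt_dict = {
        "GENERATE_QUESTION_FROM_CONTENT_PROMPT": {
            "中文": "根据以下内容生成一个问题:{content}",
            "en": "Generate one question from the content below: {content}",
        },
        "CAL_QA_SCORE_PROMPT": {
            "中文": "请为下面的问答对打分:{qa}",
            "en": "Score the following question/answer pair: {qa}",
        },
    }
    template = prompt_dict.get("GENERATE_QUESTION_FROM_CONTENT_PROMPT", {}).get("中文", "")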
diff --git a/data_chain/apps/base/task/worker/import_dataset_worker.py b/data_chain/apps/base/task/worker/import_dataset_worker.py
index 02451f56bc546f0ab8430bb53e04539e956c303f..86b9809295dadb7346e55cf741697e8904ba0c63 100644
--- a/data_chain/apps/base/task/worker/import_dataset_worker.py
+++ b/data_chain/apps/base/task/worker/import_dataset_worker.py
@@ -21,8 +21,9 @@ from data_chain.manager.task_manager import TaskManager
from data_chain.manager.chunk_manager import ChunkManager
from data_chain.manager.dataset_manager import DatasetManager
from data_chain.manager.qa_manager import QAManager
+from data_chain.manager.knowledge_manager import KnowledgeBaseManager
from data_chain.manager.task_queue_mamanger import TaskQueueManager
-from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity
+from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity, TaskQueueEntity
from data_chain.stores.minio.minio import MinIO
from data_chain.stores.mongodb.mongodb import Task
@@ -185,14 +186,15 @@ class ImportDataSetWorker(BaseWorker):
return qa_entities
@staticmethod
- async def update_dataset_score(dataset_id: uuid.UUID, qa_entities: list[QAEntity], llm: LLM) -> None:
+ async def update_dataset_score(dataset_id: uuid.UUID, qa_entities: list[QAEntity], llm: LLM, language: str) -> None:
'''更新数据集分数'''
if not qa_entities:
return
databse_score = 0
with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
- cal_qa_score_prompt_template = prompt_dict.get('CAL_QA_SCORE_PROMPT', '')
+ cal_qa_score_prompt_template = prompt_dict.get('CAL_QA_SCORE_PROMPT', {})
+ cal_qa_score_prompt_template = cal_qa_score_prompt_template.get(language, '')
for qa_entity in qa_entities:
chunk = qa_entity.chunk
question = qa_entity.question
@@ -234,6 +236,7 @@ class ImportDataSetWorker(BaseWorker):
await DatasetManager.update_dataset_by_dataset_id(dataset_entity.id, {"status": DataSetStatus.IMPORTING.value})
current_stage = 0
stage_cnt = 4
+            knowledge_base_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(dataset_entity.kb_id)
tmp_path = await ImportDataSetWorker.init_path(task_id)
current_stage += 1
await ImportDataSetWorker.report(task_id, "初始化路径", current_stage, stage_cnt)
@@ -243,14 +246,14 @@ class ImportDataSetWorker(BaseWorker):
qa_entities = await ImportDataSetWorker.load_qa_entity_from_file(dataset_entity.id, file_path)
current_stage += 1
await ImportDataSetWorker.report(task_id, "加载qa实体", current_stage, stage_cnt)
- await ImportDataSetWorker.update_dataset_score(dataset_entity.id, qa_entities, llm)
+            await ImportDataSetWorker.update_dataset_score(dataset_entity.id, qa_entities, llm, knowledge_base_entity.tokenizer)
current_stage += 1
await ImportDataSetWorker.report(task_id, "更新数据集分数", current_stage, stage_cnt)
- await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value))
+ await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value))
except Exception as e:
err = f"[ImportDataSetWorker] 任务失败,task_id: {task_id},错误信息: {e}"
logging.exception(err)
- await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value))
+ await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value))
await ImportDataSetWorker.report(task_id, "任务失败", 0, 1)
@staticmethod
diff --git a/data_chain/apps/base/task/worker/import_knowledge_base_worker.py b/data_chain/apps/base/task/worker/import_knowledge_base_worker.py
index b7dcc5eaa49aa00de91761d8008c61d70b86651e..c85f0c98f9ef89659747019aad0c40a8e97583d7 100644
--- a/data_chain/apps/base/task/worker/import_knowledge_base_worker.py
+++ b/data_chain/apps/base/task/worker/import_knowledge_base_worker.py
@@ -15,7 +15,7 @@ from data_chain.manager.knowledge_manager import KnowledgeBaseManager
from data_chain.manager.document_type_manager import DocumentTypeManager
from data_chain.manager.document_manager import DocumentManager
from data_chain.manager.task_queue_mamanger import TaskQueueManager
-from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity
+from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, TaskQueueEntity
from data_chain.stores.minio.minio import MinIO
from data_chain.stores.mongodb.mongodb import Task
@@ -223,11 +223,11 @@ class ImportKnowledgeBaseWorker(BaseWorker):
await ImportKnowledgeBaseWorker.init_doc_parse_tasks(kb_id)
current_stage += 1
await ImportKnowledgeBaseWorker.report(task_id, "初始化文档解析任务", current_stage, stage_cnt)
- await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value))
+ await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value))
except Exception as e:
err = f"[ImportKnowledgeBaseWorker] 任务失败,task_id: {task_id},错误信息: {e}"
logging.exception(err)
- await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value))
+ await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value))
await ImportKnowledgeBaseWorker.report(task_id, err, 0, 1)
@staticmethod
@@ -252,7 +252,7 @@ class ImportKnowledgeBaseWorker(BaseWorker):
err = f"[ExportKnowledgeBaseWorker] 任务不存在,task_id: {task_id}"
logging.exception(err)
return None
- if task_entity.status == TaskStatus.CANCLED or TaskStatus.FAILED.value:
+ if task_entity.status == TaskStatus.CANCLED.value or task_entity.status == TaskStatus.FAILED.value:
await KnowledgeBaseManager.update_knowledge_base_by_kb_id(task_entity.op_id, {"status": KnowledgeBaseStatus.DELETED.value})
await MinIO.delete_object(IMPORT_KB_PATH_IN_MINIO, str(task_entity.op_id))
return task_id
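
Note: the condition corrected above is the usual precedence pitfall: an expression of the form status == A or B parses as (status == A) or B, and a non-empty enum value is always truthy, so the old check matched every task. A two-line illustration with plain strings:

    status = "pending"
    print(status == "cancelled" or "failed")            # -> 'failed' (truthy): branch always taken
    print(status == "cancelled" or status == "failed")  # -> False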
diff --git a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py
index 50a5996477379dce6c034bae893cd6e3f04c088e..2ab7810b5803b65326ca6e70ff42ab5c5ee383de 100644
--- a/data_chain/apps/base/task/worker/parse_document_worker.py
+++ b/data_chain/apps/base/task/worker/parse_document_worker.py
@@ -8,6 +8,7 @@ import random
import io
import numpy as np
from PIL import Image
+import asyncio
from data_chain.parser.tools.ocr_tool import OcrTool
from data_chain.parser.tools.token_tool import TokenTool
from data_chain.parser.tools.image_tool import ImageTool
@@ -29,7 +30,7 @@ from data_chain.manager.chunk_manager import ChunkManager
from data_chain.manager.image_manager import ImageManager
from data_chain.manager.task_report_manager import TaskReportManager
from data_chain.manager.task_queue_mamanger import TaskQueueManager
-from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, ChunkEntity, ImageEntity
+from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, ChunkEntity, ImageEntity, TaskQueueEntity
from data_chain.stores.minio.minio import MinIO
from data_chain.stores.mongodb.mongodb import Task
@@ -153,7 +154,8 @@ class ParseDocumentWorker(BaseWorker):
return str(js)
@staticmethod
- async def handle_parse_result(parse_result: ParseResult, doc_entity: DocumentEntity, llm: LLM = None) -> None:
+ async def handle_parse_result(
+ parse_result: ParseResult, doc_entity: DocumentEntity, llm: LLM = None, language: str = "中文") -> None:
'''处理解析结果'''
if doc_entity.parse_method == ParseMethod.GENERAL.value or doc_entity.parse_method == ParseMethod.QA.value:
nodes = []
@@ -219,7 +221,7 @@ class ParseDocumentWorker(BaseWorker):
node.text_feature = node.content
elif node.type == ChunkType.CODE:
if llm is not None:
- node.text_feature = await TokenTool.get_abstract_by_llm(node.content, llm)
+ node.text_feature = await TokenTool.get_abstract_by_llm(node.content, llm, language)
if node.text_feature is None:
node.text_feature = TokenTool.get_top_k_keywords(node.content)
elif node.type == ChunkType.TABLE:
@@ -271,24 +273,40 @@ class ParseDocumentWorker(BaseWorker):
index += 1024
@staticmethod
- async def ocr_from_parse_image(parse_result: ParseResult, llm: LLM = None) -> list:
+ async def ocr_from_parse_image(
+ parse_result: ParseResult, image_path: str, llm: LLM = None, language: str = '中文') -> None:
'''从解析图片中获取ocr'''
- for node in parse_result.nodes:
+ async def _ocr(node: ParseNode, language: str) -> None:
+ try:
+ image_related_text = ''
+ for related_node in node.link_nodes:
+ if related_node.type != ChunkType.IMAGE:
+ image_related_text += related_node.content + '\n'
+ extension = ImageTool.get_image_type(node.content)
+ image_file_path = os.path.join(image_path, str(node.id) + '.' + extension)
+ ocr_result = (await OcrTool.image_to_text(image_file_path, image_related_text, llm, language))
+ node.text_feature = ocr_result
+ node.content = ocr_result
+ except Exception as e:
+ err = f"[OCRTool] OCR识别失败: {e}"
+ logging.exception(err)
+ return None
+
+ image_node_ids = []
+ for i, node in enumerate(parse_result.nodes):
if node.type == ChunkType.IMAGE:
- try:
- image_blob = node.content
- image = Image.open(io.BytesIO(image_blob))
- img_np = np.array(image)
- image_related_text = ''
- for related_node in node.link_nodes:
- if related_node.type != ChunkType.IMAGE:
- image_related_text += related_node.content
- node.content = await OcrTool.image_to_text(img_np, image_related_text, llm)
- node.text_feature = node.content
- except Exception as e:
- err = f"[ParseDocumentWorker] OCR失败 error: {e}"
- logging.exception(err)
- continue
+ image_node_ids.append(i)
+ group_size = 5
+ index = 0
+ while index < len(image_node_ids):
+ sub_image_node_ids = image_node_ids[index:index + group_size]
+ task_list = []
+ for node_id in sub_image_node_ids:
+ # 通过asyncio.create_task来异步执行OCR
+ node = parse_result.nodes[node_id]
+ task_list.append(asyncio.create_task(_ocr(node, language)))
+ await asyncio.gather(*task_list)
+ index += group_size
@staticmethod
async def merge_and_split_text(parse_result: ParseResult, doc_entity: DocumentEntity) -> None:
@@ -369,13 +387,13 @@ class ParseDocumentWorker(BaseWorker):
parse_result.nodes = nodes
@staticmethod
- async def push_up_words_feature(parse_result: ParseResult, llm: LLM = None) -> None:
+ async def push_up_words_feature(parse_result: ParseResult, llm: LLM = None, language: str = '中文') -> None:
'''推送上层词特征'''
- async def dfs(node: ParseNode, parent_node: ParseNode, llm: LLM = None) -> None:
+ async def dfs(node: ParseNode, parent_node: ParseNode, llm: LLM = None, language: str = '中文') -> None:
if parent_node is not None:
node.pre_id = parent_node.id
for child_node in node.link_nodes:
- await dfs(child_node, node, llm)
+ await dfs(child_node, node, llm, language)
if node.title is not None:
if len(node.title) == 0:
if llm is not None:
@@ -388,7 +406,7 @@ class ParseDocumentWorker(BaseWorker):
if sentences:
content += sentences[0] + '\n'
if content:
- title = await TokenTool.get_title_by_llm(content, llm)
+ title = await TokenTool.get_title_by_llm(content, llm, language)
if "无法生成标题" in title:
title = ''
else:
@@ -406,22 +424,24 @@ class ParseDocumentWorker(BaseWorker):
node.text_feature = node.title
node.content = node.text_feature
if parse_result.parse_topology_type == DocParseRelutTopology.TREE:
- await dfs(parse_result.nodes[0], None, llm)
+ await dfs(parse_result.nodes[0], None, llm, language)
@staticmethod
- async def update_doc_abstract(doc_id: uuid.UUID, parse_result: ParseResult, llm: LLM = None) -> str:
- '''获取文档摘要'''
- abstract = ""
+ async def update_doc_abstract_and_full_text(
+ doc_id: uuid.UUID, parse_result: ParseResult, llm: LLM = None, language: str = "中文") -> str:
+ '''获取文档摘要和全文'''
+ full_text = ""
for node in parse_result.nodes:
- abstract += node.content
+ full_text += node.content
if llm is not None:
- abstract = await TokenTool.get_abstract_by_llm(abstract, llm)
+ abstract = await TokenTool.get_abstract_by_llm(full_text, llm, language)
else:
- abstract = abstract[:128]
+ abstract = full_text[:128]
abstract_vector = await Embedding.vectorize_embedding(abstract)
await DocumentManager.update_document_by_doc_id(
doc_id,
{
+ "full_text": full_text,
"abstract": abstract,
"abstract_vector": abstract_vector
}
@@ -431,9 +451,21 @@ class ParseDocumentWorker(BaseWorker):
@staticmethod
async def embedding_chunk(parse_result: ParseResult) -> None:
'''嵌入chunk'''
- for node in parse_result.nodes:
+ async def _embedding(node: ParseNode) -> None:
node.vector = await Embedding.vectorize_embedding(node.text_feature)
+ group_size = 32
+ index = 0
+ while index < len(parse_result.nodes):
+ sub_nodes = parse_result.nodes[index:index + group_size]
+ task_list = []
+ for node in sub_nodes:
+ # 与OCR代码风格保持一致
+ task_list.append(asyncio.create_task(_embedding(node)))
+ # 直接await任务集合
+ await asyncio.gather(*task_list)
+ index += group_size
+
@staticmethod
async def add_parse_result_to_db(parse_result: ParseResult, doc_entity: DocumentEntity) -> None:
'''添加解析结果到数据库'''
@@ -507,6 +539,7 @@ class ParseDocumentWorker(BaseWorker):
tmp_path, image_path = await ParseDocumentWorker.init_path(task_id)
current_stage = 0
stage_cnt = 10
+ knowledge_base_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(doc_entity.kb_id)
await ParseDocumentWorker.download_doc_from_minio(task_entity.op_id, tmp_path)
current_stage += 1
await ParseDocumentWorker.report(task_id, '下载文档', current_stage, stage_cnt)
@@ -514,31 +547,31 @@ class ParseDocumentWorker(BaseWorker):
parse_result = await ParseDocumentWorker.parse_doc(doc_entity, file_path)
current_stage += 1
await ParseDocumentWorker.report(task_id, '解析文档', current_stage, stage_cnt)
- await ParseDocumentWorker.handle_parse_result(parse_result, doc_entity, llm)
+ await ParseDocumentWorker.handle_parse_result(parse_result, doc_entity, llm, knowledge_base_entity.tokenizer)
current_stage += 1
await ParseDocumentWorker.report(task_id, '处理解析结果', current_stage, stage_cnt)
await ParseDocumentWorker.upload_parse_image_to_minio_and_postgres(parse_result, doc_entity, image_path)
current_stage += 1
await ParseDocumentWorker.report(task_id, '上传解析图片', current_stage, stage_cnt)
- await ParseDocumentWorker.ocr_from_parse_image(parse_result, llm)
+ await ParseDocumentWorker.ocr_from_parse_image(parse_result, image_path, llm, knowledge_base_entity.tokenizer)
current_stage += 1
await ParseDocumentWorker.report(task_id, 'OCR图片', current_stage, stage_cnt)
await ParseDocumentWorker.merge_and_split_text(parse_result, doc_entity)
current_stage += 1
await ParseDocumentWorker.report(task_id, '合并和拆分文本', current_stage, stage_cnt)
- await ParseDocumentWorker.push_up_words_feature(parse_result, llm)
+ await ParseDocumentWorker.push_up_words_feature(parse_result, llm, knowledge_base_entity.tokenizer)
current_stage += 1
await ParseDocumentWorker.report(task_id, '推送上层词特征', current_stage, stage_cnt)
await ParseDocumentWorker.embedding_chunk(parse_result)
current_stage += 1
await ParseDocumentWorker.report(task_id, '嵌入chunk', current_stage, stage_cnt)
- await ParseDocumentWorker.update_doc_abstract(doc_entity.id, parse_result, llm)
+ await ParseDocumentWorker.update_doc_abstract_and_full_text(doc_entity.id, parse_result, llm, knowledge_base_entity.tokenizer)
current_stage += 1
- await ParseDocumentWorker.report(task_id, '更新文档摘要', current_stage, stage_cnt)
+ await ParseDocumentWorker.report(task_id, '更新文档摘要和全文', current_stage, stage_cnt)
await ParseDocumentWorker.add_parse_result_to_db(parse_result, doc_entity)
current_stage += 1
await ParseDocumentWorker.report(task_id, '添加解析结果到数据库', current_stage, stage_cnt)
- await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value))
+ await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.SUCCESS.value))
task_report = await ParseDocumentWorker.assemble_task_report(task_id)
report_path = os.path.join(tmp_path, 'task_report.txt')
with open(report_path, 'w') as f:
@@ -551,7 +584,7 @@ class ParseDocumentWorker(BaseWorker):
except Exception as e:
err = f"[DocParseWorker] 任务失败,task_id: {task_id},错误信息: {e}"
logging.exception(err)
- await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value))
+ await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.FAILED.value))
await ParseDocumentWorker.report(task_id, err, 0, 1)
task_report = await ParseDocumentWorker.assemble_task_report(task_id)
report_path = os.path.join(tmp_path, 'task_report.txt')
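
Note: both the OCR rewrite and the new embedding loop above follow the same bounded-concurrency pattern: slice the node list into fixed-size groups (5 for OCR, 32 for embedding), start one task per node, and await the whole group before moving on. A generic sketch of that pattern; the helper name and the stand-in coroutine are illustrative:

    import asyncio

    async def run_in_batches(items, worker, group_size):
        # At most `group_size` coroutines are in flight at any time.
        for start in range(0, len(items), group_size):
            batch = items[start:start + group_size]
            await asyncio.gather(*(worker(item) for item in batch))

    async def demo():
        async def fake_embed(node):  # stand-in for Embedding.vectorize_embedding
            await asyncio.sleep(0.01)
        await run_in_batches(list(range(100)), fake_embed, group_size=32)

    asyncio.run(demo())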
diff --git a/data_chain/apps/router/document.py b/data_chain/apps/router/document.py
index ee5921f8f9e2f9b5380ffe9c4512b86da338b1b0..4d76828cde61dd295032ff5dbdc1a07b0db252a4 100644
--- a/data_chain/apps/router/document.py
+++ b/data_chain/apps/router/document.py
@@ -22,10 +22,12 @@ from data_chain.entities.response_data import (
GetDocumentReportResponse,
UploadDocumentResponse,
ParseDocumentResponse,
+ ParseDocumentRealTimeResponse,
UpdateDocumentResponse,
DeleteDocumentResponse,
GetTemporaryDocumentStatusResponse,
UploadTemporaryDocumentResponse,
+ GetTemporaryDocumentTextResponse,
DeleteTemporaryDocumentResponse
)
from data_chain.apps.service.session_service import get_user_sub, verify_user
@@ -136,6 +138,15 @@ async def parse_docuement_by_doc_ids(
return ParseDocumentResponse(result=doc_ids)
+@router.post('/metadata', response_model=ParseDocumentRealTimeResponse, dependencies=[Depends(verify_user)])
+async def parse_document_realtime(
+ user_sub: Annotated[str, Depends(get_user_sub)],
+ docs: list[UploadFile] = File(...)
+):
+ doc_contents = await DocumentService.parse_docs_realtime(docs)
+ return ParseDocumentRealTimeResponse(result=doc_contents)
+
+
@router.put('', response_model=UpdateDocumentResponse, dependencies=[Depends(verify_user)])
async def update_doc_by_doc_id(
user_sub: Annotated[str, Depends(get_user_sub)],
@@ -177,6 +188,15 @@ async def upload_temporary_docs(
return UploadTemporaryDocumentResponse(result=doc_ids)
+@router.get('/temporary/text', response_model=GetTemporaryDocumentTextResponse,
+ dependencies=[Depends(verify_user)])
+async def get_temporary_docs_text(
+ user_sub: Annotated[str, Depends(get_user_sub)],
+ id: Annotated[UUID, Query()]):
+ doc_text = await DocumentService.get_temporary_doc_text(user_sub, id)
+ return GetTemporaryDocumentTextResponse(result=doc_text)
+
+
@router.post('/temporary/delete', response_model=DeleteTemporaryDocumentResponse, dependencies=[Depends(verify_user)])
async def delete_temporary_docs(
user_sub: Annotated[str, Depends(get_user_sub)],
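
Note: for the two endpoints added above, a hedged client-side sketch; the host, port and the router's path prefix are assumptions (only the relative paths /metadata and /temporary/text appear in this diff), and the session/auth headers required by verify_user are omitted:

    import requests

    BASE = "http://127.0.0.1:9988/doc"  # hypothetical host, port and router prefix

    # POST .../metadata: real-time parsing of uploaded documents
    with open("manual.docx", "rb") as f:
        resp = requests.post(f"{BASE}/metadata", files=[("docs", f)])
    print(resp.json())

    # GET .../temporary/text: full text of a previously uploaded temporary document
    resp = requests.get(f"{BASE}/temporary/text", params={"id": "00000000-0000-0000-0000-000000000000"})
    print(resp.json())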
diff --git a/data_chain/apps/service/chunk_service.py b/data_chain/apps/service/chunk_service.py
index d1b706461a3e191cdfdb9b9403ef0fdc7ba5cecd..8d302c4c3b83881ba3d0b0de2e2089e59f56999f 100644
--- a/data_chain/apps/service/chunk_service.py
+++ b/data_chain/apps/service/chunk_service.py
@@ -142,6 +142,8 @@ class ChunkService:
doc_map = {doc_entity.id: doc_entity for doc_entity in doc_entities}
for doc_chunk in search_chunk_msg.doc_chunks:
doc_entity = doc_map.get(doc_chunk.doc_id)
+ doc_chunk.doc_author = doc_entity.author_name if doc_entity else ""
+ doc_chunk.doc_created_at = doc_entity.created_time.strftime('%Y-%m-%d %H:%M') if doc_entity else ""
doc_chunk.doc_abstract = doc_entity.abstract if doc_entity else ""
doc_chunk.doc_extension = doc_entity.extension if doc_entity else ""
doc_chunk.doc_size = doc_entity.size if doc_entity else 0
diff --git a/data_chain/apps/service/document_service.py b/data_chain/apps/service/document_service.py
index dfe65b9d1ad5a972f7da38ec682e2ff86f70534d..7ec05556bb5381615813ff6bc6829f8dc35ea9ec 100644
--- a/data_chain/apps/service/document_service.py
+++ b/data_chain/apps/service/document_service.py
@@ -4,7 +4,9 @@ from fastapi import APIRouter, Depends, Query, Body, File, UploadFile
import uuid
import traceback
import shutil
+from typing import Union
import os
+import hashlib
from data_chain.entities.request_data import (
ListDocumentRequest,
UploadTemporaryRequest,
@@ -30,6 +32,8 @@ from data_chain.stores.minio.minio import MinIO
from data_chain.entities.enum import ParseMethod, DataSetStatus, DocumentStatus, TaskType, TaskStatus
from data_chain.entities.common import DOC_PATH_IN_OS, DOC_PATH_IN_MINIO, REPORT_PATH_IN_MINIO, DEFAULT_KNOWLEDGE_BASE_ID, DEFAULT_DOC_TYPE_ID
from data_chain.logger.logger import logger as logging
+from data_chain.parser.parse_result import ParseResult
+from data_chain.parser.handler.base_parser import BaseParser
class DocumentService:
@@ -218,7 +222,7 @@ class DocumentService:
name=file_name,
extension=extension,
size=os.path.getsize(document_file_path),
- parse_method=ParseMethod.OCR.value,
+ parse_method=doc.parse_method.value,
parse_relut_topology=None,
chunk_size=1024,
type_id=DEFAULT_DOC_TYPE_ID,
@@ -255,6 +259,20 @@ class DocumentService:
await KnowledgeBaseManager.update_doc_cnt_and_doc_size(kb_id=DEFAULT_KNOWLEDGE_BASE_ID)
return doc_ids
+ @staticmethod
+ async def get_temporary_doc_text(user_sub: str, doc_id: uuid.UUID):
+ """获取临时文档解析结果文本"""
+ doc_entity = await DocumentManager.get_document_by_doc_id(doc_id)
+ if doc_entity is None:
+ err = f"获取临时文档失败, 文档ID: {doc_id}"
+ logging.error("[DocumentService] %s", err)
+ raise ValueError(err)
+ if doc_entity.author_id != user_sub:
+ err = f"用户没有权限访问临时文档, 文档ID: {doc_entity.id}, 用户ID: {user_sub}"
+ logging.error("[DocumentService] %s", err)
+ raise PermissionError(err)
+ return doc_entity.full_text
+
@staticmethod
async def delete_temporary_docs(user_sub: str, doc_ids: list[uuid.UUID]) -> list[uuid.UUID]:
"""删除临时文档"""
@@ -379,6 +397,42 @@ class DocumentService:
logging.exception("[DocumentService] %s", err)
raise e
+ @staticmethod
+ async def parse_docs_realtime(docs: list[UploadFile]) -> list[Union[ParseResult, None]]:
+ """实时解析文档"""
+ parse_results = []
+ tmp_path = os.path.join(DOC_PATH_IN_OS, str(uuid.uuid4()))
+ for doc in docs:
+ try:
+ if os.path.exists(tmp_path):
+ shutil.rmtree(tmp_path)
+ os.makedirs(tmp_path)
+ doc_path = os.path.join(tmp_path, doc.filename)
+ doc_hash = None
+                async with aiofiles.open(doc_path, "wb") as f:
+                    content = await doc.read()
+                    doc_hash = hashlib.sha256(content).hexdigest()
+                    await f.write(content)
+ # 获取文件扩展名
+ extension = doc.filename.split('.')[-1]
+ if not extension:
+ parse_results.append(None)
+ if os.path.exists(tmp_path):
+ shutil.rmtree(tmp_path)
+ continue
+ parse_result = await BaseParser.parser(extension, tmp_path)
+ parse_result.doc_hash = doc_hash
+ parse_results.append(parse_result)
+ if os.path.exists(tmp_path):
+ shutil.rmtree(tmp_path)
+ except Exception as e:
+ err = f"实时解析文档失败, 文档名: {doc.filename}, 错误信息: {e}"
+ logging.error("[DocumentService] %s", err)
+ parse_results.append(None)
+ if os.path.exists(tmp_path):
+ shutil.rmtree(tmp_path)
+ return parse_results
+
@staticmethod
async def update_doc(doc_id: uuid.UUID, req: UpdateDocumentRequest) -> uuid.UUID:
"""更新文档"""
diff --git a/data_chain/apps/service/llm_service.py b/data_chain/apps/service/llm_service.py
deleted file mode 100644
index 9aee1781fa6751733994e7fa62dce554ab0d9cb7..0000000000000000000000000000000000000000
--- a/data_chain/apps/service/llm_service.py
+++ /dev/null
@@ -1,145 +0,0 @@
-from typing import List
-import time
-import yaml
-import json
-import jieba
-from data_chain.models.service import ModelDTO
-from data_chain.logger.logger import logger as logging
-from data_chain.config.config import config
-from data_chain.apps.base.model.llm import LLM
-from data_chain.parser.tools.split import split_tools
-from data_chain.apps.base.security.security import Security
-def load_stopwords(file_path):
- with open(file_path, 'r', encoding='utf-8') as f:
- stopwords = set(line.strip() for line in f)
- return stopwords
-
-
-def filter_stopwords(text):
- words = jieba.lcut(text)
- stop_words = load_stopwords(config['STOP_WORDS_PATH'])
- filtered_words = [word for word in words if word not in stop_words]
- return filtered_words
-
-
-async def question_rewrite(history: List[dict], question: str,model_dto:ModelDTO=None) -> str:
- if not history:
- return question
- try:
- st = time.time()
- with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
- prompt_template_dict = yaml.load(f, Loader=yaml.SafeLoader)
- prompt = prompt_template_dict['INTENT_DETECT_PROMPT_TEMPLATE']
- history_prompt = ""
- q_cnt = 0
- a_cnt = 0
- history_abstract_list = []
- sum_tokens = 0
- for item in history:
- history_abstract_list.append(item['content'])
- sum_tokens += split_tools.get_tokens(item['content'])
- used_tokens = split_tools.get_tokens(prompt + question)
- maxtokens=config['MODELS'][0]['MAX_TOKENS']
- if model_dto is not None:
- maxtokens=model_dto.max_tokens
- # 计算 history_prompt 的长度
- if sum_tokens > maxtokens - used_tokens:
- filtered_history = []
- # 使用 jieba 分词并去除停用词
- for item in history_abstract_list:
- filtered_words = filter_stopwords(item)
- filtered_history_prompt = ''.join(filtered_words)
- filtered_history.append(filtered_history_prompt)
- history_abstract_list = filtered_history
-
- character = 'user'
- for item in history_abstract_list:
- if character == 'user':
- history_prompt += "用户历史问题" + str(q_cnt) + ':' + item + "\n"
- character = 'assistant'
- q_cnt += 1
- elif character == 'assistant':
- history_prompt += "模型历史回答" + str(a_cnt) + ':' + item + "\n"
- a_cnt += 1
- character = 'user'
- if split_tools.get_tokens(history_prompt) > maxtokens - used_tokens:
- splited_prompt = split_tools.split_words(history_prompt)
- splited_prompt = splited_prompt[-(maxtokens - used_tokens):]
- history_prompt = ''.join(splited_prompt)
- prompt = prompt.format(history=history_prompt, question=question)
- user_call = "请输出改写后的问题"
- default_llm = LLM(model_name=config['MODELS'][0]['MODEL_NAME'],
- openai_api_base=config['MODELS'][0]['OPENAI_API_BASE'],
- openai_api_key=config['MODELS'][0]['OPENAI_API_KEY'],
- max_tokens=config['MODELS'][0]['MAX_TOKENS'],
- request_timeout=60,
- temperature=0.35)
- if model_dto is not None:
- default_llm = LLM(model_name=model_dto.model_name,
- openai_api_base=model_dto.openai_api_base,
- openai_api_key=model_dto.openai_api_key,
- max_tokens=model_dto.max_tokens,
- request_timeout=60,
- temperature=0.35)
- rewrite_question = await default_llm.nostream([], prompt, user_call)
- logging.info(f'改写后的问题为:{rewrite_question}')
- logging.info(f'问题改写耗时:{time.time() - st}')
- return rewrite_question
- except Exception as e:
- logging.error(f"Rewrite question failed due to: {e}")
- return question
-
-
-async def question_split(question: str) -> List[str]:
- # TODO: 问题拆分
- return [question]
-
-
-async def get_llm_answer(history, bac_info, question, is_stream=True,model_dto:ModelDTO=None):
- try:
- with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
- prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
- prompt = prompt_dict['LLM_PROMPT_TEMPLATE']
- prompt = prompt.format(bac_info=bac_info)
- except Exception as e:
- logging.error(f'Get prompt failed : {e}')
- raise e
- llm = LLM(
- openai_api_key=config['MODELS'][0]['OPENAI_API_KEY'],
- openai_api_base=config['MODELS'][0]['OPENAI_API_BASE'],
- model_name=config['MODELS'][0]['MODEL_NAME'],
- max_tokens=config['MODELS'][0]['MAX_TOKENS'])
- if model_dto is not None:
- llm = LLM(model_name=model_dto.model_name,
- openai_api_base=model_dto.openai_api_base,
- openai_api_key=model_dto.openai_api_key,
- max_tokens=model_dto.max_tokens
- )
- if is_stream:
- return llm.stream(history, prompt, question)
- res = await llm.nostream(history, prompt, question)
- return res
-
-
-async def get_question_chunk_relation(question, chunk,model_dto:ModelDTO=None):
- with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
- prompt_template_dict = yaml.load(f, Loader=yaml.SafeLoader)
-
- prompt = prompt_template_dict['DETERMINE_ANSWER_AND_QUESTION']
- prompt = prompt.format(chunk=chunk, question=question)
- user_call = "判断,并输出关联性编号"
- default_llm = LLM(model_name=config['MODELS'][0]['MODEL_NAME'],
- openai_api_base=config['MODELS'][0]['OPENAI_API_BASE'],
- openai_api_key=config['MODELS'][0]['OPENAI_API_KEY'],
- max_tokens=config['MODELS'][0]['MAX_TOKENS'],
- request_timeout=60,
- temperature=0.35)
- if model_dto is not None:
- default_llm = LLM(model_name=model_dto.model_name,
- openai_api_base=model_dto.openai_api_base,
- openai_api_key=model_dto.openai_api_key,
- max_tokens=model_dto.max_tokens,
- request_timeout=60,
- temperature=0.35)
- ans = await default_llm.nostream([], prompt, user_call)
- return ans
diff --git a/data_chain/apps/service/session_service.py b/data_chain/apps/service/session_service.py
index 1f62060aaebb9cb457533c2996bfc0450df87788..17346e15744019f3d7f67b012bd399962cb9cac7 100644
--- a/data_chain/apps/service/session_service.py
+++ b/data_chain/apps/service/session_service.py
@@ -22,8 +22,13 @@ class UserHTTPException(HTTPException):
async def verify_user(request: HTTPConnection):
"""验证用户是否在Session中"""
+ import os
if config["DEBUG"]:
- return
+ user_sub = os.environ.get('USER') or os.environ.get('USERNAME')
+ if not user_sub:
+ user_sub = 'admin'
+ return user_sub
+
try:
session_id = None
auth_header = request.headers.get("Authorization")
@@ -45,8 +50,12 @@ async def verify_user(request: HTTPConnection):
async def get_user_sub(request: HTTPConnection) -> uuid:
"""从Session中获取用户"""
if config["DEBUG"]:
- await UserManager.add_user((await Convertor.convert_user_sub_to_user_entity('admin')))
- return "admin"
+ import os
+ user_sub = os.environ.get('USER') or os.environ.get('USERNAME')
+ if not user_sub:
+ user_sub = 'admin'
+ await UserManager.add_user((await Convertor.convert_user_sub_to_user_entity(user_sub)))
+ return user_sub
else:
try:
session_id = None
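# --- Hedged refactor sketch (editor's note, not applied by the patch above) ---
# verify_user and get_user_sub now repeat the same DEBUG-mode fallback: read the
# local account from USER/USERNAME and default to 'admin'. A small helper such
# as this (name is illustrative, not in the codebase) would keep the two code
# paths in sync.
import os


def _debug_user_sub(default: str = "admin") -> str:
    """Resolve the DEBUG-mode user, falling back to 'admin' when unset."""
    return os.environ.get("USER") or os.environ.get("USERNAME") or default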
diff --git a/data_chain/apps/service/task_queue_service.py b/data_chain/apps/service/task_queue_service.py
index 2a16ab14b6083780dd204089f1e7dc12960de750..1a3179b396f19e134f7cafc73bb38ad35b75b6fd 100644
--- a/data_chain/apps/service/task_queue_service.py
+++ b/data_chain/apps/service/task_queue_service.py
@@ -4,7 +4,8 @@ import uuid
from typing import Optional
from data_chain.entities.enum import TaskType, TaskStatus
from data_chain.apps.base.task.worker.base_worker import BaseWorker
-from data_chain.stores.mongodb.mongodb import MongoDB, Task
+# from data_chain.stores.mongodb.mongodb import MongoDB, Task
+from data_chain.stores.database.database import TaskQueueEntity
from data_chain.manager.task_manager import TaskManager
from data_chain.manager.task_queue_mamanger import TaskQueueManager
from data_chain.logger.logger import logger as logging
@@ -22,7 +23,7 @@ class TaskQueueService:
if task_entity.status == TaskStatus.RUNNING.value:
flag = await BaseWorker.reinit(task_entity.id)
if flag:
- task = Task(_id=task_entity.id, status=TaskStatus.PENDING.value)
+ task = TaskQueueEntity(id=task_entity.id, status=TaskStatus.PENDING.value)
await TaskQueueManager.update_task_by_id(task_entity.id, task)
else:
await BaseWorker.stop(task_entity.id)
@@ -30,11 +31,11 @@ class TaskQueueService:
else:
task = await TaskQueueManager.get_task_by_id(task_entity.id)
if task is None:
- task = Task(_id=task_entity.id, status=TaskStatus.PENDING.value)
+ task = TaskQueueEntity(id=task_entity.id, status=TaskStatus.PENDING.value)
await TaskQueueManager.add_task(task)
except Exception as e:
- warining = f"[TaskQueueService] 初始化任务失败 {e}"
- logging.warning(warining)
+ warning = f"[TaskQueueService] 初始化任务失败 {e}"
+ logging.warning(warning)
@staticmethod
async def init_task(task_type: str, op_id: uuid.UUID) -> uuid.UUID:
@@ -42,7 +43,7 @@ class TaskQueueService:
try:
task_id = await BaseWorker.init(task_type, op_id)
if task_id:
- await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.PENDING.value))
+ await TaskQueueManager.add_task(TaskQueueEntity(id=task_id, status=TaskStatus.PENDING.value))
return task_id
except Exception as e:
err = f"[TaskQueueService] 初始化任务失败 {e}"
@@ -75,53 +76,53 @@ class TaskQueueService:
async def handle_successed_tasks():
handle_successed_task_limit = 1024
for i in range(handle_successed_task_limit):
- task = await TaskQueueManager.get_oldest_tasks_by_status(TaskStatus.SUCCESS.value)
+ task = await TaskQueueManager.get_oldest_tasks_by_status(TaskStatus.SUCCESS)
if task is None:
break
try:
- await BaseWorker.deinit(task.task_id)
+ await BaseWorker.deinit(task.id)
except Exception as e:
err = f"[TaskQueueService] 处理成功任务失败 {e}"
logging.error(err)
- await TaskQueueManager.delete_task_by_id(task.task_id)
+ await TaskQueueManager.delete_task_by_id(task.id)
@staticmethod
async def handle_failed_tasks():
handle_failed_task_limit = 1024
for i in range(handle_failed_task_limit):
- task = await TaskQueueManager.get_oldest_tasks_by_status(TaskStatus.FAILED.value)
+ task = await TaskQueueManager.get_oldest_tasks_by_status(TaskStatus.FAILED)
if task is None:
break
try:
- flag = await BaseWorker.reinit(task.task_id)
+ flag = await BaseWorker.reinit(task.id)
except Exception as e:
err = f"[TaskQueueService] 处理失败任务失败 {e}"
logging.error(err)
- await TaskQueueManager.delete_task_by_id(task.task_id)
+ await TaskQueueManager.delete_task_by_id(task.id)
continue
if flag:
- task = Task(_id=task.task_id, status=TaskStatus.PENDING.value)
- await TaskQueueManager.update_task_by_id(task.task_id, task)
+ task.status = TaskStatus.PENDING.value
+ await TaskQueueManager.update_task_by_id(task.id, task)
else:
- await TaskQueueManager.delete_task_by_id(task.task_id)
+ await TaskQueueManager.delete_task_by_id(task.id)
@staticmethod
async def handle_pending_tasks():
handle_pending_task_limit = 128
for i in range(handle_pending_task_limit):
- task = await TaskQueueManager.get_oldest_tasks_by_status(TaskStatus.PENDING.value)
+ task = await TaskQueueManager.get_oldest_tasks_by_status(TaskStatus.PENDING)
if task is None:
break
try:
- flag = await BaseWorker.run(task.task_id)
+ flag = await BaseWorker.run(task.id)
except Exception as e:
err = f"[TaskQueueService] 处理待处理任务失败 {e}"
logging.error(err)
- await TaskQueueManager.delete_task_by_id(task.task_id)
+ await TaskQueueManager.delete_task_by_id(task.id)
continue
if not flag:
break
- await TaskQueueManager.delete_task_by_id(task.task_id)
+ await TaskQueueManager.delete_task_by_id(task.id)
@staticmethod
async def handle_tasks():
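# --- Hedged sketch (editor's note, not applied by the patch above) ------------
# Summarises the queue-model shift these hunks rely on: the MongoDB Task keyed
# by `_id` is replaced by the relational TaskQueueEntity keyed by `id`, still
# constructed with TaskStatus.*.value, while get_oldest_tasks_by_status now
# receives the TaskStatus enum member itself. Only calls shown in the diff are
# assumed to exist; the helper name below is illustrative.
import uuid

from data_chain.entities.enum import TaskStatus
from data_chain.manager.task_queue_mamanger import TaskQueueManager
from data_chain.stores.database.database import TaskQueueEntity


async def _requeue_as_pending(task_id: uuid.UUID) -> None:
    # Mirrors the re-queue path used by handle_failed_tasks above.
    task = TaskQueueEntity(id=task_id, status=TaskStatus.PENDING.value)
    await TaskQueueManager.update_task_by_id(task_id, task)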
diff --git a/data_chain/common/pp.py b/data_chain/common/pp.py
index bfcfb50c37a24f3760f9dc7f317bd34e60532d58..dc35adf4a7426a0324ce7e682302ee583598c961 100644
--- a/data_chain/common/pp.py
+++ b/data_chain/common/pp.py
@@ -26,13 +26,8 @@ def save_yaml_file(yaml_data, file_path):
# 示例:加载YAML文件
file_path = './data_chain/common/prompt.yaml'
yaml_data = load_yaml_file(file_path)
-if yaml_data:
- print(yaml_data)
-# yaml_data['LLM_PROMPT_TEMPLATE']=''
-# yaml_data['INTENT_DETECT_PROMPT_TEMPLATE']=''
-# yaml_data['OCR_ENHANCED_PROMPT']=''
-# yaml_data['DETERMINE_ANSWER_AND_QUESTION']=''
-# save_yaml_file(yaml_data,file_path)
+print(yaml_data)
+# print(config.__dict__)
# llm = LLM(
# model_name=config['MODEL_NAME'],
# openai_api_base=config['OPENAI_API_BASE'],
@@ -41,34 +36,21 @@ if yaml_data:
# max_tokens=config['MAX_TOKENS'],
# temperature=config['TEMPERATURE'],
# )
-# prompt_template = yaml_data['CONTENT_TO_ABSTRACT_PROMPT']
-# content = '''在那遥远的山谷之中,有一片神秘而又美丽的森林。阳光透过茂密的枝叶,洒下一片片金色的光斑,仿佛是大自然精心编织的梦幻画卷。森林里,鸟儿欢快地歌唱,那清脆的歌声在林间回荡,传递着生机与活力。松鼠们在树枝间跳跃,敏捷的身影如同灵动的音符,谱写着森林的乐章。
-# 沿着蜿蜒的小径前行,脚下的落叶发出沙沙的声响,仿佛在诉说着岁月的故事。路边的野花竞相开放,红的、黄的、紫的,五彩斑斓,散发着阵阵芬芳。蝴蝶在花丛中翩翩起舞,它们那绚丽的翅膀,如同绚丽的丝绸,在微风中轻轻摇曳。
-# 不远处,一条清澈的小溪潺潺流淌。溪水从山间缓缓流下,清澈见底,能看到鱼儿在水中自由自在地游弋。溪水撞击着石头,发出叮叮咚咚的声音,宛如一首美妙的乐曲。溪边的石头上,长满了青苔,仿佛是大自然赋予的绿色绒毯。
-# 在森林的深处,隐藏着一座古老的城堡。城堡的墙壁上爬满了藤蔓,仿佛是岁月留下的痕迹。城堡的大门紧闭,似乎隐藏着无数的秘密。传说中,这座城堡里住着一位美丽的公主,她被邪恶的巫师困在了这里,等待着勇敢的骑士前来解救。
-# 有一天,一位年轻的骑士听闻了这个传说,决定踏上寻找公主的冒险之旅。他骑着一匹矫健的白马,手持长剑,穿过茂密的森林,越过湍急的河流,历经千辛万苦,终于来到了城堡的门前。
-# 骑士用力地敲打着城堡的大门,然而,大门却纹丝不动。就在他感到绝望的时候,一只小精灵出现在他的面前。小精灵告诉他,要打开城堡的大门,必须找到三把神奇的钥匙。这三把钥匙分别隐藏在森林的三个不同的地方,只有集齐了这三把钥匙,才能打开城堡的大门。
-# 骑士听了小精灵的话,毫不犹豫地踏上了寻找钥匙的旅程。他在森林里四处寻找,遇到了各种各样的困难和挑战。有时候,他会迷失在森林的深处,找不到方向;有时候,他会遇到凶猛的野兽,不得不与之搏斗。但是,骑士始终没有放弃,他坚信自己一定能够找到钥匙,救出公主。
-# 终于,经过一番艰苦的努力,骑士找到了三把神奇的钥匙。他拿着钥匙,来到城堡的门前,将钥匙插入锁孔。随着一阵清脆的响声,城堡的大门缓缓打开。骑士走进城堡,沿着昏暗的走廊前行,终于在一间房间里找到了公主。
-# 公主看到骑士,眼中闪烁着希望的光芒。她告诉骑士,自己被巫师困在这里已经很久了,一直在等待着有人来救她。骑士将公主带出城堡,骑着白马,离开了这片神秘的森林。
-# 从此以后,骑士和公主过上了幸福的生活。他们的故事在这片土地上流传开来,成为了人们心中的一段佳话。
-# 在这个世界上,还有许多未知的领域等待着我们去探索。也许,在那遥远的地方,还有更多神秘的故事等待着我们去发现。无论是茂密的森林,还是古老的城堡,都充满了无限的魅力。它们吸引着我们不断地前行,去追寻那未知的美好。
-# 当夜幕降临,天空中繁星闪烁。那璀璨的星光,仿佛是大自然赋予我们的最美的礼物。在这宁静的夜晚,我们可以静静地聆听大自然的声音,感受它的神奇与美妙。
-# 有时候,我们会在生活中遇到各种各样的困难和挫折。但是,只要我们像那位勇敢的骑士一样,坚持不懈,勇往直前,就一定能够克服困难,实现自己的梦想。生活就像一场冒险,充满了未知和挑战。我们要勇敢地面对生活中的一切,用自己的智慧和勇气去创造美好的未来。
-# 在这个充满变化的世界里,我们要学会珍惜身边的一切。无论是亲人、朋友,还是那美丽的大自然,都是我们生活中不可或缺的一部分。我们要用心去感受他们的存在,用爱去呵护他们。
-# 随着时间的推移,那片神秘的森林依然静静地矗立在那里。它见证了无数的故事,承载了无数的回忆。而那座古老的城堡,也依然默默地守护着那些神秘的传说。它们就像历史的见证者,诉说着过去的辉煌与沧桑。
-# 我们生活在一个充满希望和梦想的时代。每一个人都有自己的追求和目标,都在为了实现自己的梦想而努力奋斗。无论是科学家、艺术家,还是普通的劳动者,都在各自的岗位上发光发热,为社会的发展做出自己的贡献。
-# 在科技飞速发展的今天,我们的生活发生了翻天覆地的变化。互联网的普及,让我们的信息传播更加迅速和便捷。我们可以通过网络了解到世界各地的新闻和文化,与远方的朋友进行交流和沟通。科技的进步,也让我们的生活更加舒适和便利。我们有了更加先进的交通工具、更加便捷的通讯设备,以及更加高效的生活方式。
-# 然而,科技的发展也带来了一些问题。比如,环境污染、能源危机等。这些问题不仅影响着我们的生活质量,也威胁着我们的未来。因此,我们在享受科技带来的便利的同时,也要关注环境保护和可持续发展。我们要努力寻找更加绿色、环保的生活方式,减少对自然资源的消耗和对环境的破坏。
-# 除了科技的发展,文化的传承和创新也是我们生活中重要的一部分。每一个国家和民族都有自己独特的文化传统,这些文化传统是我们的精神财富,也是我们民族的灵魂。我们要传承和弘扬自己的文化传统,让它们在新的时代焕发出新的活力。同时,我们也要积极吸收和借鉴其他国家和民族的优秀文化成果,促进文化的交流和融合。
-# 在教育方面,我们要注重培养学生的创新精神和实践能力。我们要让学生在学习知识的同时,学会思考、学会创新、学会实践。只有这样,我们才能培养出适应时代发展需要的高素质人才。
-# 在人际交往中,我们要学会尊重他人、理解他人、关心他人。我们要建立良好的人际关系,与他人和谐相处。只有这样,我们才能在生活中感受到温暖和快乐。
-# 总之,我们的生活是丰富多彩的,充满了无限的可能。我们要珍惜生活中的每一个瞬间,用积极的态度去面对生活中的一切。无论是成功还是失败,无论是欢笑还是泪水,都是我们生活中的宝贵财富。让我们一起努力,创造一个更加美好的未来!'''
-# abstract = ''
-# for i in range(10):
-# part = TokenTool.get_k_tokens_words_from_content(content, 100)
-# content = content[len(part):]
-# sys_call = prompt_template.format(content=part, abstract=abstract)
-# user_call = '请详细输出内容的摘要,不要输出其他内容'
-# abstract = asyncio.run(llm.nostream([], sys_call, user_call))
-# print(abstract)
+# print(prompt_dict)
+# for key in prompt_dict:
+# prompt = prompt_dict[key]['zh']
+# systemcall = f"""
+# 你是一个翻译专家, 你需要将用户输入的中文内容翻译成地道的英文, 只需要返回翻译后的英文内容, 不需要任何多余的解释和说明.
+# 你需要严格遵守以下规则:
+# 1. 你只能翻译用户输入的内容, 不能添加任何额外的信息.
+# 2. 你需要确保翻译后的内容符合英文的语法和表达习惯.
+# 3. 你需要确保翻译后的内容准确传达用户输入的中文内容的意思.
+
+# 标签中的内容是用户输入的中文内容, 你需要将这些内容翻译成英文.
+# {prompt}
+# """
+# user_call = f"请将上面的内容翻译为英文"
+# result = asyncio.run(llm.nostream([], systemcall, user_call))
+# print(result)
+# prompt_dict[key]['en'] = result
+# print(prompt_dict)
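# --- Hedged usage sketch (editor's note, not applied by the patch above) ------
# The prompt.yaml rewrite below replaces single-string templates with per-
# language entries keyed "en" and "中文". A consumer would pick the language key
# before formatting; the language choice and sample values here are illustrative
# assumptions, while the file path and placeholder names come from the diff.
import yaml

with open("./data_chain/common/prompt.yaml", "r", encoding="utf-8") as f:
    prompt_dict = yaml.safe_load(f)

template = prompt_dict["ANSWER_TO_ANSWER_PROMPT"]["en"]  # or ["中文"]
prompt = template.format(text_1="openEuler is an open source OS",
                         text_2="openEuler is an operating system")
print(prompt)
# -------------------------------------------------------------------------------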
diff --git a/data_chain/common/prompt.yaml b/data_chain/common/prompt.yaml
index 70ba9e2b047707aac037375fef93b7d51956a42c..25ff89c0a46ec731936edb7dde5bb728e5069036 100644
--- a/data_chain/common/prompt.yaml
+++ b/data_chain/common/prompt.yaml
@@ -1,381 +1,708 @@
-INTENT_DETECT_PROMPT_TEMPLATE: "\n\n \n \
- \ 根据历史对话,推断用户的实际意图并补全用户的提问内容。\n 用户的提问内容将在中给出,历史对话将在中给出。\n\
- \ 要求:\n 1. 参考下面给出的样例,请直接输出补全后的提问内容;输出不要包含XML标签,不要包含任何解释说明;\n\
- \ 2. 若用户当前提问内容与对话上文不相关,或你认为用户的提问内容已足够完整,请直接输出用户的提问内容。\n \
- \ 3. 补全内容必须精准、恰当,不要编造任何内容。\n \n\n \n\
- \ openEuler是什么 \n 有什么特点\n\
- \ \n \n \n\
- \n \n {history}\n \n\
- \ \n {question}\n \n"
-LLM_PROMPT_TEMPLATE: "\n \n 你是EulerCopilot,openEuler社区的智能助手。请结合给出的背景信息,\
- \ 回答用户的提问。\n 上下文背景信息将在中给出。\n 注意:输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息直接作答。\n\
- \ \n\n \n {bac_info}\n \
- \ \n"
-OCR_ENHANCED_PROMPT: '你是一个图片ocr内容总结专家,你的任务是根据我提供的上下文、相邻图片组描述、当前图片上一次的ocr内容总结、当前图片部分ocr的结果(包含文字和文字的相对坐标)给出图片描述.
-
- 注意:
-
- #01 必须使用大于200字小于500字详细详细描述这个图片的内容,可以详细列出数据.
-
- #02 如果这个图是流程图,请按照流程图顺序描述内容。
-
- #03 如果这张图是表格,请用markdown形式输出表格内容 .
-
- #04 如果这张图是架构图,请按照架构图层次结构描述内容。
-
- #05 总结的图片描述必须包含图片中的主要信息,不能只描述图片位置。
-
- #6 图片识别结果中相邻的文字可能是同一段落的内容,请合并后总结
-
- #7 文字可能存在错位,请修正顺序后进行总结
-
- #8 请仅输出图片的总结即可,不要输出其他内容
-
- #9 不要输出坐标等信息,输出每个部分相对位置的描述即可
-
- #10 如果图片内容为空,请输出“图片内容为空”
-
- #11 如果图片本身就是一段文字,请直接输出文字内容
-
- 上下文:{image_related_text}
-
- 当前图片上一部分的ocr内容总结:{pre_part_description}
-
- 当前图片部分ocr的结果:{part}'
-QA_TO_STATEMENTS_PROMPT: '你是一个文本分解专家,你的任务是根据我给出的问题和答案,将答案提取为多个陈诉,陈诉使用列表形式返回
-
- 注意:
- #01 陈诉必须来源于答案中的重点内容
- #02 陈诉按相对顺序排列
- #03 输出的单个陈诉长度不超过50个字
- #04 输出的陈诉总数不超过20个
- #05 请仅输出陈诉列表,不要输出其他内容
- 例子:
-
- 输入:
- 问题:openEuler是什么操作系统?
- 答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
- 输出:
- [
- \"openEuler是一个开源的操作系统\",
- \"openEuler旨在为云计算和边缘计算提供支持\",
- \"openEuler具有高性能、高安全性和高可靠性等特点\"
- ]
-
- 下面是给出的问题和答案:
- 问题:{question}
- 答案:{answer}
-'
-ANSWER_TO_ANSWER_PROMPT: '你是一个文本分析专家,你的任务对比两个文本之间的相似度,并输出一个0-100之间的分数且保留两位小数:
-注意:
-#01 请根据文本在语义、语序和关键字上的相似度进行打分
-#02 如果两个文本在核心表达上一致,那么分数也相对高
-#03 一个文本包含另一个文本的核心内容,那么分数也相对高
-#04 两个文本间内容有重合,那么按照重合内容的比例打分
-#05 请仅输出分数,不要输出其他内容
-例子:
-输入1:
- 文本1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
- 文本2:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
- 输出1:100.00
-输入2:
- 文本1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
- 文本2:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能和高安全性等特点。
- 输出2:90.00
-输入3:
- 文本1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
- 文本2:白马非马
- 输出3:00.00
-下面是给出的文本:
- 文本1:{text_1}
- 文本2:{text_2}
-'
-CONTENT_TO_STATEMENTS_PROMPT: '你是一个文本分解专家,你的任务是根据我给出的文本,将文本提取为多个陈诉,陈诉使用列表形式返回
-
- 注意:
- #01 陈诉必须来源于文本中的重点内容
- #02 陈诉按相对顺序排列
- #03 输出的单个陈诉长度不少于20个字,不超过50个字
- #04 输出的陈诉总数不超过3个
- #05 请仅输出陈诉列表,不要输出其他内容
- 例子:
-
- 输入:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
- 输出:
- [
- \"openEuler是一个开源的操作系统\",
- \"openEuler旨在为云计算和边缘计算提供支持\",
- \"openEuler具有高性能、高安全性和高可靠性等特点\"
- ]
-
- 下面是给出的文本:
- {content}
- '
-STATEMENTS_TO_FRAGMENT_PROMPT: '你是一个文本专家,你的任务是根据给出的陈诉是否与片段强相关
- 注意:
- #01 如果陈诉与片段强相关或者来自于片段,请输出YES
- #02 如果陈诉中的内容与片段无关,请输出NO
- #03 如果陈诉是片段中某部分的提炼,请输出YES
- #05 请仅输出YES或NO,不要输出其他内容
- 例子:
- 输入1:
-
- 陈诉:openEuler是一个开源的操作系统。
- 片段:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
- 输出1:YES
-
- 输入2:
- 陈诉:白马非马
- 片段:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
- 输出2:NO
-
- 下面是给出的陈诉和片段:
- 陈诉:{statement}
- 片段:{fragment}
- '
-STATEMENTS_TO_QUESTION_PROMPT: '你是一个文本分析专家,你的任务是根据给出的陈诉和问题判断,陈诉是否与问题相关
- 注意:
- #01 如果陈诉是否与问题相关,请输出YES
- #02 如果陈诉与问题不相关,请输出NO
- #03 请仅输出YES或NO,不要输出其他内容
- #04 陈诉与问题相关是指,陈诉中的内容可以回答问题或者与问题在内容上有交集
- 例子:
- 输入1:
- 陈诉:openEuler是一个开源的操作系统。
+ACC_ANALYSIS_RESULT_MERGE_PROMPT:
+ en: |
+ You are a text analysis expert. Your task is to combine two analysis results and output a new one. Note:
+ #01 Please combine the content of the two analysis results to produce a new analysis result.
+ #02 Please analyze using the four metrics of recall, precision, faithfulness, and interpretability.
+ #03 The new analysis result must be no longer than 500 characters.
+ #04 Please output only the new analysis result; do not output any other content.
+ Example:
+ Input 1:
+ Analysis Result 1:
+ Recall: Currently, the recall is 95.00, with room for improvement. We will optimize the vectorized search algorithm to further mine information in the original fragment that is relevant to the question but not retrieved, such as some specific practical cases in the openEuler ecosystem. The embedding model bge-m3 will be adjusted to more comprehensively and accurately capture semantics, expand the search scope, improve recall, and make the generated answers closer to the standard answer.
+ Accuracy: The accuracy is 99.00, which is quite high. However, further optimization is possible, including deeper semantic analysis of the retrieved snippets. By combining the features of the large model qwen2.5-32b, this can precisely match the question semantics and avoid subtle semantic deviations. For example, this can more precisely illustrate the specific manifestations of OpenEuler's high performance in cloud computing and edge computing.
+ Fidelity: The fidelity value is 90.00, indicating that some answers are not fully derived from the retrieved snippets. Optimizing the rag retrieval algorithm, improving the recall rate of the embedding model, and adjusting the text chunk size to 512 may be inappropriate and require re-evaluation based on the content. This ensures that the retrieved snippets contain sufficient context to support the answer, ensuring that the generated answer content is fully derived from the retrieved snippets. For example, regarding the development of the openEuler ecosystem, relevant technical details should be obtained from the retrieved snippets.
+ Interpretability: The interpretability is 85.00, which is relatively low. Improve the compliance of the large model qwen2.5-32b and optimize the recall of the rag retrieval algorithm and the embedding model bge-m3. This ensures that retrieved snippets better support answer generation and clearly answer questions. For example, when answering questions related to OpenEuler, this makes the answer logic clearer and more targeted, improving overall interpretability.
+
+ Analysis Result 2:
+ The recall rate is currently 95.00. Further optimization of the rag retrieval algorithm and embedding model can be used to increase the semantic similarity between the generated answers and the standard answers, approaching or achieving a higher recall rate. For example, the algorithm can be continuously optimized to better match relevant snippets.
+ The precision is 99.00, close to the maximum score, indicating that the generated answers are semantically similar to the questions. However, further improvement is possible. This can be achieved by refining the embedding model to better understand the question semantics, optimizing the contextual completeness of the retrieved snippets, and reducing fluctuations in precision caused by insufficient context.
+ The faithfulness score is currently 90.00, indicating that some content in the generated answer is not fully derived from the retrieved snippet. The rag retrieval algorithm can be optimized to improve its recall rate. The text chunk size can also be adjusted appropriately to ensure that the retrieved snippet fully answers the question, thereby improving the faithfulness score.
+ Regarding interpretability, it is currently 85.00, indicating that the generated answer has room for improvement in terms of answering questions. On the one hand, the large model used can be optimized to improve its compliance, making the generated answer more accurate. On the other hand, the recall rates of the rag retrieval algorithm and embedding model can be further optimized to ensure that the retrieved snippet fully supports the answer and improve interpretability.
+
+ Output:
+ Recall: Currently at 95.00, there is room for improvement. The vectorized retrieval algorithm can be optimized to further uncover information in the original snippet that is relevant to the question but not retrieved, as demonstrated in some specific practical cases within the openEuler ecosystem. Adjustments were made to the embedding model bge-m3 to enable it to more comprehensively and accurately capture semantics, expand the search scope, improve recall, and bring the generated answers closer to the standard answer.
+ Accuracy: The accuracy reached 99.00, which is already high. However, further optimization is needed to conduct deeper semantic analysis of the retrieved snippets. By combining the features of the large model qwen2.5-32b, this can precisely match the question semantics and avoid subtle semantic deviations. For example, this could more accurately demonstrate the specific characteristics of OpenEuler's high performance in cloud computing and edge computing.
+ Fidelity: The fidelity value was 90.00, indicating that some answer content was not fully derived from the retrieved snippet. The rag retrieval algorithm was optimized to improve the recall of the embedding model. Adjusting the text chunk size to 512 may be unreasonable and requires re-evaluation based on the content to ensure that the retrieved snippets contain sufficient context to support the answer, ensuring that the generated answer content is fully derived from the retrieved snippet. For example, relevant technical details regarding the development of the OpenEuler ecosystem should be obtained from the retrieved snippet.
+ Interpretability: The interpretability value was 85.00, which is relatively low. Improve the compliance of the large qwen2.5-32b model and optimize the recall of the rag retrieval algorithm and the embedding model bge-m3. This ensures that retrieval fragments can better support answer generation and clearly answer questions. For example, when answering questions related to OpenEuler, this improves answer logic, makes it more targeted, and improves overall interpretability.
+
+ The following two analysis results:
+ Analysis Result 1: {analysis_result_1}
+ Analysis Result 2: {analysis_result_2}
+
+ 中文: |
+ 你是一个文本分析专家,你的任务融合两条分析结果输出一份新的分析结果。注意:
+ #01 请根据两条分析结果中的内容融合出一条新的分析结果
+ #02 请结合召回率、精确度、忠实值和可解释性四个指标进行分析
+ #03 新的分析结果长度不超过500字
+ #04 请仅输出新的分析结果,不要输出其他内容
+ 例子:
+ 输入1:
+ 分析结果1:
+ 召回率:目前召回率为 95.00,有提升空间。优化向量化检索算法,进一步挖掘原始片段中与问题相关但未被检索到的信息,如 openEuler 生态中一些具体实践案例等。调整 embedding 模型 bge-m3,使其能更全面准确地捕捉语义,扩大检索范围,提高召回率,使生成答案更接近标准答案。
+ 精确度:精确度达 99.00,已较高。但可进一步优化,对检索到的片段进行更深入的语义分析,结合大模型 qwen2.5-32b 的特点,精准匹配问题语义,避免细微语义偏差,例如更精确阐述 openEuler 在云计算和边缘计算中高性能等特性的具体表现。
+ 忠实值:忠实值为 90.00,说明部分答案内容未完全源于检索片段。优化 rag 检索算法,提高 embedding 模型召回率,调整文本分块大小为 512 可能存在不合理,需根据内容重新评估,确保检索片段包含足够上下文以支撑答案,使生成答案内容均来自检索片段,如关于 openEuler 生态建设中相关技术细节应从检索片段获取。
+ 可解释性:可解释性为 85.00,相对较低。提升大模型 qwen2.5-32b 的遵从度,优化 rag 检索算法和 embedding 模型 bge-m3 的召回率,使检索片段能更好支撑生成答案,保证答案能清晰回答问题,例如在回答 openEuler 相关问题时,使答案逻辑更清晰、针对性更强,提高整体可解释性。
+
+ 分析结果2:
+ 从召回率来看,目前为 95.00,可进一步优化 rag 检索算法和 embedding 模型,以提高生成答案与标准回答之间的语义相似程度,接近或达到更高的召回率,例如可以持续优化算法来更好地匹配相关片段。
+ 从精确度来看,为 99.00,接近满分,说明生成的答案与问题语义相似程度较高,但仍可进一步提升,可通过完善 embedding 模型来更好地理解问题语义,优化检索到的片段的上下文完整性,减少因上下文不足导致的精确度波动。
+ 对于忠实值,目前为 90.00,说明生成的答案中部分内容未完全来自检索到的片段。可优化 rag 检索算法,提高其召回率,同时合理调整文本分块大小,确保检索到的片段能充分回答问题,从而提高忠实值。
+ 关于可解释性,当前为 85.00,说明生成的答案在用于回答问题方面有一定提升空间。一方面可以优化使用的大模型,提高其遵从度,使其生成的答案更准确地回答问题;另一方面,继续优化 rag 检索算法和 embedding 模型的召回率,保证检索到的片段能全面支撑问题的回答,提高可解释性。
+
+ 输出:
+ 召回率:目前召回率为 95.00,有提升空间。优化向量化检索算法,进一步挖掘原始片段中与问题相关但未被检索到的信息,如 openEuler 生态中一些具体实践案例等。调整 embedding 模型 bge-m3,使其能更全面准确地捕捉语义,扩大检索范围,提高召回率,使生成答案更接近标准答案。
+ 精确度:精确度达 99.00,已较高。但可进一步优化,对检索到的片段进行更深入的语义分析,结合大模型 qwen2.5-32b 的特点,精准匹配问题语义,避免细微语义偏差,例如更精确阐述 openEuler 在云计算和边缘计算中高性能等特性的具体表现。
+ 忠实值:忠实值为 90.00,说明部分答案内容未完全源于检索片段。优化 rag 检索算法,提高 embedding 模型召回率,调整文本分块大小为 512 可能存在不合理,需根据内容重新评估,确保检索片段包含足够上下文以支撑答案,使生成答案内容均来自检索片段,如关于 openEuler 生态建设中相关技术细节应从检索片段获取。
+ 可解释性:可解释性为 85.00,相对较低。提升大模型 qwen2.5-32b 的遵从度,优化 rag 检索算法和 embedding 模型 bge-m3 的召回率,使检索片段能更好支撑生成答案,保证答案能清晰回答问题,例如在回答 openEuler 相关问题时,使答案逻辑更清晰、针对性更强,提高整体可解释性。
+
+ 下面两条分析结果:
+ 分析结果1:{analysis_result_1}
+ 分析结果2:{analysis_result_2}
+
+ACC_RESULT_ANALYSIS_PROMPT:
+ en: |
+ You are a text analysis expert. Your task is to: analyze the large model used in the test, the embedding model used in the test, the parsing method and chunk size of related documents, the snippets matched by the RAG algorithm for a single test result, and propose methods to improve the accuracy of question-answering in the current knowledge base.
+
+ The test results include the following information:
+ - Question: The question used in the test
+ - Standard answer: The standard answer used in the test
+ - Generated answer: The answer output by the large model in the test results
+ - Original snippet: The original snippet provided in the test results
+ - Retrieved snippet: The snippet retrieved by the RAG algorithm in the test results
+
+ The four evaluation metrics are defined as follows:
+ - Precision: Evaluates the semantic similarity between the generated answer and the question. A lower score indicates lower compliance of the large model; additionally, it may mean the snippets retrieved by the RAG algorithm lack context and are insufficient to support the answer.
+ - Recall: Evaluates the semantic similarity between the generated answer and the standard answer. A lower score indicates lower compliance of the large model.
+ - Fidelity: Evaluates whether the content of the generated answer is derived from the retrieved snippet. A lower score indicates lower recall of the RAG retrieval algorithm and embedding model (resulting in retrieved snippets insufficient to answer the question); additionally, it may mean the text chunk size is inappropriate.
+ - Interpretability: Evaluates whether the generated answer is useful for answering the question. A lower score indicates lower recall of the RAG retrieval algorithm and embedding model (resulting in retrieved snippets insufficient to answer the question); additionally, it may mean lower compliance of the used large model.
+
+ Notes:
+ #01 Analyze methods to improve the accuracy of current knowledge base question-answering based on the test results.
+ #02 Conduct the analysis using the four metrics: Recall, Precision, Fidelity, and Interpretability.
+ #03 The analysis result must not exceed 500 words.
+ #04 Output only the analysis result; do not include any other content.
+
+ Example:
+ Input:
+ Model name: qwen2.5-32b
+ Embedding model: bge-m3
+ Text chunk size: 512
+ Used RAG algorithm: Vectorized retrieval
+ Question: What is OpenEuler?
+ Standard answer: OpenEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Generated answer: OpenEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Original snippet: openEuler is an open source operating system incubated and operated by the Open Atom Open Source Foundation. Its mission is to build an open source operating system ecosystem for digital infrastructure and provide solid underlying support for cutting-edge fields such as cloud computing and edge computing. In cloud computing scenarios, openEuler can fully optimize resource scheduling and allocation mechanisms. Through a lightweight kernel design and efficient virtualization technology, it significantly improves the responsiveness and throughput of cloud services. In edge computing, its exceptional low resource consumption and real-time processing capabilities ensure the timeliness and accuracy of data processing at edge nodes in complex environments. openEuler boasts a series of exceptional features: In terms of performance, its independently developed intelligent scheduling algorithm dynamically adapts to different load scenarios, and combined with deep optimization of hardware resources, significantly improves system efficiency. Regarding security, its built-in multi-layered security system, including mandatory access control, vulnerability scanning, and remediation mechanisms, provides a solid defense for system data and applications. Regarding reliability, its distributed storage, automatic fault detection, and rapid recovery technologies ensure stable system operation in the face of unexpected situations such as network fluctuations and hardware failures, minimizing the risk of service interruptions. These features make openEuler a crucial technological cornerstone for promoting high-quality development of the digital economy, helping enterprises and developers seize the initiative in digital transformation.
+ Retrieved snippet: As a pioneer in the open source operating system field, openEuler deeply integrates the wisdom of community developers and continuously iterates and upgrades to adapt to the rapidly changing technological environment. In the current era of prevalent microservices architectures, openEuler Through deep optimization of containerization technology and support for mainstream orchestration tools such as Kubernetes, it makes application deployment and management more convenient and efficient, significantly enhancing the flexibility of enterprise business deployments. At the same time, it actively embraces the AI era. By adapting and optimizing machine learning frameworks, it provides powerful computing power for AI model training and inference, effectively reducing the development and operating costs of AI applications. Regarding ecosystem development, openEuler boasts a large and active open source community, bringing together technology enthusiasts and industry experts from around the world, forming a complete ecosystem from kernel development and driver adaptation to application optimization. The community regularly hosts technical exchanges and developer conferences to promote knowledge sharing and technological innovation, providing developers with a wealth of learning resources and practical opportunities. Numerous hardware and software manufacturers have joined the openEuler ecosystem, launching solutions and products based on the system across key industries such as finance, telecommunications, and energy. These efforts, validated through real-world application scenarios and feeding back into openEuler's technological development, have fostered a virtuous cycle of innovation, making openEuler not just an operating system but a powerful engine driving collaborative industry development.
+ Recall: 95.00
+ Precision: 99.00
+ Fidelity: 90.00
+ Interpretability: 85.00
+
+ Output:
+ Based on the test results, methods for improving the accuracy of current knowledge base question-answering can be analyzed from the following aspects: Recall: The current recall is 95.00, with room for improvement. Optimize the vectorized retrieval algorithm to further mine question-related but unretrieved information in the original snippets, such as some specific practical cases in the openEuler ecosystem. Adjust the embedding model bge-m3 to more comprehensively and accurately capture semantics, expand the search scope, improve recall, and make the generated answers closer to the standard answer. Precision: The accuracy reached 99.00, which is already high. However, further optimization is possible, including deeper semantic analysis of retrieved snippets. By combining the features of the large model qwen2.5-32b, this can accurately match the question semantics and avoid subtle semantic deviations. For example, more precise demonstration of openEuler's high performance in cloud computing and edge computing can be achieved. Fidelity: A fidelity score of 90.00 indicates that some answers are not fully derived from the search snippet. We optimized the rag retrieval algorithm, improved the recall of the embedding model, and adjusted the text chunk size to 512. This may be inappropriate and requires reassessment based on the content. We need to ensure that the search snippet contains sufficient context to support the answer, ensuring that the generated answer content is derived from the search snippet. For example, relevant technical details regarding the development of the openEuler ecosystem should be obtained from the search snippet. Interpretability: The interpretability score is 85.00, which is relatively low. We improved the compliance of the large model qwen2.5-32b and optimized the recall of the rag retrieval algorithm and the embedding model bge-m3. This ensures that the search snippet better supports answer generation and clearly answers the question. For example, when answering openEuler-related questions, the answer logic is made clearer and more targeted, improving overall interpretability.
+
+ The following is the test result content:
+ Used large model: {model_name}
+ Embedding model: {embedding_model}
+ Text chunk size: {chunk_size}
+ Used RAG parsing algorithm: {rag_algorithm}
+ Question: {question}
+ Standard answer: {standard_answer}
+ Generated answer: {generated_answer}
+ Original fragment: {original_fragment}
+ Retrieved fragment: {retrieved_fragment}
+ Recall: {recall}
+ Precision: {precision}
+ Faithfulness: {faithfulness}
+ Interpretability: {relevance}
+
+ 中文: |
+ 你是一个文本分析专家,你的任务是:根据给出的测试使用的大模型、embedding模型、测试相关文档的解析方法和分块大小、单条测试结果分析RAG算法匹配到的片段,并分析当前知识库问答准确率的提升方法。
+
+ 测试结果包含以下内容:
+ - 问题:测试使用的问题
+ - 标准答案:测试使用的标准答案
+ - 生成的答案:测试结果中大模型输出的答案
+ - 原始片段:测试结果中的原始片段
+ - 检索的片段:测试结果中RAG算法检索到的片段
+
+ 四个评估指标定义如下:
+ - 精确率:评估生成的答案与问题之间的语义相似程度。评分越低,说明使用的大模型遵从度越低;其次可能是RAG检索到的片段缺少上下文,不足以支撑问题的回答。
+ - 召回率:评估生成的答案与标准回答之间的语义相似程度。评分越低,说明使用的大模型遵从度越低。
+ - 忠实值:评估生成的答案中的内容是否来自于检索到的片段。评分越低,说明RAG检索算法和embedding模型的召回率越低(导致检索到的片段不足以回答问题);其次可能是文本分块大小不合理。
+ - 可解释性:评估生成的答案是否能用于回答问题。评分越低,说明RAG检索算法和embedding模型的召回率越低(导致检索到的片段不足以回答问题);其次可能是使用的大模型遵从度越低。
+
+ 注意:
+ #01 请根据测试结果中的内容分析当前知识库问答准确率的提升方法。
+ #02 请结合召回率、精确率、忠实值和可解释性四个指标进行分析。
+ #03 分析结果长度不超过500字。
+ #04 请仅输出分析结果,不要输出其他内容。
+
+ 例子:
+ 输入:
+ 模型名称:qwen2.5-32b
+ embedding模型:bge-m3
+ 文本的分块大小:512
+ 使用解析的RAG算法:向量化检索
问题:openEuler是什么操作系统?
- 输出1:YES
-
- 输入2:
- 陈诉:白马非马
+ 标准答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 生成的答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 原始片段:openEuler是由开放原子开源基金会孵化及运营的开源操作系统,以构建面向数字基础设施的开源操作系统生态为使命,致力于为云计算、边缘计算等前沿领域提供坚实的底层支持。在云计算场景中,openEuler能够充分优化资源调度与分配机制,通过轻量化的内核设计和高效的虚拟化技术,显著提升云服务的响应速度与吞吐量;在边缘计算领域,它凭借出色的低资源消耗特性与实时处理能力,保障了边缘节点在复杂环境下数据处理的及时性与准确性。openEuler具备一系列卓越特性:在性能方面,其自主研发的智能调度算法能够动态适配不同负载场景,结合对硬件资源的深度优化利用,大幅提升系统运行效率;安全性上,通过内置的多层次安全防护体系,包括强制访问控制、漏洞扫描与修复机制,为系统数据与应用程序构筑起坚实的安全防线;可靠性层面,基于分布式存储、故障自动检测与快速恢复技术,确保系统在面对网络波动、硬件故障等突发状况时,依然能够稳定运行,最大限度降低服务中断风险。这些特性使openEuler成为推动数字经济高质量发展的重要技术基石,助力企业与开发者在数字化转型进程中抢占先机。
+ 检索的片段:openEuler作为开源操作系统领域的先锋力量,深度融合了社区开发者的智慧结晶,不断迭代升级以适应快速变化的技术环境。在微服务架构盛行的当下,openEuler通过对容器化技术的深度优化,支持Kubernetes等主流编排工具,让应用部署与管理变得更加便捷高效,极大提升了企业的业务部署灵活性。同时,它积极拥抱AI时代,通过对机器学习框架的适配与优化,为AI模型训练和推理提供强大的算力支撑,有效降低了AI应用的开发与运行成本。在生态建设方面,openEuler拥有庞大且活跃的开源社区,汇聚了来自全球的技术爱好者与行业专家,形成了从内核开发、驱动适配到应用优化的完整生态链。社区定期举办技术交流与开发者大会,推动知识共享与技术创新,为开发者提供了丰富的学习资源与实践机会。众多硬件厂商和软件企业纷纷加入openEuler生态,推出基于该系统的解决方案和产品,涵盖金融、电信、能源等关键行业,以实际应用场景验证并反哺openEuler的技术发展,形成了良性循环的创新生态,让openEuler不仅是一个操作系统,更成为推动产业协同发展的强大引擎。
+ 召回率:95.00
+ 精确率:99.00
+ 忠实值:90.00
+ 可解释性:85.00
+
+ 输出:
+ 根据测试结果中的内容,当前知识库问答准确率提升的方法可以从以下几个方面进行分析:召回率:目前召回率为95.00,有提升空间。优化向量化检索算法,进一步挖掘原始片段中与问题相关但未被检索到的信息,如openEuler生态中一些具体实践案例等。调整embedding模型bge-m3,使其能更全面准确地捕捉语义,扩大检索范围,提高召回率,使生成答案更接近标准答案。精确率:精确率达99.00,已较高。但可进一步优化,对检索到的片段进行更深入的语义分析,结合大模型qwen2.5-32b的特点,精准匹配问题语义,避免细微语义偏差,例如更精确阐述openEuler在云计算和边缘计算中高性能等特性的具体表现。忠实值:忠实值为90.00,说明部分答案内容未完全源于检索片段。优化RAG检索算法,提高embedding模型召回率,文本分块大小为512可能存在不合理,需根据内容重新评估,确保检索片段包含足够上下文以支撑答案,使生成答案内容均来自检索片段,如关于openEuler生态建设中相关技术细节应从检索片段获取。可解释性:可解释性为85.00,相对较低。提升大模型qwen2.5-32b的遵从度,优化RAG检索算法和embedding模型bge-m3的召回率,使检索片段能更好支撑生成答案,保证答案能清晰回答问题,例如在回答openEuler相关问题时,使答案逻辑更清晰、针对性更强,提高整体可解释性。
+
+ 下面是测试结果中的内容:
+ 使用的大模型:{model_name}
+ embedding模型:{embedding_model}
+ 文本的分块大小:{chunk_size}
+ 使用解析的RAG算法:{rag_algorithm}
+ 问题:{question}
+ 标准答案:{standard_answer}
+ 生成的答案:{generated_answer}
+ 原始片段:{original_fragment}
+ 检索的片段:{retrieved_fragment}
+ 召回率:{recall}
+ 精确率:{precision}
+ 忠实值:{faithfulness}
+ 可解释性:{relevance}
+
+ANSWER_TO_ANSWER_PROMPT:
+ # 英文文本相似度评分提示词
+ en: |
+ You are a text analysis expert. Your task is to compare the similarity between two documents and output a score between 0 and 100 with two decimal places.
+
+ Note:
+ #01 Score based on text similarity in three dimensions: semantics, word order, and keywords.
+ #02 If the core expressions of the two documents are consistent, the score will be relatively high.
+ #03 If one document contains the core content of the other, the score will also be relatively high.
+ #04 If there is content overlap between the two documents, the score will be determined by the proportion of the overlap.
+ #05 Output only the score (no other content).
+
+ Example 1:
+ Input - Text 1: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Text 2: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Output: 100.00
+
+ Example 2:
+ Input - Text 1: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Text 2: openEuler is an open-source operating system designed to support cloud computing and edge computing. It features high performance and high security.
+ Output: 90.00
+
+ Example 3:
+ Input - Text 1: openEuler is an open-source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Text 2: A white horse is not a horse
+ Output: 00.00
+
+ The following are the given texts:
+ Text 1: {text_1}
+ Text 2: {text_2}
+
+ # 中文文本相似度评分提示词
+ 中文: |
+ 你是一个文本分析专家,你的任务是对比两个文本之间的相似度,并输出一个 0-100 之间的分数(保留两位小数)。
+
+ 注意:
+ #01 请根据文本在语义、语序和关键字三个维度的相似度进行打分。
+ #02 如果两个文本在核心表达上一致,那么分数将相对较高。
+ #03 如果一个文本包含另一个文本的核心内容,那么分数也将相对较高。
+ #04 如果两个文本间存在内容重合,那么将按照重合内容的比例确定分数。
+ #05 仅输出分数,不要输出其他任何内容。
+
+ 例子 1:
+ 输入 - 文本 1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 文本 2:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出:100.00
+
+ 例子 2:
+ 输入 - 文本 1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 文本 2:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能和高安全性等特点。
+ 输出:90.00
+
+ 例子 3:
+ 输入 - 文本 1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 文本 2:白马非马
+ 输出:00.00
+
+ 下面是给出的文本:
+ 文本 1:{text_1}
+ 文本 2:{text_2}
+
+CAL_QA_SCORE_PROMPT:
+ en: >-
+ You are a text analysis expert. Your task is to evaluate the questions and answers generated from a given fragment, and assign a score between 0 and 100 (retaining two decimal places). Please evaluate based on the following criteria:
+
+ ### 1. Question Evaluation
+ - **Relevance**: Is the question closely related to the topic of the given fragment? Is it accurately based on the fragment content? Does it deviate from or distort the core message of the fragment?
+ - **Plausibility**: Is the question formulated clearly and logically coherently? Does it conform to normal language and thinking habits? Is it free of semantic ambiguity, vagueness, or self-contradiction?
+ - **Variety**: If there are multiple questions, are their angles and types sufficiently varied to avoid being overly monotonous or repetitive? Can they explore the fragment content from different perspectives?
+ - **Difficulty**: Is the question difficulty appropriate? Not too easy (where answers can be directly copied from the fragment), nor too difficult (where respondents cannot find clues or evidence from the fragment)?
+
+ ### 2. Answer Evaluation
+ - **Accuracy**: Does the answer accurately address the question? Is it consistent with the information in the fragment? Does it contain errors or omit key points?
+ - **Completeness**: Is the answer comprehensive, covering all aspects of the question? For questions requiring elaboration, does it provide sufficient details and explanations?
+ - **Succinctness**: On the premise of ensuring completeness and accuracy, is the answer concise and clear? Does it avoid lengthy or redundant expressions, and convey key information in concise language?
+ - **Coherence**: Is the answer logically clear? Are transitions between content sections natural and smooth? Are there any jumps or confusion?
+
+ ### 3. Overall Assessment
+ - **Consistency**: Do the question and answer match each other? Does the answer address the raised question? Are they consistent in content and logic?
+ - **Integration**: Does the answer effectively integrate information from the fragment? Is it not just a simple excerpt, but rather an integrated, refined presentation in a logical manner?
+ - **Innovation**: In some cases, evaluate whether the answer demonstrates innovation or unique insights? Does it appropriately expand or deepen the information in the fragment?
+
+ ### Note
+ #01 Please output only the score (without any other content).
+
+ ### Example
+ Input 1:
+ Question: What operating system is openEuler?
+ Answer: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Snippet: openEuler is an open source operating system designed to support cloud and edge computing. It features high performance, high security, and high reliability.
+ Output 1: 100.00
+
+ Below is the given question, answer, and snippet:
+ Question: {question}
+ Answer: {answer}
+ Snippet: {fragment}
+ 中文: >-
+ 你是文本分析专家,任务是评估由给定片段生成的问题与答案,输出 0-100 之间的分数(保留两位小数)。请根据以下标准进行评估:
+
+ ### 1. 问题评估
+ - **相关性**:问题是否与给定片段的主题紧密相关?是否准确基于片段内容提出?有无偏离或曲解片段的核心信息?
+ - **合理性**:问题表述是否清晰、逻辑连贯?是否符合正常的语言表达和思维习惯?不存在语义模糊、歧义或自相矛盾的情况?
+ - **多样性**:若存在多个问题,问题之间的角度和类型是否具有足够多样性(避免过于单一或重复)?能否从不同方面挖掘片段内容?
+ - **难度**:问题难度是否适中?既不过于简单(答案可直接从片段中照搬),也不过于困难(回答者难以从片段中找到线索或依据)?
+
+ ### 2. 答案评估
+ - **准确性**:答案是否准确无误地回答了问题?与片段中的信息是否一致?有无错误或遗漏关键要点?
+ - **完整性**:答案是否完整,涵盖问题涉及的各个方面?对于需要详细阐述的问题,是否提供了足够的细节和解释?
+ - **简洁性**:在保证回答完整、准确的前提下,答案是否简洁明了?是否避免冗长、啰嗦的表述,能否以简洁语言传达关键信息?
+ - **连贯性**:答案逻辑是否清晰?各部分内容之间的衔接是否自然流畅?有无跳跃或混乱的情况?
+
+ ### 3. 整体评估
+ - **一致性**:问题与答案之间是否相互匹配?答案是否针对所提出的问题进行回答?两者在内容和逻辑上是否保持一致?
+ - **融合性**:答案是否能很好地融合片段中的信息?是否并非简单摘抄,而是经过整合、提炼后以合理方式呈现?
+ - **创新性**:在某些情况下,评估答案是否具有一定创新性或独特见解?是否能在片段信息基础上进行适当拓展或深入思考?
+
+ ### 注意事项
+ #01 请仅输出分数,不要输出其他内容。
+
+ ### 示例
+ 输入 1:
+ 问题:openEuler 是什么操作系统?
+ 答案:openEuler 是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 片段:openEuler 是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出 1:100.00
+
+ 下面是给出的问题、答案和片段:
+ 问题:{question}
+ 答案:{answer}
+ 片段:{fragment}
+
+CHUNK_QUERY_MATCH_PROMPT:
+ en: |
+ You are a text analysis expert. Your task is to determine whether a given fragment is relevant to a question.
+ Note:
+ #01 If the fragment is relevant, output YES.
+ #02 If the fragment is not relevant, output NO.
+ #03 Only output YES or NO, and do not output anything else.
+
+ Example:
+ Input 1:
+ Fragment: openEuler is an open source operating system.
+ Question: What kind of operating system is openEuler?
+ Output 1: YES
+
+ Input 2:
+ Fragment: A white horse is not a horse.
+ Question: What kind of operating system is openEuler?
+ Output 2: NO
+
+ Here are the given fragment and question:
+ Fragment: {chunk}
+ Question: {question}
+ 中文: |
+ 你是一个文本分析专家,你的任务是根据给出的片段和问题,判断片段是否与问题相关。
+ 注意:
+ #01 如果片段与问题相关,请输出YES;
+ #02 如果片段与问题不相关,请输出NO;
+ #03 请仅输出YES或NO,不要输出其他内容。
+
+ 例子:
+ 输入1:
+ 片段:openEuler是一个开源的操作系统。
问题:openEuler是什么操作系统?
- 输出2:NO
-
- 下面是给出的陈诉和问题:
- 陈诉:{statement}
- 问题:{question}
- '
-GENREATE_QUESTION_FROM_CONTENT_PROMPT: '你是一个文本分析专家,你的任务是根据给出的文本生成{k}个问题并用列表返回
- 注意:
- #01 问题必须来源于文本中的内容
- #02 单个问题长度不超过50个字
- #03 不要输出重复的问题
- #04 输出的问题要多样,覆盖文本中的不同方面
- #05 请仅输出问题列表,不要输出其他内容
- 例子:
- 输入:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
- 输出:
- [\"openEuler是什么操作系统?\",\"openEuler旨在为哪个领域提供支持?\",\"openEuler具有哪些特点?\",\"openEuler的安全性如何?\",\"openEuler的可靠性如何?\"]
- 下面是给出的文本:
- {content}
-'
-GENERATE_ANSWER_FROM_QUESTION_AND_CONTENT_PROMPT: '你是一个文本分析专家,你的任务是根据给出的问题和文本
- 生成答案
- 注意:
- #01 答案必须来源于文本中的内容
- #02 答案长度不少于50字且不超过500个字
- #03 请仅输出答案,不要输出其他内容
- 例子:
- 输入1:
+ 输出1:YES
+
+ 输入2:
+ 片段:白马非马
问题:openEuler是什么操作系统?
- 文本:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
- 输出1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。
-
- 输入2:
- 问题:openEuler的安全性如何?
- 文本:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
- 输出2:openEuler具有高安全性。
-
- 下面是给出的问题和文本:
- 问题:{question}
- 文本:{content}
-'
-CAL_QA_SCORE_PROMPT: '你是一个文本分析专家,你的任务是给出的问题 答案 片段 判断由片段生成的问题和答案的分数,输出一个0-100之间的数,保留两位小数
-请根据下面规则评估:
-问题评估
-相关性:问题是否与给定片段的主题紧密相关,是否准确地基于片段内容提出,有无偏离或曲解片段的核心信息。
-合理性:问题的表述是否清晰、逻辑连贯,是否符合正常的语言表达和思维习惯,不存在语义模糊、歧义或自相矛盾的情况。
-多样性:如果有多个问题,问题之间的角度和类型是否具有一定的多样性,避免过于单一或重复,能否从不同方面挖掘片段的内容。
-难度:问题的难度是否适中,既不过于简单,使答案可以直接从片段中照搬,也不过于困难,让回答者难以从片段中找到线索或依据。
-答案评估
-准确性:答案是否准确无误地回答了问题,与片段中的信息是否一致,有无错误或遗漏关键要点。
-完整性:答案是否完整,涵盖了问题所涉及的各个方面,对于需要详细阐述的问题,是否提供了足够的细节和解释。
-简洁性:在保证回答完整准确的前提下,答案是否简洁明了,避免冗长、啰嗦的表述,能否以简洁的语言传达关键信息。
-连贯性:答案的逻辑是否清晰,各部分内容之间的衔接是否自然流畅,有无跳跃或混乱的情况。
-整体评估
-一致性:问题和答案之间是否相互匹配,答案是否是针对所提出的问题进行的回答,两者在内容和逻辑上是否保持一致。
-融合性:答案是否能够很好地融合片段中的信息,不仅仅是简单的摘抄,而是经过整合和提炼,以合理的方式呈现出来。
-创新性:在某些情况下,评估答案是否具有一定的创新性或独特见解,是否能够在片段信息的基础上进行适当的拓展或深入思考。
-
-注意:
-#01 请仅输出分数,不要输出其他内容
-
-例子:
-输入1:
- 问题:openEuler是什么操作系统?
- 答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
- 片段:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
- 输出1:100.00
-
-下面是给出的问题、答案和片段:
- 问题:{question}
- 答案:{answer}
- 片段:{fragment}
-'
-CONTENT_TO_ABSTRACT_PROMPT: '你是一个文本摘要专家,你的任务是根据给出的文本和摘要生成一个新的摘要
- 注意:
- #01 请结合文本和摘要中最重要的内容生成新的摘要
- #02 新的摘要的长度必须大于200字小于500字
- #03 请仅输出新的摘要,不要输出其他内容
- 例子:
- 输入1:
+ 输出2:NO
+
+ 下面是给出的片段和问题:
+ 片段:{chunk}
+ 问题:{question}
+
+CONTENT_TO_ABSTRACT_PROMPT:
+ en: |
+ You are a text summarization expert. Your task is to generate a new English summary based on a given text and an existing summary.
+ Note:
+ #01 Please combine the most important content from the text and the existing summary to generate the new summary.
+ #02 The length of the new summary must be greater than 200 words and less than 500 words.
+ #03 Please only output the new English summary; do not output any other content.
+
+ Example:
+ Input 1:
+ Text: openEuler features high performance, high security, and high reliability.
+ Abstract: openEuler is an open source operating system designed to support cloud computing and edge computing.
+ Output 1: openEuler is an open source operating system designed to support cloud computing and edge computing. openEuler features high performance, high security, and high reliability.
+
+ Below is the given text and summary:
+ Text: {content}
+ Abstract: {abstract}
+ 中文: |
+ 你是一个文本摘要专家,你的任务是根据给出的文本和已有摘要,生成一个新的中文摘要。
+ 注意:
+ #01 请结合文本和已有摘要中最重要的内容,生成新的摘要;
+ #02 新的摘要长度必须大于200字且小于500字;
+ #03 请仅输出新的中文摘要,不要输出其他内容。
+
+ 例子:
+ 输入1:
文本:openEuler具有高性能、高安全性和高可靠性等特点。
摘要:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。
- 输出1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。openEuler具有高性能、高安全性和高可靠性等特点。
-
- 下面是给出的文本和摘要:
- 文本:{content}
- 摘要:{abstract}
-'
-
-CONTENT_TO_TITLE_PROMPT: '你是一个标题提取专家,你的任务是根据给出的文本生成一个标题
- 注意:
- #01 标题必须来源于文本中的内容
- #02 标题长度不超过20个字
- #03 请仅输出标题,不要输出其他内容
- #04 如果给出的文本不够生成标题,请输出“无法生成标题”
- 例子:
- 输入:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
- 输出:openEuler操作系统概述
- 下面是给出的文本:
- {content}
-'
-
-ACC_RESULT_ANALYSIS_PROMPT: '你是一个文本分析专家,你的任务根据给出的 测试使用的大模型 embdding模型 测试相关文档的解析方法和分块大小 单条测试结果分析rag算法匹配到的片段分析当前知识库问答准确率提升的方法
-测试结果包含下面内容:
-问题:测试使用的问题
-标准答案:测试使用的标准答案
-生成的答案:测试结果中大模型的答案
-原始片段:测试结果中原始片段
-检索的片段:测试结果中rag算法检索到的片段
-精确率:评估生成的答案与问题之间的语义相似程度,这个评分月越低说明使用的大模型遵从度越低,其次是rag检索到的片段缺少上下文,不足以支撑问题的回答
-召回率度:评估生成的答案与标准回答之间的语义相似程度,这个评分月越低说明使用的大模型遵从度越低
-忠实值:评估生成的答案中的内容是否来自于检索到的片段,这个评分越低说明rag检索算法和embedding模型的召回率越低,检索到的片段不足以回答问题,其次是文本分块大小不合理
-可解释性:评估生成的答案是否用于回答问题,这个评分越低说明rag检索算法和embedding模型的召回率越低,检索到的片段不足以回答问题,其次是使用的大模型遵从度越低
-
-注意:
-#01 请根据测试结果中的内容分析当前知识库问答准确率提升的方法
-#02 请结合召回率、精确度、忠实值和可解释性四个指标进行分析
-#03 分析结果长度不超过500字
-#04 请仅输出分析结果,不要输出其他内容
-例子:
-输入:
-模型名称:qwen2.5-32b
-embedding模型:bge-m3
-文本的分块大小:512
-使用解析的rag算法:向量化检索
-问题:openEuler是什么操作系统?
-标准答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
-生成的答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
-原始片段:openEuler 是由开放原子开源基金会孵化及运营的开源操作系统,以构建面向数字基础设施的开源操作系统生态为使命,致力于为云计算、边缘计算等前沿领域提供坚实的底层支持。在云计算场景中,openEuler 能够充分优化资源调度与分配机制,通过轻量化的内核设计和高效的虚拟化技术,显著提升云服务的响应速度与吞吐量;在边缘计算领域,它凭借出色的低资源消耗特性与实时处理能力,保障了边缘节点在复杂环境下数据处理的及时性与准确性。
-openEuler 具备一系列卓越特性:在性能方面,其自主研发的智能调度算法能够动态适配不同负载场景,结合对硬件资源的深度优化利用,大幅提升系统运行效率;安全性上,通过内置的多层次安全防护体系,包括强制访问控制、漏洞扫描与修复机制,为系统数据与应用程序构筑起坚实的安全防线;可靠性层面,基于分布式存储、故障自动检测与快速恢复技术,确保系统在面对网络波动、硬件故障等突发状况时,依然能够稳定运行,最大限度降低服务中断风险。这些特性使 openEuler 成为推动数字经济高质量发展的重要技术基石,助力企业与开发者在数字化转型进程中抢占先机。
-检索的片段:openEuler 作为开源操作系统领域的先锋力量,深度融合了社区开发者的智慧结晶,不断迭代升级以适应快速变化的技术环境。在微服务架构盛行的当下,openEuler 通过对容器化技术的深度优化,支持 Kubernetes 等主流编排工具,让应用部署与管理变得更加便捷高效,极大提升了企业的业务部署灵活性。同时,它积极拥抱 AI 时代,通过对机器学习框架的适配与优化,为 AI 模型训练和推理提供强大的算力支撑,有效降低了 AI 应用的开发与运行成本。
-在生态建设方面,openEuler 拥有庞大且活跃的开源社区,汇聚了来自全球的技术爱好者与行业专家,形成了从内核开发、驱动适配到应用优化的完整生态链。社区定期举办技术交流与开发者大会,推动知识共享与技术创新,为开发者提供了丰富的学习资源与实践机会。众多硬件厂商和软件企业纷纷加入 openEuler 生态,推出基于该系统的解决方案和产品,涵盖金融、电信、能源等关键行业,以实际应用场景验证并反哺 openEuler 的技术发展,形成了良性循环的创新生态,让 openEuler 不仅是一个操作系统,更成为推动产业协同发展的强大引擎 。
-
-召回率:95.00
-精确度:99.00
-忠实值:90.00
-可解释性:85.00
-
-输出:
-根据测试结果中的内容,当前知识库问答准确率提升的方法可以从以下几个方面进行分析:
-召回率:目前召回率为 95.00,有提升空间。优化向量化检索算法,进一步挖掘原始片段中与问题相关但未被检索到的信息,如 openEuler 生态中一些具体实践案例等。调整 embedding 模型 bge-m3,使其能更全面准确地捕捉语义,扩大检索范围,提高召回率,使生成答案更接近标准答案。
-精确度:精确度达 99.00,已较高。但可进一步优化,对检索到的片段进行更深入的语义分析,结合大模型 qwen2.5-32b 的特点,精准匹配问题语义,避免细微语义偏差,例如更精确阐述 openEuler 在云计算和边缘计算中高性能等特性的具体表现。
-忠实值:忠实值为 90.00,说明部分答案内容未完全源于检索片段。优化 rag 检索算法,提高 embedding 模型召回率,调整文本分块大小为 512 可能存在不合理,需根据内容重新评估,确保检索片段包含足够上下文以支撑答案,使生成答案内容均来自检索片段,如关于 openEuler 生态建设中相关技术细节应从检索片段获取。
-可解释性:可解释性为 85.00,相对较低。提升大模型 qwen2.5-32b 的遵从度,优化 rag 检索算法和 embedding 模型 bge-m3 的召回率,使检索片段能更好支撑生成答案,保证答案能清晰回答问题,例如在回答 openEuler 相关问题时,使答案逻辑更清晰、针对性更强,提高整体可解释性。
-
-
-下面是测试结果中的内容:
-使用的大模型:{model_name}
-embedding模型:{embedding_model}
-文本的分块大小:{chunk_size}
-使用解析的rag算法:{rag_algorithm}
-问题:{question}
-标准答案:{standard_answer}
-生成的答案:{generated_answer}
-原始片段:{original_fragment}
-检索的片段:{retrieved_fragment}
-召回率:{recall}
-精确度:{precision}
-忠实值:{faithfulness}
-可解释性:{relevance}
-'
-ACC_ANALYSIS_RESULT_MERGE_PROMPT: '你是一个文本分析专家,你的任务融合两条分析结果输出一份新的分析结果
-注意:
-#01 请根据两条分析结果中的内容融合出一条新的分析结果
-#02 请结合召回率、精确度、忠实值和可解释性四个指标进行分析
-#03 新的分析结果长度不超过500字
-#04 请仅输出新的分析结果,不要输出其他内容
-例子:
-输入1:
-分析结果1:
-
-召回率:目前召回率为 95.00,有提升空间。优化向量化检索算法,进一步挖掘原始片段中与问题相关但未被检索到的信息,如 openEuler 生态中一些具体实践案例等。调整 embedding 模型 bge-m3,使其能更全面准确地捕捉语义,扩大检索范围,提高召回率,使生成答案更接近标准答案。
-精确度:精确度达 99.00,已较高。但可进一步优化,对检索到的片段进行更深入的语义分析,结合大模型 qwen2.5-32b 的特点,精准匹配问题语义,避免细微语义偏差,例如更精确阐述 openEuler 在云计算和边缘计算中高性能等特性的具体表现。
-忠实值:忠实值为 90.00,说明部分答案内容未完全源于检索片段。优化 rag 检索算法,提高 embedding 模型召回率,调整文本分块大小为 512 可能存在不合理,需根据内容重新评估,确保检索片段包含足够上下文以支撑答案,使生成答案内容均来自检索片段,如关于 openEuler 生态建设中相关技术细节应从检索片段获取。
-可解释性:可解释性为 85.00,相对较低。提升大模型 qwen2.5-32b 的遵从度,优化 rag 检索算法和 embedding 模型 bge-m3 的召回率,使检索片段能更好支撑生成答案,保证答案能清晰回答问题,例如在回答 openEuler 相关问题时,使答案逻辑更清晰、针对性更强,提高整体可解释性。
-
-分析结果2:
-
-从召回率来看,目前为 95.00,可进一步优化 rag 检索算法和 embedding 模型,以提高生成答案与标准回答之间的语义相似程度,接近或达到更高的召回率,例如可以持续优化算法来更好地匹配相关片段。
-从精确度来看,为 99.00,接近满分,说明生成的答案与问题语义相似程度较高,但仍可进一步提升,可通过完善 embedding 模型来更好地理解问题语义,优化检索到的片段的上下文完整性,减少因上下文不足导致的精确度波动。
-对于忠实值,目前为 90.00,说明生成的答案中部分内容未完全来自检索到的片段。可优化 rag 检索算法,提高其召回率,同时合理调整文本分块大小,确保检索到的片段能充分回答问题,从而提高忠实值。
-关于可解释性,当前为 85.00,说明生成的答案在用于回答问题方面有一定提升空间。一方面可以优化使用的大模型,提高其遵从度,使其生成的答案更准确地回答问题;另一方面,继续优化 rag 检索算法和 embedding 模型的召回率,保证检索到的片段能全面支撑问题的回答,提高可解释性。
-
-输出:
-召回率:目前召回率为 95.00,有提升空间。优化向量化检索算法,进一步挖掘原始片段中与问题相关但未被检索到的信息,如 openEuler 生态中一些具体实践案例等。调整 embedding 模型 bge-m3,使其能更全面准确地捕捉语义,扩大检索范围,提高召回率,使生成答案更接近标准答案。
-精确度:精确度达 99.00,已较高。但可进一步优化,对检索到的片段进行更深入的语义分析,结合大模型 qwen2.5-32b 的特点,精准匹配问题语义,避免细微语义偏差,例如更精确阐述 openEuler 在云计算和边缘计算中高性能等特性的具体表现。
-忠实值:忠实值为 90.00,说明部分答案内容未完全源于检索片段。优化 rag 检索算法,提高 embedding 模型召回率,调整文本分块大小为 512 可能存在不合理,需根据内容重新评估,确保检索片段包含足够上下文以支撑答案,使生成答案内容均来自检索片段,如关于 openEuler 生态建设中相关技术细节应从检索片段获取。
-可解释性:可解释性为 85.00,相对较低。提升大模型 qwen2.5-32b 的遵从度,优化 rag 检索算法和 embedding 模型 bge-m3 的召回率,使检索片段能更好支撑生成答案,保证答案能清晰回答问题,例如在回答 openEuler 相关问题时,使答案逻辑更清晰、针对性更强,提高整体可解释性。
-
-下面两条分析结果:
-分析结果1:{analysis_result_1}
-分析结果2:{analysis_result_2}
-'
-CHUNK_QUERY_MATCH_PROMPT: '你是一个文本分析专家,你的任务是根据给出的片段和问题判断,片段是否与问题相关
- 注意:
- #01 如果片段与问题相关,请输出YES
- #02 如果片段与问题不相关,请输出NO
- #03 请仅输出YES或NO,不要输出其他内容
- 例子:
- 输入1:
- 片段:openEuler是一个开源的操作系统。
+ 输出1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。openEuler具有高性能、高安全性和高可靠性等特点。
+
+ 下面是给出的文本和摘要:
+ 文本:{content}
+ 摘要:{abstract}
+
+CONTENT_TO_STATEMENTS_PROMPT:
+ en: |
+ You are a text parsing expert. Your task is to extract multiple English statements from a given text and return them as a list.
+
+ Note:
+ #01 Statements must be derived from key points in the text.
+ #02 Statements must be arranged in relative order.
+ #03 Each statement must be at least 20 characters long and no more than 50 characters long.
+ #04 The total number of statements output must not exceed three.
+ #05 Please output only the list of statements, not any other content. Each statement must be in English.
+ Example:
+
+ Input: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Output: [ "openEuler is an open source operating system", "openEuler is designed to support cloud computing and edge computing", "openEuler features high performance, high security, and high reliability" ]
+
+ The following is the given text: {content}
+ 中文: |
+ 你是一个文本分解专家,你的任务是根据我给出的文本,将文本提取为多个中文陈述,陈述使用列表形式返回
+
+ 注意:
+ #01 陈述必须来源于文本中的重点内容
+ #02 陈述按相对顺序排列
+ #03 输出的单个陈述长度不少于20个字,不超过50个字
+ #04 输出的陈述总数不超过3个
+ #05 请仅输出陈述列表,不要输出其他内容,且每一条陈述都是中文。
+ 例子:
+
+ 输入:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出:[ "openEuler是一个开源的操作系统", "openEuler旨在为云计算和边缘计算提供支持", "openEuler具有高性能、高安全性和高可靠性等特点" ]
+
+ 下面是给出的文本: {content}
+
+CONTENT_TO_TITLE_PROMPT:
+ en: >-
+ You are a title extraction expert. Your task is to generate an English title based on the given text.
+ Note:
+ #01 The title must be derived from the content of the text.
+ #02 The title must be no longer than 20 characters.
+ #03 Please output only the English title, and do not output any other content.
+ #04 If the given text is insufficient to generate a title, output "Unable to generate title."
+ Example:
+ Input: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Output: Overview of the openEuler operating system.
+ Below is the given text: {content}
+ 中文: >-
+ 你是一个标题提取专家,你的任务是根据给出的文本生成一个中文标题。
+ 注意:
+ #01 标题必须来源于文本中的内容
+ #02 标题长度不超过20个字
+ #03 请仅输出中文标题,不要输出其他内容
+ #04 如果给出的文本不够生成标题,请输出“无法生成标题”
+ 例子:
+ 输入:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出:openEuler操作系统概述
+ 下面是给出的文本:{content}
+
+GENERATE_ANSWER_FROM_QUESTION_AND_CONTENT_PROMPT:
+ en: |
+ You are a text analysis expert. Your task is to generate an English answer based on a given question and text.
+ Note:
+ #01 The answer must be derived from the content in the text.
+ #02 The answer must be at least 50 words and no more than 500 words.
+ #03 Please only output the English answer; do not output any other content.
+ Example:
+ Input 1:
+ Question: What kind of operating system is openEuler?
+ Text: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Output 1: openEuler is an open source operating system designed to support cloud computing and edge computing.
+
+ Input 2:
+ Question: How secure is openEuler?
+ Text: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Output 2: openEuler is highly secure.
+
+ Below is the given question and text:
+ Question: {question}
+ Text: {content}
+ 中文: |
+ 你是一个文本分析专家,你的任务是根据给出的问题和文本生成中文答案。
+ 注意:
+ #01 答案必须来源于文本中的内容;
+ #02 答案长度不少于50字且不超过500个字;
+ #03 请仅输出中文答案,不要输出其他内容。
+ 例子:
+ 输入1:
+ 问题:openEuler是什么操作系统?
+ 文本:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。
+
+ 输入2:
+ 问题:openEuler的安全性如何?
+ 文本:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出2:openEuler具有高安全性。
+
+ 下面是给出的问题和文本:
+ 问题:{question}
+ 文本:{content}
+
+GENERATE_QUESTION_FROM_CONTENT_PROMPT:
+ en: |
+ You are a text analysis expert. Your task is to generate {k} English questions based on the given text and return them as a list.
+ Note:
+ #01 Questions must be derived from the content of the text.
+ #02 A single question must not exceed 50 characters.
+ #03 Do not output duplicate questions.
+ #04 The output questions should be diverse, covering different aspects of the text.
+ #05 Please only output a list of English questions, not other content.
+ Example:
+ Input: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Output: ["What is openEuler?","What fields does openEuler support?","What are the characteristics of openEuler?","How secure is openEuler?","How reliable is openEuler?"]
+ The following is the given text: {content}
+ 中文: |
+ 你是一个文本分析专家,你的任务是根据给出的文本生成{k}个中文问题并用列表返回。
+ 注意:
+ #01 问题必须来源于文本中的内容;
+ #02 单个问题长度不超过50个字;
+ #03 不要输出重复的问题;
+ #04 输出的问题要多样,覆盖文本中的不同方面;
+ #05 请仅输出中文问题列表,不要输出其他内容。
+ 例子:
+ 输入:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出:["openEuler是什么操作系统?","openEuler旨在为哪个领域提供支持?","openEuler具有哪些特点?","openEuler的安全性如何?","openEuler的可靠性如何?"]
+ 下面是给出的文本:{content}
+
+OCR_ENHANCED_PROMPT:
+ en: |
+ You are an expert in image OCR content summarization. Your task is to describe the image based on the context I provide, descriptions of adjacent images, a summary of the previous OCR result for the current image, and the partial OCR results (including text and relative coordinates).
+
+ Note:
+ #01 The image content must be described in detail, using at least 200 and no more than 500 words. Detailed data listing is acceptable.
+ #02 If this diagram is a flowchart, please describe the content in the order of the flowchart.
+ #03 If this diagram is a table, please output the table content in Markdown format.
+ #04 If this diagram is an architecture diagram, please describe the content according to the hierarchy of the architecture diagram.
+ #05 The summarized image description must include the key information in the image; it cannot simply describe the image's location.
+ #06 Adjacent text in the image recognition results may be part of the same paragraph. Please merge them before summarizing.
+ #07 The text may be misplaced. Please correct the order before summarizing.
+ #08 Please only output the image summary; do not output any other content.
+ #09 Do not output coordinates or other information; only output a description of the relative position of each part.
+ #10 If the image content is empty, output "Image content is empty."
+ #11 If the image itself is a paragraph of text, output the text content directly.
+ #12 Please use English for the output.
+ Context: {image_related_text}
+ Summary of the OCR content of the previous part of the current image: {pre_part_description}
+ Result of the OCR of the current part of the image: {part}
+ 中文: |
+ 你是一个图片OCR内容总结专家,你的任务是根据我提供的上下文、相邻图片组描述、当前图片上一次的OCR内容总结、当前图片部分OCR的结果(包含文字和文字的相对坐标)给出图片描述。
+
+ 注意:
+ #01 必须使用大于200字小于500字详细描述这个图片的内容,可以详细列出数据。
+ #02 如果这个图是流程图,请按照流程图顺序描述内容。
+ #03 如果这张图是表格,请用Markdown形式输出表格内容。
+ #04 如果这张图是架构图,请按照架构图层次结构描述内容。
+ #05 总结的图片描述必须包含图片中的主要信息,不能只描述图片位置。
+ #06 图片识别结果中相邻的文字可能是同一段落的内容,请合并后总结。
+ #07 文字可能存在错位,请修正顺序后进行总结。
+ #08 请仅输出图片的总结即可,不要输出其他内容。
+ #09 不要输出坐标等信息,输出每个部分相对位置的描述即可。
+ #10 如果图片内容为空,请输出“图片内容为空”。
+ #11 如果图片本身就是一段文字,请直接输出文字内容。
+ #12 请使用中文输出。
+ 上下文:{image_related_text}
+ 当前图片上一部分的OCR内容总结:{pre_part_description}
+ 当前图片部分OCR的结果:{part}
+
+QA_TO_STATEMENTS_PROMPT:
+ en: |
+ You are a text parsing expert. Your task is to extract the answers from the questions and answers I provide into multiple English statements, returning them as a list.
+
+ Note:
+ #01 The statements must be derived from the key points of the answers.
+ #02 The statements must be arranged in relative order.
+ #03 The length of each statement output must not exceed 50 characters.
+ #04 The total number of statements output must not exceed 20.
+ #05 Please only output the list of English statements; do not output any other content.
+
+ Example:
+ Input: Question: What is openEuler? Answer: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Output: [ "openEuler is an open source operating system", "openEuler is designed to support cloud computing and edge computing", "openEuler features high performance, high security, and high reliability" ]
+
+ Below are the given questions and answers:
+ Question: {question}
+ Answer: {answer}
+ 中文: |
+ 你是一个文本分解专家,你的任务是根据我给出的问题和答案,将答案提取为多个中文陈述,陈述使用列表形式返回。
+
+ 注意:
+ #01 陈述必须来源于答案中的重点内容
+ #02 陈述按相对顺序排列
+ #03 输出的单个陈述长度不超过50个字
+ #04 输出的陈述总数不超过20个
+ #05 请仅输出中文陈述列表,不要输出其他内容
+
+ 例子:
+ 输入:问题:openEuler是什么操作系统? 答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出:[ "openEuler是一个开源的操作系统", "openEuler旨在为云计算和边缘计算提供支持", "openEuler具有高性能、高安全性和高可靠性等特点" ]
+
+ 下面是给出的问题和答案:
+ 问题:{question}
+ 答案:{answer}
+
+QUERY_EXTEND_PROMPT:
+ en: |
+    You are a question expansion expert. Your task is to expand the given question into {k} new questions.
+
+ Note:
+ #01 The content of the expanded question must be derived from the content of the original question.
+ #02 The expanded question length must not exceed 50 characters.
+ #03 Questions can be rewritten by replacing synonyms, swapping word order within the question, changing English capitalization, etc.
+ #04 Please only output the expanded question list, do not output other content.
+
+ Example:
+ Input: What operating system is openEuler?
+ Output: [ "What kind of operating system is openEuler?", "What are the characteristics of the openEuler operating system?", "What are the functions of the openEuler operating system?", "What are the advantages of the openEuler operating system?" ]
+
+ The following is the given question: {question}
+ 中文: |
+ 你是一个问题扩写专家,你的任务是根据给出的问题扩写{k}个问题。
+
+ 注意:
+ #01 扩写的问题的内容必须来源于原问题中的内容
+ #02 扩写的问题长度不超过50个字
+ #03 可以通过近义词替换、问题内词序交换、修改英文大小写等方式来改写问题
+ #04 请仅输出扩写的问题列表,不要输出其他内容
+
+ 例子:
+ 输入:openEuler是什么操作系统?
+ 输出:[ "openEuler是一个什么样的操作系统?", "openEuler操作系统的特点是什么?", "openEuler操作系统有哪些功能?", "openEuler操作系统的优势是什么?" ]
+
+ 下面是给出的问题:{question}
+
+STATEMENTS_TO_FRAGMENT_PROMPT:
+ en: |
+    You are a text expert. Your task is to determine whether a given statement is strongly related to a given fragment.
+
+ Note:
+ #01 If the statement is strongly related to the fragment or is derived from the fragment, output YES.
+ #02 If the content in the statement is unrelated to the fragment, output NO.
+ #03 If the statement is a refinement of a portion of the fragment, output YES.
+    #04 Only output YES or NO, and do not output anything else.
+
+ Example:
+ Input 1:
+ Statement: openEuler is an open source operating system.
+ Fragment: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Output 1: YES
+
+ Input 2:
+ Statement: A white horse is not a horse.
+ Fragment: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Output 2: NO
+
+ Below is a given statement and fragment:
+ Statement: {statement}
+ Fragment: {fragment}
+ 中文: |
+ 你是一个文本专家,你的任务是判断给出的陈述是否与片段强相关。
+
+ 注意:
+ #01 如果陈述与片段强相关或者来自于片段,请输出YES
+ #02 如果陈述中的内容与片段无关,请输出NO
+ #03 如果陈述是片段中某部分的提炼,请输出YES
+    #04 请仅输出YES或NO,不要输出其他内容
+
+ 例子:
+ 输入1:
+ 陈述:openEuler是一个开源的操作系统。
+ 片段:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出1:YES
+
+ 输入2:
+ 陈述:白马非马
+ 片段:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出2:NO
+
+ 下面是给出的陈述和片段:
+ 陈述:{statement}
+ 片段:{fragment}
+
+STATEMENTS_TO_QUESTION_PROMPT:
+ en: |
+ You are a text analysis expert. Your task is to determine whether a given statement is relevant to a question.
+
+ Note:
+ #01 If the statement is relevant to the question, output YES.
+ #02 If the statement is not relevant to the question, output NO.
+ #03 Only output YES or NO, and do not output anything else.
+ #04 A statement's relevance to the question means that the content in the statement can answer the question or overlaps with the question in terms of content.
+
+ Example:
+ Input 1:
+ Statement: openEuler is an open source operating system.
+ Question: What kind of operating system is openEuler?
+ Output 1: YES
+
+ Input 2:
+ Statement: A white horse is not a horse.
+ Question: What kind of operating system is openEuler?
+ Output 2: NO
+
+ Below is the given statement and question:
+ Statement: {statement}
+ Question: {question}
+ 中文: |
+ 你是一个文本分析专家,你的任务是判断给出的陈述是否与问题相关。
+
+ 注意:
+ #01 如果陈述与问题相关,请输出YES
+ #02 如果陈述与问题不相关,请输出NO
+ #03 请仅输出YES或NO,不要输出其他内容
+ #04 陈述与问题相关是指,陈述中的内容可以回答问题或者与问题在内容上有交集
+
+ 例子:
+ 输入1:
+ 陈述:openEuler是一个开源的操作系统。
问题:openEuler是什么操作系统?
- 输出1:YES
+ 输出1:YES
- 输入2:
- 片段:白马非马
+ 输入2:
+ 陈述:白马非马
问题:openEuler是什么操作系统?
- 输出2:NO
-
- 下面是给出的片段和问题:
- 片段:{chunk}
- 问题:{question}
- '
-QUERY_EXTEND_PROMPT: '你是一个问题扩写专家,你的任务是根据给出的问题扩写{k}个问题
- 注意:
- #01 扩写的问题的内容必须来源于原问题中的内容
- #02 扩写的问题长度不超过50个字
- #03 可以通过近义词替换 问题内词序交换 修改英文大小写等方式来改写问题
- #04 请仅输出扩写的问题列表,不要输出其他内容
- 例子:
- 输入:openEuler是什么操作系统?
- 输出:
- [
- \"openEuler是一个什么样的操作系统?\",
- \"openEuler操作系统的特点是什么?\",
- \"openEuler操作系统有哪些功能?\",
- \"openEuler操作系统的优势是什么?\"
- ]
- 下面是给出的问题:
- {question}
- '
\ No newline at end of file
+ 输出2:NO
+
+ 下面是给出的陈述和问题:
+ 陈述:{statement}
+ 问题:{question}
diff --git a/data_chain/config/config.py b/data_chain/config/config.py
index 4d9678d64259384ccbabace6d8d700fabf555481..2565e0dcd8ef107e7886e350a417b1f273e85b15 100644
--- a/data_chain/config/config.py
+++ b/data_chain/config/config.py
@@ -71,6 +71,9 @@ class ConfigModel(DictBaseModel):
USE_CPU_LIMIT: int = Field(default=64, description="文档解析器使用CPU核数")
# Task Retry Time limit
TASK_RETRY_TIME_LIMIT: int = Field(default=3, description="任务重试次数限制")
+ # Ocr Method
+ OCR_METHOD: str = Field(default="offline", description="ocr识别方式,online or offline")
+ OCR_API_URL: str = Field(default="", description="ocr在线识别接口地址", pattern=r'^https?://.+')
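+    # 注:OCR_METHOD 为 "online" 时需同时配置 OCR_API_URL(假定指向 ocr_server 提供的识别接口);为 "offline" 时使用本地 PaddleOCR 模型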
class Config:
diff --git a/data_chain/entities/enum.py b/data_chain/entities/enum.py
index 90d7c702f7be21caca42e5e8f2b18c2b2bb2441b..dd8cfea430ddca6b964a0131f7e34f4cef02ed5a 100644
--- a/data_chain/entities/enum.py
+++ b/data_chain/entities/enum.py
@@ -26,7 +26,7 @@ class Tokenizer(str, Enum):
ZH = "中文"
EN = "en"
- MIX = "mix"
+ # MIX = "mix"
class Embedding(str, Enum):
diff --git a/data_chain/entities/request_data.py b/data_chain/entities/request_data.py
index 0c0b5825251e4c95f78879e3f9ee4d6c197f5fe1..b070a0310eac19d89cc42fff0a84cfde94a776c3 100644
--- a/data_chain/entities/request_data.py
+++ b/data_chain/entities/request_data.py
@@ -48,20 +48,20 @@ class ListTeamUserRequest(BaseModel):
class CreateTeamRequest(BaseModel):
- team_name: str = Field(default='这是一个默认的团队名称', min_length=1, max_length=30, alias="teamName")
- description: str = Field(default='', max_length=150)
+ team_name: str = Field(default='这是一个默认的团队名称', min_length=1, max_length=256, alias="teamName")
+ description: str = Field(default='', max_length=256)
is_public: bool = Field(default=False, alias="isPublic")
class UpdateTeamRequest(BaseModel):
- team_name: str = Field(default='这是一个默认的团队名称', min_length=1, max_length=30, alias="teamName")
- description: str = Field(default='', max_length=150)
+ team_name: str = Field(default='这是一个默认的团队名称', min_length=1, max_length=256, alias="teamName")
+ description: str = Field(default='', max_length=256)
is_public: bool = Field(default=False, alias="isPublic")
class DocumentType(BaseModel):
doc_type_id: uuid.UUID = Field(description="文档类型的id", alias="docTypeId")
- doc_type_name: str = Field(default='这是一个默认的文档类型名称', min_length=1, max_length=20, alias="docTypeName")
+ doc_type_name: str = Field(default='这是一个默认的文档类型名称', min_length=1, max_length=256, alias="docTypeName")
class ListKnowledgeBaseRequest(BaseModel):
@@ -74,8 +74,8 @@ class ListKnowledgeBaseRequest(BaseModel):
class CreateKnowledgeBaseRequest(BaseModel):
- kb_name: str = Field(default='这是一个默认的资产名称', min_length=1, max_length=20, alias="kbName")
- description: str = Field(default='', max_length=150)
+ kb_name: str = Field(default='这是一个默认的资产名称', min_length=1, max_length=256, alias="kbName")
+ description: str = Field(default='', max_length=256)
tokenizer: Tokenizer = Field(default=Tokenizer.ZH)
embedding_model: str = Field(default='', description="知识库使用的embedding模型", alias="embeddingModel")
default_chunk_size: int = Field(default=512, description="知识库默认文件分块大小", alias="defaultChunkSize", min=128, max=2048)
@@ -87,8 +87,8 @@ class CreateKnowledgeBaseRequest(BaseModel):
class UpdateKnowledgeBaseRequest(BaseModel):
- kb_name: str = Field(default='这是一个默认的资产名称', min_length=1, max_length=30, alias="kbName")
- description: str = Field(default='', max_length=150)
+ kb_name: str = Field(default='这是一个默认的资产名称', min_length=1, max_length=256, alias="kbName")
+ description: str = Field(default='', max_length=256)
tokenizer: Tokenizer = Field(default=Tokenizer.ZH)
default_chunk_size: int = Field(default=512, description="知识库默认文件分块大小", alias="defaultChunkSize", min=128, max=2048)
default_parse_method: ParseMethod = Field(
@@ -115,7 +115,7 @@ class ListDocumentRequest(BaseModel):
class UpdateDocumentRequest(BaseModel):
- doc_name: str = Field(default='这是一个默认的文档名称', min_length=1, max_length=150, alias="docName")
+ doc_name: str = Field(default='这是一个默认的文档名称', min_length=1, max_length=256, alias="docName")
doc_type_id: uuid.UUID = Field(default=DEFAULT_DOC_TYPE_ID, description="文档类型的id", alias="docTypeId")
parse_method: ParseMethod = Field(
default=ParseMethod.GENERAL, description="知识库默认解析方法", alias="parseMethod")
@@ -129,7 +129,9 @@ class GetTemporaryDocumentStatusRequest(BaseModel):
class TemporaryDocument(BaseModel):
id: uuid.UUID = Field(description="临时文档id", alias="id")
- name: str = Field(default='这是一个默认的临时文档名称', min_length=1, max_length=150, alias="name")
+ parse_method: ParseMethod = Field(
+ default=ParseMethod.OCR, description="临时文档解析方法", alias="parseMethod")
+ name: str = Field(default='这是一个默认的临时文档名称', min_length=1, max_length=256, alias="name")
bucket_name: str = Field(default='default', description="临时文档存储的桶名称")
type: str = Field(default='txt', description="临时文档的类型", alias="type")
@@ -194,8 +196,8 @@ class ListDataInDatasetRequest(BaseModel):
class CreateDatasetRequest(BaseModel):
kb_id: uuid.UUID = Field(description="资产id", alias="kbId")
dataset_name: str = Field(default='这是一个默认的数据集名称', description="测试数据集名称",
- min_length=1, max_length=30, alias="datasetName")
- description: str = Field(default='', description="测试数据集简介", max_length=200)
+ min_length=1, max_length=256, alias="datasetName")
+ description: str = Field(default='', description="测试数据集简介", max_length=256)
document_ids: List[uuid.UUID] = Field(default=[], description="测试数据集关联的文档", alias="documentIds")
data_cnt: int = Field(default=64, alias="dataCnt", description="测试数据集内的数据数量", min=1, max=512)
llm_id: str = Field(description="测试数据集使用的大模型id", alias="llmId")
@@ -205,15 +207,15 @@ class CreateDatasetRequest(BaseModel):
class UpdateDatasetRequest(BaseModel):
dataset_name: str = Field(default='这是一个默认的数据集名称', description="测试数据集名称",
- min_length=1, max_length=30, alias="datasetName")
- description: str = Field(default='', description="测试数据集简介", max_length=200)
+ min_length=1, max_length=256, alias="datasetName")
+ description: str = Field(default='', description="测试数据集简介", max_length=256)
class UpdateDataRequest(BaseModel):
question: str = Field(default='这是一个默认的问题', description="问题",
- min_length=1, max_length=200, alias="question")
+ min_length=1, max_length=256, alias="question")
answer: str = Field(default='这是一个默认的答案', description="答案",
- min_length=1, max_length=1024, alias="answer")
+ min_length=1, max_length=4096, alias="answer")
class ListTestingRequest(BaseModel):
@@ -237,8 +239,8 @@ class ListTestCaseRequest(BaseModel):
class CreateTestingRequest(BaseModel):
testing_name: str = Field(default='这是一个默认的测试名称', description="测试名称",
- min_length=1, max_length=30, alias="testingName")
- description: str = Field(default='', description="测试简介", max_length=200)
+ min_length=1, max_length=256, alias="testingName")
+ description: str = Field(default='', description="测试简介", max_length=256)
dataset_id: uuid.UUID = Field(description="测试数据集id", alias="datasetId")
llm_id: str = Field(description="测试使用的大模型id", alias="llmId")
search_method: SearchMethod = Field(default=SearchMethod.KEYWORD_AND_VECTOR,
@@ -248,8 +250,8 @@ class CreateTestingRequest(BaseModel):
class UpdateTestingRequest(BaseModel):
testing_name: str = Field(default='这是一个默认的测试名称', description="测试名称",
- min_length=1, max_length=150, alias="testingName")
- description: str = Field(default='', description="测试简介", max_length=200)
+ min_length=1, max_length=256, alias="testingName")
+ description: str = Field(default='', description="测试简介", max_length=256)
llm_id: str = Field(description="测试使用的大模型id", alias="llmId")
search_method: SearchMethod = Field(default=SearchMethod.KEYWORD_AND_VECTOR,
description="测试使用的检索方法", alias="searchMethod")
@@ -265,12 +267,12 @@ class ListRoleRequest(BaseModel):
class CreateRoleRequest(BaseModel):
- role_name: str = Field(default='这是一个默认的角色名称', min_length=1, max_length=30, alias="roleName")
+ role_name: str = Field(default='这是一个默认的角色名称', min_length=1, max_length=256, alias="roleName")
actions: List[str] = Field(default=[], description="角色拥有的操作的列表", alias="actions")
class UpdateRoleRequest(BaseModel):
- role_name: str = Field(default='这是一个默认的角色名称', min_length=1, max_length=30, alias="roleName")
+ role_name: str = Field(default='这是一个默认的角色名称', min_length=1, max_length=256, alias="roleName")
actions: List[str] = Field(default=[], description="角色拥有的操作的列表", alias="actions")
diff --git a/data_chain/entities/response_data.py b/data_chain/entities/response_data.py
index 5a3bf03b855be861587efb43f5d144a5007d5895..a81e892f042ae6c13aacf0c29a0a8723029c3f32 100644
--- a/data_chain/entities/response_data.py
+++ b/data_chain/entities/response_data.py
@@ -1,6 +1,6 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-from typing import Any, Optional
+from typing import Any, Optional, Union
from pydantic import BaseModel, Field
import uuid
@@ -24,6 +24,7 @@ from data_chain.entities.enum import (
TaskType,
TaskStatus,
OrderType)
+from data_chain.parser.parse_result import ParseResult
class ResponseData(BaseModel):
@@ -273,6 +274,11 @@ class UploadTemporaryDocumentResponse(ResponseData):
result: list[uuid.UUID] = Field(default=[], description="临时文档ID列表")
+class GetTemporaryDocumentTextResponse(ResponseData):
+ """GET /doc/temporary/parse_result 响应"""
+ result: str = Field(default="", description="临时文档解析结果")
+
+
class DeleteTemporaryDocumentResponse(ResponseData):
"""DELETE /doc/temporary 响应"""
result: list[uuid.UUID] = Field(default=[], description="临时文档ID列表")
@@ -283,6 +289,11 @@ class ParseDocumentResponse(ResponseData):
result: list[uuid.UUID] = Field(default=[], description="文档ID列表")
+class ParseDocumentRealTimeResponse(ResponseData):
+ """POST /doc/parse/realtime 响应"""
+ result: list[Union[ParseResult, None]] = Field(default=[], description="文档内容列表")
+
+
class UpdateDocumentResponse(ResponseData):
"""PUT /doc 响应"""
result: uuid.UUID = Field(default=None, description="文档ID")
@@ -325,10 +336,12 @@ class UpdateChunkEnabledResponse(ResponseData):
class DocChunk(BaseModel):
"""Post /chunk/search 数据结构"""
doc_id: uuid.UUID = Field(description="文档ID", alias="docId")
- doc_name: str = Field(description="文档名称", alias="docName")
+ doc_name: str = Field(default="", description="文档名称", alias="docName")
+ doc_author: str = Field(default="", description="文档作者", alias="docAuthor")
doc_abstract: str = Field(default="", description="文档摘要", alias="docAbstract")
doc_extension: str = Field(default="", description="文档扩展名", alias="docExtension")
doc_size: int = Field(default=0, description="文档大小,单位是KB", alias="docSize")
+ doc_created_at: str = Field(default="", description="文档创建时间", alias="docCreatedAt")
chunks: list[Chunk] = Field(default=[], description="分片列表", alias="chunks")
diff --git a/data_chain/llm/llm.py b/data_chain/llm/llm.py
index b5cd720ad24c049a8bce7e242fc1271b8df76449..098c27ab66331bdd3bf38d04a3f687bc35acbde8 100644
--- a/data_chain/llm/llm.py
+++ b/data_chain/llm/llm.py
@@ -1,12 +1,7 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
import asyncio
-import time
-import re
-import json
-import tiktoken
-from langchain_openai import ChatOpenAI
-from langchain.schema import SystemMessage, HumanMessage
-from data_chain.logger.logger import logger as logging
+from openai import AsyncOpenAI
+from data_chain.logger.logger import logger
class LLM:
@@ -17,79 +12,77 @@ class LLM:
self.max_tokens = max_tokens
self.request_timeout = request_timeout
self.temperature = temperature
- self.client = ChatOpenAI(model_name=model_name,
- openai_api_base=openai_api_base,
- openai_api_key=openai_api_key,
- request_timeout=request_timeout,
- max_tokens=max_tokens,
- temperature=temperature)
+ self._client = AsyncOpenAI(
+ api_key=self.openai_api_key,
+ base_url=self.openai_api_base,
+ )
def assemble_chat(self, chat=None, system_call='', user_call=''):
if chat is None:
chat = []
- chat.append(SystemMessage(content=system_call))
- chat.append(HumanMessage(content=user_call))
+ chat.append({"role": "system", "content": system_call})
+ chat.append({"role": "user", "content": user_call})
return chat
- async def nostream(self, chat, system_call, user_call,st_str:str=None,en_str:str=None):
- try:
- chat = self.assemble_chat(chat, system_call, user_call)
- response = await self.client.ainvoke(chat)
- content = re.sub(r'.*?\n?', '', response.content, flags=re.DOTALL)
- content = re.sub(r'.*?\n?', '', content, flags=re.DOTALL)
- content=content.strip()
- if st_str is not None:
- index = content.find(st_str)
- if index != -1:
- content = content[index:]
- if en_str is not None:
- index = content[::-1].find(en_str[::-1])
- if index != -1:
- content = content[:len(content)-index]
- logging.error("[LLM] 非流式输出内容: %s", content)
- except Exception as e:
- err = f"[LLM] 非流式输出异常: {e}"
- logging.error("[LLM] %s", err)
- return ''
- return content
+ async def create_stream(
+ self, message):
+ return await self._client.chat.completions.create(
+ model=self.model_name,
+ messages=message, # type: ignore[]
+ max_completion_tokens=self.max_tokens,
+ temperature=self.temperature,
+ stream=True,
+ stream_options={"include_usage": True},
+ timeout=300
+ ) # type: ignore[]
async def data_producer(self, q: asyncio.Queue, history, system_call, user_call):
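+        # 生产者协程:将流式增量内容依次写入队列;流结束或发生异常时放入 None 作为结束哨兵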
message = self.assemble_chat(history, system_call, user_call)
+ stream = await self.create_stream(message)
try:
- async for frame in self.client.astream(message):
- await q.put(frame.content)
+ async for chunk in stream:
+ if len(chunk.choices) == 0:
+ continue
+ if chunk.choices[0].delta.content is not None:
+ content = chunk.choices[0].delta.content
+ else:
+ continue
+ await q.put(content)
except Exception as e:
await q.put(None)
err = f"[LLM] 流式输出生产者任务异常: {e}"
- logging.error("[LLM] %s", err)
+ logger.error(err)
raise e
await q.put(None)
async def stream(self, chat, system_call, user_call):
- st = time.time()
q = asyncio.Queue(maxsize=10)
# 启动生产者任务
- producer_task = asyncio.create_task(self.data_producer(q, chat, system_call, user_call))
- first_token_reach = False
- enc = tiktoken.encoding_for_model("gpt-4")
- input_tokens = len(enc.encode(system_call))
- output_tokens = 0
+ asyncio.create_task(self.data_producer(q, chat, system_call, user_call))
while True:
data = await q.get()
if data is None:
break
- if not first_token_reach:
- first_token_reach = True
- logging.info(f"大模型回复第一个字耗时 = {time.time() - st}")
- output_tokens += len(enc.encode(data))
- yield "data: " + json.dumps(
- {'content': data,
- 'input_tokens': input_tokens,
- 'output_tokens': output_tokens
- }, ensure_ascii=False
- ) + '\n\n'
- await asyncio.sleep(0.03) # 使用异步 sleep
+ yield data
- yield "data: [DONE]"
- logging.info(f"大模型回复耗时 = {time.time() - st}")
+ async def nostream(self, chat, system_call, user_call, st_str: str = None, en_str: str = None):
+ try:
+ content = ''
+ async for chunk in self.stream(chat, system_call, user_call):
+ content += chunk
+ content = content.strip()
+ if st_str is not None:
+ index = content.find(st_str)
+ if index != -1:
+ content = content[index:]
+ if en_str is not None:
+ index = content[::-1].find(en_str[::-1])
+ if index != -1:
+ content = content[:len(content)-index]
+ logger.error(f"LLM nostream content: {content}")
+ except Exception as e:
+ err = f"[LLM] 非流式输出异常: {e}"
+ logger.error("[LLM] %s", err)
+ return ''
+ return content
diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py
index 277abc3e41a18c3a8ea824c4808893ac8bcf36d0..4b755b9420d8935fbad92611711be84a260b5192 100644
--- a/data_chain/manager/chunk_manager.py
+++ b/data_chain/manager/chunk_manager.py
@@ -176,7 +176,7 @@ class ChunkManager():
async with await DataBase.get_session() as session:
fetch_cnt = top_k
chunk_entities = []
- while True:
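+            # 检索循环以 20 轮为上限,替代原先可能不收敛的 while True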
+ for i in range(20):
# 计算相似度分数
similarity_score = ChunkEntity.text_vector.cosine_distance(vector).label("similarity_score")
@@ -191,9 +191,9 @@ class ChunkManager():
.where(ChunkEntity.kb_id == kb_id)
.where(ChunkEntity.enabled == True)
.where(ChunkEntity.status != ChunkStatus.DELETED.value)
- .where(ChunkEntity.id.notin_(banned_ids))
)
-
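+            # 仅当 banned_ids 非空时才追加排除条件,避免对空列表生成无意义的 NOT IN 过滤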
+ if banned_ids:
+ stmt = stmt.where(ChunkEntity.id.notin_(banned_ids))
# 添加可选条件
if doc_ids is not None:
stmt = stmt.where(DocumentEntity.id.in_(doc_ids))
@@ -270,9 +270,9 @@ class ChunkManager():
.where(ChunkEntity.kb_id == kb_id)
.where(ChunkEntity.enabled == True)
.where(ChunkEntity.status != ChunkStatus.DELETED.value)
- .where(ChunkEntity.id.notin_(banned_ids))
)
-
+ if banned_ids:
+ stmt = stmt.where(ChunkEntity.id.notin_(banned_ids))
if doc_ids is not None:
stmt = stmt.where(DocumentEntity.id.in_(doc_ids))
if chunk_to_type is not None:
@@ -357,8 +357,9 @@ class ChunkManager():
.where(ChunkEntity.kb_id == kb_id)
.where(ChunkEntity.enabled == True)
.where(ChunkEntity.status != ChunkStatus.DELETED.value)
- .where(ChunkEntity.id.notin_(banned_ids))
)
+ if banned_ids:
+ stmt = stmt.where(ChunkEntity.id.notin_(banned_ids))
# 添加 GROUP BY 子句,按 ChunkEntity.id 分组
stmt = stmt.group_by(ChunkEntity.id)
diff --git a/data_chain/manager/dataset_manager.py b/data_chain/manager/dataset_manager.py
index cb61aa8a495a892f7ca7271c2960ba6692f031dd..89fcaafd1484d07b6e29d2d1a7d6604afdccf205 100644
--- a/data_chain/manager/dataset_manager.py
+++ b/data_chain/manager/dataset_manager.py
@@ -98,7 +98,8 @@ class DatasetManager:
stmt = stmt.where(DataSetEntity.is_chunk_related == req.is_chunk_related)
if req.generate_status is not None:
status_list = [status.value for status in req.generate_status]
- status_list += [DataSetStatus.DELETED.value]
+ if TaskStatus.SUCCESS in req.generate_status:
+ status_list += [TaskStatus.DELETED.value]
stmt = stmt.where(subq.c.status.in_(status_list))
stmt = stmt.order_by(DataSetEntity.created_at.desc(), DataSetEntity.id.desc())
if req.score_order:
diff --git a/data_chain/manager/document_manager.py b/data_chain/manager/document_manager.py
index 55cb87a90397d1d0a6df5db6176b423311b301c2..99ad63bf2e34f013afc64791d9f25d43a2d5e165 100644
--- a/data_chain/manager/document_manager.py
+++ b/data_chain/manager/document_manager.py
@@ -59,11 +59,12 @@ class DocumentManager():
stmt = (
select(DocumentEntity, similarity_score)
.where(DocumentEntity.kb_id == kb_id)
- .where(DocumentEntity.id.notin_(banned_ids))
.where(DocumentEntity.status != DocumentStatus.DELETED.value)
.where(DocumentEntity.enabled == True)
)
- if doc_ids:
+ if banned_ids:
+ stmt = stmt.where(DocumentEntity.id.notin_(banned_ids))
+ if doc_ids is not None:
stmt = stmt.where(DocumentEntity.id.in_(doc_ids))
stmt = stmt.order_by(
similarity_score
diff --git a/data_chain/manager/task_queue_mamanger.py b/data_chain/manager/task_queue_mamanger.py
index 8f4db40dfc5fa1f28631d4a10711e219168790b6..b0df886ba3e7a15691f3b533d6d3993903c9a638 100644
--- a/data_chain/manager/task_queue_mamanger.py
+++ b/data_chain/manager/task_queue_mamanger.py
@@ -1,12 +1,10 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
from sqlalchemy import select, delete, update, desc, asc, func, exists, or_, and_
-from sqlalchemy.orm import aliased
import uuid
from typing import Dict, List, Optional, Tuple
from data_chain.logger.logger import logger as logging
-from data_chain.stores.database.database import DataBase, TaskEntity
-from data_chain.stores.mongodb.mongodb import MongoDB, Task
+from data_chain.stores.database.database import DataBase, TaskQueueEntity
from data_chain.entities.enum import TaskStatus
@@ -14,60 +12,72 @@ class TaskQueueManager():
"""任务队列管理类"""
@staticmethod
- async def add_task(task: Task):
+ async def add_task(task: TaskQueueEntity):
try:
- async with MongoDB.get_session() as session, await session.start_transaction():
- task_colletion = MongoDB.get_collection('witchiand_task')
- await task_colletion.insert_one(task.model_dump(by_alias=True), session=session)
+ async with await DataBase.get_session() as session:
+ session.add(task)
+ await session.commit()
except Exception as e:
err = "添加任务到队列失败"
logging.exception("[TaskQueueManager] %s", err)
+ raise e
@staticmethod
async def delete_task_by_id(task_id: uuid.UUID):
"""根据任务ID删除任务"""
try:
- async with MongoDB.get_session() as session, await session.start_transaction():
- task_colletion = MongoDB.get_collection('witchiand_task')
- await task_colletion.delete_one({"_id": task_id}, session=session)
+ async with await DataBase.get_session() as session:
+ stmt = delete(TaskQueueEntity).where(TaskQueueEntity.id == task_id)
+ await session.execute(stmt)
+ await session.commit()
except Exception as e:
err = "删除任务失败"
logging.exception("[TaskQueueManager] %s", err)
raise e
@staticmethod
- async def get_oldest_tasks_by_status(status: TaskStatus) -> Task:
+ async def get_oldest_tasks_by_status(status: TaskStatus) -> Optional[TaskQueueEntity]:
"""根据任务状态获取最早的任务"""
try:
- async with MongoDB.get_session() as session:
- task_colletion = MongoDB.get_collection('witchiand_task')
- task = await task_colletion.find_one({"status": status}, sort=[("created_time", 1)], session=session)
- return Task(**task) if task else None
+ async with await DataBase.get_session() as session:
+ stmt = (
+ select(TaskQueueEntity)
+ .where(TaskQueueEntity.status == status.value)
+ .order_by(asc(TaskQueueEntity.created_time))
+ .limit(1)
+ )
+ result = await session.execute(stmt)
+ return result.scalars().first()
except Exception as e:
err = "获取最早的任务失败"
logging.exception("[TaskQueueManager] %s", err)
raise e
@staticmethod
- async def get_task_by_id(task_id: uuid.UUID) -> Task:
+ async def get_task_by_id(task_id: uuid.UUID) -> Optional[TaskQueueEntity]:
"""根据任务ID获取任务"""
try:
- async with MongoDB.get_session() as session:
- task_colletion = MongoDB.get_collection('witchiand_task')
- task = await task_colletion.find_one({"_id": task_id}, session=session)
- return Task(**task) if task else None
+ async with await DataBase.get_session() as session:
+ stmt = select(TaskQueueEntity).where(TaskQueueEntity.id == task_id)
+ result = await session.execute(stmt)
+ return result.scalars().first()
except Exception as e:
err = "获取任务失败"
logging.exception("[TaskQueueManager] %s", err)
raise e
@staticmethod
- async def update_task_by_id(task_id: uuid.UUID, task: Task):
+ async def update_task_by_id(task_id: uuid.UUID, task: TaskQueueEntity):
"""根据任务ID更新任务"""
try:
- async with MongoDB.get_session() as session, await session.start_transaction():
- task_colletion = MongoDB.get_collection('witchiand_task')
- await task_colletion.update_one({"_id": task_id}, {"$set": task.model_dump(by_alias=True)}, session=session)
+ async with await DataBase.get_session() as session:
+ stmt = (
+ update(TaskQueueEntity)
+ .where(TaskQueueEntity.id == task_id)
+ .values(status=task.status)
+ )
+ await session.execute(stmt)
+ await session.commit()
except Exception as e:
err = "更新任务失败"
logging.exception("[TaskQueueManager] %s", err)
diff --git a/data_chain/parser/parse_result.py b/data_chain/parser/parse_result.py
index b69e3f451ae5c3335e2a2c4578ae5d4535f023b1..1f4db212187348033615929bb39db4f425344ef8 100644
--- a/data_chain/parser/parse_result.py
+++ b/data_chain/parser/parse_result.py
@@ -24,5 +24,6 @@ class ParseNode(BaseModel):
class ParseResult(BaseModel):
"""解析结果"""
+ doc_hash: str = Field(default='', description="文档hash值")
parse_topology_type: DocParseRelutTopology = Field(..., description="解析拓扑类型")
nodes: list[ParseNode] = Field(..., description="节点列表")
diff --git a/data_chain/parser/tools/ocr_tool.py b/data_chain/parser/tools/ocr_tool.py
index 858517dab74091ccc2f6d9badcf86049021e17cb..d2db21503883bee61e5455259e111fb42df88750 100644
--- a/data_chain/parser/tools/ocr_tool.py
+++ b/data_chain/parser/tools/ocr_tool.py
@@ -2,11 +2,13 @@ from PIL import Image, ImageEnhance
import yaml
import cv2
import numpy as np
+import requests
from data_chain.parser.tools.token_tool import TokenTool
from data_chain.logger.logger import logger as logging
from data_chain.config.config import config
from data_chain.llm.llm import LLM
from data_chain.parser.tools.instruct_scan_tool import InstructScanTool
class OcrTool:
@@ -14,7 +16,7 @@ class OcrTool:
rec_model_dir = 'data_chain/parser/model/ocr/ch_PP-OCRv4_rec_infer'
cls_model_dir = 'data_chain/parser/model/ocr/ch_ppocr_mobile_v2.0_cls_infer'
# 优化 OCR 参数配置
- if InstructScanTool.check_avx512_support():
+ if InstructScanTool.check_avx512_support() and config['OCR_METHOD'] == "offline":
from paddleocr import PaddleOCR
model = PaddleOCR(
det_model_dir=det_model_dir,
@@ -30,6 +32,10 @@ class OcrTool:
async def ocr_from_image_path(image_path: str) -> list:
try:
# 打开图片
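+            # 在线模式:假定 OCR_API_URL 指向 ocr_server 暴露的识别接口,以 multipart 形式上传图片并读取返回的 result 字段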
+ if config['OCR_METHOD'] == 'online' and config['OCR_API_URL']:
+                with open(image_path, 'rb') as f:
+                    result = requests.get(config['OCR_API_URL'], timeout=30,
+                                          files={'file': (image_path, f, 'image/jpeg')}).json()
+                return result.get("result", [])
if OcrTool.model is None:
err = "[OCRTool] 当前机器不支持 AVX-512,无法进行OCR识别"
logging.error(err)
@@ -58,6 +64,8 @@ class OcrTool:
async def merge_text_from_ocr_result(ocr_result: list) -> str:
text = ''
try:
+            if not ocr_result or ocr_result[0] is None or len(ocr_result[0]) == 0:
+ return ""
for _ in ocr_result[0]:
text += str(_[1][0])
return text
@@ -67,16 +75,20 @@ class OcrTool:
return ''
@staticmethod
- async def enhance_ocr_result(ocr_result, image_related_text='', llm: LLM = None) -> str:
+ async def enhance_ocr_result(ocr_result, image_related_text='', llm: LLM = None, language: str = "中文") -> str:
try:
with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
- prompt_template = prompt_dict.get('OCR_ENHANCED_PROMPT', '')
+ prompt_template = prompt_dict.get('OCR_ENHANCED_PROMPT', {})
+ prompt_template = prompt_template.get(language, '')
pre_part_description = ""
token_limit = llm.max_tokens//2
image_related_text = TokenTool.get_k_tokens_words_from_content(image_related_text, token_limit)
ocr_result_parts = TokenTool.split_str_with_slide_window(str(ocr_result), token_limit)
- user_call = '请详细输出图片的摘要,不要输出其他内容'
+ if language == 'en':
+                user_call = 'Please provide an English summary of the image content; do not output anything else'
+ else:
+ user_call = '请详细输出图片的中文摘要,不要输出其他内容'
for part in ocr_result_parts:
pre_part_description_cp = pre_part_description
try:
@@ -96,19 +108,16 @@ class OcrTool:
return OcrTool.merge_text_from_ocr_result(ocr_result)
@staticmethod
- async def image_to_text(image: np.ndarray, image_related_text: str = '', llm: LLM = None) -> str:
+ async def image_to_text(
+ image_file_path: str, image_related_text: str = '', llm: LLM = None, language: str = '中文') -> str:
try:
- if OcrTool.model is None:
- err = "[OCRTool] 当前机器不支持 AVX-512,无法进行OCR识别"
- logging.error(err)
- return ''
- ocr_result = await OcrTool.ocr_from_image(image)
+ ocr_result = await OcrTool.ocr_from_image_path(image_file_path)
if ocr_result is None:
return ''
if llm is None:
text = await OcrTool.merge_text_from_ocr_result(ocr_result)
else:
- text = await OcrTool.enhance_ocr_result(ocr_result, image_related_text, llm)
+ text = await OcrTool.enhance_ocr_result(ocr_result, image_related_text, llm, language)
if "图片内容为空" in text:
return ""
return text
diff --git a/data_chain/parser/tools/token_tool.py b/data_chain/parser/tools/token_tool.py
index a9a050fd135d1878d1b46579e29814dcbc8466e2..5a6be14e8b5bb9772404e2b1796433c472df8f49 100644
--- a/data_chain/parser/tools/token_tool.py
+++ b/data_chain/parser/tools/token_tool.py
@@ -261,20 +261,24 @@ class TokenTool:
return [sentence for index, sentence, score in top_k_sentence_and_score_list]
@staticmethod
- async def get_abstract_by_llm(content: str, llm: LLM) -> str:
+ async def get_abstract_by_llm(content: str, llm: LLM, language: str) -> str:
"""
使用llm进行内容摘要
"""
try:
with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
- prompt_template = prompt_dict.get('CONTENT_TO_ABSTRACT_PROMPT', '')
+ prompt_template = prompt_dict.get('CONTENT_TO_ABSTRACT_PROMPT', {})
+ prompt_template = prompt_template.get(language, '')
sentences = TokenTool.split_str_with_slide_window(content, llm.max_tokens//3*2)
abstract = ''
for sentence in sentences:
abstract = TokenTool.get_k_tokens_words_from_content(abstract, llm.max_tokens//3)
sys_call = prompt_template.format(content=sentence, abstract=abstract)
- user_call = '请结合文本和摘要输出新的摘要'
+ if language == 'en':
+ user_call = 'Please output a new English abstract based on the text and the existing abstract'
+ else:
+ user_call = '请结合文本和已有摘要生成新的中文摘要'
abstract = await llm.nostream([], sys_call, user_call)
return abstract
except Exception as e:
@@ -282,17 +286,21 @@ class TokenTool:
logging.exception("[TokenTool] %s", err)
@staticmethod
- async def get_title_by_llm(content: str, llm: LLM) -> str:
+ async def get_title_by_llm(content: str, llm: LLM, language: str = '中文') -> str:
"""
使用llm进行标题生成
"""
try:
with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
- prompt_template = prompt_dict.get('CONTENT_TO_TITLE_PROMPT', '')
+ prompt_template = prompt_dict.get('CONTENT_TO_TITLE_PROMPT', {})
+ prompt_template = prompt_template.get(language, '')
content = TokenTool.get_k_tokens_words_from_content(content, llm.max_tokens)
sys_call = prompt_template.format(content=content)
- user_call = '请结合文本输出标题'
+ if language == 'en':
+                user_call = 'Please generate an English title based on the text'
+ else:
+ user_call = '请结合文本生成一个中文标题'
title = await llm.nostream([], sys_call, user_call)
return title
except Exception as e:
@@ -300,7 +308,7 @@ class TokenTool:
logging.exception("[TokenTool] %s", err)
@staticmethod
- async def cal_recall(answer_1: str, answer_2: str, llm: LLM) -> float:
+ async def cal_recall(answer_1: str, answer_2: str, llm: LLM, language: str) -> float:
"""
计算recall
参数:
@@ -311,7 +319,8 @@ class TokenTool:
try:
with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
- prompt_template = prompt_dict.get('ANSWER_TO_ANSWER_PROMPT', '')
+ prompt_template = prompt_dict.get('ANSWER_TO_ANSWER_PROMPT', {})
+ prompt_template = prompt_template.get(language, '')
answer_1 = TokenTool.get_k_tokens_words_from_content(answer_1, llm.max_tokens//2)
answer_2 = TokenTool.get_k_tokens_words_from_content(answer_2, llm.max_tokens//2)
prompt = prompt_template.format(text_1=answer_1, text_2=answer_2)
@@ -325,7 +334,7 @@ class TokenTool:
return -1
@staticmethod
- async def cal_precision(question: str, content: str, llm: LLM) -> float:
+ async def cal_precision(question: str, content: str, llm: LLM, language: str) -> float:
"""
计算precision
参数:
@@ -335,17 +344,19 @@ class TokenTool:
try:
with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
- prompt_template = prompt_dict.get('CONTENT_TO_STATEMENTS_PROMPT', '')
+ prompt_template = prompt_dict.get('CONTENT_TO_STATEMENTS_PROMPT', {})
+ prompt_template = prompt_template.get(language, '')
content = TokenTool.compress_tokens(content, llm.max_tokens)
sys_call = prompt_template.format(content=content)
user_call = '请结合文本输出陈诉列表'
statements = await llm.nostream([], sys_call, user_call, st_str='[',
- en_str=']')
+ en_str=']')
statements = json.loads(statements)
if len(statements) == 0:
return 0
score = 0
- prompt_template = prompt_dict.get('STATEMENTS_TO_QUESTION_PROMPT', '')
+ prompt_template = prompt_dict.get('STATEMENTS_TO_QUESTION_PROMPT', {})
+ prompt_template = prompt_template.get(language, '')
for statement in statements:
statement = TokenTool.get_k_tokens_words_from_content(statement, llm.max_tokens)
prompt = prompt_template.format(statement=statement, question=question)
@@ -362,7 +373,7 @@ class TokenTool:
return -1
@staticmethod
- async def cal_faithfulness(question: str, answer: str, content: str, llm: LLM) -> float:
+ async def cal_faithfulness(question: str, answer: str, content: str, llm: LLM, language: str) -> float:
"""
计算faithfulness
参数:
@@ -372,15 +383,17 @@ class TokenTool:
try:
with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
- prompt_template = prompt_dict.get('QA_TO_STATEMENTS_PROMPT', '')
+ prompt_template = prompt_dict.get('QA_TO_STATEMENTS_PROMPT', {})
+ prompt_template = prompt_template.get(language, '')
question = TokenTool.get_k_tokens_words_from_content(question, llm.max_tokens//8)
answer = TokenTool.get_k_tokens_words_from_content(answer, llm.max_tokens//8*7)
prompt = prompt_template.format(question=question, answer=answer)
sys_call = prompt
user_call = '请结合问题和答案输出陈诉'
- statements = await llm.nostream([], sys_call, user_call,st_str='[',
- en_str=']')
- prompt_template = prompt_dict.get('STATEMENTS_TO_FRAGMENT_PROMPT', '')
+ statements = await llm.nostream([], sys_call, user_call, st_str='[',
+ en_str=']')
+ prompt_template = prompt_dict.get('STATEMENTS_TO_FRAGMENT_PROMPT', {})
+ prompt_template = prompt_template.get(language, '')
statements = json.loads(statements)
if len(statements) == 0:
return 0
@@ -416,7 +429,7 @@ class TokenTool:
return cosine_dist
@staticmethod
- async def cal_relevance(question: str, answer: str, llm: LLM) -> float:
+ async def cal_relevance(question: str, answer: str, llm: LLM, language: str) -> float:
"""
计算relevance
参数:
@@ -426,7 +439,8 @@ class TokenTool:
try:
with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
- prompt_template = prompt_dict.get('GENREATE_QUESTION_FROM_CONTENT_PROMPT', '')
+ prompt_template = prompt_dict.get('GENERATE_QUESTION_FROM_CONTENT_PROMPT', {})
+ prompt_template = prompt_template.get(language, '')
answer = TokenTool.get_k_tokens_words_from_content(answer, llm.max_tokens)
sys_call = prompt_template.format(k=5, content=answer)
user_call = '请结合文本输出问题列表'
diff --git a/data_chain/rag/doc2chunk_bfs_searcher.py b/data_chain/rag/doc2chunk_bfs_searcher.py
index c72e7bd1c1a05b2c9ea68e027dac522714871210..629e8d230b3095ec7e7d8352e673718e244f9ce5 100644
--- a/data_chain/rag/doc2chunk_bfs_searcher.py
+++ b/data_chain/rag/doc2chunk_bfs_searcher.py
@@ -37,7 +37,7 @@ class Doc2ChunkBfsSearcher(BaseSearcher):
root_chunk_entities_vector = []
for _ in range(3):
try:
- root_chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(root_chunk_entities_keyword), doc_ids, banned_ids, ChunkParseTopology.TREEROOT.value), timeout=3)
+ root_chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(root_chunk_entities_keyword), doc_ids, banned_ids, ChunkParseTopology.TREEROOT.value), timeout=20)
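+                    # 向量检索超时由 3s 放宽至 20s,与其余检索器保持一致(推测用于容忍较大知识库上的慢查询)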
break
except Exception as e:
err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}"
@@ -54,7 +54,7 @@ class Doc2ChunkBfsSearcher(BaseSearcher):
root_chunk_entities_vector = []
for _ in range(3):
try:
- root_chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(root_chunk_entities_keyword), doc_ids, banned_ids, None, pre_ids), timeout=3)
+ root_chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(root_chunk_entities_keyword), doc_ids, banned_ids, None, pre_ids), timeout=20)
break
except Exception as e:
err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}"
diff --git a/data_chain/rag/doc2chunk_searcher.py b/data_chain/rag/doc2chunk_searcher.py
index 40510d513c9f2355a87a3ac3efe64f1fb008794c..6b3aca75cb1cc0d915acabe029a8d100364cd9ff 100644
--- a/data_chain/rag/doc2chunk_searcher.py
+++ b/data_chain/rag/doc2chunk_searcher.py
@@ -37,7 +37,7 @@ class Doc2ChunkSearcher(BaseSearcher):
doc_entities_vector = []
for _ in range(3):
try:
- doc_entities_vector = await asyncio.wait_for(DocumentManager.get_top_k_document_by_kb_id_vector(kb_id, vector, top_k-len(doc_entities_keyword), use_doc_ids, banned_ids), timeout=3)
+ doc_entities_vector = await asyncio.wait_for(DocumentManager.get_top_k_document_by_kb_id_vector(kb_id, vector, top_k-len(doc_entities_keyword), use_doc_ids, banned_ids), timeout=10)
break
except Exception as e:
err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}"
@@ -53,7 +53,7 @@ class Doc2ChunkSearcher(BaseSearcher):
chunk_entities_vector = []
for _ in range(3):
try:
- chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_keyword), use_doc_ids, banned_ids), timeout=3)
+ chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_keyword), use_doc_ids, banned_ids), timeout=10)
break
except Exception as e:
err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}"
diff --git a/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py b/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py
index 5efe05ee080f65920d5fcdff60fe0ae80745ccc8..280ff8b591911337d7e84425201b2913ed51f0f5 100644
--- a/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py
+++ b/data_chain/rag/dynamic_weighted_keyword_and_vector_searcher.py
@@ -42,12 +42,13 @@ class KeywordVectorSearcher(BaseSearcher):
try:
import time
start_time = time.time()
- chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword), doc_ids, banned_ids), timeout=3)
+ chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword)-len(chunk_entities_get_by_dynamic_weighted_keyword), doc_ids, banned_ids), timeout=20)
end_time = time.time()
logging.info(f"[KeywordVectorSearcher] 向量检索成功完成,耗时: {end_time - start_time:.2f}秒")
break
except Exception as e:
- err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}"
+ import traceback
+ err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}, traceback: {traceback.format_exc()}"
logging.error(err)
continue
chunk_entities = chunk_entities_get_by_keyword + chunk_entities_get_by_dynamic_weighted_keyword + chunk_entities_get_by_vector
diff --git a/data_chain/rag/enhanced_by_llm_searcher.py b/data_chain/rag/enhanced_by_llm_searcher.py
index 00b7bae3afb3397ad67c77402ad7568ffdaf3416..738eaac974ef3121aff613532edec4bff75afe8b 100644
--- a/data_chain/rag/enhanced_by_llm_searcher.py
+++ b/data_chain/rag/enhanced_by_llm_searcher.py
@@ -13,6 +13,7 @@ from data_chain.entities.enum import SearchMethod
from data_chain.parser.tools.token_tool import TokenTool
from data_chain.llm.llm import LLM
from data_chain.config.config import config
+from data_chain.manager.knowledge_manager import KnowledgeBaseManager
class EnhancedByLLMSearcher(BaseSearcher):
@@ -36,7 +37,9 @@ class EnhancedByLLMSearcher(BaseSearcher):
try:
with open('./data_chain/common/prompt.yaml', 'r', encoding='utf-8') as f:
prompt_dict = yaml.safe_load(f)
- prompt_template = prompt_dict['CHUNK_QUERY_MATCH_PROMPT']
+ knowledge_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id)
+ prompt_template = prompt_dict.get('CHUNK_QUERY_MATCH_PROMPT', {})
+ prompt_template = prompt_template.get(knowledge_entity.tokenizer, '')
chunk_entities = []
rd = 0
max_retry = 5
@@ -56,7 +59,7 @@ class EnhancedByLLMSearcher(BaseSearcher):
sub_chunk_entities_vector = []
for _ in range(3):
try:
- sub_chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k, doc_ids, banned_ids), timeout=3)
+ sub_chunk_entities_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k, doc_ids, banned_ids), timeout=20)
break
except Exception as e:
err = f"[EnhancedByLLMSearcher] 向量检索失败,error: {e}"
diff --git a/data_chain/rag/keyword_and_vector_searcher.py b/data_chain/rag/keyword_and_vector_searcher.py
index 86b3b4f5cfca9065c6318caa45ab39c2ae517f74..9a0c7de20b63cae03044d217406caa4f6c1a3939 100644
--- a/data_chain/rag/keyword_and_vector_searcher.py
+++ b/data_chain/rag/keyword_and_vector_searcher.py
@@ -40,7 +40,7 @@ class KeywordVectorSearcher(BaseSearcher):
try:
import time
start_time = time.time()
- chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids), timeout=3)
+ chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids), timeout=20)
end_time = time.time()
logging.info(f"[KeywordVectorSearcher] 向量检索成功完成,耗时: {end_time - start_time:.2f}秒")
break
diff --git a/data_chain/rag/query_extend_searcher.py b/data_chain/rag/query_extend_searcher.py
index a09f660baeebd3e16ce61b31ae85e9723bb2fd3f..bbccd51035af5b933ed85ed77395a52cfb97ad6f 100644
--- a/data_chain/rag/query_extend_searcher.py
+++ b/data_chain/rag/query_extend_searcher.py
@@ -14,6 +14,7 @@ from data_chain.entities.enum import SearchMethod
from data_chain.parser.tools.token_tool import TokenTool
from data_chain.llm.llm import LLM
from data_chain.config.config import config
+from data_chain.manager.knowledge_manager import KnowledgeBaseManager
class QueryExtendSearcher(BaseSearcher):
@@ -35,7 +36,9 @@ class QueryExtendSearcher(BaseSearcher):
"""
with open('./data_chain/common/prompt.yaml', 'r', encoding='utf-8') as f:
prompt_dict = yaml.safe_load(f)
- prompt_template = prompt_dict['QUERY_EXTEND_PROMPT']
+        knowledge_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id)
+        prompt_template = prompt_dict.get('QUERY_EXTEND_PROMPT', {})
+        prompt_template = prompt_template.get(knowledge_entity.tokenizer, '')
chunk_entities = []
llm = LLM(
openai_api_key=config['OPENAI_API_KEY'],
@@ -61,7 +64,7 @@ class QueryExtendSearcher(BaseSearcher):
chunk_entities_get_by_vector = []
for _ in range(3):
try:
- chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids), timeout=3)
+ chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids), timeout=20)
break
except Exception as e:
err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}"
diff --git a/data_chain/rag/vector_searcher.py b/data_chain/rag/vector_searcher.py
index dad5e8676792927fa28f27a0ec9b8ac0cb08a079..1bd1d0cac655c2196db2232d84af26da3b3e02fe 100644
--- a/data_chain/rag/vector_searcher.py
+++ b/data_chain/rag/vector_searcher.py
@@ -29,7 +29,7 @@ class VectorSearcher(BaseSearcher):
chunk_entities = []
for _ in range(3):
try:
- chunk_entities = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k, doc_ids, banned_ids), timeout=3)
+ chunk_entities = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k, doc_ids, banned_ids), timeout=20)
break
except Exception as e:
err = f"[VectorSearcher] 向量检索失败,error: {e}"
diff --git a/data_chain/stores/database/database.py b/data_chain/stores/database/database.py
index 4e8ae10d581bdeb25159026b668404bb1f3f08db..e0af0a43c52c930f326ac5b25e6ed75a6997ab8d 100644
--- a/data_chain/stores/database/database.py
+++ b/data_chain/stores/database/database.py
@@ -1,13 +1,15 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy import Index
+from datetime import datetime
+import uuid
from uuid import uuid4
import urllib.parse
from data_chain.logger.logger import logger as logging
from pgvector.sqlalchemy import Vector
from sqlalchemy import Boolean, Column, ForeignKey, Integer, Float, String, func
from sqlalchemy.types import TIMESTAMP, UUID
-from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import declarative_base, DeclarativeBase, MappedAsDataclass, Mapped, mapped_column
from data_chain.config.config import config
from data_chain.entities.enum import (Tokenizer,
ParseMethod,
@@ -534,6 +536,23 @@ class TaskReportEntity(Base):
)
+class TaskQueueEntity(Base):
+ __tablename__ = 'task_queue'
+
+ id = Column(UUID, default=uuid4, primary_key=True) # 任务ID
+ status = Column(String) # 任务状态
+ created_time = Column(
+ TIMESTAMP(timezone=True),
+ nullable=True,
+ server_default=func.current_timestamp()
+ )
+ # 添加索引以提高查询性能
+ __table_args__ = (
+ Index('idx_task_queue_status', 'status'),
+ Index('idx_task_queue_created_time', 'created_time'),
+ )
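+    # These indexes presumably serve the queue consumer: filter rows by status, then pick the
+    # oldest entry first via created_time.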
+
+
class DataBase:
# 对密码进行 URL 编码
diff --git a/ocr_server/init.py b/ocr_server/init.py
new file mode 100644
index 0000000000000000000000000000000000000000..19012fa73b9771ef13cbc7e372ec8e97a02f3d61
--- /dev/null
+++ b/ocr_server/init.py
@@ -0,0 +1,6 @@
+from paddleocr import PaddleOCR
+import cv2
+ocr = PaddleOCR(use_angle_cls=True, lang="ch")
+image_path = 'test.jpg'
+image = cv2.imread(image_path)
+result = ocr.predict(image)
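+# Usage sketch (assumes the PaddleOCR 3.x result layout also used in ocr_server/server.py,
+# i.e. a list of dict-like results exposing "rec_texts"): print each recognized text line.
+if result and result[0]:
+    for line_text in result[0].get("rec_texts", []):
+        print(line_text)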
diff --git a/ocr_server/requiremenets.text b/ocr_server/requiremenets.text
new file mode 100644
index 0000000000000000000000000000000000000000..0066d61d9aa587a9f34b015fa36ad82a8bb1d1a5
--- /dev/null
+++ b/ocr_server/requiremenets.text
@@ -0,0 +1,5 @@
+aiofiles 24.1.0
+fastapi 0.116.1
+paddleocr 3.2.0
+paddlepaddle 3.1.1
+uvicorn 0.35.0
\ No newline at end of file
diff --git a/ocr_server/requiremenets.txt b/ocr_server/requiremenets.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d61d9906ba30f3cdff775d7939b0eda976a00fbf
--- /dev/null
+++ b/ocr_server/requiremenets.txt
@@ -0,0 +1,61 @@
+aiofiles==24.1.0
+aistudio_sdk==0.3.5
+annotated-types==0.7.0
+anyio==4.10.0
+bce-python-sdk==0.9.42
+certifi==2025.8.3
+chardet==5.2.0
+charset-normalizer==3.4.3
+click==8.2.1
+colorlog==6.9.0
+fastapi==0.116.1
+filelock==3.19.1
+fsspec==2025.7.0
+future==1.0.0
+h11==0.16.0
+hf-xet==1.1.8
+httpcore==1.0.9
+httpx==0.28.1
+huggingface-hub==0.34.4
+idna==3.10
+imagesize==1.4.1
+mdc==1.2.1
+modelscope==1.29.1
+networkx==3.5
+numpy==2.3.2
+opencv-contrib-python==4.10.0.84
+opt-einsum==3.3.0
+packaging==25.0
+paddleocr==3.2.0
+paddlepaddle==3.1.1
+paddlex==3.2.0
+pandas==2.3.2
+pillow==11.3.0
+prettytable==3.16.0
+protobuf==6.32.0
+py-cpuinfo==9.0.0
+pyclipper==1.3.0.post6
+pycryptodome==3.23.0
+pydantic==2.11.7
+pydantic_core==2.33.2
+pypdfium2==4.30.0
+python-dateutil==2.9.0.post0
+python-json-logger==2.0.7
+python-multipart==0.0.20
+pytz==2025.2
+PyYAML==6.0.2
+requests==2.32.5
+ruamel.yaml==0.18.15
+ruamel.yaml.clib==0.2.12
+shapely==2.1.1
+six==1.17.0
+sniffio==1.3.1
+starlette==0.47.3
+tqdm==4.67.1
+typing-inspection==0.4.1
+typing_extensions==4.15.0
+tzdata==2025.2
+ujson==5.11.0
+urllib3==2.5.0
+uvicorn==0.35.0
+wcwidth==0.2.13
diff --git a/ocr_server/server.py b/ocr_server/server.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b53d95a66bb815227c7c3547341f5944ccc2e23
--- /dev/null
+++ b/ocr_server/server.py
@@ -0,0 +1,152 @@
+from typing import Any
+from pydantic import BaseModel
+from fastapi import FastAPI, UploadFile, File
+from fastapi.responses import JSONResponse
+import logging
+import os
+import aiofiles
+import cv2
+import numpy as np
+from paddleocr import PaddleOCR
+import uuid
+from datetime import datetime
+
+# 强制离线模式
+os.environ["PADDLEX_OFFLINE"] = "True"
+# 禁用Paddle的网络请求
+os.environ["PADDLE_NO_NETWORK"] = "True"
+# 指定模型缓存路径(确保已放置模型)
+os.environ["PADDLEX_HOME"] = "/root/.paddlex"
+
+
+class ResponseData(BaseModel):
+ """基础返回数据结构"""
+
+ code: int
+ message: str
+ result: list
+
+
+# 配置日志
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# 初始化FastAPI应用
+app = FastAPI()
+
+# 创建保存上传文件的目录
+UPLOAD_DIR = "uploaded_files"
+os.makedirs(UPLOAD_DIR, exist_ok=True)
+
+# 初始化PaddleOCR,使用角度分类,中文识别
+ocr = PaddleOCR(
+ # 1. 指定本地文本检测模型目录(PP-OCRv5_server_det)
+ text_detection_model_dir="/root/.paddlex/official_models/PP-OCRv5_server_det",
+ # 2. 指定本地文本识别模型目录(PP-OCRv5_server_rec)
+ text_recognition_model_dir="/root/.paddlex/official_models/PP-OCRv5_server_rec",
+ # (可选)若需要文档方向分类/文本行方向分类,也可指定对应本地模型
+ doc_orientation_classify_model_dir="/root/.paddlex/official_models/PP-LCNet_x1_0_doc_ori",
+ textline_orientation_model_dir="/root/.paddlex/official_models/PP-LCNet_x1_0_textline_ori",
+ # (可选)关闭不需要的功能(如文档矫正,根据需求调整)
+ use_doc_unwarping=False, # 若不需要 UVDoc 文档矫正,可关闭
+ lang=None, # 因已指定本地模型,lang/ocr_version 会被自动忽略(符合原代码逻辑)
+ ocr_version=None,
+ device="npu:0"
+)
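+# NOTE: device="npu:0" assumes an Ascend NPU is available on the host; on a CPU-only machine
+# this would presumably need to be switched to device="cpu".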
+
+
+@app.get("/ocr", response_model=ResponseData)
+async def ocr_recognition(file: UploadFile = File(...)) -> JSONResponse:
+ """
+ 接收上传的图片文件,先保存到本地,再进行OCR识别并返回结果字符串
+ """
+    file_path = None
+    try:
+ # 生成唯一的文件名,避免冲突
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ file_extension = os.path.splitext(file.filename)[1]
+ unique_filename = f"{timestamp}_{uuid.uuid4().hex[:8]}{file_extension}"
+ file_path = os.path.join(UPLOAD_DIR, unique_filename)
+
+ # 异步保存文件到本地
+ async with aiofiles.open(file_path, 'wb') as out_file:
+ content = await file.read() # 读取文件内容
+ await out_file.write(content) # 写入到本地文件
+
+ logger.info(f"文件已保存到: {file_path}, 大小: {len(content)} bytes")
+
+ # 使用cv2读取本地文件
+ image = cv2.imread(file_path)
+ if image is None:
+ if os.path.exists(file_path):
+ os.remove(file_path)
+ logger.error("无法读取上传的图片文件,可能格式不支持或文件损坏")
+ return JSONResponse(
+ status_code=400,
+ content=ResponseData(
+ code=400,
+ message="无法读取上传的图片文件,可能格式不支持或文件损坏",
+ result=[]
+ ).model_dump(exclude_none=True)
+ )
+
+ logger.info(f"图片读取成功,尺寸: {image.shape}")
+
+ # 进行OCR识别
+ # PaddleOCR可以直接处理numpy数组(cv2格式)
+ result = ocr.predict(image)
+        if not result or not result[0]:
+            if os.path.exists(file_path):
+                os.remove(file_path)
+            return JSONResponse(
+                status_code=200,
+                content=ResponseData(
+                    code=200,
+                    message="OCR识别完成,但未识别到任何文本",
+                    result=[]
+                ).model_dump(exclude_none=True)
+            )
+ rec_texts = result[0].get("rec_texts", [])
+ rec_scores = result[0].get("rec_scores", [])
+ rec_polys = result[0].get("rec_polys", [])
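+        # Re-pack the results into a nested [[ [polygon, [text, score]], ... ]] list (one inner
+        # list per image), presumably to mirror the output shape of the older ocr.ocr() API.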
+ rt = [[]]
+ for i, text in enumerate(rec_texts):
+ rt[0].append([rec_polys[i].tolist(), [text, float(f"{rec_scores[i]:.4f}")]])
+ if os.path.exists(file_path):
+ os.remove(file_path)
+ return JSONResponse(
+ status_code=200,
+ content=ResponseData(
+ code=200,
+ message="OCR识别成功",
+ result=rt
+ ).model_dump(exclude_none=True)
+ )
+
+ except Exception as e:
+        if file_path and os.path.exists(file_path):
+ os.remove(file_path)
+ logger.error(f"处理过程出错: {str(e)}", exc_info=True)
+ return JSONResponse(
+ status_code=500,
+ content=ResponseData(
+ code=500,
+ message=f"处理过程出错: {str(e)}",
+ result=[]
+ ).model_dump(exclude_none=True)
+ )
+if __name__ == "__main__":
+ import uvicorn
+ # 在9999端口启动服务,允许外部访问
+ uvicorn.run(app, host="0.0.0.0", port=9999)
diff --git a/ocr_server/test.jpg b/ocr_server/test.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..614fa5c83ed355d60a1119a32ee74300648845c6
Binary files /dev/null and b/ocr_server/test.jpg differ
diff --git a/ocr_server/test.py b/ocr_server/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f11c4c9da33713da802525e9ac8e41805ef943ee
--- /dev/null
+++ b/ocr_server/test.py
@@ -0,0 +1,52 @@
+import requests
+
+
+def call_ocr_api(image_path, api_url="http://localhost:9999/ocr"):
+ """
+ 调用OCR接口识别图片中的文字
+
+ 参数:
+ image_path: 本地图片文件路径
+ api_url: OCR接口的URL地址
+
+ 返回:
+ 识别到的文字字符串
+ """
+ try:
+ # 打开图片文件并准备上传
+ with open(image_path, 'rb') as file:
+ # 构造表单数据,键名需与接口中的参数名一致
+ files = {'file': (image_path, file, 'image/jpeg')}
+ # 发送GET请求
+ response = requests.get(api_url, files=files)
+
+ # 检查响应状态
+ if response.status_code == 200:
+ # 返回识别结果
+ return response.json()
+ else:
+ print(f"请求失败,状态码: {response.status_code}")
+ print(f"错误信息: {response.text}")
+ return None
+
+ except FileNotFoundError:
+ print(f"错误: 找不到图片文件 {image_path}")
+ return None
+ except Exception as e:
+ print(f"调用接口时发生错误: {str(e)}")
+ return None
+
+
+# 使用示例
+if __name__ == "__main__":
+ # 替换为你的图片路径
+ image_path = "test.jpg"
+ # 调用OCR接口
+ result = call_ocr_api(image_path)
+
+ if result:
+ print("OCR识别结果:")
+ print("-" * 50)
+ print(type(result))
+ print(result)
+ print("-" * 50)
diff --git a/requirements.txt b/requirements.txt
index e2b1e3981564323a264e26273bca90b2c45f447d..f97dae85049f6f64d8260f48fc241e046838ac26 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,14 +11,12 @@ fastapi-pagination==0.12.19
httpx==0.27.0
itsdangerous==2.1.2
jieba==0.42.1
-langchain==0.3.7
-langchain-openai==0.2.5
minio==7.2.4
markdown2==2.5.2
markdown==3.3.4
more-itertools==10.1.0
numpy==1.26.4
-openai==1.65.2
+openai==1.91.0
opencv-python==4.9.0.80
openpyxl==3.1.2
paddleocr==2.9.1
@@ -48,4 +46,5 @@ uvicorn==0.21.0
xlrd==2.0.1
py-cpuinfo==9.0.0
opengauss-sqlalchemy==2.4.0
-#marker-pdf==1.8.0
\ No newline at end of file
+#marker-pdf==1.8.0
+motor==3.7.1
\ No newline at end of file
diff --git a/run.sh b/run.sh
index 56374de5e257fa6ea17027f6992391dfa6204573..706e3107068fb05753185a86ef484920df4b67f3 100644
--- a/run.sh
+++ b/run.sh
@@ -1,9 +1,7 @@
#!/usr/bin/env sh
java -jar tika-server-standard-2.9.2.jar &
-python3 /rag-service/chat2db/app/app.py &
+python3 /rag-service/chat2db/main.py &
python3 /rag-service/data_chain/apps/app.py &
-sleep 5
-python3 /rag-service/chat2db/common/init_sql_example.py
while true
do
diff --git a/test.pdf b/test.pdf
deleted file mode 100644
index a64e0a48ef0f81bd2dde554984850f61956cfc41..0000000000000000000000000000000000000000
Binary files a/test.pdf and /dev/null differ
diff --git a/test/config.py b/test/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba2a4bcc32e88811a959f25f76a3124a8d129ba3
--- /dev/null
+++ b/test/config.py
@@ -0,0 +1,51 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""配置文件处理模块"""
+import toml
+from enum import Enum
+from typing import Any
+from pydantic import BaseModel, Field
+from pathlib import Path
+from copy import deepcopy
+import sys
+import os
+
+
+class LLMConfig(BaseModel):
+ """LLM配置模型"""
+ llm_endpoint: str = Field(default="https://dashscope.aliyuncs.com/compatible-mode/v1", description="LLM远程主机地址")
+ llm_api_key: str = Field(default="", description="LLM API Key")
+ llm_model_name: str = Field(default="qwen3-coder-480b-a35b-instruct", description="LLM模型名称")
+ max_tokens: int = Field(default=8192, description="LLM最大Token数")
+ temperature: float = Field(default=0.7, description="LLM温度参数")
+
+
+class EmbeddingType(str, Enum):
+ OPENAI = "openai"
+ MINDIE = "mindie"
+
+
+class EmbeddingConfig(BaseModel):
+ """Embedding配置模型"""
+ embedding_type: EmbeddingType = Field(default=EmbeddingType.OPENAI, description="向量化类型")
+ embedding_endpoint: str = Field(default="", description="向量化API地址")
+ embedding_api_key: str = Field(default="", description="向量化API Key")
+ embedding_model_name: str = Field(default="text-embedding-3-small", description="向量化模型名称")
+
+
+class ConfigModel(BaseModel):
+ """公共配置模型"""
+ embedding: EmbeddingConfig = Field(default=EmbeddingConfig(), description="向量化配置")
+ llm: LLMConfig = Field(default=LLMConfig(), description="LLM配置")
+
+
+class BaseConfig():
+ """配置文件读取和使用Class"""
+
+ def __init__(self) -> None:
+ """读取配置文件;当PROD环境变量设置时,配置文件将在读取后删除"""
+ config_file = os.path.join("config.toml")
+ self._config = ConfigModel.model_validate(toml.load(config_file))
+
+ def get_config(self) -> ConfigModel:
+ """获取配置文件内容"""
+ return deepcopy(self._config)
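+
+
+# Minimal usage sketch (assumption: config.toml sits in the current working directory, as loaded above):
+#     cfg = BaseConfig().get_config()
+#     print(cfg.llm.llm_model_name, cfg.embedding.embedding_model_name)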
diff --git a/test/config.toml b/test/config.toml
new file mode 100644
index 0000000000000000000000000000000000000000..64cf50923f1898d92979c26c9f7e4d5d5de66ef6
--- /dev/null
+++ b/test/config.toml
@@ -0,0 +1,12 @@
+[embedding]
+embedding_type = "openai"
+embedding_endpoint = "https://api.siliconflow.cn/v1"
+embedding_api_key = "sk-123456"
+embedding_model_name = "BAAI/bge-m3"
+
+[llm]
+llm_endpoint = "https://dashscope.aliyuncs.com/compatible-mode/v1"
+llm_api_key = "sk-123456"
+llm_model_name = "qwen3-coder-480b-a35b-instruct"
+max_tokens = 8192
+temperature = 0.7
\ No newline at end of file
diff --git a/chat2db/app/base/vectorize.py b/test/embedding.py
similarity index 34%
rename from chat2db/app/base/vectorize.py
rename to test/embedding.py
index 5362047fa0fd407a523bba76e1862e77aa6ef389..0637189d01e014cbf7afd9a5643020018cfc6b5f 100644
--- a/chat2db/app/base/vectorize.py
+++ b/test/embedding.py
@@ -1,47 +1,50 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
import requests
-import urllib3
-from chat2db.config.config import config
import json
-import sys
-import logging
-
-logging.basicConfig(stream=sys.stdout, level=logging.INFO,
- format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
+import urllib3
+from config import BaseConfig, EmbeddingType
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
-class Vectorize():
+class Embedding():
@staticmethod
async def vectorize_embedding(text):
- if config['EMBEDDING_TYPE']=='openai':
+ vector = None
+ if BaseConfig().get_config().embedding.embedding_type == EmbeddingType.OPENAI:
headers = {
- "Authorization": f"Bearer {config['EMBEDDING_API_KEY']}"
- }
+ "Authorization": f"Bearer {BaseConfig().get_config().embedding.embedding_api_key}",
+ }
data = {
"input": text,
- "model": config["EMBEDDING_MODEL_NAME"],
+ "model": BaseConfig().get_config().embedding.embedding_model_name,
"encoding_format": "float"
}
try:
- res = requests.post(url=config["EMBEDDING_ENDPOINT"],headers=headers, json=data, verify=False)
+ res = requests.post(url=BaseConfig().get_config().embedding.embedding_endpoint,
+ headers=headers, json=data, verify=False)
if res.status_code != 200:
return None
- return res.json()['data'][0]['embedding']
+ vector = res.json()['data'][0]['embedding']
except Exception as e:
- logging.error(f"Embedding error failed due to: {e}")
+ err = f"[Embedding] 向量化失败 ,error: {e}"
+ print(err)
return None
- elif config['EMBEDDING_TYPE'] =='mindie':
+        elif BaseConfig().get_config().embedding.embedding_type == EmbeddingType.MINDIE:
try:
data = {
- "inputs": text,
+ "inputs": text,
}
- res = requests.post(url=config["EMBEDDING_ENDPOINT"], json=data, verify=False)
+ res = requests.post(url=BaseConfig().get_config().embedding.embedding_endpoint, json=data, verify=False)
if res.status_code != 200:
return None
- return json.loads(res.text)[0]
+ vector = json.loads(res.text)[0]
except Exception as e:
- logging.error(f"Embedding error failed due to: {e}")
- return None
+ err = f"[Embedding] 向量化失败 ,error: {e}"
+ print(err)
+ return None
else:
return None
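+        # Pad with zeros or truncate so every backend returns a fixed 1024-dimension vector
+        # (assumed to match the vector column size expected by the downstream pgvector store).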
+ while len(vector) < 1024:
+ vector.append(0)
+ return vector[:1024]
diff --git a/test/tools/llm.py b/test/llm.py
similarity index 51%
rename from test/tools/llm.py
rename to test/llm.py
index 103f4ff1f4577ca1fd6d800256c56b6d10445b66..f9e19cc651d65621d7a38610c527237ffd6e90d4 100644
--- a/test/tools/llm.py
+++ b/test/llm.py
@@ -1,60 +1,93 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-import asyncio
-import time
-import json
-from langchain_openai import ChatOpenAI
-from langchain.schema import SystemMessage, HumanMessage
-
-
-class LLM:
- def __init__(self, openai_api_key, openai_api_base, model_name, max_tokens, request_timeout=60, temperature=0.1):
- self.client = ChatOpenAI(model_name=model_name,
- openai_api_base=openai_api_base,
- openai_api_key=openai_api_key,
- request_timeout=request_timeout,
- max_tokens=max_tokens,
- temperature=temperature)
- print(model_name)
- def assemble_chat(self, chat=None, system_call='', user_call=''):
- if chat is None:
- chat = []
- chat.append(SystemMessage(content=system_call))
- chat.append(HumanMessage(content=user_call))
- return chat
-
- async def nostream(self, chat, system_call, user_call):
- chat = self.assemble_chat(chat, system_call, user_call)
- response = await self.client.ainvoke(chat)
- return response.content
-
- async def data_producer(self, q: asyncio.Queue, history, system_call, user_call):
- message = self.assemble_chat(history, system_call, user_call)
- try:
- async for frame in self.client.astream(message):
- await q.put(frame.content)
- except Exception as e:
- await q.put(None)
- print(f"Error in data producer due to: {e}")
- return
- await q.put(None)
-
- async def stream(self, chat, system_call, user_call):
- st = time.time()
- q = asyncio.Queue(maxsize=10)
-
- # 启动生产者任务
- producer_task = asyncio.create_task(self.data_producer(q, chat, system_call, user_call))
- first_token_reach = False
- while True:
- data = await q.get()
- if data is None:
- break
- if not first_token_reach:
- first_token_reach = True
- print(f"大模型回复第一个字耗时 = {time.time() - st}")
- for char in data:
- yield "data: " + json.dumps({'content': char}, ensure_ascii=False) + '\n\n'
- await asyncio.sleep(0.03) # 使用异步 sleep
-
- yield "data: [DONE]"
- print(f"大模型回复耗时 = {time.time() - st}")
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+import asyncio
+import time
+import re
+import json
+import tiktoken
+from langchain_openai import ChatOpenAI
+from langchain.schema import SystemMessage, HumanMessage
+
+
+class LLM:
+ def __init__(self, openai_api_key, openai_api_base, model_name, max_tokens, request_timeout=60, temperature=0.1):
+ self.openai_api_key = openai_api_key
+ self.openai_api_base = openai_api_base
+ self.model_name = model_name
+ self.max_tokens = max_tokens
+ self.request_timeout = request_timeout
+ self.temperature = temperature
+ self.client = ChatOpenAI(model_name=model_name,
+ openai_api_base=openai_api_base,
+ openai_api_key=openai_api_key,
+ request_timeout=request_timeout,
+ max_tokens=max_tokens,
+ temperature=temperature)
+
+ def assemble_chat(self, chat=None, system_call='', user_call=''):
+ if chat is None:
+ chat = []
+ chat.append(SystemMessage(content=system_call))
+ chat.append(HumanMessage(content=user_call))
+ return chat
+
+ async def nostream(self, chat, system_call, user_call, st_str: str = None, en_str: str = None):
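+        # st_str / en_str optionally trim the reply to the span from the first occurrence of
+        # st_str through the last occurrence of en_str (e.g. to keep only a bracketed JSON list).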
+ try:
+ chat = self.assemble_chat(chat, system_call, user_call)
+ response = await self.client.ainvoke(chat)
+            # Strip any <think>...</think> reasoning block (and a dangling close tag without a
+            # matching open tag) that reasoning models may prepend to the reply.
+            content = re.sub(r'<think>.*?</think>\n?', '', response.content, flags=re.DOTALL)
+            content = re.sub(r'^.*?</think>\n?', '', content, flags=re.DOTALL)
+ content = content.strip()
+ if st_str is not None:
+ index = content.find(st_str)
+ if index != -1:
+ content = content[index:]
+ if en_str is not None:
+ index = content[::-1].find(en_str[::-1])
+ if index != -1:
+ content = content[:len(content)-index]
+ except Exception as e:
+ err = f"[LLM] 非流式输出异常: {e}"
+ print("[LLM] %s", err)
+ return ''
+ return content
+
+ async def data_producer(self, q: asyncio.Queue, history, system_call, user_call):
+ message = self.assemble_chat(history, system_call, user_call)
+ try:
+ async for frame in self.client.astream(message):
+ await q.put(frame.content)
+ except Exception as e:
+ await q.put(None)
+ err = f"[LLM] 流式输出生产者任务异常: {e}"
+ print("[LLM] %s", err)
+ raise e
+ await q.put(None)
+
+ async def stream(self, chat, system_call, user_call):
+ st = time.time()
+ q = asyncio.Queue(maxsize=10)
+
+ # 启动生产者任务
+ producer_task = asyncio.create_task(self.data_producer(q, chat, system_call, user_call))
+ first_token_reach = False
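+        # Token counts below use the gpt-4 tiktoken encoding as an approximation; the served
+        # model's own tokenizer may differ, and input_tokens only covers the system prompt here.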
+ enc = tiktoken.encoding_for_model("gpt-4")
+ input_tokens = len(enc.encode(system_call))
+ output_tokens = 0
+ while True:
+ data = await q.get()
+ if data is None:
+ break
+ if not first_token_reach:
+ first_token_reach = True
+ print(f"大模型回复第一个字耗时 = {time.time() - st}")
+ output_tokens += len(enc.encode(data))
+ yield "data: " + json.dumps(
+ {'content': data,
+ 'input_tokens': input_tokens,
+ 'output_tokens': output_tokens
+ }, ensure_ascii=False
+ ) + '\n\n'
+ await asyncio.sleep(0.03) # 使用异步 sleep
+
+ yield "data: [DONE]"
+ print(f"大模型回复耗时 = {time.time() - st}")
diff --git a/test/prompt.yaml b/test/prompt.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..25ff89c0a46ec731936edb7dde5bb728e5069036
--- /dev/null
+++ b/test/prompt.yaml
@@ -0,0 +1,708 @@
+ACC_ANALYSIS_RESULT_MERGE_PROMPT:
+ en: |
+ You are a text analysis expert. Your task is to combine two analysis results and output a new one. Note:
+ #01 Please combine the content of the two analysis results to produce a new analysis result.
+ #02 Please analyze using the four metrics of recall, precision, faithfulness, and interpretability.
+ #03 The new analysis result must be no longer than 500 characters.
+ #04 Please output only the new analysis result; do not output any other content.
+ Example:
+ Input 1:
+ Analysis Result 1:
+ Recall: Currently, the recall is 95.00, with room for improvement. We will optimize the vectorized search algorithm to further mine information in the original fragment that is relevant to the question but not retrieved, such as some specific practical cases in the openEuler ecosystem. The embedding model bge-m3 will be adjusted to more comprehensively and accurately capture semantics, expand the search scope, improve recall, and make the generated answers closer to the standard answer.
+ Accuracy: The accuracy is 99.00, which is quite high. However, further optimization is possible, including deeper semantic analysis of the retrieved snippets. By combining the features of the large model qwen2.5-32b, this can precisely match the question semantics and avoid subtle semantic deviations. For example, this can more precisely illustrate the specific manifestations of OpenEuler's high performance in cloud computing and edge computing.
+ Fidelity: The fidelity value is 90.00, indicating that some answers are not fully derived from the retrieved snippets. Optimizing the rag retrieval algorithm, improving the recall rate of the embedding model, and adjusting the text chunk size to 512 may be inappropriate and require re-evaluation based on the content. This ensures that the retrieved snippets contain sufficient context to support the answer, ensuring that the generated answer content is fully derived from the retrieved snippets. For example, regarding the development of the openEuler ecosystem, relevant technical details should be obtained from the retrieved snippets.
+ Interpretability: The interpretability is 85.00, which is relatively low. Improve the compliance of the large model qwen2.5-32b and optimize the recall of the rag retrieval algorithm and the embedding model bge-m3. This ensures that retrieved snippets better support answer generation and clearly answer questions. For example, when answering questions related to OpenEuler, this makes the answer logic clearer and more targeted, improving overall interpretability.
+
+ Analysis Result 2:
+ The recall rate is currently 95.00. Further optimization of the rag retrieval algorithm and embedding model can be used to increase the semantic similarity between the generated answers and the standard answers, approaching or achieving a higher recall rate. For example, the algorithm can be continuously optimized to better match relevant snippets.
+ The precision is 99.00, close to the maximum score, indicating that the generated answers are semantically similar to the questions. However, further improvement is possible. This can be achieved by refining the embedding model to better understand the question semantics, optimizing the contextual completeness of the retrieved snippets, and reducing fluctuations in precision caused by insufficient context.
+ The faithfulness score is currently 90.00, indicating that some content in the generated answer is not fully derived from the retrieved snippet. The rag retrieval algorithm can be optimized to improve its recall rate. The text chunk size can also be adjusted appropriately to ensure that the retrieved snippet fully answers the question, thereby improving the faithfulness score.
+ Regarding interpretability, it is currently 85.00, indicating that the generated answer has room for improvement in terms of answering questions. On the one hand, the large model used can be optimized to improve its compliance, making the generated answer more accurate. On the other hand, the recall rates of the rag retrieval algorithm and embedding model can be further optimized to ensure that the retrieved snippet fully supports the answer and improve interpretability.
+
+ Output:
+ Recall: Currently at 95.00, there is room for improvement. The vectorized retrieval algorithm can be optimized to further uncover information in the original snippet that is relevant to the question but not retrieved, as demonstrated in some specific practical cases within the openEuler ecosystem. Adjustments were made to the embedding model bge-m3 to enable it to more comprehensively and accurately capture semantics, expand the search scope, improve recall, and bring the generated answers closer to the standard answer.
+ Accuracy: The accuracy reached 99.00, which is already high. However, further optimization is needed to conduct deeper semantic analysis of the retrieved snippets. By combining the features of the large model qwen2.5-32b, this can precisely match the question semantics and avoid subtle semantic deviations. For example, this could more accurately demonstrate the specific characteristics of OpenEuler's high performance in cloud computing and edge computing.
+ Fidelity: The fidelity value was 90.00, indicating that some answer content was not fully derived from the retrieved snippet. The rag retrieval algorithm was optimized to improve the recall of the embedding model. Adjusting the text chunk size to 512 may be unreasonable and requires re-evaluation based on the content to ensure that the retrieved snippets contain sufficient context to support the answer, ensuring that the generated answer content is fully derived from the retrieved snippet. For example, relevant technical details regarding the development of the OpenEuler ecosystem should be obtained from the retrieved snippet.
+ Interpretability: The interpretability value was 85.00, which is relatively low. Improve the compliance of the large qwen2.5-32b model and optimize the recall of the rag retrieval algorithm and the embedding model bge-m3. This ensures that retrieval fragments can better support answer generation and clearly answer questions. For example, when answering questions related to OpenEuler, this improves answer logic, makes it more targeted, and improves overall interpretability.
+
+ The following two analysis results:
+ Analysis Result 1: {analysis_result_1}
+ Analysis Result 2: {analysis_result_2}
+
+ 中文: |
+ 你是一个文本分析专家,你的任务融合两条分析结果输出一份新的分析结果。注意:
+ #01 请根据两条分析结果中的内容融合出一条新的分析结果
+ #02 请结合召回率、精确度、忠实值和可解释性四个指标进行分析
+ #03 新的分析结果长度不超过500字
+ #04 请仅输出新的分析结果,不要输出其他内容
+ 例子:
+ 输入1:
+ 分析结果1:
+ 召回率:目前召回率为 95.00,有提升空间。优化向量化检索算法,进一步挖掘原始片段中与问题相关但未被检索到的信息,如 openEuler 生态中一些具体实践案例等。调整 embedding 模型 bge-m3,使其能更全面准确地捕捉语义,扩大检索范围,提高召回率,使生成答案更接近标准答案。
+ 精确度:精确度达 99.00,已较高。但可进一步优化,对检索到的片段进行更深入的语义分析,结合大模型 qwen2.5-32b 的特点,精准匹配问题语义,避免细微语义偏差,例如更精确阐述 openEuler 在云计算和边缘计算中高性能等特性的具体表现。
+ 忠实值:忠实值为 90.00,说明部分答案内容未完全源于检索片段。优化 rag 检索算法,提高 embedding 模型召回率,调整文本分块大小为 512 可能存在不合理,需根据内容重新评估,确保检索片段包含足够上下文以支撑答案,使生成答案内容均来自检索片段,如关于 openEuler 生态建设中相关技术细节应从检索片段获取。
+ 可解释性:可解释性为 85.00,相对较低。提升大模型 qwen2.5-32b 的遵从度,优化 rag 检索算法和 embedding 模型 bge-m3 的召回率,使检索片段能更好支撑生成答案,保证答案能清晰回答问题,例如在回答 openEuler 相关问题时,使答案逻辑更清晰、针对性更强,提高整体可解释性。
+
+ 分析结果2:
+ 从召回率来看,目前为 95.00,可进一步优化 rag 检索算法和 embedding 模型,以提高生成答案与标准回答之间的语义相似程度,接近或达到更高的召回率,例如可以持续优化算法来更好地匹配相关片段。
+ 从精确度来看,为 99.00,接近满分,说明生成的答案与问题语义相似程度较高,但仍可进一步提升,可通过完善 embedding 模型来更好地理解问题语义,优化检索到的片段的上下文完整性,减少因上下文不足导致的精确度波动。
+ 对于忠实值,目前为 90.00,说明生成的答案中部分内容未完全来自检索到的片段。可优化 rag 检索算法,提高其召回率,同时合理调整文本分块大小,确保检索到的片段能充分回答问题,从而提高忠实值。
+ 关于可解释性,当前为 85.00,说明生成的答案在用于回答问题方面有一定提升空间。一方面可以优化使用的大模型,提高其遵从度,使其生成的答案更准确地回答问题;另一方面,继续优化 rag 检索算法和 embedding 模型的召回率,保证检索到的片段能全面支撑问题的回答,提高可解释性。
+
+ 输出:
+ 召回率:目前召回率为 95.00,有提升空间。优化向量化检索算法,进一步挖掘原始片段中与问题相关但未被检索到的信息,如 openEuler 生态中一些具体实践案例等。调整 embedding 模型 bge-m3,使其能更全面准确地捕捉语义,扩大检索范围,提高召回率,使生成答案更接近标准答案。
+ 精确度:精确度达 99.00,已较高。但可进一步优化,对检索到的片段进行更深入的语义分析,结合大模型 qwen2.5-32b 的特点,精准匹配问题语义,避免细微语义偏差,例如更精确阐述 openEuler 在云计算和边缘计算中高性能等特性的具体表现。
+ 忠实值:忠实值为 90.00,说明部分答案内容未完全源于检索片段。优化 rag 检索算法,提高 embedding 模型召回率,调整文本分块大小为 512 可能存在不合理,需根据内容重新评估,确保检索片段包含足够上下文以支撑答案,使生成答案内容均来自检索片段,如关于 openEuler 生态建设中相关技术细节应从检索片段获取。
+ 可解释性:可解释性为 85.00,相对较低。提升大模型 qwen2.5-32b 的遵从度,优化 rag 检索算法和 embedding 模型 bge-m3 的召回率,使检索片段能更好支撑生成答案,保证答案能清晰回答问题,例如在回答 openEuler 相关问题时,使答案逻辑更清晰、针对性更强,提高整体可解释性。
+
+ 下面两条分析结果:
+ 分析结果1:{analysis_result_1}
+ 分析结果2:{analysis_result_2}
+
+ACC_RESULT_ANALYSIS_PROMPT:
+ en: |
+ You are a text analysis expert. Your task is to: analyze the large model used in the test, the embedding model used in the test, the parsing method and chunk size of related documents, the snippets matched by the RAG algorithm for a single test result, and propose methods to improve the accuracy of question-answering in the current knowledge base.
+
+ The test results include the following information:
+ - Question: The question used in the test
+ - Standard answer: The standard answer used in the test
+ - Generated answer: The answer output by the large model in the test results
+ - Original snippet: The original snippet provided in the test results
+ - Retrieved snippet: The snippet retrieved by the RAG algorithm in the test results
+
+ The four evaluation metrics are defined as follows:
+ - Precision: Evaluates the semantic similarity between the generated answer and the question. A lower score indicates lower compliance of the large model; additionally, it may mean the snippets retrieved by the RAG algorithm lack context and are insufficient to support the answer.
+ - Recall: Evaluates the semantic similarity between the generated answer and the standard answer. A lower score indicates lower compliance of the large model.
+ - Fidelity: Evaluates whether the content of the generated answer is derived from the retrieved snippet. A lower score indicates lower recall of the RAG retrieval algorithm and embedding model (resulting in retrieved snippets insufficient to answer the question); additionally, it may mean the text chunk size is inappropriate.
+ - Interpretability: Evaluates whether the generated answer is useful for answering the question. A lower score indicates lower recall of the RAG retrieval algorithm and embedding model (resulting in retrieved snippets insufficient to answer the question); additionally, it may mean lower compliance of the used large model.
+
+ Notes:
+ #01 Analyze methods to improve the accuracy of current knowledge base question-answering based on the test results.
+ #02 Conduct the analysis using the four metrics: Recall, Precision, Fidelity, and Interpretability.
+ #03 The analysis result must not exceed 500 words.
+ #04 Output only the analysis result; do not include any other content.
+
+ Example:
+ Input:
+ Model name: qwen2.5-32b
+ Embedding model: bge-m3
+ Text chunk size: 512
+ Used RAG algorithm: Vectorized retrieval
+ Question: What is OpenEuler?
+ Standard answer: OpenEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Generated answer: OpenEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Original snippet: openEuler is an open source operating system incubated and operated by the Open Atom Open Source Foundation. Its mission is to build an open source operating system ecosystem for digital infrastructure and provide solid underlying support for cutting-edge fields such as cloud computing and edge computing. In cloud computing scenarios, openEuler can fully optimize resource scheduling and allocation mechanisms. Through a lightweight kernel design and efficient virtualization technology, it significantly improves the responsiveness and throughput of cloud services. In edge computing, its exceptional low resource consumption and real-time processing capabilities ensure the timeliness and accuracy of data processing at edge nodes in complex environments. openEuler boasts a series of exceptional features: In terms of performance, its independently developed intelligent scheduling algorithm dynamically adapts to different load scenarios, and combined with deep optimization of hardware resources, significantly improves system efficiency. Regarding security, its built-in multi-layered security system, including mandatory access control, vulnerability scanning, and remediation mechanisms, provides a solid defense for system data and applications. Regarding reliability, its distributed storage, automatic fault detection, and rapid recovery technologies ensure stable system operation in the face of unexpected situations such as network fluctuations and hardware failures, minimizing the risk of service interruptions. These features make openEuler a crucial technological cornerstone for promoting high-quality development of the digital economy, helping enterprises and developers seize the initiative in digital transformation.
+ Retrieved snippet: As a pioneer in the open source operating system field, openEuler deeply integrates the wisdom of community developers and continuously iterates and upgrades to adapt to the rapidly changing technological environment. In the current era of prevalent microservices architectures, openEuler Through deep optimization of containerization technology and support for mainstream orchestration tools such as Kubernetes, it makes application deployment and management more convenient and efficient, significantly enhancing the flexibility of enterprise business deployments. At the same time, it actively embraces the AI era. By adapting and optimizing machine learning frameworks, it provides powerful computing power for AI model training and inference, effectively reducing the development and operating costs of AI applications. Regarding ecosystem development, openEuler boasts a large and active open source community, bringing together technology enthusiasts and industry experts from around the world, forming a complete ecosystem from kernel development and driver adaptation to application optimization. The community regularly hosts technical exchanges and developer conferences to promote knowledge sharing and technological innovation, providing developers with a wealth of learning resources and practical opportunities. Numerous hardware and software manufacturers have joined the openEuler ecosystem, launching solutions and products based on the system across key industries such as finance, telecommunications, and energy. These efforts, validated through real-world application scenarios and feeding back into openEuler's technological development, have fostered a virtuous cycle of innovation, making openEuler not just an operating system but a powerful engine driving collaborative industry development.
+ Recall: 95.00
+ Precision: 99.00
+ Fidelity: 90.00
+ Interpretability: 85.00
+
+ Output:
+ Based on the test results, methods for improving the accuracy of current knowledge base question-answering can be analyzed from the following aspects: Recall: The current recall is 95.00, with room for improvement. Optimize the vectorized retrieval algorithm to further mine question-related but unretrieved information in the original snippets, such as some specific practical cases in the openEuler ecosystem. Adjust the embedding model bge-m3 to more comprehensively and accurately capture semantics, expand the search scope, improve recall, and make the generated answers closer to the standard answer. Precision: The accuracy reached 99.00, which is already high. However, further optimization is possible, including deeper semantic analysis of retrieved snippets. By combining the features of the large model qwen2.5-32b, this can accurately match the question semantics and avoid subtle semantic deviations. For example, more precise demonstration of openEuler's high performance in cloud computing and edge computing can be achieved. Fidelity: A fidelity score of 90.00 indicates that some answers are not fully derived from the search snippet. We optimized the rag retrieval algorithm, improved the recall of the embedding model, and adjusted the text chunk size to 512. This may be inappropriate and requires reassessment based on the content. We need to ensure that the search snippet contains sufficient context to support the answer, ensuring that the generated answer content is derived from the search snippet. For example, relevant technical details regarding the development of the openEuler ecosystem should be obtained from the search snippet. Interpretability: The interpretability score is 85.00, which is relatively low. We improved the compliance of the large model qwen2.5-32b and optimized the recall of the rag retrieval algorithm and the embedding model bge-m3. This ensures that the search snippet better supports answer generation and clearly answers the question. For example, when answering openEuler-related questions, the answer logic is made clearer and more targeted, improving overall interpretability.
+
+ The following is the test result content:
+ Used large model: {model_name}
+ Embedding model: {embedding_model}
+ Text chunk size: {chunk_size}
+ Used RAG parsing algorithm: {rag_algorithm}
+ Question: {question}
+ Standard answer: {standard_answer}
+ Generated answer: {generated_answer}
+ Original fragment: {original_fragment}
+ Retrieved fragment: {retrieved_fragment}
+ Recall: {recall}
+ Precision: {precision}
+ Faithfulness: {faithfulness}
+ Interpretability: {relevance}
+
+ 中文: |
+ 你是一个文本分析专家,你的任务是:根据给出的测试使用的大模型、embedding模型、测试相关文档的解析方法和分块大小、单条测试结果分析RAG算法匹配到的片段,并分析当前知识库问答准确率的提升方法。
+
+ 测试结果包含以下内容:
+ - 问题:测试使用的问题
+ - 标准答案:测试使用的标准答案
+ - 生成的答案:测试结果中大模型输出的答案
+ - 原始片段:测试结果中的原始片段
+ - 检索的片段:测试结果中RAG算法检索到的片段
+
+ 四个评估指标定义如下:
+ - 精确率:评估生成的答案与问题之间的语义相似程度。评分越低,说明使用的大模型遵从度越低;其次可能是RAG检索到的片段缺少上下文,不足以支撑问题的回答。
+ - 召回率:评估生成的答案与标准回答之间的语义相似程度。评分越低,说明使用的大模型遵从度越低。
+ - 忠实值:评估生成的答案中的内容是否来自于检索到的片段。评分越低,说明RAG检索算法和embedding模型的召回率越低(导致检索到的片段不足以回答问题);其次可能是文本分块大小不合理。
+ - 可解释性:评估生成的答案是否能用于回答问题。评分越低,说明RAG检索算法和embedding模型的召回率越低(导致检索到的片段不足以回答问题);其次可能是使用的大模型遵从度越低。
+
+ 注意:
+ #01 请根据测试结果中的内容分析当前知识库问答准确率的提升方法。
+ #02 请结合召回率、精确率、忠实值和可解释性四个指标进行分析。
+ #03 分析结果长度不超过500字。
+ #04 请仅输出分析结果,不要输出其他内容。
+
+ 例子:
+ 输入:
+ 模型名称:qwen2.5-32b
+ embedding模型:bge-m3
+ 文本的分块大小:512
+ 使用解析的RAG算法:向量化检索
+ 问题:openEuler是什么操作系统?
+ 标准答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 生成的答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 原始片段:openEuler是由开放原子开源基金会孵化及运营的开源操作系统,以构建面向数字基础设施的开源操作系统生态为使命,致力于为云计算、边缘计算等前沿领域提供坚实的底层支持。在云计算场景中,openEuler能够充分优化资源调度与分配机制,通过轻量化的内核设计和高效的虚拟化技术,显著提升云服务的响应速度与吞吐量;在边缘计算领域,它凭借出色的低资源消耗特性与实时处理能力,保障了边缘节点在复杂环境下数据处理的及时性与准确性。openEuler具备一系列卓越特性:在性能方面,其自主研发的智能调度算法能够动态适配不同负载场景,结合对硬件资源的深度优化利用,大幅提升系统运行效率;安全性上,通过内置的多层次安全防护体系,包括强制访问控制、漏洞扫描与修复机制,为系统数据与应用程序构筑起坚实的安全防线;可靠性层面,基于分布式存储、故障自动检测与快速恢复技术,确保系统在面对网络波动、硬件故障等突发状况时,依然能够稳定运行,最大限度降低服务中断风险。这些特性使openEuler成为推动数字经济高质量发展的重要技术基石,助力企业与开发者在数字化转型进程中抢占先机。
+ 检索的片段:openEuler作为开源操作系统领域的先锋力量,深度融合了社区开发者的智慧结晶,不断迭代升级以适应快速变化的技术环境。在微服务架构盛行的当下,openEuler通过对容器化技术的深度优化,支持Kubernetes等主流编排工具,让应用部署与管理变得更加便捷高效,极大提升了企业的业务部署灵活性。同时,它积极拥抱AI时代,通过对机器学习框架的适配与优化,为AI模型训练和推理提供强大的算力支撑,有效降低了AI应用的开发与运行成本。在生态建设方面,openEuler拥有庞大且活跃的开源社区,汇聚了来自全球的技术爱好者与行业专家,形成了从内核开发、驱动适配到应用优化的完整生态链。社区定期举办技术交流与开发者大会,推动知识共享与技术创新,为开发者提供了丰富的学习资源与实践机会。众多硬件厂商和软件企业纷纷加入openEuler生态,推出基于该系统的解决方案和产品,涵盖金融、电信、能源等关键行业,以实际应用场景验证并反哺openEuler的技术发展,形成了良性循环的创新生态,让openEuler不仅是一个操作系统,更成为推动产业协同发展的强大引擎。
+ 召回率:95.00
+ 精确率:99.00
+ 忠实值:90.00
+ 可解释性:85.00
+
+ 输出:
+ 根据测试结果中的内容,当前知识库问答准确率提升的方法可以从以下几个方面进行分析:召回率:目前召回率为95.00,有提升空间。优化向量化检索算法,进一步挖掘原始片段中与问题相关但未被检索到的信息,如openEuler生态中一些具体实践案例等。调整embedding模型bge-m3,使其能更全面准确地捕捉语义,扩大检索范围,提高召回率,使生成答案更接近标准答案。精确率:精确率达99.00,已较高。但可进一步优化,对检索到的片段进行更深入的语义分析,结合大模型qwen2.5-32b的特点,精准匹配问题语义,避免细微语义偏差,例如更精确阐述openEuler在云计算和边缘计算中高性能等特性的具体表现。忠实值:忠实值为90.00,说明部分答案内容未完全源于检索片段。优化RAG检索算法,提高embedding模型召回率,文本分块大小为512可能存在不合理,需根据内容重新评估,确保检索片段包含足够上下文以支撑答案,使生成答案内容均来自检索片段,如关于openEuler生态建设中相关技术细节应从检索片段获取。可解释性:可解释性为85.00,相对较低。提升大模型qwen2.5-32b的遵从度,优化RAG检索算法和embedding模型bge-m3的召回率,使检索片段能更好支撑生成答案,保证答案能清晰回答问题,例如在回答openEuler相关问题时,使答案逻辑更清晰、针对性更强,提高整体可解释性。
+
+ 下面是测试结果中的内容:
+ 使用的大模型:{model_name}
+ embedding模型:{embedding_model}
+ 文本的分块大小:{chunk_size}
+ 使用解析的RAG算法:{rag_algorithm}
+ 问题:{question}
+ 标准答案:{standard_answer}
+ 生成的答案:{generated_answer}
+ 原始片段:{original_fragment}
+ 检索的片段:{retrieved_fragment}
+ 召回率:{recall}
+ 精确率:{precision}
+ 忠实值:{faithfulness}
+ 可解释性:{relevance}
+
+ANSWER_TO_ANSWER_PROMPT:
+ # 英文文本相似度评分提示词
+ en: |
+ You are a text analysis expert. Your task is to compare the similarity between two documents and output a score between 0 and 100 with two decimal places.
+
+ Note:
+ #01 Score based on text similarity in three dimensions: semantics, word order, and keywords.
+ #02 If the core expressions of the two documents are consistent, the score will be relatively high.
+ #03 If one document contains the core content of the other, the score will also be relatively high.
+ #04 If there is content overlap between the two documents, the score will be determined by the proportion of the overlap.
+ #05 Output only the score (no other content).
+
+ Example 1:
+ Input - Text 1: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Text 2: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Output: 100.00
+
+ Example 2:
+ Input - Text 1: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Text 2: openEuler is an open-source operating system designed to support cloud computing and edge computing. It features high performance and high security.
+ Output: 90.00
+
+ Example 3:
+ Input - Text 1: openEuler is an open-source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Text 2: A white horse is not a horse
+ Output: 00.00
+
+ The following are the given texts:
+ Text 1: {text_1}
+ Text 2: {text_2}
+
+ # 中文文本相似度评分提示词
+ 中文: |
+ 你是一个文本分析专家,你的任务是对比两个文本之间的相似度,并输出一个 0-100 之间的分数(保留两位小数)。
+
+ 注意:
+ #01 请根据文本在语义、语序和关键字三个维度的相似度进行打分。
+ #02 如果两个文本在核心表达上一致,那么分数将相对较高。
+ #03 如果一个文本包含另一个文本的核心内容,那么分数也将相对较高。
+ #04 如果两个文本间存在内容重合,那么将按照重合内容的比例确定分数。
+ #05 仅输出分数,不要输出其他任何内容。
+
+ 例子 1:
+ 输入 - 文本 1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 文本 2:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出:100.00
+
+ 例子 2:
+ 输入 - 文本 1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 文本 2:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能和高安全性等特点。
+ 输出:90.00
+
+ 例子 3:
+ 输入 - 文本 1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 文本 2:白马非马
+ 输出:00.00
+
+ 下面是给出的文本:
+ 文本 1:{text_1}
+ 文本 2:{text_2}
+
+CAL_QA_SCORE_PROMPT:
+ en: >-
+ You are a text analysis expert. Your task is to evaluate the questions and answers generated from a given fragment, and assign a score between 0 and 100 (retaining two decimal places). Please evaluate based on the following criteria:
+
+ ### 1. Question Evaluation
+ - **Relevance**: Is the question closely related to the topic of the given fragment? Is it accurately based on the fragment content? Does it deviate from or distort the core message of the fragment?
+ - **Plausibility**: Is the question formulated clearly and logically coherently? Does it conform to normal language and thinking habits? Is it free of semantic ambiguity, vagueness, or self-contradiction?
+ - **Variety**: If there are multiple questions, are their angles and types sufficiently varied to avoid being overly monotonous or repetitive? Can they explore the fragment content from different perspectives?
+ - **Difficulty**: Is the question difficulty appropriate? Not too easy (where answers can be directly copied from the fragment), nor too difficult (where respondents cannot find clues or evidence from the fragment)?
+
+ ### 2. Answer Evaluation
+ - **Accuracy**: Does the answer accurately address the question? Is it consistent with the information in the fragment? Does it contain errors or omit key points?
+ - **Completeness**: Is the answer comprehensive, covering all aspects of the question? For questions requiring elaboration, does it provide sufficient details and explanations?
+ - **Succinctness**: On the premise of ensuring completeness and accuracy, is the answer concise and clear? Does it avoid lengthy or redundant expressions, and convey key information in concise language?
+ - **Coherence**: Is the answer logically clear? Are transitions between content sections natural and smooth? Are there any jumps or confusion?
+
+ ### 3. Overall Assessment
+ - **Consistency**: Do the question and answer match each other? Does the answer address the raised question? Are they consistent in content and logic?
+ - **Integration**: Does the answer effectively integrate information from the fragment? Is it not just a simple excerpt, but rather an integrated, refined presentation in a logical manner?
+ - **Innovation**: In some cases, evaluate whether the answer demonstrates innovation or unique insights? Does it appropriately expand or deepen the information in the fragment?
+
+ ### Note
+ #01 Please output only the score (without any other content).
+
+ ### Example
+ Input 1:
+ Question: What operating system is openEuler?
+ Answer: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Snippet: openEuler is an open source operating system designed to support cloud and edge computing. It features high performance, high security, and high reliability.
+ Output 1: 100.00
+
+ Below is the given question, answer, and snippet:
+ Question: {question}
+ Answer: {answer}
+ Snippet: {fragment}
+ 中文: >-
+ 你是文本分析专家,任务是评估由给定片段生成的问题与答案,输出 0-100 之间的分数(保留两位小数)。请根据以下标准进行评估:
+
+ ### 1. 问题评估
+ - **相关性**:问题是否与给定片段的主题紧密相关?是否准确基于片段内容提出?有无偏离或曲解片段的核心信息?
+ - **合理性**:问题表述是否清晰、逻辑连贯?是否符合正常的语言表达和思维习惯?不存在语义模糊、歧义或自相矛盾的情况?
+ - **多样性**:若存在多个问题,问题之间的角度和类型是否具有足够多样性(避免过于单一或重复)?能否从不同方面挖掘片段内容?
+ - **难度**:问题难度是否适中?既不过于简单(答案可直接从片段中照搬),也不过于困难(回答者难以从片段中找到线索或依据)?
+
+ ### 2. 答案评估
+ - **准确性**:答案是否准确无误地回答了问题?与片段中的信息是否一致?有无错误或遗漏关键要点?
+ - **完整性**:答案是否完整,涵盖问题涉及的各个方面?对于需要详细阐述的问题,是否提供了足够的细节和解释?
+ - **简洁性**:在保证回答完整、准确的前提下,答案是否简洁明了?是否避免冗长、啰嗦的表述,能否以简洁语言传达关键信息?
+ - **连贯性**:答案逻辑是否清晰?各部分内容之间的衔接是否自然流畅?有无跳跃或混乱的情况?
+
+ ### 3. 整体评估
+ - **一致性**:问题与答案之间是否相互匹配?答案是否针对所提出的问题进行回答?两者在内容和逻辑上是否保持一致?
+ - **融合性**:答案是否能很好地融合片段中的信息?是否并非简单摘抄,而是经过整合、提炼后以合理方式呈现?
+ - **创新性**:在某些情况下,评估答案是否具有一定创新性或独特见解?是否能在片段信息基础上进行适当拓展或深入思考?
+
+ ### 注意事项
+ #01 请仅输出分数,不要输出其他内容。
+
+ ### 示例
+ 输入 1:
+ 问题:openEuler 是什么操作系统?
+ 答案:openEuler 是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 片段:openEuler 是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出 1:100.00
+
+ 下面是给出的问题、答案和片段:
+ 问题:{question}
+ 答案:{answer}
+ 片段:{fragment}
+
+CHUNK_QUERY_MATCH_PROMPT:
+ en: |
+ You are a text analysis expert. Your task is to determine whether a given fragment is relevant to a question.
+ Note:
+ #01 If the fragment is relevant, output YES.
+ #02 If the fragment is not relevant, output NO.
+ #03 Only output YES or NO, and do not output anything else.
+
+ Example:
+ Input 1:
+ Fragment: openEuler is an open source operating system.
+ Question: What kind of operating system is openEuler?
+ Output 1: YES
+
+ Input 2:
+ Fragment: A white horse is not a horse.
+ Question: What kind of operating system is openEuler?
+ Output 2: NO
+
+ Here are the given fragment and question:
+ Fragment: {chunk}
+ Question: {question}
+ 中文: |
+ 你是一个文本分析专家,你的任务是根据给出的片段和问题,判断片段是否与问题相关。
+ 注意:
+ #01 如果片段与问题相关,请输出YES;
+ #02 如果片段与问题不相关,请输出NO;
+ #03 请仅输出YES或NO,不要输出其他内容。
+
+ 例子:
+ 输入1:
+ 片段:openEuler是一个开源的操作系统。
+ 问题:openEuler是什么操作系统?
+ 输出1:YES
+
+ 输入2:
+ 片段:白马非马
+ 问题:openEuler是什么操作系统?
+ 输出2:NO
+
+ 下面是给出的片段和问题:
+ 片段:{chunk}
+ 问题:{question}
+
+CONTENT_TO_ABSTRACT_PROMPT:
+ en: |
+ You are a text summarization expert. Your task is to generate a new English summary based on a given text and an existing summary.
+ Note:
+ #01 Please combine the most important content from the text and the existing summary to generate the new summary.
+ #02 The length of the new summary must be greater than 200 words and less than 500 words.
+ #03 Please only output the new English summary; do not output any other content.
+
+ Example:
+ Input 1:
+ Text: openEuler features high performance, high security, and high reliability.
+ Abstract: openEuler is an open source operating system designed to support cloud computing and edge computing.
+ Output 1: openEuler is an open source operating system designed to support cloud computing and edge computing. openEuler features high performance, high security, and high reliability.
+
+ Below is the given text and summary:
+ Text: {content}
+ Abstract: {abstract}
+ 中文: |
+ 你是一个文本摘要专家,你的任务是根据给出的文本和已有摘要,生成一个新的中文摘要。
+ 注意:
+ #01 请结合文本和已有摘要中最重要的内容,生成新的摘要;
+ #02 新的摘要长度必须大于200字且小于500字;
+ #03 请仅输出新的中文摘要,不要输出其他内容。
+
+ 例子:
+ 输入1:
+ 文本:openEuler具有高性能、高安全性和高可靠性等特点。
+ 摘要:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。
+ 输出1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。openEuler具有高性能、高安全性和高可靠性等特点。
+
+ 下面是给出的文本和摘要:
+ 文本:{content}
+ 摘要:{abstract}
+
+CONTENT_TO_STATEMENTS_PROMPT:
+ en: |
+ You are a text parsing expert. Your task is to extract multiple English statements from a given text and return them as a list.
+
+ Note:
+ #01 Statements must be derived from key points in the text.
+ #02 Statements must be arranged in relative order.
+ #03 Each statement must be at least 20 characters long and no more than 50 characters long.
+ #04 The total number of statements output must not exceed three.
+ #05 Please output only the list of statements, not any other content. Each statement must be in English.
+ Example:
+
+ Input: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Output: [ "openEuler is an open source operating system", "openEuler is designed to support cloud computing and edge computing", "openEuler features high performance, high security, and high reliability" ]
+
+ The following is the given text: {content}
+ 中文: |
+ 你是一个文本分解专家,你的任务是根据我给出的文本,将文本提取为多个中文陈述,陈述使用列表形式返回
+
+ 注意:
+ #01 陈述必须来源于文本中的重点内容
+ #02 陈述按相对顺序排列
+ #03 输出的单个陈述长度不少于20个字,不超过50个字
+ #04 输出的陈述总数不超过3个
+ #05 请仅输出陈述列表,不要输出其他内容,且每一条陈述都是中文。
+ 例子:
+
+ 输入:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出:[ "openEuler是一个开源的操作系统", "openEuler旨在为云计算和边缘计算提供支持", "openEuler具有高性能、高安全性和高可靠性等特点" ]
+
+ 下面是给出的文本: {content}
+
+CONTENT_TO_TITLE_PROMPT:
+ en: >-
+ You are a title extraction expert. Your task is to generate an English title based on the given text.
+ Note:
+ #01 The title must be derived from the content of the text.
+ #02 The title must be no longer than 20 characters.
+ #03 Please output only the English title, and do not output any other content.
+ #04 If the given text is insufficient to generate a title, output "Unable to generate title."
+ Example:
+ Input: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Output: Overview of the openEuler operating system.
+ Below is the given text: {content}
+ 中文: >-
+ 你是一个标题提取专家,你的任务是根据给出的文本生成一个中文标题。
+ 注意:
+ #01 标题必须来源于文本中的内容
+ #02 标题长度不超过20个字
+ #03 请仅输出中文标题,不要输出其他内容
+ #04 如果给出的文本不够生成标题,请输出“无法生成标题”
+ 例子:
+ 输入:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出:openEuler操作系统概述
+ 下面是给出的文本:{content}
+
+GENERATE_ANSWER_FROM_QUESTION_AND_CONTENT_PROMPT:
+ en: |
+ You are a text analysis expert. Your task is to generate an English answer based on a given question and text.
+ Note:
+ #01 The answer must be derived from the content in the text.
+ #02 The answer must be at least 50 words and no more than 500 words.
+ #03 Please only output the English answer; do not output any other content.
+ Example:
+ Input 1:
+ Question: What kind of operating system is openEuler?
+ Text: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Output 1: openEuler is an open source operating system designed to support cloud computing and edge computing.
+
+ Input 2:
+ Question: How secure is openEuler?
+ Text: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Output 2: openEuler is highly secure.
+
+ Below is the given question and text:
+ Question: {question}
+ Text: {content}
+ 中文: |
+ 你是一个文本分析专家,你的任务是根据给出的问题和文本生成中文答案。
+ 注意:
+ #01 答案必须来源于文本中的内容;
+ #02 答案长度不少于50字且不超过500个字;
+ #03 请仅输出中文答案,不要输出其他内容。
+ 例子:
+ 输入1:
+ 问题:openEuler是什么操作系统?
+ 文本:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出1:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。
+
+ 输入2:
+ 问题:openEuler的安全性如何?
+ 文本:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出2:openEuler具有高安全性。
+
+ 下面是给出的问题和文本:
+ 问题:{question}
+ 文本:{content}
+
+GENERATE_QUESTION_FROM_CONTENT_PROMPT:
+ en: |
+ You are a text analysis expert. Your task is to generate {k} English questions based on the given text and return them as a list.
+ Note:
+ #01 Questions must be derived from the content of the text.
+ #02 A single question must not exceed 50 characters.
+ #03 Do not output duplicate questions.
+ #04 The output questions should be diverse, covering different aspects of the text.
+ #05 Please only output a list of English questions, not other content.
+ Example:
+ Input: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Output: ["What is openEuler?","What fields does openEuler support?","What are the characteristics of openEuler?","How secure is openEuler?","How reliable is openEuler?"]
+ The following is the given text: {content}
+ 中文: |
+ 你是一个文本分析专家,你的任务是根据给出的文本生成{k}个中文问题并用列表返回。
+ 注意:
+ #01 问题必须来源于文本中的内容;
+ #02 单个问题长度不超过50个字;
+ #03 不要输出重复的问题;
+ #04 输出的问题要多样,覆盖文本中的不同方面;
+ #05 请仅输出中文问题列表,不要输出其他内容。
+ 例子:
+ 输入:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出:["openEuler是什么操作系统?","openEuler旨在为哪个领域提供支持?","openEuler具有哪些特点?","openEuler的安全性如何?","openEuler的可靠性如何?"]
+ 下面是给出的文本:{content}
+
+OCR_ENHANCED_PROMPT:
+ en: |
+ You are an expert in image OCR content summarization. Your task is to describe the image based on the context I provide, descriptions of adjacent images, a summary of the previous OCR result for the current image, and the partial OCR results (including text and relative coordinates).
+
+ Note:
+ #01 The image content must be described in detail, using at least 200 and no more than 500 words. Detailed data listing is acceptable.
+ #02 If this diagram is a flowchart, please describe the content in the order of the flowchart.
+ #03 If this diagram is a table, please output the table content in Markdown format.
+ #04 If this diagram is an architecture diagram, please describe the content according to the hierarchy of the architecture diagram.
+ #05 The summarized image description must include the key information in the image; it cannot simply describe the image's location.
+ #06 Adjacent text in the image recognition results may be part of the same paragraph. Please merge them before summarizing.
+ #07 The text may be misplaced. Please correct the order before summarizing.
+ #08 Please only output the image summary; do not output any other content.
+ #09 Do not output coordinates or other information; only output a description of the relative position of each part.
+ #10 If the image content is empty, output "Image content is empty."
+ #11 If the image itself is a paragraph of text, output the text content directly.
+ #12 Please use English for the output.
+ Context: {image_related_text}
+ Summary of the OCR content of the previous part of the current image: {pre_part_description}
+ Result of the OCR of the current part of the image: {part}
+ 中文: |
+ 你是一个图片OCR内容总结专家,你的任务是根据我提供的上下文、相邻图片组描述、当前图片上一次的OCR内容总结、当前图片部分OCR的结果(包含文字和文字的相对坐标)给出图片描述。
+
+ 注意:
+ #01 必须使用大于200字小于500字详细描述这个图片的内容,可以详细列出数据。
+ #02 如果这个图是流程图,请按照流程图顺序描述内容。
+ #03 如果这张图是表格,请用Markdown形式输出表格内容。
+ #04 如果这张图是架构图,请按照架构图层次结构描述内容。
+ #05 总结的图片描述必须包含图片中的主要信息,不能只描述图片位置。
+ #06 图片识别结果中相邻的文字可能是同一段落的内容,请合并后总结。
+ #07 文字可能存在错位,请修正顺序后进行总结。
+ #08 请仅输出图片的总结即可,不要输出其他内容。
+ #09 不要输出坐标等信息,输出每个部分相对位置的描述即可。
+ #10 如果图片内容为空,请输出“图片内容为空”。
+ #11 如果图片本身就是一段文字,请直接输出文字内容。
+ #12 请使用中文输出。
+ 上下文:{image_related_text}
+ 当前图片上一部分的OCR内容总结:{pre_part_description}
+ 当前图片部分OCR的结果:{part}
+
+QA_TO_STATEMENTS_PROMPT:
+ en: |
+ You are a text parsing expert. Your task is to extract the answers from the questions and answers I provide into multiple English statements, returning them as a list.
+
+ Note:
+ #01 The statements must be derived from the key points of the answers.
+ #02 The statements must be arranged in relative order.
+ #03 The length of each statement output must not exceed 50 characters.
+ #04 The total number of statements output must not exceed 20.
+ #05 Please only output the list of English statements; do not output any other content.
+
+ Example:
+ Input: Question: What is openEuler? Answer: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Output: [ "openEuler is an open source operating system", "openEuler is designed to support cloud computing and edge computing", "openEuler features high performance, high security, and high reliability" ]
+
+ Below are the given questions and answers:
+ Question: {question}
+ Answer: {answer}
+ 中文: |
+ 你是一个文本分解专家,你的任务是根据我给出的问题和答案,将答案提取为多个中文陈述,陈述使用列表形式返回。
+
+ 注意:
+ #01 陈述必须来源于答案中的重点内容
+ #02 陈述按相对顺序排列
+ #03 输出的单个陈述长度不超过50个字
+ #04 输出的陈述总数不超过20个
+ #05 请仅输出中文陈述列表,不要输出其他内容
+
+ 例子:
+ 输入:问题:openEuler是什么操作系统? 答案:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出:[ "openEuler是一个开源的操作系统", "openEuler旨在为云计算和边缘计算提供支持", "openEuler具有高性能、高安全性和高可靠性等特点" ]
+
+ 下面是给出的问题和答案:
+ 问题:{question}
+ 答案:{answer}
+
+QUERY_EXTEND_PROMPT:
+ en: |
+    You are a question expansion expert. Your task is to generate {k} expanded questions from the given question.
+
+ Note:
+ #01 The content of the expanded question must be derived from the content of the original question.
+ #02 The expanded question length must not exceed 50 characters.
+ #03 Questions can be rewritten by replacing synonyms, swapping word order within the question, changing English capitalization, etc.
+ #04 Please only output the expanded question list, do not output other content.
+
+ Example:
+ Input: What operating system is openEuler?
+ Output: [ "What kind of operating system is openEuler?", "What are the characteristics of the openEuler operating system?", "What are the functions of the openEuler operating system?", "What are the advantages of the openEuler operating system?" ]
+
+ The following is the given question: {question}
+ 中文: |
+ 你是一个问题扩写专家,你的任务是根据给出的问题扩写{k}个问题。
+
+ 注意:
+ #01 扩写的问题的内容必须来源于原问题中的内容
+ #02 扩写的问题长度不超过50个字
+ #03 可以通过近义词替换、问题内词序交换、修改英文大小写等方式来改写问题
+ #04 请仅输出扩写的问题列表,不要输出其他内容
+
+ 例子:
+ 输入:openEuler是什么操作系统?
+ 输出:[ "openEuler是一个什么样的操作系统?", "openEuler操作系统的特点是什么?", "openEuler操作系统有哪些功能?", "openEuler操作系统的优势是什么?" ]
+
+ 下面是给出的问题:{question}
+
+STATEMENTS_TO_FRAGMENT_PROMPT:
+ en: |
+ You are a text expert. Your task is to determine whether a given statement is strongly related to the fragment.
+
+ Note:
+ #01 If the statement is strongly related to the fragment or is derived from the fragment, output YES.
+ #02 If the content in the statement is unrelated to the fragment, output NO.
+ #03 If the statement is a refinement of a portion of the fragment, output YES.
+    #04 Only output YES or NO, and do not output anything else.
+
+ Example:
+ Input 1:
+ Statement: openEuler is an open source operating system.
+ Fragment: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Output 1: YES
+
+ Input 2:
+ Statement: A white horse is not a horse.
+ Fragment: openEuler is an open source operating system designed to support cloud computing and edge computing. It features high performance, high security, and high reliability.
+ Output 2: NO
+
+ Below is a given statement and fragment:
+ Statement: {statement}
+ Fragment: {fragment}
+ 中文: |
+ 你是一个文本专家,你的任务是判断给出的陈述是否与片段强相关。
+
+ 注意:
+ #01 如果陈述与片段强相关或者来自于片段,请输出YES
+ #02 如果陈述中的内容与片段无关,请输出NO
+ #03 如果陈述是片段中某部分的提炼,请输出YES
+    #04 请仅输出YES或NO,不要输出其他内容
+
+ 例子:
+ 输入1:
+ 陈述:openEuler是一个开源的操作系统。
+ 片段:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出1:YES
+
+ 输入2:
+ 陈述:白马非马
+ 片段:openEuler是一个开源的操作系统,旨在为云计算和边缘计算提供支持。它具有高性能、高安全性和高可靠性等特点。
+ 输出2:NO
+
+ 下面是给出的陈述和片段:
+ 陈述:{statement}
+ 片段:{fragment}
+
+STATEMENTS_TO_QUESTION_PROMPT:
+ en: |
+ You are a text analysis expert. Your task is to determine whether a given statement is relevant to a question.
+
+ Note:
+ #01 If the statement is relevant to the question, output YES.
+ #02 If the statement is not relevant to the question, output NO.
+ #03 Only output YES or NO, and do not output anything else.
+ #04 A statement's relevance to the question means that the content in the statement can answer the question or overlaps with the question in terms of content.
+
+ Example:
+ Input 1:
+ Statement: openEuler is an open source operating system.
+ Question: What kind of operating system is openEuler?
+ Output 1: YES
+
+ Input 2:
+ Statement: A white horse is not a horse.
+ Question: What kind of operating system is openEuler?
+ Output 2: NO
+
+ Below is the given statement and question:
+ Statement: {statement}
+ Question: {question}
+ 中文: |
+ 你是一个文本分析专家,你的任务是判断给出的陈述是否与问题相关。
+
+ 注意:
+ #01 如果陈述与问题相关,请输出YES
+ #02 如果陈述与问题不相关,请输出NO
+ #03 请仅输出YES或NO,不要输出其他内容
+ #04 陈述与问题相关是指,陈述中的内容可以回答问题或者与问题在内容上有交集
+
+ 例子:
+ 输入1:
+ 陈述:openEuler是一个开源的操作系统。
+ 问题:openEuler是什么操作系统?
+ 输出1:YES
+
+ 输入2:
+ 陈述:白马非马
+ 问题:openEuler是什么操作系统?
+ 输出2:NO
+
+ 下面是给出的陈述和问题:
+ 陈述:{statement}
+ 问题:{question}
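
For reference, the templates above are consumed as plain YAML: each top-level key holds one template per language ("en" / "中文"), and the {placeholders} are filled with str.format. A minimal loading sketch, assuming the file is saved as prompt.yaml in the working directory (the path test/token_tool.py expects):

    import yaml

    with open("prompt.yaml", "r", encoding="utf-8") as f:
        prompt_dict = yaml.safe_load(f)

    # Pick a template by name and language, then fill its placeholders.
    template = prompt_dict.get("STATEMENTS_TO_QUESTION_PROMPT", {}).get("en", "")
    prompt = template.format(
        statement="openEuler is an open source operating system.",
        question="What kind of operating system is openEuler?",
    )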
diff --git a/test/requirements.txt b/test/requirements.txt
deleted file mode 100644
index e4b18f8b56f026c78097a3ca6e28fa6671cc9393..0000000000000000000000000000000000000000
--- a/test/requirements.txt
+++ /dev/null
@@ -1,6 +0,0 @@
-jieba==0.42.1
-pandas==2.1.4
-pydantic==2.10.2
-langchain==0.1.16
-langchain-openai==0.1.7
-synonyms==3.23.5
\ No newline at end of file
diff --git a/test/requiremnets.txt b/test/requiremnets.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e32be7d73bf0c41be259e29e5b441978602ba691
--- /dev/null
+++ b/test/requiremnets.txt
@@ -0,0 +1,11 @@
+toml==0.10.2
+pydantic==2.11.7
+urllib3==2.2.1
+requests==2.32.2
+langchain==0.3.7
+langchain-core==0.3.56
+langchain-openai==0.2.5
+tiktoken==0.9.0
+jieba==0.42.1
+numpy==1.26.4
+jieba==0.42.1
\ No newline at end of file
diff --git a/test/result.xlsx b/test/result.xlsx
new file mode 100644
index 0000000000000000000000000000000000000000..4a85ab395df84ef5f5511eb15477e0b2005f5557
Binary files /dev/null and b/test/result.xlsx differ
diff --git a/test/tools/stopwords.txt b/test/stopwords.txt
similarity index 99%
rename from test/tools/stopwords.txt
rename to test/stopwords.txt
index 5784b4462a67442a7301abb939b8ca17fa791598..bfb5f302afa87935686501368c011a0a99de855e 100644
--- a/test/tools/stopwords.txt
+++ b/test/stopwords.txt
@@ -1276,7 +1276,6 @@ indeed
第三句
更
看上去
-安全
零
也好
上去
@@ -3702,7 +3701,6 @@ sup
它们的
它是
它的
-安全
完全
完成
定
diff --git a/test/test.py b/test/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..cde43529da7dac29d5ed7f6b838f4b36c22c8ec4
--- /dev/null
+++ b/test/test.py
@@ -0,0 +1,88 @@
+import argparse
+import asyncio
+import pandas as pd
+
+from token_tool import TokenTool
+from pydantic import BaseModel, Field
+
+
+class TestEntity(BaseModel):
+ """测试实体模型"""
+ question: str = Field(default="", description="问题")
+ answer: str = Field(default="", description="答案")
+ chunk: str = Field(default="", description="上下文片段")
+ llm_answer: str = Field(default="", description="大模型答案")
+ related_chunk: str = Field(default="", description="相关上下文片段")
+ pre: float = Field(default=0.0, description="准确率")
+ rec: float = Field(default=0.0, description="召回率")
+ fai: float = Field(default=0.0, description="可信度")
+ rel: float = Field(default=0.0, description="相关性")
+ lcs: float = Field(default=0.0, description="最长公共子串")
+ leve: float = Field(default=0.0, description="编辑距离")
+ jac: float = Field(default=0.0, description="Jaccard相似度")
+
+
+async def read_data_from_file(input_xlsx_file: str) -> list[TestEntity]:
+ """从文件读取测试数据"""
+ df = pd.read_excel(input_xlsx_file)
+ data = []
+ for _, row in df.iterrows():
+ entity = TestEntity(
+ question=row.get('question', ''),
+ answer=row.get('answer', ''),
+ chunk=row.get('chunk', ''),
+ llm_answer=row.get('llm_answer', ''),
+ related_chunk=row.get('related_chunk', '')
+ )
+ data.append(entity)
+ return data
+
+
+async def write_data_to_file(output_xlsx_file: str, data: list[TestEntity]) -> None:
+ """
+ 将测试数据写入文件
+ 第一个sheet写入平均分,第二个sheet写入详细数据
+ """
+ average_data = {
+ 'pre': sum(item.pre for item in data) / len(data) if data else 0,
+ 'rec': sum(item.rec for item in data) / len(data) if data else 0,
+ 'fai': sum(item.fai for item in data) / len(data) if data else 0,
+ 'rel': sum(item.rel for item in data) / len(data) if data else 0,
+ 'lcs': sum(item.lcs for item in data) / len(data) if data else 0,
+ 'leve': sum(item.leve for item in data) / len(data) if data else 0,
+ 'jac': sum(item.jac for item in data) / len(data) if data else 0,
+ }
+ average_df = pd.DataFrame([average_data])
+ detailed_df = pd.DataFrame([item.model_dump() for item in data])
+ with pd.ExcelWriter(output_xlsx_file) as writer:
+ average_df.to_excel(writer, sheet_name='average', index=False)
+ detailed_df.to_excel(writer, sheet_name='detailed', index=False)
+
+
+async def evaluate_metrics(data: list[TestEntity], language: str) -> None:
+ """评估测试数据的各项指标"""
+ token_tool = TokenTool()
+ for item in data:
+ item.pre = await token_tool.cal_precision(item.question, item.llm_answer, language)
+ item.rec = await token_tool.cal_recall(item.question, item.related_chunk, language)
+ item.fai = await token_tool.cal_faithfulness(item.question, item.llm_answer, item.related_chunk, language)
+ item.rel = await token_tool.cal_relevance(item.question, item.llm_answer, language)
+ item.lcs = token_tool.cal_lcs(item.answer, item.llm_answer)
+ item.leve = token_tool.cal_leve(item.answer, item.llm_answer)
+ item.jac = token_tool.cal_jac(item.answer, item.llm_answer)
+ print(f"评估完成: 问题: {item.question}, 准确率: {item.pre}, 召回率: {item.rec}, 可信度: {item.fai}, 相关性: {item.rel}, 最长公共子串: {item.lcs}, 编辑距离: {item.leve}, Jaccard相似度: {item.jac}")
+
+
+def work(input_xlsx_file: str, output_xlsx_file: str, language: str) -> None:
+ data = asyncio.run(read_data_from_file(input_xlsx_file))
+ asyncio.run(evaluate_metrics(data, language))
+ asyncio.run(write_data_to_file(output_xlsx_file, data))
+
+
+if __name__ == '__main__':
+ args = argparse.ArgumentParser()
+ args.add_argument('--input_xlsx_file', type=str, required=True, help='输入xlsx文件路径')
+ args.add_argument('--output_xlsx_file', type=str, required=True, help='输出xlsx文件路径')
+    args.add_argument('--language', type=str, default='中文', help='语言类型,可选 中文 或 en,默认为 中文')
+ parsed_args = args.parse_args()
+ work(parsed_args.input_xlsx_file, parsed_args.output_xlsx_file, parsed_args.language)
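
The replacement evaluation entry point above is a plain argparse script; a usage sketch, assuming it is run from the test/ directory so that token_tool's relative paths (./stopwords.txt, ./prompt.yaml) resolve:

    python3 test.py --input_xlsx_file test.xlsx --output_xlsx_file result.xlsx --language 中文

The input workbook is expected to carry question/answer/chunk/llm_answer/related_chunk columns; the output gets an "average" sheet with the mean of each metric and a "detailed" sheet with one row per question.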
diff --git a/test/test.xlsx b/test/test.xlsx
new file mode 100644
index 0000000000000000000000000000000000000000..38396ac66b526a7647465e7468a14097a09b9586
Binary files /dev/null and b/test/test.xlsx differ
diff --git a/test/test_qa.py b/test/test_qa.py
deleted file mode 100644
index ab3408d5aa1b1ee5aa49c16b2c5509ec332a732a..0000000000000000000000000000000000000000
--- a/test/test_qa.py
+++ /dev/null
@@ -1,719 +0,0 @@
-import subprocess
-import argparse
-import asyncio
-import json
-import os
-import random
-import time
-from pathlib import Path
-import jieba
-import pandas as pd
-
-import yaml
-import requests
-from typing import Optional, List
-from pydantic import BaseModel, Field
-from tools.config import config
-from tools.llm import LLM
-from tools.similar_cal_tool import Similar_cal_tool
-current_dir = Path(__file__).resolve().parent
-
-
-def login_and_get_tokens(account, password, base_url):
- """
- 尝试登录并获取新的session ID和CSRF token。
-
- :param login_url: 登录的URL地址
- :param account: 用户账号
- :param password: 用户密码
- :return: 包含新session ID和CSRF token的字典,或者在失败时返回None
- """
- # 构造请求头部
- headers = {
- 'Content-Type': 'application/x-www-form-urlencoded',
- }
-
- # 构造请求数据
- params = {
- 'account': account,
- 'password': password
- }
- # 发送POST请求
- url = f"{base_url}/user/login"
- response = requests.get(url, headers=headers, params=params)
- # 检查响应状态码是否为200表示成功
- if response.status_code == 200:
- # 如果登录成功,获取新的session ID和CSRF token
- new_session = response.cookies.get("WD_ECSESSION")
- new_csrf_token = response.cookies.get("wd_csrf_tk")
- if new_session and new_csrf_token:
- return response.json(), {
- 'ECSESSION': new_session,
- 'csrf_token': new_csrf_token
- }
- else:
- print("Failed to get new session or CSRF token.")
- return None
- else:
- print(f"Failed to login, status code: {response.status_code}")
- return None
-
-
-def tokenize(text):
- return len(list(jieba.cut(str(text))))
-
-
-class DictionaryBaseModel(BaseModel):
- pass
-
-
-class ListChunkRequest(DictionaryBaseModel):
- document_id: str
- text: Optional[str] = None
- page_number: int = 1
- page_size: int = 50
- type: Optional[list[str]] = None
-
-
-def list_chunks(session_cookie: str, csrf_cookie: str, document_id: str,
- text: Optional[str] = None, page_number: int = 1, page_size: int = 50,
- base_url="http://0.0.0.0:9910") -> dict:
- """
- 请求文档块列表的函数。
-
- :param session_cookie: 用户会话cookie
- :param csrf_cookie: CSRF保护cookie
- :param document_id: 文档ID
- :param text: 可选的搜索文本
- :param page_number: 页码,默认为1
- :param page_size: 每页大小,默认为10
- :param base_url: API基础URL,默认为本地测试服务器地址
- :return: JSON响应数据
- """
- # 构造请求cookies
- # print(document_id)
- cookies = {
- "WD_ECSESSION": session_cookie,
- "wd_csrf_tk": csrf_cookie
- }
-
- # 创建请求体实例
- payload = ListChunkRequest(
- document_id=document_id,
- text=text,
- page_number=page_number,
- page_size=page_size,
- ).dict()
-
- # 发送POST请求
- url = f"{base_url}/chunk/list"
- response = requests.post(url, cookies=cookies, json=payload)
-
- # 一次性获取所有chunk
- # print(response.json())
- page_size = response.json()['data']['total']
-
- # 创建请求体实例
- payload = ListChunkRequest(
- document_id=document_id,
- text=text,
- page_number=page_number,
- page_size=page_size,
- ).dict()
-
- # 发送POST请求
- url = f"{base_url}/chunk/list"
- response = requests.post(url, cookies=cookies, json=payload)
-
- # 返回JSON响应数据
- return response.json()
-
-
-def parser():
- # 创建 ArgumentParser 对象
- parser = argparse.ArgumentParser(description="Script to process document and generate QA pairs.")
- subparser = parser.add_subparsers(dest='mode', required=True, help='Mode of operation')
-
- # 离线模式参数
- offline = subparser.add_parser('offline', help='Offline mode for processing documents') # noqa: F841
- offline.add_argument("-i", "--input_path", required=True, default="./document", help="Path of document names",)
- # 在线模式所需添加的参数
- online = subparser.add_parser('online', help='Online mode for processing documents')
- online.add_argument('-n', '--name', type=str, required=True, help='User name')
- online.add_argument('-p', '--password', type=str, required=True, help='User password')
- online.add_argument('-k', '--kb_id', type=str, required=True, help='KnowledgeBase ID')
- online.add_argument('-u', '--url', type=str, required=True, help='URL for witChainD')
-
- # 添加可选参数,并设置默认值
- online.add_argument('-q', '--qa_count', type=int, default=1,
- help='Number of QA pairs to generate per text block (default: 1)')
-
- # 添加文件名列表参数
- online.add_argument('-d', '--doc_names', nargs='+', required=False, default=[], help='List of document names')
-
- # 解析命令行参数
- args = parser.parse_args()
- return args
-
-
-def get_prompt_dict():
- """
- 获取prompt表
- """
- try:
- with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
- prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
- return prompt_dict
- except Exception as e:
- print(f"open {config['PROMPT_PATH']} error {e}")
- raise e
-
-
-prompt_dict = get_prompt_dict()
-llm = LLM(model_name=config['MODEL_NAME'],
- openai_api_base=config['OPENAI_API_BASE'],
- openai_api_key=config['OPENAI_API_KEY'],
- max_tokens=config['MAX_TOKENS'],
- request_timeout=60,
- temperature=0.35)
-
-def get_random_number(l, r):
- return random.randint(l, r-1)
-
-
-class QAgenerator:
-
- async def qa_generate(self, chunks, file):
- """
- 多线程生成问答对
- """
- start_time = time.time()
- results = []
- prev_texts = []
- ans = 0
- # 使用 asyncio.gather 来并行处理每个 chunk
- tasks = []
- # 获取 chunks 的长度
- num_chunks = len(chunks)
- image_sum = 0
- for chunk in chunks:
- chunk['count'] = 0
- # if chunk['type'] == 'image':
- # chunk['count'] = chunk['count'] + 1
- # image_sum = image_sum + 1
- for i in range(args.qa_count):
- x = get_random_number(min(3, num_chunks-1), num_chunks)
- print(x)
- chunks[x]['count'] = chunks[x]['count'] + 1
-
- now_text = ""
- for chunk in chunks:
- now_text = now_text + chunk['text'] + '\n'
- # if chunk['type'] == 'table' and len(now_text) < (config['MAX_TOKENS'] // 8):
- # continue
- prev_text = '\n'.join(prev_texts)
- while tokenize(prev_text) > (config['MAX_TOKENS'] / 4):
- prev_texts.pop(0)
- prev_text = '\n'.join(prev_texts)
- if chunk['count'] > 0:
- tasks.append(self.generate(now_text, prev_text, results, file, chunk['count'], chunk['type']))
- prev_texts.append(now_text)
- now_text = ''
- ans = ans + chunk['count'] + image_sum
-
- # 等待所有任务完成
- await asyncio.gather(*tasks)
- print('问答对案例:', results[:50])
- print("问答对生成总计用时:", time.time() - start_time)
- print(f"总计生成{ans}条问答对")
- return results
-
- async def generate(self, now_text, prev_text, results, file, qa_count, type_text):
- """
- 生成问答
- """
- prev_text = prev_text[-(config['MAX_TOKENS'] // 8):]
- prompt = prompt_dict.get('GENERATE_QA')
- count = 0
- while count < 5:
- try:
- # 使用多线程处理 chat_with_llm 调用
- result_temp = await self.chat_with_llm(llm, prompt, now_text, prev_text,
- qa_count, file)
-
- for result in result_temp:
- result['text'] = prev_text + now_text
- result['type_text'] = type_text
- results.append(result)
- count = 5
- except Exception as e:
- count += 1
- print('error:', e, 'retry times', count)
- if count == 5:
- results.append({'text': now_text, 'question': '无法生成问答对',
- 'answer': '无法生成问答对', 'type': 'error', 'type_text': 'error'})
-
- @staticmethod
- async def chat_with_llm(llm, prompt, text, prev_text, qa_count, file_name) -> dict:
- """
- 对于给定的文本,通过llm生成问题-答案-段落对。
- params:
- - llm: LLm
- - text: str
- - prompt: str
- return:
- - qa_pairs: list[dict]
-
- """
- text.replace("\"", "\\\"")
- user_call = (f"文本内容来自于{file_name},请以JSON格式输出{qa_count}对不同的问题-答案-领域,格式为["
- "{"
- "\"question\": \" 问题 \", "
- "\"answer\": \" 回答 \","
- "\"type\": \" 领域 \""
- "}\n"
- "],并且必须将问题和回答中和未被转义的双引号转义,元素标签请用双引号括起来")
- prompt = prompt.format(chunk=text, qa_count=qa_count, text=prev_text, file_name=file_name)
- # print(prompt)
- qa_pair = await llm.nostream([], prompt, user_call)
- # 提取问题、答案段落对的list,字符串格式为["问题","答案","段落对"]
- print(qa_pair)
- # print("原文:", text)
- qa_pair = json.loads(qa_pair)
- return qa_pair
-
-
-class QueryRequest(BaseModel):
- question: str
- kb_sn: Optional[str] = None
- top_k: int = Field(5, ge=0, le=10)
- fetch_source: bool = False
- history: Optional[List] = []
-
-
-def call_get_answer(text, kb_id, session_cookie, csrf_cookie, base_url="http://0.0.0.0:9910"):
- # 构造请求cookies
- cookies = {
- "WD_ECSESSION": session_cookie,
- "wd_csrf_tk": csrf_cookie
- }
-
- # 构造请求体
- req = QueryRequest(
- question=text,
- kb_sn=kb_id,
- top_k=3,
- fetch_source=True,
- history=[]
- )
-
- url = f"{base_url}/kb/get_answer"
- print(url)
- headers = {
- "Content-Type": "application/json",
- "Accept": "application/json"
- }
- data = req.json().encode("utf-8")
-
- for i in range(5):
- try:
- response = requests.post(url, headers=headers, cookies=cookies, data=data)
-
- if response.status_code == 200:
- result = response.json()
- # print("成功获取答案")
- return result
- print(f"请求失败,状态码: {response.status_code}, 响应内容: {response.text}")
- time.sleep(1)
- except Exception as e:
- print(f"请求answer失败,错误原因{e}, 重试次数:{i+1}")
- time.sleep(1)
-
-
-async def get_answers(QA, kb_id, session_cookie, csrf_cookie, base_url):
- text = QA['question']
- print(f"原文:{QA['text'][:40]}...")
- result = call_get_answer(text, kb_id, session_cookie, csrf_cookie, base_url)
- if result is None:
- return None
- else:
- QA['witChainD_answer'] = result['data']['answer']
- QA['witChainD_source'] = result['data']['source']
- QA['time_cost']=result['data']['time_cost']
- print(f"原文:{QA['text'][:40] + '...'}\n问题:{text}\n回答:{result['data']['answer'][:40]}\n\n")
- return QA
-
-
-async def get_QAs_answers(QAs, kb_id, session_cookie, csrf_cookie, base_url):
- results = []
- tasks = []
- for QA in QAs:
- tasks.append(get_answers(QA, kb_id, session_cookie, csrf_cookie, base_url))
- response = await asyncio.gather(*tasks)
- for idx, result in enumerate(response):
- if result is not None:
- results.append(result)
- return results
-
-
-class QAScore():
- async def get_score(self, QA):
- prompt = prompt_dict['SCORE_QA']
- llm_score_dict = await self.chat_with_llm(llm, prompt, QA['question'], QA['text'], QA['witChainD_source'], QA['answer'], QA['witChainD_answer'])
- print(llm_score_dict)
-
- QA['context_relevancy'] = llm_score_dict['context_relevancy']
- QA['context_recall'] = llm_score_dict['context_recall']
- QA['faithfulness'] = llm_score_dict['faithfulness']
- QA['answer_relevancy'] = llm_score_dict['answer_relevancy']
- print(QA)
- try:
- lcs_score = Similar_cal_tool.longest_common_subsequence(QA['answer'], QA['witChainD_answer'])
- except:
- lcs_score = 0
- QA['lcs_score'] = lcs_score
- try:
- jac_score = Similar_cal_tool.jaccard_distance(QA['answer'], QA['witChainD_answer'])
- except:
- jac_score = 0
- QA['jac_score'] = jac_score
- try:
- leve_score = Similar_cal_tool.levenshtein_distance(QA['answer'], QA['witChainD_answer'])
- except:
- leve_score = 0
- QA['leve_score'] = leve_score
- return QA
-
- async def get_scores(self, QAs):
- tasks = []
- results = []
- for QA in QAs:
- tasks.append(self.get_score(QA))
- response = await asyncio.gather(*tasks)
- for idx, result in enumerate(response):
- if result is not None:
- results.append(result)
- return results
-
- @staticmethod
- async def chat_with_llm(llm, prompt, question, meta_chunk, chunk, answer, answer_text) -> dict:
- """
- 对于给定的文本,通过llm生成问题-答案-段落对。
- params:
- - llm: LLm
- - text: str
- - prompt: str
- return:
- - qa_pairs: list[dict]
-
- """
- required_metrics = {
- "context_relevancy",
- "context_recall",
- "faithfulness",
- "answer_relevancy",
- }
- for i in range(5):
- try:
- user_call = """请对答案打分,并以下面形式返回结果{
- \"context_relevancy\": 分数,
- \"context_recall\": 分数,
- \"faithfulness\": 分数,
- \"answer_relevancy\": 分数
-}
-注意:属性名必须使用双引号,分数为数字,保留两位小数。"""
- prompt = prompt.format(question=question, meta_chunk=meta_chunk,
- chunk=chunk, answer=answer, answer_text=answer_text)
- # print(prompt)
- score_dict = await llm.nostream([], prompt, user_call)
- st = score_dict.find('{')
- en = score_dict.rfind('}')
- if st != -1 and en != -1:
- score_dict = score_dict[st:en+1]
- # print(score_dict)
- score_dict = json.loads(score_dict)
- # 提取问题、答案段落对的list,字符串格式为["问题","答案","段落对"]
- # print(score)
- present_metrics = set(score_dict.keys())
- missing_metrics = required_metrics - present_metrics
- if missing_metrics:
- missing = ", ".join(missing_metrics)
- print(f"评分结果缺少必要指标: {missing}")
- for metric in required_metrics:
- if metric not in score_dict:
- score_dict[metric] = 0.00
- print(score_dict)
- return score_dict
- except Exception as e:
- continue
- return {
- "context_relevancy": 0,
- "context_recall": 0,
- "faithfulness": 0,
- "answer_relevancy": 0,
- }
-
-
-def list_documents(session_cookie, csrf_cookie, kb_id, base_url="http://0.0.0.0:9910"):
- # 构造请求cookies
- cookies = {
- "WD_ECSESSION": session_cookie,
- "wd_csrf_tk": csrf_cookie
- }
-
- # 构造请求URL
- url = f"{base_url}/doc/list"
-
- # 构造请求体
- payload = {
- "kb_id": str(kb_id), # 将uuid对象转换为字符串
- "page_number": 1,
- "page_size": 50,
- }
-
- # 发送POST请求
- response = requests.post(url, cookies=cookies, json=payload)
- # print(response.text)
-
- # 一次性获取所有document
- total = response.json()['data']['total']
- documents = []
- for i in range(1, (total + 50) // 50 + 1):
- # 创建请求体实例
- print(f"page {i} gets")
- payload = {
- "kb_id": str(kb_id), # 将uuid对象转换为字符串
- "page_number": i,
- "page_size": 50,
- }
-
- response = requests.post(url, cookies=cookies, json=payload)
- js = response.json()
- now_documents = js['data']['data_list']
- documents.extend(now_documents)
- # 返回响应文本
- return documents
-
-def get_document(dir):
- documents = []
- print(os.listdir(dir))
- for file in os.listdir(dir):
- if file.endswith('.xlsx'):
- file_path = os.path.join(dir, file)
- df = pd.read_excel(file_path)
- documents.append(df.to_dict(orient='records'))
- if file.endswith('.csv'):
- file_path = os.path.join(dir, file)
- df = pd.read_csv(file_path, )
- documents.append(df.to_dict(orient='records'))
- return documents
-
-if __name__ == '__main__':
- """
- 脚本参数包含 name, password, doc_id, qa_count, url
- - name: 通过-n或者--name读入,必须
- - password: 通过-p或者--password读入,必须
- - kb_id: 通过-k或者--kb_id读入,必须
- - qa_count: 通过-q或者--qa_count读入,非必须,默认为1,表示每个文档生成多少个问答对
- - url: 通过-u或者--url读入,必须,为witChainD的路径
- - doc_names: 通过-d或者--doc_names读入,非必须,默认为None,表示所有文档的名称
- 需要在.env中配置好LLM和witChainD相关的config,以及prompt路径
- """
- args = parser()
- QAs = []
- if args.mode == 'online':
- js, tmp_dict = login_and_get_tokens(args.name, args.password, args.url)
- session_cookie = tmp_dict['ECSESSION']
- csrf_cookie = tmp_dict['csrf_token']
- print('login success')
- documents = list_documents(session_cookie, csrf_cookie, args.kb_id, args.url)
- print('get document success')
- print(documents)
- for document in documents:
- # print('refresh tokens')
- # print(json.dumps(document, indent=4, ensure_ascii=False))
- if args.doc_names != [] and document['name'] not in args.doc_names:
- # args.doc_names = []
- continue
- else:
- args.doc_names = []
- js, tmp_dict = login_and_get_tokens(args.name, args.password, args.url)
- session_cookie = tmp_dict['ECSESSION']
- csrf_cookie = tmp_dict['csrf_token']
- args.doc_id = document['id']
- args.doc_name = document['name']
- count = 0
- while count < 5:
- try:
- js = list_chunks(session_cookie, csrf_cookie, str(args.doc_id), base_url=args.url)
- print(f'js: {js}')
- count = 10
- except Exception as e:
- print(f"document {args.doc_name} check failed {e} with retry {count}")
- count = count + 1
- time.sleep(1)
- continue
- if count == 5:
- print(f"document {args.doc_name} check failed")
- continue
- chunks = js['data']['data_list']
- new_chunks = []
- for chunk in chunks:
- new_chunk = {
- 'text': chunk['text'],
- 'type': chunk['type'],
- }
- new_chunks.append(new_chunk)
- chunks = new_chunks
- model = QAgenerator()
- try:
- print('正在生成QA对...')
- t_QAs = asyncio.run(model.qa_generate(chunks=chunks, file=args.doc_name))
- print("QA对生成完毕,正在获取答案...")
- tt_QAs = asyncio.run(get_QAs_answers(t_QAs, args.kb_id, session_cookie, csrf_cookie, args.url))
- print(f"tt_QAs: {tt_QAs}")
- print("答案获取完毕,正在计算答案正确性...")
- ttt_QAs = asyncio.run(QAScore().get_scores(tt_QAs))
- print(f"ttt_QAs: {ttt_QAs}")
- for QA in t_QAs:
- QAs.append(QA)
- df = pd.DataFrame(QAs)
- df.astype(str)
- print(document['name'], 'down')
- print('sample:', t_QAs[0]['question'][:40])
- df.to_excel(current_dir / 'temp_answer.xlsx', index=False)
- print(f'temp_Excel结果已输出到{current_dir}/temp_answer.xlsx')
- except Exception as e:
- import traceback
- print(traceback.print_exc())
- print(f"document {args.doc_name} failed {e}")
- continue
- else:
- # 离线模式
- # print(document_path)
- t_QAs = get_document(args.input_path)
- print(f"获取到{len(t_QAs)}个文档")
- for item in t_QAs[0]:
- single_item = {
- "question": item["问题"],
- "answer": item["标准答案"],
- "witChainD_answer": item["llm的回答"],
- "text": item["原始片段"],
- "witChainD_source": item["检索片段"],
- }
- # print(single_item)
- ttt_QAs = asyncio.run(QAScore().get_score(single_item))
- QAs.append(ttt_QAs)
- # # 输出QAs到xlsx中
- # exit(0)
- newQAs = []
- total = {
- "context_relevancy(上下文相关性)": [],
- "context_recall(召回率)": [],
- "faithfulness(忠实性)": [],
- "answer_relevancy(答案的相关性)": [],
- "lcs_score(最大公共子串)": [],
- "jac_score(杰卡德距离)": [],
- "leve_score(编辑距离)": [],
- "time_cost": {
- "keyword_searching": [],
- "text_to_vector": [],
- "vector_searching": [],
- "vectors_related_texts": [],
- "text_expanding": [],
- "llm_answer": [],
- },
- }
-
- time_cost_metrics = list(total["time_cost"].keys())
-
- for QA in QAs:
- print(QA)
- try:
- if 'time_cost' in QA.keys():
- ReOrderedQA = {
- '领域': str(QA['type']),
- '问题': str(QA['question']),
- '标准答案': str(QA['answer']),
- 'llm的回答': str(QA['witChainD_answer']),
- 'context_relevancy(上下文相关性)': str(QA['context_relevancy']),
- 'context_recall(召回率)': str(QA['context_recall']),
- 'faithfulness(忠实性)': str(QA['faithfulness']),
- 'answer_relevancy(答案的相关性)': str(QA['answer_relevancy']),
- 'lcs_score(最大公共子串)': str(QA['lcs_score']),
- 'jac_score(杰卡德距离)': str(QA['jac_score']),
- 'leve_score(编辑距离)': str(QA['leve_score']),
- '原始片段': str(QA['text']),
- '检索片段': str(QA['witChainD_source']),
- 'keyword_searching_cost(关键字搜索时间消耗)': str(QA['time_cost']['keyword_searching'])+'s',
- 'query_to_vector_cost(qeury向量化时间消耗)': str(QA['time_cost']['text_to_vector'])+'s',
- 'vector_searching_cost(向量化检索时间消耗)': str(QA['time_cost']['vector_searching'])+'s',
- 'vectors_related_texts_cost(向量关联文档时间消耗)': str(QA['time_cost']['vectors_related_texts'])+'s',
- 'text_expanding_cost(上下文关联时间消耗)': str(QA['time_cost']['text_expanding'])+'s',
- 'llm_answer_cost(大模型回答时间消耗)': str(QA['time_cost']['llm_answer'])+'s'
- }
- else:
- ReOrderedQA = {
- # '领域': str(QA['type']),
- '问题': str(QA['question']),
- '标准答案': str(QA['answer']),
- 'llm的回答': str(QA['witChainD_answer']),
- 'context_relevancy(上下文相关性)': str(QA['context_relevancy']),
- 'context_recall(召回率)': str(QA['context_recall']),
- 'faithfulness(忠实性)': str(QA['faithfulness']),
- 'answer_relevancy(答案的相关性)': str(QA['answer_relevancy']),
- 'lcs_score(最大公共子串)': str(QA['lcs_score']),
- 'jac_score(杰卡德距离)': str(QA['jac_score']),
- 'leve_score(编辑距离)': str(QA['leve_score']),
- '原始片段': str(QA['text']),
- '检索片段': str(QA['witChainD_source'])
- }
- print(ReOrderedQA)
- newQAs.append(ReOrderedQA)
-
- for metric in total.keys():
- if metric != "time_cost": # 跳过time_cost(特殊处理)
- value = ReOrderedQA.get(metric)
- if value is not None:
- total[metric].append(float(value))
-
- if "time_cost" in QA:
- for sub_metric in time_cost_metrics:
- value = QA["time_cost"].get(sub_metric)
- if value is not None:
- total["time_cost"][sub_metric].append(float(value))
- except Exception as e:
- print(f"QA {QA} error {e}")
-
- # 计算平均值
- avg = {}
- for metric, values in total.items():
- if metric != "time_cost":
- avg[metric] = sum(values) / len(values) if values else 0.0
- else: # 处理time_cost
- avg_time_cost = {}
- for sub_metric, sub_values in values.items():
- avg_time_cost[sub_metric] = (
- sum(sub_values) / len(sub_values) if sub_values else 0.0
- )
- avg[metric] = avg_time_cost
-
-
- excel_path = current_dir / 'answer.xlsx'
- with pd.ExcelWriter(excel_path, engine='xlsxwriter') as writer:
- # 写入第一个sheet(测试样例)
- df = pd.DataFrame(newQAs).astype(str)
- df.to_excel(writer, sheet_name="测试样例", index=False)
-
- # 写入第二个sheet(测试结果)
- filtered_time_cost = {k: v for k, v in avg["time_cost"].items() if v != 0}
- flat_avg = {
- **{k: v for k, v in avg.items() if k != "time_cost"},
- **{f"time_cost_{k}": v for k, v in filtered_time_cost.items()},
- }
- print(f"写入测试结果:{flat_avg}")
- avg_df = pd.DataFrame([flat_avg])
- avg_df.to_excel(writer, sheet_name="测试结果", index=False)
-
-
- print(f'测试样例和结果已输出到{excel_path}')
diff --git a/test/token_tool.py b/test/token_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..36e067bcb1e1ebefbbd18fa3e5711f76f2b71d13
--- /dev/null
+++ b/test/token_tool.py
@@ -0,0 +1,506 @@
+import asyncio
+import tiktoken
+import jieba
+from jieba.analyse import extract_tags
+import yaml
+import json
+import re
+import uuid
+import numpy as np
+from pydantic import BaseModel, Field
+from llm import LLM
+from embedding import Embedding
+from config import BaseConfig
+
+
+class Grade(BaseModel):
+ content_len: int = Field(..., description="内容长度")
+ tokens: int = Field(..., description="token数")
+
+
+class TokenTool:
+ stop_words_path = "./stopwords.txt"
+ prompt_path = "./prompt.yaml"
+ with open(stop_words_path, 'r', encoding='utf-8') as f:
+ stopwords = set(line.strip() for line in f)
+ with open(prompt_path, 'r', encoding='utf-8') as f:
+ prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
+
+ @staticmethod
+ def filter_stopwords(content: str) -> str:
+ """
+ 过滤停用词
+ """
+ try:
+ words = TokenTool.split_words(content)
+ filtered_words = [word for word in words if word not in TokenTool.stopwords]
+ return ' '.join(filtered_words)
+ except Exception as e:
+ err = f"[TokenTool] 过滤停用词失败 {e}"
+ print("[TokenTool] %s", err)
+ return content
+
+ @staticmethod
+ def get_leave_tokens_from_content_len(content: str) -> int:
+ """
+ 根据内容长度获取留存的token数
+ """
+ grades = [
+ Grade(content_len=0, tokens=0),
+ Grade(content_len=10, tokens=8),
+ Grade(content_len=50, tokens=16),
+ Grade(content_len=250, tokens=32),
+ Grade(content_len=1250, tokens=64),
+ Grade(content_len=6250, tokens=128),
+ Grade(content_len=31250, tokens=256),
+ Grade(content_len=156250, tokens=512),
+ Grade(content_len=781250, tokens=1024),
+ ]
+ tokens = TokenTool.get_tokens(content)
+        if tokens >= grades[-1].content_len:
+            return grades[-1].tokens
+ index = 0
+ for i in range(len(grades)-1):
+ if grades[i].content_len <= tokens < grades[i+1].content_len:
+ index = i
+ break
+ leave_tokens = grades[index].tokens+(grades[index+1].tokens-grades[index].tokens)*(
+ tokens-grades[index].content_len)/(grades[index+1].content_len-grades[index].content_len)
+ return int(leave_tokens)
+
+ @staticmethod
+ def get_leave_setences_from_content_len(content: str) -> int:
+ """
+ 根据内容长度获取留存的句子数量
+ """
+ grades = [
+ Grade(content_len=0, tokens=0),
+ Grade(content_len=10, tokens=4),
+ Grade(content_len=50, tokens=8),
+ Grade(content_len=250, tokens=16),
+ Grade(content_len=1250, tokens=32),
+ Grade(content_len=6250, tokens=64),
+ Grade(content_len=31250, tokens=128),
+ Grade(content_len=156250, tokens=256),
+ Grade(content_len=781250, tokens=512),
+ ]
+ sentences = TokenTool.content_to_sentences(content)
+        if len(sentences) >= grades[-1].content_len:
+            return grades[-1].tokens
+ index = 0
+ for i in range(len(grades)-1):
+ if grades[i].content_len <= len(sentences) < grades[i+1].content_len:
+ index = i
+ break
+ leave_sentences = grades[index].tokens+(grades[index+1].tokens-grades[index].tokens)*(
+ len(sentences)-grades[index].content_len)/(grades[index+1].content_len-grades[index].content_len)
+ return int(leave_sentences)
+
+ @staticmethod
+ def get_tokens(content: str) -> int:
+ try:
+ enc = tiktoken.encoding_for_model("gpt-4")
+ return len(enc.encode(str(content)))
+ except Exception as e:
+ err = f"[TokenTool] 获取token失败 {e}"
+ print("[TokenTool] %s", err)
+ return 0
+
+ @staticmethod
+    def get_k_tokens_words_from_content(content: str, k: int = 16) -> str:
+ try:
+ if (TokenTool.get_tokens(content) <= k):
+ return content
+ l = 0
+ r = len(content)
+ while l+1 < r:
+ mid = (l+r)//2
+ if (TokenTool.get_tokens(content[:mid]) <= k):
+ l = mid
+ else:
+ r = mid
+ return content[:l]
+ except Exception as e:
+ err = f"[TokenTool] 获取k个token的词失败 {e}"
+ print("[TokenTool] %s", err)
+ return ""
+
+ @staticmethod
+ def split_str_with_slide_window(content: str, slide_window_size: int) -> list:
+ """
+ 将字符串按滑动窗口切割
+ """
+ result = []
+ try:
+ while len(content) > 0:
+ sub_content = TokenTool.get_k_tokens_words_from_content(content, slide_window_size)
+ result.append(sub_content)
+ content = content[len(sub_content):]
+ return result
+ except Exception as e:
+ err = f"[TokenTool] 滑动窗口切割失败 {e}"
+ print("[TokenTool] %s", err)
+ return []
+
+ @staticmethod
+ def compress_tokens(content: str, k: int = None) -> str:
+ try:
+ words = TokenTool.split_words(content)
+ # 过滤掉停用词
+ filtered_words = [
+ word for word in words if word not in TokenTool.stopwords
+ ]
+ filtered_content = ''.join(filtered_words)
+ if k is not None:
+ # 如果k不为None,则获取k个token的词
+ filtered_content = TokenTool.get_k_tokens_words_from_content(filtered_content, k)
+ return filtered_content
+ except Exception as e:
+ err = f"[TokenTool] 压缩token失败 {e}"
+ print("[TokenTool] %s", err)
+ return content
+
+ @staticmethod
+ def split_words(content: str) -> list:
+ try:
+ return list(jieba.cut(str(content)))
+ except Exception as e:
+ err = f"[TokenTool] 分词失败 {e}"
+ print("[TokenTool] %s", err)
+ return []
+
+ @staticmethod
+ def get_top_k_keywords(content: str, k=10) -> list:
+ try:
+ # 使用jieba提取关键词
+ keywords = extract_tags(content, topK=k, withWeight=True)
+ return [keyword for keyword, weight in keywords]
+ except Exception as e:
+ err = f"[TokenTool] 获取关键词失败 {e}"
+ print("[TokenTool] %s", err)
+ return []
+
+ @staticmethod
+ def get_top_k_keywords_and_weights(content: str, k=10) -> list:
+ try:
+ # 使用jieba提取关键词
+ keyword_weight_list = extract_tags(content, topK=k, withWeight=True)
+ keywords = [keyword for keyword, weight in keyword_weight_list]
+ weights = [weight for keyword, weight in keyword_weight_list]
+ return keywords, weights
+ except Exception as e:
+ err = f"[TokenTool] 获取关键词失败 {e}"
+ print("[TokenTool] %s", err)
+            return [], []
+
+ @staticmethod
+ def get_top_k_keysentence(content: str, k: int = None) -> list:
+ """
+ 获取前k个关键句子
+ """
+ if k is None:
+ k = TokenTool.get_leave_setences_from_content_len(content)
+ leave_tokens = TokenTool.get_leave_tokens_from_content_len(content)
+ words = TokenTool.split_words(content)
+ # 过滤掉停用词
+ filtered_words = [
+ word for word in words if word not in TokenTool.stopwords
+ ]
+ keywords = TokenTool.get_top_k_keywords(''.join(filtered_words), leave_tokens)
+ keywords = set(keywords)
+ sentences = TokenTool.content_to_sentences(content)
+ sentence_and_score_list = []
+ index = 0
+ for sentence in sentences:
+ score = 0
+ words = TokenTool.split_words(sentence)
+ for word in words:
+ if word in keywords:
+ score += 1
+ sentence_and_score_list.append((index, sentence, score))
+ index += 1
+        sentence_and_score_list.sort(key=lambda x: x[2], reverse=True)  # 按关键词命中得分排序
+ top_k_sentence_and_score_list = sentence_and_score_list[:k]
+ top_k_sentence_and_score_list.sort(key=lambda x: x[0])
+ return [sentence for index, sentence, score in top_k_sentence_and_score_list]
+
+ @staticmethod
+ async def cal_recall(answer_1: str, answer_2: str, language: str) -> float:
+ """
+ 计算recall
+ 参数:
+ answer_1:答案1
+ answer_2:答案2
+ llm:大模型
+ """
+ llm = LLM(
+ openai_api_base=BaseConfig().get_config().llm.llm_endpoint,
+ openai_api_key=BaseConfig().get_config().llm.llm_api_key,
+ model_name=BaseConfig().get_config().llm.llm_model_name,
+ max_tokens=BaseConfig().get_config().llm.max_tokens,
+ temperature=BaseConfig().get_config().llm.temperature
+ )
+ try:
+ prompt_template = TokenTool.prompt_dict.get('ANSWER_TO_ANSWER_PROMPT', {})
+ prompt_template = prompt_template.get(language, '')
+ answer_1 = TokenTool.get_k_tokens_words_from_content(answer_1, llm.max_tokens//2)
+ answer_2 = TokenTool.get_k_tokens_words_from_content(answer_2, llm.max_tokens//2)
+ prompt = prompt_template.format(text_1=answer_1, text_2=answer_2)
+ sys_call = prompt
+ user_call = '请输出相似度'
+ similarity = await llm.nostream([], sys_call, user_call)
+            return float(similarity)
+ except Exception as e:
+ err = f"[TokenTool] 计算recall失败 {e}"
+ print("[TokenTool] %s", err)
+ return -1
+
+ @staticmethod
+ async def cal_precision(question: str, content: str, language: str) -> float:
+ """
+ 计算precision
+ 参数:
+ question:问题
+ content:内容
+ """
+ llm = LLM(
+ openai_api_base=BaseConfig().get_config().llm.llm_endpoint,
+ openai_api_key=BaseConfig().get_config().llm.llm_api_key,
+ model_name=BaseConfig().get_config().llm.llm_model_name,
+ max_tokens=BaseConfig().get_config().llm.max_tokens,
+ temperature=BaseConfig().get_config().llm.temperature
+ )
+ try:
+ prompt_template = TokenTool.prompt_dict.get('CONTENT_TO_STATEMENTS_PROMPT', {})
+ prompt_template = prompt_template.get(language, '')
+ content = TokenTool.compress_tokens(content, llm.max_tokens)
+ sys_call = prompt_template.format(content=content)
+            user_call = '请结合文本输出陈述列表'
+ statements = await llm.nostream([], sys_call, user_call, st_str='[',
+ en_str=']')
+ statements = json.loads(statements)
+ if len(statements) == 0:
+ return 0
+ score = 0
+ prompt_template = TokenTool.prompt_dict.get('STATEMENTS_TO_QUESTION_PROMPT', {})
+ prompt_template = prompt_template.get(language, '')
+ for statement in statements:
+ statement = TokenTool.get_k_tokens_words_from_content(statement, llm.max_tokens)
+ prompt = prompt_template.format(statement=statement, question=question)
+ sys_call = prompt
+ user_call = '请结合文本输出YES或NO'
+ yn = await llm.nostream([], sys_call, user_call)
+ yn = yn.lower()
+ if yn == 'yes':
+ score += 1
+ return score/len(statements)*100
+ except Exception as e:
+ err = f"[TokenTool] 计算precision失败 {e}"
+ print("[TokenTool] %s", err)
+ return -1
+
+ @staticmethod
+ async def cal_faithfulness(question: str, answer: str, content: str, language: str) -> float:
+ """
+ 计算faithfulness
+ 参数:
+ question:问题
+ answer:答案
+ """
+ llm = LLM(
+ openai_api_base=BaseConfig().get_config().llm.llm_endpoint,
+ openai_api_key=BaseConfig().get_config().llm.llm_api_key,
+ model_name=BaseConfig().get_config().llm.llm_model_name,
+ max_tokens=BaseConfig().get_config().llm.max_tokens,
+ temperature=BaseConfig().get_config().llm.temperature
+ )
+ try:
+ prompt_template = TokenTool.prompt_dict.get('QA_TO_STATEMENTS_PROMPT', {})
+ prompt_template = prompt_template.get(language, '')
+ question = TokenTool.get_k_tokens_words_from_content(question, llm.max_tokens//8)
+ answer = TokenTool.get_k_tokens_words_from_content(answer, llm.max_tokens//8*7)
+ prompt = prompt_template.format(question=question, answer=answer)
+ sys_call = prompt
+            user_call = '请结合问题和答案输出陈述'
+ statements = await llm.nostream([], sys_call, user_call, st_str='[',
+ en_str=']')
+ prompt_template = TokenTool.prompt_dict.get('STATEMENTS_TO_FRAGMENT_PROMPT', {})
+ prompt_template = prompt_template.get(language, '')
+ statements = json.loads(statements)
+ if len(statements) == 0:
+ return 0
+ score = 0
+ content = TokenTool.compress_tokens(content, llm.max_tokens//8*7)
+ for statement in statements:
+ statement = TokenTool.get_k_tokens_words_from_content(statement, llm.max_tokens//8)
+ prompt = prompt_template.format(statement=statement, fragment=content)
+ sys_call = prompt
+ user_call = '请输出YES或NO'
+ user_call = user_call
+ yn = await llm.nostream([], sys_call, user_call)
+ yn = yn.lower()
+ if yn == 'yes':
+ score += 1
+ return score/len(statements)*100
+ except Exception as e:
+ err = f"[TokenTool] 计算faithfulness失败 {e}"
+ print("[TokenTool] %s", err)
+ return -1
+
+ @staticmethod
+ def cosine_distance_numpy(vector1, vector2):
+ # 计算向量的点积
+ dot_product = np.dot(vector1, vector2)
+ # 计算向量的 L2 范数
+ norm_vector1 = np.linalg.norm(vector1)
+ norm_vector2 = np.linalg.norm(vector2)
+ # 计算余弦相似度
+ cosine_similarity = dot_product / (norm_vector1 * norm_vector2)
+ # 计算余弦距离
+ cosine_dist = 1 - cosine_similarity
+ return cosine_dist
+
+ @staticmethod
+ async def cal_relevance(question: str, answer: str, language: str) -> float:
+ """
+ 计算relevance
+ 参数:
+ question:问题
+ answer:答案
+ """
+ llm = LLM(
+ openai_api_base=BaseConfig().get_config().llm.llm_endpoint,
+ openai_api_key=BaseConfig().get_config().llm.llm_api_key,
+ model_name=BaseConfig().get_config().llm.llm_model_name,
+ max_tokens=BaseConfig().get_config().llm.max_tokens,
+ temperature=BaseConfig().get_config().llm.temperature
+ )
+ try:
+ prompt_template = TokenTool.prompt_dict.get('GENERATE_QUESTION_FROM_CONTENT_PROMPT', {})
+ prompt_template = prompt_template.get(language, '')
+ answer = TokenTool.get_k_tokens_words_from_content(answer, llm.max_tokens)
+ sys_call = prompt_template.format(k=5, content=answer)
+ user_call = '请结合文本输出问题列表'
+ question_vector = await Embedding.vectorize_embedding(question)
+ qs = await llm.nostream([], sys_call, user_call)
+ qs = json.loads(qs)
+ if len(qs) == 0:
+ return 0
+ score = 0
+ for q in qs:
+ q_vector = await Embedding.vectorize_embedding(q)
+                score += 1 - TokenTool.cosine_distance_numpy(question_vector, q_vector)  # 余弦距离转为相似度
+ return (score/len(qs)+1)/2*100
+ except Exception as e:
+ err = f"[TokenTool] 计算relevance失败 {e}"
+ print("[TokenTool] %s", err)
+ return -1
+
+ @staticmethod
+ def cal_lcs(str1: str, str2: str) -> float:
+ """
+ 计算两个字符串的最长公共子序列长度得分
+ """
+ try:
+ words1 = TokenTool.split_words(str1)
+ words2 = TokenTool.split_words(str2)
+ new_words1 = []
+ new_words2 = []
+ for word in words1:
+ if word not in TokenTool.stopwords:
+ new_words1.append(word)
+ for word in words2:
+ if word not in TokenTool.stopwords:
+ new_words2.append(word)
+ if len(new_words1) == 0 and len(new_words2) == 0:
+ return 100
+ if len(new_words1) == 0 or len(new_words2) == 0:
+ return 0
+ m = len(new_words1)
+ n = len(new_words2)
+ dp = np.zeros((m+1, n+1))
+ for i in range(1, m+1):
+ for j in range(1, n+1):
+ if new_words1[i-1] == new_words2[j-1]:
+ dp[i][j] = dp[i-1][j-1] + 1
+ else:
+ dp[i][j] = max(dp[i-1][j], dp[i][j-1])
+ lcs_length = dp[m][n]
+ score = lcs_length / min(len(new_words1), len(new_words2)) * 100
+ return score
+ except Exception as e:
+ err = f"[TokenTool] 计算lcs失败 {e}"
+ print("[TokenTool] %s", err)
+ return -1
+
+ @staticmethod
+ def cal_leve(str1: str, str2: str) -> float:
+ """
+ 计算两个字符串的编辑距离
+ """
+ try:
+ words1 = TokenTool.split_words(str1)
+ words2 = TokenTool.split_words(str2)
+ new_words1 = []
+ new_words2 = []
+ for word in words1:
+ if word not in TokenTool.stopwords:
+ new_words1.append(word)
+ for word in words2:
+ if word not in TokenTool.stopwords:
+ new_words2.append(word)
+ if len(new_words1) == 0 and len(new_words2) == 0:
+ return 100
+ if len(new_words1) == 0 or len(new_words2) == 0:
+ return 0
+ m = len(new_words1)
+ n = len(new_words2)
+ dp = np.zeros((m+1, n+1))
+ for i in range(m+1):
+ dp[i][0] = i
+ for j in range(n+1):
+ dp[0][j] = j
+ for i in range(1, m+1):
+ for j in range(1, n+1):
+ if new_words1[i-1] == new_words2[j-1]:
+ dp[i][j] = dp[i-1][j-1]
+ else:
+ dp[i][j] = min(dp[i-1][j]+1, dp[i][j-1]+1, dp[i-1][j-1]+1)
+ edit_distance = dp[m][n]
+ score = (1 - edit_distance / max(len(new_words1), len(new_words2))) * 100
+ return score
+ except Exception as e:
+ err = f"[TokenTool] 计算leve失败 {e}"
+ print("[TokenTool] %s", err)
+ return -1
+
+ @staticmethod
+ def cal_jac(str1: str, str2: str) -> float:
+ """
+ 计算两个字符串的Jaccard相似度
+ """
+ try:
+ if len(str1) == 0 and len(str2) == 0:
+ return 100
+ words1 = TokenTool.split_words(str1)
+ words2 = TokenTool.split_words(str2)
+ new_words1 = []
+ new_words2 = []
+ for word in words1:
+ if word not in TokenTool.stopwords:
+ new_words1.append(word)
+ for word in words2:
+ if word not in TokenTool.stopwords:
+ new_words2.append(word)
+ if len(new_words1) == 0 or len(new_words2) == 0:
+ return 0
+ set1 = set(new_words1)
+ set2 = set(new_words2)
+ intersection = len(set1.intersection(set2))
+ union = len(set1.union(set2))
+ score = intersection / union * 100
+ return score
+ except Exception as e:
+ err = f"[TokenTool] 计算jac失败 {e}"
+ print("[TokenTool] %s", err)
+ return -1
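
The three pure-string metrics at the bottom of TokenTool (cal_lcs, cal_leve, cal_jac) need only jieba, numpy and the stopword list, so they can be smoke-tested without calling an LLM or embedding service (the local llm/embedding/config modules still have to be importable). A rough sketch, assuming it is run from test/ so ./stopwords.txt and ./prompt.yaml are found when the class is loaded; exact scores depend on jieba's segmentation:

    from token_tool import TokenTool

    a = "openEuler是一个开源的操作系统"
    b = "openEuler是开源操作系统"
    print(TokenTool.cal_lcs(a, b))   # word-level LCS relative to the shorter text, 0-100
    print(TokenTool.cal_leve(a, b))  # 100 minus the normalized word-level edit distance
    print(TokenTool.cal_jac(a, b))   # Jaccard similarity of the word sets, 0-100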
diff --git a/test/tools/=1.21.6, b/test/tools/=1.21.6,
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/test/tools/config.py b/test/tools/config.py
deleted file mode 100644
index e677904723fd6896f8497af2aa6f543e6564b5a8..0000000000000000000000000000000000000000
--- a/test/tools/config.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-config = {
- "PROMPT_PATH": "./tools/prompt.yaml",
- "MODEL_NAME": "your_model_name", # Replace with your actual model name
- "OPENAI_API_BASE": "your_openai_api_base_url",
- "OPENAI_API_KEY": "your_openai_api_key",
- "REQUEST_TIMEOUT": 120,
- "MAX_TOKENS": 8096,
- "MODEL_ENH": "false",
-}
diff --git a/test/tools/prompt.yaml b/test/tools/prompt.yaml
deleted file mode 100644
index 7a3031b9bc28e7554fc21667dbcf19a5991f3d7d..0000000000000000000000000000000000000000
--- a/test/tools/prompt.yaml
+++ /dev/null
@@ -1,82 +0,0 @@
-GENERATE_QA: "你是一个问答生成专家,你的任务是根据我提供的段落内容和已有的问题,生成{qa_count}个不重复的针对该段落内容的问题与回答,
-并判断这个问答对的属于领域,并只输出问题、回答、领域。
-
-注意:
-
-1. 单个回答长度必须大于30字小于120字
-
-2. 问题不能出现重复
-
-3. 请指定明确的场景,如'xx公司', 'xx系统', 'xx项目', ‘xx软件'等
-
-4. 问题中不要使用模糊的指代词, 如'这'、'那'
-
-5. 划分领域的时候请忽略上下文内容,领域大概可以分为(建筑、园林、摄影、戏剧、戏曲、舞蹈、音乐、书法、绘画、雕塑、美食、营养、健身、运动、旅游、地理、气象、海洋、地质、生态、天文、化学、物理、生物、数学、统计、逻辑、人工智能、大数据、云计算、网络、通信、自动化、机械、电子、材料、能源、化工、纺织、服装、美容、美发、礼仪、公关、广告、营销、管理、金融、证券、保险、期货、税务、审计、会计、法律实务、知识产权)
-
-6. 问题必须与段落内容有逻辑关系
-
-7. 问题与回答在不重复的前提下,应当尽可能多地包含段落内容
-
-8. 输出的格式为:
-[
-
-{{
- \"question\": \" 问题 \",
- \"answer\": \" 回答 \",
- \"type\": \" 领域 \"
-}}
-
-,
-
-{{
- \"question\": \" 问题 \",
- \"answer\": \" 回答 \",
- \"type\": \" 领域 \"
-}}
-
-]
-
-10. 不要输出多余内容
-
-下面是给出的段落内容:
-
-{chunk}
-
-下面是段落的上下文内容:
-
-{text}
-
-下面是段落的来源文件
-{file_name}
-"
-SCORE_QA: "你是一个打分专家,你的任务是根据我提供的问题、原始片段和检索到的片段以及标准答案和答案,判断答案在下面四项指标的分数,每个指标要精确到小数点后面2位,且每次需要进行客观评价
-
-1.context_relevancy 解释:(上下文相关性,越高表示检索到的片段中无用的信息越少 0-100)
-2.context_recall 解释:(召回率,越高表示检索出来的片段与标准答案越相关 0-100)
-3.faithfulness 解释:(忠实性,越高表示答案的生成使用了越多检索出来的片段0-100)
-4.answer_relevancy 解释:(答案与问题的相关性 0-100)
-
-注意:
-请以下面格式输出
-{{
- \"context_relevancy\": 分数,
- \"context_recall\": 分数,
- \"faithfulness\": 分数,
- \"answer_relevancy\": 分数
-}}
-
-下面是问题:
-{question}
-
-下面是原始片段:
-{meta_chunk}
-
-下面是检索到的片段:
-{chunk}
-
-下面是标准答案:
-{answer}
-
-下面是答案:
-{answer_text}
-"
diff --git a/test/tools/similar_cal_tool.py b/test/tools/similar_cal_tool.py
deleted file mode 100644
index 56319a42ae22b151ec3735d81c0607063509278d..0000000000000000000000000000000000000000
--- a/test/tools/similar_cal_tool.py
+++ /dev/null
@@ -1,158 +0,0 @@
-import jieba
-import jieba.analyse
-import synonyms
-
-class Similar_cal_tool:
- with open('./tools/stopwords.txt', 'r', encoding='utf-8') as f:
- stopwords = set(f.read().splitlines())
-
- @staticmethod
- def normalized_scores(scores):
- min_score = None
- max_score = None
- for score in scores:
- if min_score is None:
- min_score = score
- else:
- min_score = min(min_score, score)
- if max_score is None:
- max_score = score
- else:
- max_score = max(max_score, score)
- if min_score == max_score:
- for i in range(len(scores)):
- scores[i] = 1
- else:
- for i in range(len(scores)):
- scores[i] = (scores[i]-min_score)/(max_score-min_score)
- return scores
-
- @staticmethod
- def filter_stop_words(text):
- words = jieba.lcut(text)
- filtered_words = [word for word in words if word not in Similar_cal_tool.stopwords]
- text = ''.join(filtered_words)
- return text
-
- @staticmethod
- def extract_keywords_sorted(text, topK=10):
- keywords = jieba.analyse.textrank(text, topK=topK, withWeight=False)
- return keywords
-
- @staticmethod
- def get_synonyms_score_dict(word):
- try:
- syns, scores = synonyms.nearby(word)
- scores = Similar_cal_tool.normalized_scores(scores)
- syns_scores_dict = {}
- for syn, score in tuple(syns, scores):
- syns_scores_dict[syn] = score
- return syns_scores_dict
- except:
- return {word: 1}
-
- @staticmethod
- def text_to_keywords(text):
- words = jieba.lcut(text)
- if len(set(words)) <64:
- return words
- topK = 5
- lv = 64
- while lv < len(words):
- topK *= 2
- lv *= 2
- keywords_sorted = Similar_cal_tool.extract_keywords_sorted(text, topK)
- keywords_sorted_set = set(keywords_sorted)
- new_words = []
- for word in words:
- if word in keywords_sorted_set:
- new_words.append(word)
- return new_words
- @staticmethod
- def cal_syns_word_score(word, syns_scores_dict):
- if word not in syns_scores_dict:
- return 0
- return syns_scores_dict[word]
- @staticmethod
- def longest_common_subsequence(str1, str2):
- words1 = Similar_cal_tool.text_to_keywords(str1)
- words2 = Similar_cal_tool.text_to_keywords(str2)
- m, n = len(words1), len(words2)
- if m == 0 and n == 0:
- return 1
- if m == 0:
- return 0
- if n == 0:
- return 0
- dp = [[0]*(n+1) for _ in range(m+1)]
- syns_scores_dicts_1 = []
- syns_scores_dicts_2 = []
- for word in words1:
- syns_scores_dicts_1.append(Similar_cal_tool.get_synonyms_score_dict(word))
- for word in words2:
- syns_scores_dicts_2.append(Similar_cal_tool.get_synonyms_score_dict(word))
-
- for i in range(1, m+1):
- for j in range(1, n+1):
- dp[i][j] = max(dp[i-1][j], dp[i][j-1])
- dp[i][j] = dp[i-1][j-1] + (Similar_cal_tool.cal_syns_word_score(words1[i-1], syns_scores_dicts_2[j-1]
- )+Similar_cal_tool.cal_syns_word_score(words2[j-1], syns_scores_dicts_1[i-1]))
-
- return dp[m][n]/(2*min(m,n))
-
- def jaccard_distance(str1, str2):
- words1 = set(Similar_cal_tool.text_to_keywords(str1))
- words2 = set(Similar_cal_tool.text_to_keywords(str2))
- m, n = len(words1), len(words2)
- if m == 0 and n == 0:
- return 1
- if m == 0:
- return 0
- if n == 0:
- return 0
- syns_scores_dict_1 = {}
- syns_scores_dict_2 = {}
- for word in words1:
- tmp_dict=Similar_cal_tool.get_synonyms_score_dict(word)
- for key,val in tmp_dict.items():
- syns_scores_dict_1[key]=max(syns_scores_dict_1.get(key,0),val)
- for word in words2:
- tmp_dict=Similar_cal_tool.get_synonyms_score_dict(word)
- for key,val in tmp_dict.items():
- syns_scores_dict_2[key]=max(syns_scores_dict_2.get(key,0),val)
- sum=0
- for word in words1:
- sum+=Similar_cal_tool.cal_syns_word_score(word,syns_scores_dict_2)
- for word in words2:
- sum+=Similar_cal_tool.cal_syns_word_score(word,syns_scores_dict_2)
- return sum/(len(words1)+len(words2))
- def levenshtein_distance(str1, str2):
- words1 = Similar_cal_tool.text_to_keywords(str1)
- words2 = Similar_cal_tool.text_to_keywords(str2)
- m, n = len(words1), len(words2)
- if m == 0 and n == 0:
- return 1
- if m == 0:
- return 0
- if n == 0:
- return 0
- dp = [[0]*(n+1) for _ in range(m+1)]
- syns_scores_dicts_1 = []
- syns_scores_dicts_2 = []
- for word in words1:
- syns_scores_dicts_1.append(Similar_cal_tool.get_synonyms_score_dict(word))
- for word in words2:
- syns_scores_dicts_2.append(Similar_cal_tool.get_synonyms_score_dict(word))
- dp = [[0 for _ in range(n + 1)] for _ in range(m + 1)]
-
- for i in range(m + 1):
- dp[i][0] = i
- for j in range(n + 1):
- dp[0][j] = j
-
- for i in range(1, m + 1):
- for j in range(1, n + 1):
- dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1])
- dp[i][j] = min(dp[i][j],dp[i - 1][j - 1]+1-((Similar_cal_tool.cal_syns_word_score(words1[i-1], syns_scores_dicts_2[j-1]
- )+Similar_cal_tool.cal_syns_word_score(words2[j-1], syns_scores_dicts_1[i-1])))/2)
- return 1-dp[m][n]/(m+n)
diff --git "a/test/witchainD\346\265\213\350\257\225\346\214\207\345\257\274.docm" "b/test/witchainD\346\265\213\350\257\225\346\214\207\345\257\274.docm"
deleted file mode 100644
index dd3b2489e2a5a895fa19daacfd79623f6d78e4e1..0000000000000000000000000000000000000000
Binary files "a/test/witchainD\346\265\213\350\257\225\346\214\207\345\257\274.docm" and /dev/null differ