1 Star 0 Fork 2

xulei8/DaqiRAG

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
app.py 11.43 KB
一键复制 编辑 原始数据 按行查看 历史
xulei8 提交于 21天前 . 移动建表操作
from flask import Flask, render_template, request, jsonify, url_for
from werkzeug.utils import secure_filename
from docx import Document
import chromadb
import os
import numpy as np
import requests
from sentence_transformers import SentenceTransformer
import logging
from dotenv import load_dotenv
from sentence_transformers import CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity
from funcs import merge_small_paragraphs, create_knowledge_base_table,create_files_table
import sqlite3
from sqlite3 import Error
# Create the files table as part of start-up initialization.
create_files_table()

# Load the variables defined in the .env file into the environment.
load_dotenv()

# Read the SiliconFlow API key from the environment and refuse to start
# when it is missing or shorter than 20 characters.
siliconflow_api_key = os.getenv('SILICONFLOW_API_KEY')
if siliconflow_api_key is None or siliconflow_api_key == '':
    raise ValueError("API密钥未设置")
elif len(siliconflow_api_key) < 20:
    raise ValueError("API密钥无效")
# Flask application object; static files live in ./static and are served under /static.
app = Flask(__name__, static_folder='static', static_url_path='/static')
# Make the builtin zip() callable from Jinja templates.
app.jinja_env.globals['zip'] = zip
app.config.update(
    UPLOAD_FOLDER='uploads',
    MAX_CONTENT_LENGTH=16 * 1024 * 1024,  # 16MB max-limit
)

# Configure the log handler: INFO level, timestamped records on a stream handler.
app.logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
)
app.logger.addHandler(handler)

# Make sure the upload directory exists.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
# Proxy settings: a value is only honoured when it is longer than 18
# characters; otherwise the variable is blanked out in the environment.
http_proxy = os.getenv('HTTP_PROXY')
if not http_proxy or len(http_proxy) <= 18:
    os.environ['HTTP_PROXY'] = ''
    app.logger.info('未使用HTTP代理')
else:
    os.environ['HTTP_PROXY'] = http_proxy
    app.logger.info(f'使用HTTP代理: {http_proxy}')

https_proxy = os.getenv('HTTPS_PROXY')
if not https_proxy or len(https_proxy) <= 18:
    os.environ['HTTPS_PROXY'] = ''
    app.logger.info('未使用HTTPS代理')
else:
    os.environ['HTTPS_PROXY'] = https_proxy
    app.logger.info(f'使用HTTPS代理: {https_proxy}')
# Load the BAAI/bge-m3 embedding model (used to embed documents and queries).
app.logger.info('开始加载BAAI/bge-m3嵌入模型...')
model = SentenceTransformer('BAAI/bge-m3')
app.logger.info('BAAI/bge-m3嵌入模型加载成功')

# Load the BAAI/bge-reranker-v2-m3 cross-encoder used to rerank search hits.
# The start message previously named a truncated model id ("...-v2-m"); it now
# matches the model that is actually loaded and the completion message below.
app.logger.info('开始加载BAAI/bge-reranker-v2-m3重排模型...')
reranker = CrossEncoder('BAAI/bge-reranker-v2-m3')
app.logger.info('BAAI/bge-reranker-v2-m3重排模型加载完成')
# Define the embedding function.
from chromadb import EmbeddingFunction


class MyEmbeddingFunction(EmbeddingFunction):
    """Embedding function that encodes text with the module-level ``model``."""

    def __call__(self, input: list):
        """Encode every text in *input*; each vector is returned as a plain list."""
        vectors = []
        for text in input:
            encoded = model.encode(text)
            vectors.append(encoded.tolist())
        return vectors


embedding_function = MyEmbeddingFunction()
# Start with an empty collections registry.
collections = dict()
# Initialise the ChromaDB client; its data is persisted under ./chroma_db.
chroma_client = chromadb.PersistentClient(path="./chroma_db")
from chroma_utils import getCollection
import sqlite3
def process_docx(file_path):
    """Return the text of every non-blank paragraph in a .docx file.

    Each paragraph is stripped of surrounding whitespace; paragraphs that are
    empty after stripping are dropped. Document order is preserved.
    """
    stripped = (para.text.strip() for para in Document(file_path).paragraphs)
    return [text for text in stripped if text]
@app.route('/')
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = render_template('index.html')
    return page
@app.route('/get_enabled_knowledge_bases', methods=['GET'])
def get_enabled_knowledge_bases():
    """Return the enabled knowledge bases as JSON.

    Returns:
        A JSON list of ``{'id': ..., 'title': ...}`` objects, one per row of
        ``knowledge_bases`` whose ``enabled`` column is 1. On any failure the
        response is ``{'error': <message>}`` with HTTP status 500.
    """
    try:
        conn = sqlite3.connect('rag.db')
        try:
            rows = conn.execute(
                'SELECT id, title FROM knowledge_bases WHERE enabled = 1'
            ).fetchall()
        finally:
            # Close even when the query raises so a failing request cannot
            # leak the connection.
            conn.close()
        return jsonify([{'id': kb_id, 'title': title} for kb_id, title in rows])
    except Exception as e:
        app.logger.error(f'获取启用知识库出错: {str(e)}')
        return jsonify({'error': str(e)}), 500
@app.route('/upload_page')
def upload_page():
    """Render the upload page together with the list of uploaded files.

    Every row of the ``files`` table is passed to ``upload.html`` as a dict
    keyed by column name. If the query fails the error is logged and the page
    is rendered with an empty list instead.
    """
    try:
        conn = sqlite3.connect('rag.db')
        try:
            cursor = conn.execute('SELECT * FROM files')
            columns = [col[0] for col in cursor.description]
            uploaded_files = [dict(zip(columns, row)) for row in cursor.fetchall()]
        finally:
            # Close even when the query raises so the connection is not leaked.
            conn.close()
        return render_template('upload.html', uploaded_files=uploaded_files)
    except Exception as e:
        app.logger.error(f'查询已上传文件出错: {str(e)}')
        return render_template('upload.html', uploaded_files=[])
# Logging configuration applied after the Flask app has been initialised.
# An identical StreamHandler is already attached right after the app is
# created; attaching a second one makes every log record print twice, so the
# handler is only added when no handler with this format is present yet.
app.logger.setLevel(logging.INFO)
_log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
if not any(
    getattr(existing.formatter, '_fmt', None) == _log_format
    for existing in app.logger.handlers
):
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(_log_format))
    app.logger.addHandler(handler)
@app.route('/upload', methods=['POST'])
@app.route('/upload/<int:id>', methods=['POST'])
def upload_file(id=None):
    """Ingest an uploaded .docx file into a knowledge base.

    The knowledge-base id comes from the URL (``/upload/<id>``) or, when the
    URL carries none, from the ``kb_id`` form field. The document is split
    into paragraphs, small paragraphs are merged, the result is added to the
    collection returned by ``getCollection(id)`` and a row is inserted into
    the ``files`` table.

    Returns:
        ``{'success': True, 'message': ...}`` with 200 on success;
        ``{'error': ...}`` with 400 for an invalid id or file, and with 500
        when processing the document fails. A failure to record the file in
        SQLite is logged but does not fail the request.
    """
    if id is None:
        raw_id = request.form.get('kb_id')
        try:
            id = int(raw_id) if raw_id else None
        except ValueError:
            id = None
    if id is None or id < 1:
        # A bad id is a client error: answer 400 instead of raising (which
        # surfaced as an HTTP 500).
        return jsonify({'error': 'id必须大于等于1'}), 400

    app.logger.info('收到文件上传请求')
    if 'file' not in request.files:
        return jsonify({'error': '没有文件被上传'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': '没有选择文件'}), 400

    # Keep only the final path component so a crafted name such as
    # "../../x.docx" cannot escape the upload folder. secure_filename() is
    # deliberately not used: it drops non-ASCII characters, which would
    # collapse Chinese file names to "docx" and make document ids collide.
    filename = os.path.basename(file.filename.replace('\\', '/'))
    if not filename.lower().endswith('.docx'):
        return jsonify({'error': '只支持.docx文件'}), 400

    file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    try:
        file.save(file_path)
        collection = getCollection(id)

        # Extract the paragraphs and merge the small ones.
        text_content = process_docx(file_path)
        original_count = len(text_content)
        app.logger.info(f'合并前段落数量: {original_count}')
        merged_text_content = merge_small_paragraphs(text_content, model)
        merged_count = len(merged_text_content)
        app.logger.info(f'合并后段落数量: {merged_count}')
        text_content = merged_text_content
        app.logger.info(f'成功处理文档内容,段落数: {len(text_content)}')

        collection.add(
            documents=text_content,
            ids=[f"{filename}_{i}" for i in range(len(text_content))]
        )
        app.logger.info(f'文件已存入知识库: {filename}')

        try:
            # Record the file in the database (best effort).
            app.logger.info(f'file insert : {filename}')
            conn = sqlite3.connect('rag.db')
            try:
                conn.execute('''
                    INSERT INTO files (file_name, file_type, kb_id, doc_count)
                    VALUES (?, ?, ?, ?)
                ''', (filename, 'docx', id, merged_count))
                conn.commit()
            finally:
                conn.close()
            app.logger.info(f'file insert 2: {filename}')
        except sqlite3.Error as e:
            app.logger.error(f'插入文件记录到数据库出错: {str(e)}')

        return jsonify({'success': True, 'message': '文件上传成功'}), 200
    except Exception as e:
        app.logger.error(f'文件处理失败: {str(e)}')
        return jsonify({'error': str(e)}), 500
    finally:
        # Always remove the temporary upload, including when processing failed.
        try:
            if os.path.exists(file_path):
                os.remove(file_path)
        except OSError as e:
            app.logger.warning(f'清理上传文件失败: {str(e)}')
@app.route('/search', methods=['POST'])
def search():
    """Answer a question from a knowledge base with retrieval, reranking and an LLM.

    Expects a JSON object with:
        query:   the question text (required, non-empty string).
        context: optional extra context supplied by the front end.
        id:      optional id passed to ``getCollection``; 0 is used when absent.

    The 10 nearest documents are fetched from the collection, rescored with
    the ``reranker`` cross-encoder, and the 8 best are sent to the SiliconFlow
    chat-completions API as context.

    Returns:
        ``{'success': True, 'knowledge_results': {...}, 'llm_response': ...}``
        on success; ``{'error': ...}`` with 400 for a malformed request and
        with 500 when retrieval or the LLM call fails.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({'error': '请求体必须是JSON对象'}), 400
    query = payload.get('query')
    if not isinstance(query, str) or not query.strip():
        return jsonify({'error': '查询内容不能为空'}), 400
    context_from_frontend = payload.get('context')  # context passed by the front end
    app.logger.info(f'收到搜索请求: "{query}",前端上下文: "{context_from_frontend}"')

    kb_id = payload.get('id')
    if kb_id is None:
        kb_id = 0

    try:
        # Initial retrieval.
        collection = getCollection(kb_id)
        initial_results = collection.query(
            query_texts=[query],
            n_results=10
        )
        candidate_docs = initial_results['documents'][0]
        candidate_ids = initial_results['ids'][0]

        # Rerank with BAAI/bge-reranker-v2-m3 and keep the 8 best hits.
        # An empty collection yields no candidates; skip the reranker then.
        if candidate_docs:
            scores = reranker.predict([[query, doc] for doc in candidate_docs])
            top_indices = np.argsort(scores)[::-1][:8]
        else:
            top_indices = []
    except Exception as e:
        app.logger.error(f'知识库检索失败: {str(e)}')
        return jsonify({'error': f'知识库检索失败:{str(e)}'}), 500

    results = {
        'documents': [[candidate_docs[i] for i in top_indices]],
        'ids': [[candidate_ids[i] for i in top_indices]]
    }
    context = '\n'.join(results['documents'][0])
    app.logger.info(f'返回知识库结果数: {len(results["documents"][0])}')

    try:
        msg = {
            'model': 'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B',
            'messages': [{
                'role': 'user',
                'content': f'基于以下知识库内容:{context}\n前端上下文: {context_from_frontend}\n\n请回答:{query}'
            }]
        }
        app.logger.info("msg %s ", msg)
        response = requests.post(
            'https://api.siliconflow.cn/v1/chat/completions',
            headers={
                'Authorization': 'Bearer ' + siliconflow_api_key,
                'Content-Type': 'application/json'
            },
            json=msg,
            # (connect, read) seconds: without a timeout a stalled API call
            # would block this worker forever.
            timeout=(10, 300)
        )
        response.raise_for_status()
        llm_response = response.json()['choices'][0]['message']['content']
        app.logger.info("msg %s ", llm_response)
    except Exception as e:
        app.logger.error(f'大模型处理失败: {str(e)}')
        return jsonify({'error': f'大模型处理失败:{str(e)}'}), 500

    return jsonify({
        'success': True,
        'knowledge_results': results,
        'llm_response': llm_response
    })
# Import the route-registration function from the external `route` module.
# NOTE(review): the import sits here rather than at the top of the file,
# presumably so `app` exists first or to avoid a circular import — confirm
# before moving it.
from route import register_routes
# Register the additional routes on the Flask app.
register_routes(app)
# Create the knowledge-base table at application start-up.
create_knowledge_base_table()
@app.route('/docmanage')
def docmanage():
    """Render the knowledge-base management page.

    All rows of ``knowledge_bases``, ordered by the ``sort`` column, are
    passed to ``docmanage.html`` as ``knowledge_bases`` (a list of row tuples).
    """
    conn = sqlite3.connect('rag.db')
    try:
        knowledge_bases = conn.execute(
            'SELECT * FROM knowledge_bases ORDER BY sort'
        ).fetchall()
    finally:
        # Close even when the query raises so the connection is not leaked.
        conn.close()
    return render_template('docmanage.html', knowledge_bases=knowledge_bases)
@app.route('/add_knowledge_base', methods=['POST'])
def add_knowledge_base():
    """Insert a knowledge base from the submitted form and go back to the list.

    Form fields: ``title``, ``remark``, ``sort``, and the checkbox ``enabled``
    (treated as true when the field is present in the form at all).
    Redirects to the ``docmanage`` page afterwards.
    """
    title = request.form.get('title')
    remark = request.form.get('remark')
    sort = request.form.get('sort')
    enabled = 'enabled' in request.form

    conn = sqlite3.connect('rag.db')
    try:
        conn.execute(
            'INSERT INTO knowledge_bases (title, remark, sort, enabled) VALUES (?,?,?,?)',
            (title, remark, sort, enabled)
        )
        conn.commit()
    finally:
        # Close even when the insert raises so the connection is not leaked.
        conn.close()
    return redirect(url_for('docmanage'))
@app.route('/edit_knowledge_base/<int:id>', methods=['GET', 'POST'])
def edit_knowledge_base(id):
    """Show (GET) or save (POST) the edit form of one knowledge base.

    POST updates ``title``, ``remark``, ``sort`` and ``enabled`` of the row
    with the given id from the form and redirects to ``docmanage``.
    GET renders ``edit_knowledge_base.html`` with the row, or answers 404
    with ``{'error': ...}`` when no such row exists.
    """
    conn = sqlite3.connect('rag.db')
    try:
        if request.method == 'POST':
            title = request.form.get('title')
            remark = request.form.get('remark')
            sort = request.form.get('sort')
            enabled = 'enabled' in request.form
            conn.execute(
                'UPDATE knowledge_bases SET title=?, remark=?, sort=?, enabled=? WHERE id=?',
                (title, remark, sort, enabled, id)
            )
            conn.commit()
            return redirect(url_for('docmanage'))

        knowledge_base = conn.execute(
            'SELECT * FROM knowledge_bases WHERE id=?', (id,)
        ).fetchone()
    finally:
        # Runs on every path (including the early redirect and errors) so the
        # connection is never leaked.
        conn.close()

    if knowledge_base is None:
        # Unknown id: do not render an edit form for a row that does not exist.
        return jsonify({'error': '知识库不存在'}), 404
    return render_template('edit_knowledge_base.html', knowledge_base=knowledge_base)
@app.route('/delete_knowledge_base/<int:id>')
def delete_knowledge_base(id):
    """Delete the knowledge base with the given id and go back to the list.

    Only the ``knowledge_bases`` row is removed here.
    """
    # NOTE(review): this destructive action is reachable via GET, so link
    # prefetchers or a cross-site request can trigger it. Switching to POST
    # needs a matching change in the page that links here, so it is only
    # flagged, not changed.
    # NOTE(review): rows in `files` and the vector collection belonging to
    # this knowledge base are not touched here — confirm whether they should
    # be cleaned up as well.
    conn = sqlite3.connect('rag.db')
    try:
        conn.execute('DELETE FROM knowledge_bases WHERE id=?', (id,))
        conn.commit()
    finally:
        # Close even when the delete raises so the connection is not leaked.
        conn.close()
    return redirect(url_for('docmanage'))
from flask import redirect
if __name__ == '__main__':
    # Connect to the SQLite database; the file is created if it does not exist.
    conn = sqlite3.connect('rag.db')
    app.logger.info('成功连接到SQLite数据库rag.db')
    # Close the database connection again.
    conn.close()
    app.logger.info('已关闭SQLite数据库连接')

    # Debug mode enables the interactive Werkzeug debugger, which allows code
    # execution from the browser. It stays on by default (as before) for local
    # development but can now be switched off with FLASK_DEBUG=0/false/no.
    debug_flag = os.getenv('FLASK_DEBUG', '1')
    debug_enabled = bool(debug_flag) and debug_flag.lower() not in ('0', 'false', 'no')
    app.run(debug=debug_enabled)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/xulei8/rag.git
git@gitee.com:xulei8/rag.git
xulei8
rag
DaqiRAG
new2

搜索帮助