diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/manager/database_manager.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/manager/database_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..e77808c4df807b3e0b0c6e9520bc1a4069129439 --- /dev/null +++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/manager/database_manager.py @@ -0,0 +1,257 @@ +""" +数据库操作类 - 使用 SQLAlchemy ORM +""" +import os +import struct +import uuid +from typing import List, Optional, Dict, Any +from datetime import datetime +import logging +from sqlalchemy import create_engine, text, inspect +from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy.exc import SQLAlchemyError + +from base.models import Base, KnowledgeBase, Document, Chunk +from base.config import get_embedding_vector_dimension +from base.manager.document_manager import DocumentManager +import sqlite_vec + +logger = logging.getLogger(__name__) + + +class Database: + """SQLite 数据库操作类 - 使用 SQLAlchemy ORM""" + + def __init__(self, db_path: str = "knowledge_base.db"): + """ + 初始化数据库连接 + :param db_path: 数据库文件路径 + """ + db_dir = os.path.dirname(os.path.abspath(db_path)) + if db_dir and not os.path.exists(db_dir): + os.makedirs(db_dir, exist_ok=True) + + self.db_path = os.path.abspath(db_path) + self.engine = create_engine( + f'sqlite:///{self.db_path}', + echo=False, + connect_args={'check_same_thread': False} + ) + self.SessionLocal = sessionmaker(bind=self.engine, autocommit=False, autoflush=False) + self._init_database() + + def _init_database(self): + """初始化数据库表结构""" + try: + # 创建所有表 + Base.metadata.create_all(self.engine) + + # 加载 sqlite-vec 扩展并创建 FTS5 和 vec_index 表 + with self.engine.begin() as conn: + # 创建 FTS5 虚拟表(需要使用原生 SQL) + conn.execute(text(""" + CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5( + id UNINDEXED, + content, + content_rowid=id + ) + """)) + + # 加载 sqlite-vec 扩展 + try: + raw_conn = conn.connection.dbapi_connection + raw_conn.enable_load_extension(True) + sqlite_vec.load(raw_conn) + raw_conn.enable_load_extension(False) + except Exception as e: + logger.warning(f"加载 sqlite-vec 扩展失败: {e}") + + # 创建 vec_index 虚拟表 + try: + vector_dim = get_embedding_vector_dimension() + conn.execute(text(f""" + CREATE VIRTUAL TABLE IF NOT EXISTS vec_index USING vec0( + embedding float[{vector_dim}] + ) + """)) + except Exception as e: + logger.warning(f"创建 vec_index 表失败: {e}") + except Exception as e: + logger.exception(f"[Database] 初始化数据库失败: {e}") + raise e + + def get_session(self) -> Session: + """获取数据库会话""" + return self.SessionLocal() + + def get_connection(self): + """ + 获取原始数据库连接(用于特殊操作,如 FTS5 和 vec_index) + 注意:此方法保留以兼容现有代码,但推荐使用 get_session() + 返回一个上下文管理器,使用后会自动关闭 + """ + return self.engine.connect() + + def add_knowledge_base(self, kb_id: str, name: str, chunk_size: int, + embedding_model: Optional[str] = None, + embedding_endpoint: Optional[str] = None, + embedding_api_key: Optional[str] = None) -> bool: + """添加知识库""" + session = self.get_session() + try: + kb = KnowledgeBase( + id=kb_id, + name=name, + chunk_size=chunk_size, + embedding_model=embedding_model, + embedding_endpoint=embedding_endpoint, + embedding_api_key=embedding_api_key + ) + session.add(kb) + session.commit() + return True + except SQLAlchemyError as e: + logger.exception(f"[Database] 添加知识库失败: {e}") + session.rollback() + return False + finally: + session.close() + + def get_knowledge_base(self, kb_name: str) -> Optional[KnowledgeBase]: + """获取知识库""" + session = 
self.get_session() + try: + return session.query(KnowledgeBase).filter_by(name=kb_name).first() + finally: + session.close() + + def delete_knowledge_base(self, kb_id: str) -> bool: + """删除知识库(级联删除相关文档和chunks)""" + session = self.get_session() + try: + kb = session.query(KnowledgeBase).filter_by(id=kb_id).first() + if kb: + session.delete(kb) + session.commit() + return True + return False + except SQLAlchemyError as e: + logger.exception(f"[Database] 删除知识库失败: {e}") + session.rollback() + return False + finally: + session.close() + + def list_knowledge_bases(self) -> List[KnowledgeBase]: + """列出所有知识库""" + session = self.get_session() + try: + return session.query(KnowledgeBase).order_by(KnowledgeBase.created_at.desc()).all() + finally: + session.close() + + def import_database(self, source_db_path: str) -> tuple[int, int]: + """ + 导入数据库,将其中的内容合并到当前数据库 + + :param source_db_path: 源数据库文件路径 + :return: (imported_kb_count, imported_doc_count) + """ + source_db = Database(source_db_path) + source_session = source_db.get_session() + + try: + # 读取源数据库的知识库 + source_kbs = source_session.query(KnowledgeBase).all() + if not source_kbs: + return 0, 0 + + # 读取源数据库的文档 + source_docs = source_session.query(Document).all() + + # 合并到当前数据库 + target_session = self.get_session() + + try: + imported_kb_count = 0 + imported_doc_count = 0 + + for source_kb in source_kbs: + # 检查知识库是否已存在,如果存在则生成唯一名称 + kb_name = source_kb.name + existing_kb = self.get_knowledge_base(kb_name) + if existing_kb: + # 生成唯一名称 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + counter = 1 + unique_kb_name = f"{kb_name}_{timestamp}" + while self.get_knowledge_base(unique_kb_name): + unique_kb_name = f"{kb_name}_{timestamp}_{counter}" + counter += 1 + kb_name = unique_kb_name + + # 导入知识库 + new_kb_id = str(uuid.uuid4()) + if self.add_knowledge_base(new_kb_id, kb_name, source_kb.chunk_size, + source_kb.embedding_model, source_kb.embedding_endpoint, + source_kb.embedding_api_key): + imported_kb_count += 1 + + # 导入该知识库下的文档 + kb_docs = [doc for doc in source_docs if doc.kb_id == source_kb.id] + manager = DocumentManager(target_session) + + for source_doc in kb_docs: + # 检查文档是否已存在,如果存在则生成唯一名称 + doc_name = source_doc.name + existing_doc = manager.get_document(new_kb_id, doc_name) + if existing_doc: + # 生成唯一名称 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + # 分离文件名和扩展名 + if '.' in doc_name: + name_part, ext_part = doc_name.rsplit('.', 1) + unique_doc_name = f"{name_part}_{timestamp}.{ext_part}" + else: + unique_doc_name = f"{doc_name}_{timestamp}" + + # 如果新名称仍然存在,继续添加后缀 + counter = 1 + final_doc_name = unique_doc_name + while manager.get_document(new_kb_id, final_doc_name): + if '.' 
in doc_name: + name_part, ext_part = doc_name.rsplit('.', 1) + final_doc_name = f"{name_part}_{timestamp}_{counter}.{ext_part}" + else: + final_doc_name = f"{doc_name}_{timestamp}_{counter}" + counter += 1 + doc_name = final_doc_name + + # 导入文档 + new_doc_id = str(uuid.uuid4()) + if manager.add_document(new_doc_id, new_kb_id, doc_name, + source_doc.file_path, source_doc.file_type, + source_doc.content, source_doc.chunk_size): + imported_doc_count += 1 + + # 导入chunks(包含向量) + source_chunks = source_session.query(Chunk).filter_by(doc_id=source_doc.id).all() + for source_chunk in source_chunks: + new_chunk_id = str(uuid.uuid4()) + # 提取向量(如果存在) + embedding = None + if source_chunk.embedding: + embedding_bytes = source_chunk.embedding + if len(embedding_bytes) > 0 and len(embedding_bytes) % 4 == 0: + embedding = list(struct.unpack(f'{len(embedding_bytes)//4}f', embedding_bytes)) + + manager.add_chunk(new_chunk_id, new_doc_id, source_chunk.content, + source_chunk.tokens, source_chunk.chunk_index, embedding) + return imported_kb_count, imported_doc_count + finally: + target_session.close() + finally: + source_session.close() + source_db = None + diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/manager/document_manager.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/manager/document_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..5186ee241584f7dc8b8d5cc6a6fc2d4c40f02238 --- /dev/null +++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/manager/document_manager.py @@ -0,0 +1,394 @@ +""" +文档操作模块 - 使用 SQLAlchemy ORM +""" +import os +import struct +import uuid +import asyncio +from typing import List, Optional, Tuple +from datetime import datetime +import logging +from sqlalchemy import text +from sqlalchemy.orm import Session +from sqlalchemy.exc import SQLAlchemyError + +from base.models import Document, Chunk +from base.embedding import Embedding +from base.parser.parser import Parser +from base.token_tool import TokenTool +import jieba + +logger = logging.getLogger(__name__) + + +class DocumentManager: + """文档操作管理器""" + + def __init__(self, session: Session): + """ + 初始化文档管理器 + :param session: 数据库会话 + """ + self.session = session + + def add_document(self, doc_id: str, kb_id: str, name: str, file_path: str, + file_type: str, content: Optional[str] = None, chunk_size: Optional[int] = None) -> bool: + """添加文档""" + try: + document = Document( + id=doc_id, + kb_id=kb_id, + name=name, + file_path=file_path, + file_type=file_type, + content=content, + chunk_size=chunk_size, + updated_at=datetime.now() + ) + self.session.add(document) + self.session.commit() + return True + except SQLAlchemyError as e: + logger.exception(f"[DocumentManager] 添加文档失败: {e}") + self.session.rollback() + return False + + def delete_document(self, kb_id: str, doc_name: str) -> bool: + """删除文档(级联删除相关chunks)""" + try: + doc = self.session.query(Document).filter_by(kb_id=kb_id, name=doc_name).first() + if doc: + self.session.delete(doc) + self.session.commit() + return True + return False + except SQLAlchemyError as e: + logger.exception(f"[DocumentManager] 删除文档失败: {e}") + self.session.rollback() + return False + + def get_document(self, kb_id: str, doc_name: str) -> Optional[Document]: + """获取文档""" + return self.session.query(Document).filter_by(kb_id=kb_id, name=doc_name).first() + + def list_documents_by_kb(self, kb_id: str) -> List[Document]: + """列出知识库下的所有文档""" + return 
self.session.query(Document).filter_by(kb_id=kb_id).order_by(Document.created_at.desc()).all() + + def add_chunk(self, chunk_id: str, doc_id: str, content: str, tokens: int, chunk_index: int, + embedding: Optional[List[float]] = None) -> bool: + """添加 chunk(可包含向量)""" + try: + embedding_bytes = None + if embedding: + embedding_bytes = struct.pack(f'{len(embedding)}f', *embedding) + + chunk = Chunk( + id=chunk_id, + doc_id=doc_id, + content=content, + tokens=tokens, + chunk_index=chunk_index, + embedding=embedding_bytes + ) + self.session.add(chunk) + self.session.flush() + + # 添加 FTS5 索引(需要使用原生 SQL) + fts_content = self._prepare_fts_content(content) + self.session.execute(text(""" + INSERT INTO chunks_fts (id, content) + VALUES (:chunk_id, :content) + """), {"chunk_id": chunk_id, "content": fts_content}) + + # 检查并更新 vec_index(需要使用原生 SQL) + if embedding_bytes: + conn = self.session.connection() + result = conn.execute(text(""" + SELECT name FROM sqlite_master + WHERE type='table' AND name='vec_index' + """)) + if result.fetchone(): + result = conn.execute(text(""" + SELECT rowid FROM chunks WHERE id = :chunk_id + """), {"chunk_id": chunk_id}) + row = result.fetchone() + if row: + vec_rowid = row[0] + # 先删除可能存在的旧记录,避免 UNIQUE constraint 冲突 + conn.execute(text(""" + DELETE FROM vec_index WHERE rowid = :rowid + """), {"rowid": vec_rowid}) + # 然后插入新记录 + conn.execute(text(""" + INSERT INTO vec_index(rowid, embedding) + VALUES (:rowid, :embedding) + """), {"rowid": vec_rowid, "embedding": embedding_bytes}) + + self.session.commit() + return True + except SQLAlchemyError as e: + logger.exception(f"[DocumentManager] 添加chunk失败: {e}") + self.session.rollback() + return False + + def _prepare_fts_content(self, content: str) -> str: + """ + 准备 FTS5 内容(对中文进行 jieba 分词) + :param content: 原始内容 + :return: 分词后的内容(用空格连接) + """ + try: + words = jieba.cut(content) + words = [word.strip() for word in words if word.strip()] + return ' '.join(words) + except Exception: + return content + + def update_chunk_embedding(self, chunk_id: str, embedding: List[float]) -> bool: + """更新 chunk 的向量""" + try: + embedding_bytes = struct.pack(f'{len(embedding)}f', *embedding) + + chunk = self.session.query(Chunk).filter_by(id=chunk_id).first() + if not chunk: + return False + + chunk.embedding = embedding_bytes + self.session.flush() + + # 检查并更新 vec_index(需要使用原生 SQL) + conn = self.session.connection() + result = conn.execute(text(""" + SELECT name FROM sqlite_master + WHERE type='table' AND name='vec_index' + """)) + if result.fetchone(): + result = conn.execute(text(""" + SELECT rowid FROM chunks WHERE id = :chunk_id + """), {"chunk_id": chunk_id}) + row = result.fetchone() + if row: + vec_rowid = row[0] + # 先删除可能存在的旧记录,避免 UNIQUE constraint 冲突 + conn.execute(text(""" + DELETE FROM vec_index WHERE rowid = :rowid + """), {"rowid": vec_rowid}) + # 然后插入新记录 + conn.execute(text(""" + INSERT INTO vec_index(rowid, embedding) + VALUES (:rowid, :embedding) + """), {"rowid": vec_rowid, "embedding": embedding_bytes}) + + self.session.commit() + return True + except SQLAlchemyError as e: + logger.exception(f"[DocumentManager] 更新chunk向量失败: {e}") + self.session.rollback() + return False + + def delete_document_chunks(self, doc_id: str) -> None: + """删除文档的所有chunks""" + chunks = self.session.query(Chunk).filter_by(doc_id=doc_id).all() + conn = self.session.connection() + for chunk in chunks: + # 删除FTS5索引 + conn.execute(text(""" + DELETE FROM chunks_fts WHERE id = :chunk_id + """), {"chunk_id": chunk.id}) + # 删除向量索引(如果chunk有向量) + if 
chunk.embedding: + result = conn.execute(text(""" + SELECT rowid FROM chunks WHERE id = :chunk_id + """), {"chunk_id": chunk.id}) + row = result.fetchone() + if row: + conn.execute(text(""" + DELETE FROM vec_index WHERE rowid = :rowid + """), {"rowid": row[0]}) + # 删除chunk + self.session.delete(chunk) + self.session.commit() + + def update_document_content(self, doc_id: str, content: str, chunk_size: int) -> None: + """更新文档的content和chunk_size""" + doc = self.session.query(Document).filter_by(id=doc_id).first() + if doc: + doc.chunk_size = chunk_size + doc.content = content + doc.updated_at = datetime.now() + self.session.commit() + + +def _generate_unique_name(base_name: str, check_exists_func) -> str: + """ + 生成唯一名称,如果已存在则添加时间戳 + + :param base_name: 基础名称 + :param check_exists_func: 检查是否存在的函数,接受名称参数,返回是否存在 + :return: 唯一名称 + """ + if not check_exists_func(base_name): + return base_name + + # 如果已存在,添加时间戳 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + # 分离文件名和扩展名 + if '.' in base_name: + name_part, ext_part = base_name.rsplit('.', 1) + new_name = f"{name_part}_{timestamp}.{ext_part}" + else: + new_name = f"{base_name}_{timestamp}" + + # 如果新名称仍然存在,继续添加后缀 + counter = 1 + final_name = new_name + while check_exists_func(final_name): + if '.' in base_name: + name_part, ext_part = base_name.rsplit('.', 1) + final_name = f"{name_part}_{timestamp}_{counter}.{ext_part}" + else: + final_name = f"{base_name}_{timestamp}_{counter}" + counter += 1 + + return final_name + + +async def import_document(session: Session, kb_id: str, file_path: str, + chunk_size: int) -> Tuple[bool, str, Optional[dict]]: + """ + 导入文档(异步) + + :param session: 数据库会话 + :param kb_id: 知识库ID + :param file_path: 文件路径 + :param chunk_size: chunk大小 + :return: (success, message, data) + """ + try: + doc_name = os.path.basename(file_path) + content = Parser.parse(file_path) + if not content: + return False, "文档解析失败", None + + chunks = TokenTool.split_content_to_chunks(content, chunk_size) + if not chunks: + return False, "文档内容为空", None + + manager = DocumentManager(session) + + # 检查文档是否已存在,如果存在则生成唯一名称 + def check_doc_exists(name: str) -> bool: + return manager.get_document(kb_id, name) is not None + + unique_doc_name = _generate_unique_name(doc_name, check_doc_exists) + + doc_id = str(uuid.uuid4()) + file_type = file_path.lower().split('.')[-1] + + if not manager.add_document(doc_id, kb_id, unique_doc_name, file_path, file_type, content, chunk_size): + return False, "添加文档失败", None + + chunk_ids = [] + chunk_data = [] + + # 先收集所有chunk数据 + for idx, chunk_content in enumerate(chunks): + chunk_id = str(uuid.uuid4()) + tokens = TokenTool.get_tokens(chunk_content) + chunk_data.append((chunk_id, chunk_content, tokens, idx)) + + # 批量生成向量(异步) + embeddings_list = [None] * len(chunk_data) + if Embedding.is_configured() and chunk_data: + try: + chunk_contents = [content for _, content, _, _ in chunk_data] + embeddings_list = await Embedding.vectorize_embeddings_batch(chunk_contents, max_concurrent=5) + except Exception as e: + logger.warning(f"批量生成向量失败: {e}") + + # 添加chunks(包含向量) + for (chunk_id, chunk_content, tokens, idx), embedding in zip(chunk_data, embeddings_list): + if manager.add_chunk(chunk_id, doc_id, chunk_content, tokens, idx, embedding): + chunk_ids.append(chunk_id) + + return True, f"成功导入文档,共 {len(chunk_ids)} 个 chunks", { + "doc_id": doc_id, + "doc_name": unique_doc_name, + "original_name": doc_name if unique_doc_name != doc_name else None, + "chunk_count": len(chunk_ids), + "file_path": file_path + } + except Exception as e: 
+ logger.exception(f"[import_document] 导入文档失败: {e}") + return False, "导入文档失败", None + + +async def update_document(session: Session, kb_id: str, doc_name: str, chunk_size: int) -> Tuple[bool, str, Optional[dict]]: + """ + 更新文档的chunk_size并重新解析(异步) + + :param session: 数据库会话 + :param kb_id: 知识库ID + :param doc_name: 文档名称 + :param chunk_size: 新的chunk大小 + :return: (success, message, data) + """ + try: + manager = DocumentManager(session) + doc = manager.get_document(kb_id, doc_name) + if not doc: + return False, f"文档 '{doc_name}' 不存在", None + + # 删除旧文档的所有chunks + manager.delete_document_chunks(doc.id) + + # 重新解析文档 + if not doc.file_path or not os.path.exists(doc.file_path): + return False, "文档文件不存在", None + + content = Parser.parse(doc.file_path) + if not content: + return False, "文档解析失败", None + + chunks = TokenTool.split_content_to_chunks(content, chunk_size) + if not chunks: + return False, "文档内容为空", None + + # 收集所有chunk数据 + chunk_ids = [] + chunk_data = [] + + for idx, chunk_content in enumerate(chunks): + chunk_id = str(uuid.uuid4()) + tokens = TokenTool.get_tokens(chunk_content) + chunk_data.append((chunk_id, chunk_content, tokens, idx)) + + # 批量生成向量(异步) + embeddings_list = [None] * len(chunk_data) + if Embedding.is_configured() and chunk_data: + try: + chunk_contents = [content for _, content, _, _ in chunk_data] + embeddings_list = await Embedding.vectorize_embeddings_batch(chunk_contents, max_concurrent=5) + except Exception as e: + logger.warning(f"批量生成向量失败: {e}") + + # 添加chunks(包含向量) + for (chunk_id, chunk_content, tokens, idx), embedding in zip(chunk_data, embeddings_list): + if manager.add_chunk(chunk_id, doc.id, chunk_content, tokens, idx, embedding): + chunk_ids.append(chunk_id) + + # 更新文档的chunk_size和content + manager.update_document_content(doc.id, content, chunk_size) + + return True, f"成功修改文档,共 {len(chunk_ids)} 个 chunks", { + "doc_id": doc.id, + "doc_name": doc_name, + "chunk_count": len(chunk_ids), + "chunk_size": chunk_size + } + except Exception as e: + logger.exception(f"[update_document] 修改文档失败: {e}") + return False, "修改文档失败", None + diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/search/keyword.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/search/keyword.py new file mode 100644 index 0000000000000000000000000000000000000000..d1994d0b3be3ae46aff3d84961d1870891ba3e5c --- /dev/null +++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/search/keyword.py @@ -0,0 +1,92 @@ +""" +关键词检索模块 - 使用 SQLAlchemy +""" +import logging +from typing import List, Dict, Any, Optional +from sqlalchemy import text +import jieba + +logger = logging.getLogger(__name__) + + +def _prepare_fts_query(query: str) -> str: + """ + 准备 FTS5 查询 + :param query: 原始查询文本 + :return: FTS5 查询字符串 + """ + def escape_fts_word(word: str) -> str: + # 包含以下任意字符时,整体作为短语用双引号包裹,避免触发 FTS5 语法解析 + # 特别是 '%' 在 FTS5 MATCH 语法中会导致 "syntax error near '%'" + special_chars = [ + '"', "'", '(', ')', '*', ':', '?', '+', '-', '|', '&', + '{', '}', '[', ']', '^', '$', '\\', '/', '!', '~', ';', + ',', '.', ' ', '%' + ] + if any(char in word for char in special_chars): + escaped_word = word.replace('"', '""') + return f'"{escaped_word}"' + return word + + try: + words = jieba.cut(query) + words = [word.strip() for word in words if word.strip()] + if not words: + return escape_fts_word(query) + + escaped_words = [escape_fts_word(word) for word in words] + fts_query = ' OR '.join(escaped_words) + return fts_query + except Exception: + return escape_fts_word(query) + + +def 
search_by_keyword(conn, query: str, top_k: int = 5, doc_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
+    """
+    Keyword search (FTS5; Chinese text is tokenized with jieba)
+    :param conn: database connection object (SQLAlchemy Connection)
+    :param query: query text
+    :param top_k: number of results to return
+    :param doc_ids: optional list of document IDs used as a filter
+    :return: list of chunks
+    """
+    try:
+        fts_query = _prepare_fts_query(query)
+
+        params = {"fts_query": fts_query, "top_k": top_k}
+        where_clause = "WHERE chunks_fts MATCH :fts_query"
+
+        if doc_ids:
+            placeholders = ','.join([f':doc_id_{i}' for i in range(len(doc_ids))])
+            for i, doc_id in enumerate(doc_ids):
+                params[f'doc_id_{i}'] = doc_id
+            where_clause += f" AND c.doc_id IN ({placeholders})"
+
+        sql = f"""
+            SELECT c.id, c.doc_id, c.content, c.tokens, c.chunk_index,
+                   d.name as doc_name,
+                   chunks_fts.rank
+            FROM chunks_fts
+            JOIN chunks c ON c.id = chunks_fts.id
+            JOIN documents d ON d.id = c.doc_id
+            {where_clause}
+            ORDER BY chunks_fts.rank
+            LIMIT :top_k
+        """
+        result = conn.execute(text(sql), params)
+
+        results = []
+        for row in result:
+            results.append({
+                'id': row.id,
+                'doc_id': row.doc_id,
+                'content': row.content,
+                'tokens': row.tokens,
+                'chunk_index': row.chunk_index,
+                'doc_name': row.doc_name,
+                'score': row.rank if row.rank is not None else 0.0
+            })
+        return results
+    except Exception as e:
+        logger.exception(f"[KeywordSearch] Keyword search failed: {e}")
+        return []
diff --git a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/search/vector.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/search/vector.py
new file mode 100644
index 0000000000000000000000000000000000000000..179423caabe42a6fca900affe490cdd46fcd5036
--- /dev/null
+++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/search/vector.py
@@ -0,0 +1,67 @@
+"""
+Vector search module - uses SQLAlchemy
+"""
+import logging
+import struct
+from typing import List, Dict, Any, Optional
+from sqlalchemy import text
+
+logger = logging.getLogger(__name__)
+
+
+def search_by_vector(conn, query_vector: List[float], top_k: int = 5, doc_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
+    """
+    Vector search
+    :param conn: database connection object (SQLAlchemy Connection)
+    :param query_vector: query vector
+    :param top_k: number of results to return
+    :param doc_ids: optional list of document IDs used as a filter
+    :return: list of chunks
+    """
+    try:
+        # Check whether the vec_index table exists
+        result = conn.execute(text("""
+            SELECT name FROM sqlite_master
+            WHERE type='table' AND name='vec_index'
+        """))
+        if not result.fetchone():
+            return []
+
+        query_vector_bytes = struct.pack(f'{len(query_vector)}f', *query_vector)
+
+        params = {"query_vector": query_vector_bytes, "top_k": top_k}
+        where_clause = "WHERE v.embedding MATCH :query_vector AND k = :top_k"
+
+        if doc_ids:
+            placeholders = ','.join([f':doc_id_{i}' for i in range(len(doc_ids))])
+            for i, doc_id in enumerate(doc_ids):
+                params[f'doc_id_{i}'] = doc_id
+            where_clause += f" AND c.doc_id IN ({placeholders})"
+
+        sql = f"""
+            SELECT c.id, c.doc_id, c.content, c.tokens, c.chunk_index,
+                   d.name as doc_name,
+                   distance
+            FROM vec_index v
+            JOIN chunks c ON c.rowid = v.rowid
+            JOIN documents d ON d.id = c.doc_id
+            {where_clause}
+            ORDER BY distance
+        """
+        result = conn.execute(text(sql), params)
+
+        results = []
+        for row in result:
+            results.append({
+                'id': row.id,
+                'doc_id': row.doc_id,
+                'content': row.content,
+                'tokens': row.tokens,
+                'chunk_index': row.chunk_index,
+                'doc_name': row.doc_name,
+                'score': row.distance
+            })
+        return results
+    except Exception as e:
+        logger.exception(f"[VectorSearch] Vector search failed: {e}")
+        return []
diff --git 
a/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/search/weighted_keyword_and_vector_search.py b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/search/weighted_keyword_and_vector_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..f824151409345d2fa67aa4c593938e0f5dba1495
--- /dev/null
+++ b/mcp_center/servers/oe-cli-mcp-server/mcp_tools/rag_tools/base/search/weighted_keyword_and_vector_search.py
@@ -0,0 +1,122 @@
+import logging
+from typing import List, Dict, Any, Optional
+from base.search.keyword import search_by_keyword as keyword_search
+from base.search.vector import search_by_vector as vector_search
+from base.embedding import Embedding
+from base.rerank import Rerank
+
+logger = logging.getLogger(__name__)
+
+
+async def weighted_keyword_and_vector_search(
+    conn,
+    query: str,
+    top_k: int = 5,
+    weight_keyword: float = 0.3,
+    weight_vector: float = 0.7,
+    doc_ids: Optional[List[str]] = None
+) -> List[Dict[str, Any]]:
+    """
+    Weighted hybrid search combining keyword and vector retrieval (async)
+
+    :param conn: database connection object (SQLAlchemy Connection)
+    :param query: query text
+    :param top_k: number of results to return
+    :param weight_keyword: weight applied to keyword search scores
+    :param weight_vector: weight applied to vector search scores
+    :param doc_ids: optional list of document IDs used as a filter
+    :return: merged list of chunks
+    """
+    try:
+        # Run keyword and vector search, fetching 2 * top_k candidates from each
+        keyword_chunks = []
+        vector_chunks = []
+
+        # Keyword search
+        try:
+            keyword_chunks = keyword_search(conn, query, 2 * top_k, doc_ids)
+        except Exception as e:
+            logger.warning(f"[WeightedSearch] Keyword search failed: {e}")
+
+        # Vector search (requires a configured embedding backend)
+        if Embedding.is_configured():
+            try:
+                query_vector = await Embedding.vectorize_embedding(query)
+                if query_vector:
+                    vector_chunks = vector_search(conn, query_vector, 2 * top_k, doc_ids)
+            except Exception as e:
+                logger.warning(f"[WeightedSearch] Vector search failed: {e}")
+
+        # No results from either search
+        if not keyword_chunks and not vector_chunks:
+            return []
+
+        # Normalize and merge the results
+        merged_chunks = {}
+
+        # Process keyword search results
+        if keyword_chunks:
+            # Normalize rank scores (lower rank is better; convert to higher-is-better)
+            keyword_scores = [chunk.get('score', 0.0) for chunk in keyword_chunks if chunk.get('score') is not None]
+            if keyword_scores:
+                min_rank = min(keyword_scores)
+                max_rank = max(keyword_scores)
+                rank_range = max_rank - min_rank
+
+                for chunk in keyword_chunks:
+                    chunk_id = chunk['id']
+                    rank = chunk.get('score', 0.0)
+                    # Convert to a higher-is-better score normalized to 0-1
+                    if rank_range > 0:
+                        normalized_score = 1.0 - ((rank - min_rank) / rank_range)
+                    else:
+                        normalized_score = 1.0
+                    weighted_score = normalized_score * weight_keyword
+
+                    if chunk_id not in merged_chunks:
+                        merged_chunks[chunk_id] = chunk.copy()
+                        merged_chunks[chunk_id]['score'] = weighted_score
+                    else:
+                        merged_chunks[chunk_id]['score'] += weighted_score
+
+        # Process vector search results
+        if vector_chunks:
+            # Normalize distance scores (smaller distance is better; convert to higher-is-better)
+            vector_scores = [chunk.get('score', 0.0) for chunk in vector_chunks if chunk.get('score') is not None]
+            if vector_scores:
+                min_distance = min(vector_scores)
+                max_distance = max(vector_scores)
+                distance_range = max_distance - min_distance
+
+                for chunk in vector_chunks:
+                    chunk_id = chunk['id']
+                    distance = chunk.get('score', 0.0)
+                    # Convert to a higher-is-better score normalized to 0-1
+                    if distance_range > 0:
+                        normalized_score = 1.0 - ((distance - min_distance) / distance_range)
+                    else:
+                        normalized_score = 1.0
+                    weighted_score = normalized_score * weight_vector
+
+                    if chunk_id not in merged_chunks:
+                        merged_chunks[chunk_id] = chunk.copy()
+                        merged_chunks[chunk_id]['score'] = weighted_score
+                    else:
+                        merged_chunks[chunk_id]['score'] += weighted_score
+
+        # Convert to a list and sort by combined score
+        merged_list = list(merged_chunks.values())
+        merged_list.sort(key=lambda x: x.get('score', 0.0), reverse=True)
+
+        # Rerank
+        reranked_chunks = Rerank.rerank_chunks(merged_list, query)
+
+        # Keep the top_k results
+        final_chunks = reranked_chunks[:top_k]
+
+        return final_chunks
+
+    except Exception as e:
+        logger.exception(f"[WeightedSearch] Hybrid search failed: {e}")
+        return []
+
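
A minimal end-to-end usage sketch of the modules added above (not part of the change set): it creates a knowledge base, imports one document, and runs the hybrid search. It assumes the `base.*` package layout used inside this diff is importable as shown; the knowledge-base name and the file path are placeholders, and if no embedding backend is configured the search falls back to keyword-only results.

import asyncio
import uuid

# Placeholder import paths: adjust to wherever the "base" package ends up on sys.path.
from base.manager.database_manager import Database
from base.manager.document_manager import import_document
from base.search.weighted_keyword_and_vector_search import weighted_keyword_and_vector_search


async def main() -> None:
    # Create (or open) the SQLite knowledge base file and register one KB entry.
    db = Database("knowledge_base.db")
    kb_id = str(uuid.uuid4())
    db.add_knowledge_base(kb_id, "demo-kb", chunk_size=512)

    # Parse, chunk, and (when embeddings are configured) vectorize a document.
    # "docs/example.md" is a placeholder path.
    session = db.get_session()
    try:
        ok, message, data = await import_document(session, kb_id, "docs/example.md", chunk_size=512)
        print(ok, message, data)
    finally:
        session.close()

    # Hybrid retrieval: FTS5 keyword hits and sqlite-vec KNN hits are merged
    # with the default 0.3 / 0.7 weights and then reranked.
    with db.get_connection() as conn:
        for chunk in await weighted_keyword_and_vector_search(conn, "how to configure sqlite-vec", top_k=5):
            print(chunk["doc_name"], chunk.get("score"), chunk["content"][:80])


if __name__ == "__main__":
    asyncio.run(main())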