diff --git a/data_chain/apps/base/convertor.py b/data_chain/apps/base/convertor.py index 211f6504e2afcde9c42c8569b3be7f4fcd96fb4d..4fcf6844a05f0ebf3eaef6d2a71e54ace324132e 100644 --- a/data_chain/apps/base/convertor.py +++ b/data_chain/apps/base/convertor.py @@ -377,7 +377,8 @@ class Convertor: 'name': req.kb_name, 'description': req.description, 'tokenizer': req.tokenizer.value, - 'rerank_model': req.rerank_model, + 'rerank_method': req.rerank_methond.value, + 'rerank_name': req.rerank_name, 'spearating_characters': req.spearating_characters, 'upload_count_limit': req.upload_count_limit, 'upload_size_limit': req.upload_size_limit, @@ -401,7 +402,8 @@ class Convertor: authorName=knowledge_base_entity.author_name, tokenizer=knowledge_base_entity.tokenizer, embeddingModel=knowledge_base_entity.embedding_model, - rerankModel=knowledge_base_entity.rerank_model, + rerankMethod=knowledge_base_entity.rerank_method, + rerankName=knowledge_base_entity.rerank_name, spearatingCharacters=knowledge_base_entity.spearating_characters, description=knowledge_base_entity.description, docCnt=knowledge_base_entity.doc_cnt, @@ -448,7 +450,8 @@ class Convertor: tokenizer=req.tokenizer.value, description=req.description, embedding_model=req.embedding_model, - rerank_model=req.rerank_model, + rerank_method=req.rerank_methond.value, + rerank_name=req.rerank_name, spearating_characters=req.spearating_characters, upload_count_limit=req.upload_count_limit, upload_size_limit=req.upload_size_limit, diff --git a/data_chain/apps/router/other.py b/data_chain/apps/router/other.py index 2bcf71ad240ef69df96161f1d48c31a6917394b9..1b91e4aeb59de7e8919bfa7dfe7d515ae3d078fe 100644 --- a/data_chain/apps/router/other.py +++ b/data_chain/apps/router/other.py @@ -6,12 +6,13 @@ import hashlib from typing import Annotated from uuid import UUID from data_chain.config.config import config -from data_chain.entities.enum import Embedding, Tokenizer, ParseMethod, SearchMethod +from data_chain.entities.enum import 
Embedding, RerankType, Tokenizer, ParseMethod, SearchMethod from data_chain.entities.response_data import ( LLM, ListLLMMsg, ListLLMResponse, ListEmbeddingResponse, + RerankMethod, ListRerankResponse, ListTokenizerResponse, ListParseMethodResponse, @@ -39,7 +40,8 @@ async def list_llms_by_user_sub( 'MAX_TOKENS': config['MAX_TOKENS'], 'TEMPERATURE': config['TEMPERATURE'] } - config_json = json.dumps(config_params, sort_keys=True, ensure_ascii=False).encode('utf-8') + config_json = json.dumps( + config_params, sort_keys=True, ensure_ascii=False).encode('utf-8') hash_object = hashlib.sha256(config_json) hash_hex = hash_object.hexdigest() llm = LLM( @@ -56,11 +58,17 @@ async def list_embeddings(): embeddings = [config['EMBEDDING_MODEL_NAME']] return ListEmbeddingResponse(result=embeddings) + @router.get('/rerank', response_model=ListRerankResponse, dependencies=[Depends(verify_user)]) async def list_reranks(): - reranks = [config['RERANK_MODEL_NAME']] + aloghrithm_rerank = RerankMethod( + rerankMethod=RerankType.ALGORITHM, rerankerName="jaccard dis reranker") + model_rerank = RerankMethod( + rerankMethod=RerankType(config['RERANK_MODEL_NAME']), rerankerName=config['RERANK_MODEL_NAME']) + reranks = [aloghrithm_rerank, model_rerank] return ListRerankResponse(result=reranks) + @router.get('/tokenizer', response_model=ListTokenizerResponse, dependencies=[Depends(verify_user)]) async def list_tokenizers(): tokenizers = [tokenizer.value for tokenizer in Tokenizer] diff --git a/data_chain/apps/router/role.py b/data_chain/apps/router/role.py index 1b3f31ced64cf2853ed1976424d5a10f44879085..245d899cbd14a506bea9dc74407a62836369ed90 100644 --- a/data_chain/apps/router/role.py +++ b/data_chain/apps/router/role.py @@ -53,7 +53,7 @@ async def list_roles( if not (await TeamService.validate_user_action_in_team(user_sub, req.team_id, action)): raise Exception('用户没有权限查看该团队角色') list_role_msg = await RoleService.list_roles(req) - await TeamService.add_team_msg(user_sub, req.team_id, 
IdType.TEAM, '查看了角色列表') + await TeamService.add_team_msg(user_sub, req.team_id, IdType.TEAM, '查看了角色列表', 'role list viewed') return ListRoleResponse(message='角色列表获取成功', result=list_role_msg) @@ -65,7 +65,7 @@ async def create_role(user_sub: Annotated[str, Depends(get_user_sub)], if not (await TeamService.validate_user_action_in_team(user_sub, team_id, action)): raise Exception('用户没有权限创建该团队角色') role_id = await RoleService.create_role(team_id, req) - await TeamService.add_team_msg(user_sub, role_id, IdType.ROLE, '创建了{teamName}的角色{roleName}') + await TeamService.add_team_msg(user_sub, role_id, IdType.ROLE, '创建了角色{roleName}', 'created role {roleName}') return CreateRoleResponse(message='角色创建成功', result=role_id) @@ -78,7 +78,7 @@ async def update_role_by_role_id( if not (await RoleService.validate_user_action_to_role(user_sub, role_id, action)): raise Exception('用户没有权限修改该团队角色') role_id = await RoleService.update_role(role_id, req) - await TeamService.add_team_msg(user_sub, role_id, IdType.ROLE, '更新了{teamName}的角色{roleName}') + await TeamService.add_team_msg(user_sub, role_id, IdType.ROLE, '更新了角色{roleName}', 'updated role {roleName}') return UpdateRoleResponse(message='角色更新成功', result=role_id) @@ -90,5 +90,5 @@ async def delete_role_by_role_ids( if not (await RoleService.validate_user_action_to_role(user_sub, role_id, action)): raise Exception('用户没有权限删除该团队角色') role_id = await RoleService.delete_role(role_id) - await TeamService.add_team_msg(user_sub, role_id, IdType.ROLE, '删除了{teamName}的角色{roleName}') + await TeamService.add_team_msg(user_sub, role_id, IdType.ROLE, '删除了角色{roleName}', 'deleted role {roleName}') return DeleteRoleResponse(message='角色删除成功', result=role_id) diff --git a/data_chain/apps/service/chunk_service.py b/data_chain/apps/service/chunk_service.py index a533a03cd8c9d0784ffc35ace9a26b7039b6bc82..f6c51d46281056ae642cce670d5a62298c3cacac 100644 --- a/data_chain/apps/service/chunk_service.py +++ b/data_chain/apps/service/chunk_service.py @@ -86,17 +86,18 @@ 
class ChunkService: err = f"知识库不存在,知识库ID: {kb_id}" logging.warning("[ChunkService] %s", err) continue - if kb_id!=DEFAULT_KNOWLEDGE_BASE_ID and not await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action): + if kb_id != DEFAULT_KNOWLEDGE_BASE_ID and not await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action): err = f"用户没有权限访问该知识库,知识库ID: {kb_id}" logging.warning("[ChunkService] %s", err) continue - top_k=req.top_k + top_k = req.top_k if req.is_rerank: - top_k=req.top_k*3 - sub_chunk_entities= await BaseSearcher.search(req.search_method.value, kb_id, req.query, top_k, req.doc_ids, req.banned_ids) + top_k = req.top_k*3 + sub_chunk_entities = await BaseSearcher.search(req.search_method.value, kb_id, req.query, top_k, req.doc_ids, req.banned_ids) if req.is_rerank: - sub_chunk_indexs = await BaseSearcher.rerank(sub_chunk_entities,kb_entity.rerank_model, req.query) - sub_chunk_entities = [sub_chunk_entities[i] for i in sub_chunk_indexs] + sub_chunk_indexs = await BaseSearcher.rerank(sub_chunk_entities, kb_entity.rerank_method, req.query) + sub_chunk_entities = [sub_chunk_entities[i] + for i in sub_chunk_indexs] sub_chunk_entities = sub_chunk_entities[:req.top_k] chunk_entities += sub_chunk_entities except Exception as e: @@ -106,7 +107,7 @@ class ChunkService: if len(chunk_entities) == 0: return SearchChunkMsg(docChunks=[]) if req.is_rerank: - chunk_indexs = await BaseSearcher.rerank(chunk_entities,None, req.query) + chunk_indexs = await BaseSearcher.rerank(chunk_entities, None, req.query) chunk_entities = [chunk_entities[i] for i in chunk_indexs] chunk_entities = chunk_entities[:req.top_k] chunk_ids = [chunk_entity.id for chunk_entity in chunk_entities] diff --git a/data_chain/apps/service/team_service.py b/data_chain/apps/service/team_service.py index e540f051fd9056fc350084ba1cae2783f4c9ba84..83c964a9bd235e2795bdb4fb06d5158add1bfe55 100644 --- a/data_chain/apps/service/team_service.py +++ 
b/data_chain/apps/service/team_service.py @@ -147,10 +147,8 @@ class TeamService: return None team_entity = await TeamManager.get_team_by_id(role_entity.team_id) team_id = team_entity.id - zh_message = zh_message.format( - teamName=team_entity.name, roleName=role_entity.name) - en_message = en_message.format( - teamName=team_entity.name, roleName=role_entity.name) + zh_message = zh_message.format(roleName=role_entity.name) + en_message = en_message.format(roleName=role_entity.name) elif id_type == IdType.USER: team_entity = await TeamManager.get_team_by_id(id) if team_entity is None: diff --git a/data_chain/entities/enum.py b/data_chain/entities/enum.py index 3b2e3d0d1b990b6a46ce4761f5dedd5818b47251..c10a85c5efe778e8bcda060ac7c5dc4424b170c1 100644 --- a/data_chain/entities/enum.py +++ b/data_chain/entities/enum.py @@ -16,12 +16,11 @@ class EmbeddingType(str, Enum): class RerankType(str, Enum): """rerank 服务的类型""" + ALGORITHM = "algorithm" BAILIAN = "bailian" GUIJILIUDONG = "guijiliudong" VLLM = "vllm" ASSECEND = "assecend" - - class TeamType(str, Enum): """团队类型""" MYCREATED = "mycreated" diff --git a/data_chain/entities/request_data.py b/data_chain/entities/request_data.py index 8c4230971d87d743b8ef355eac6ec8a614642567..533212bb03f2050928e1556c4b7ba812f68d7484 100644 --- a/data_chain/entities/request_data.py +++ b/data_chain/entities/request_data.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, validator, constr from data_chain.entities.enum import ( TeamType, Tokenizer, + RerankType, ParseMethod, UserStatus, UserMessageType, @@ -106,8 +107,10 @@ class CreateKnowledgeBaseRequest(BaseModel): tokenizer: Tokenizer = Field(default=Tokenizer.ZH) embedding_model: str = Field( default='', description="知识库使用的embedding模型", alias="embeddingModel") - rerank_model: Optional[str] = Field( - default=None, description="知识库使用的rerank模型", alias="rerankModel") + rerank_methond: RerankType = Field( + default=RerankType.ALGORITHM, description="知识库使用的rerank模型", 
alias="rerankMethod") + rerank_name: str = Field( + default="jaccard dis reranker", description="知识库使用的rerank模型名称", alias="rerankName") spearating_characters: Optional[str] = Field( default=None, description="知识库分块的分隔符", alias="spearatingCharacters") default_chunk_size: int = Field( @@ -127,8 +130,10 @@ class UpdateKnowledgeBaseRequest(BaseModel): max_length=256, alias="kbName") description: str = Field(default='', max_length=256) tokenizer: Tokenizer = Field(default=Tokenizer.ZH) - rerank_model: Optional[str] = Field( - default=None, description="知识库使用的rerank模型", alias="rerankModel") + rerank_methond: RerankType = Field( + default=RerankType.ALGORITHM, description="知识库使用的rerank模型", alias="rerankMethod") + rerank_name: str = Field( + default="jaccard dis reranker", description="知识库使用的rerank模型名称", alias="rerankName") spearating_characters: Optional[str] = Field( default=None, description="知识库分块的分隔符", alias="spearatingCharacters") default_chunk_size: int = Field( diff --git a/data_chain/entities/response_data.py b/data_chain/entities/response_data.py index 00d5f168f1b7c440ee183ab978874be6a7cc8ef0..92635fc59573ccf57816e0c4f22b6e7287a47443 100644 --- a/data_chain/entities/response_data.py +++ b/data_chain/entities/response_data.py @@ -8,6 +8,7 @@ import uuid from data_chain.entities.enum import ( TeamType, ActionType, + RerankType, Tokenizer, ParseMethod, UserStatus, @@ -155,8 +156,10 @@ class Knowledgebase(BaseModel): author_name: str = Field(description="知识库创建者的用户名", alias="authorName") tokenizer: Tokenizer = Field(description="分词器", alias="tokenizer") embedding_model: str = Field(description="嵌入模型", alias="embeddingModel") - rerank_model: Optional[str] = Field( - default=None, description="rerank模型", alias="rerankModel") + rerank_methond: RerankType = Field( + default=RerankType.ALGORITHM, description="知识库使用的rerank模型", alias="rerankMethod") + rerank_name: str = Field( + default="jaccard dis reranker", description="知识库使用的rerank模型名称", alias="rerankName") 
spearating_characters: Optional[str] = Field(default=None, description="分隔符", alias="spearatingCharacters") description: str = Field(description="知识库描述", max=150) @@ -767,9 +770,15 @@ class ListEmbeddingResponse(ResponseData): result: list[str] = Field(default=[], description="向量化模型的列表数据结构") +class RerankMethod(BaseModel): + rerank_method: RerankType = Field( + description="重排序模型类型", alias="rerankMethod") + reranker_name: str = Field(description="重排序模型描述", alias="rerankerName") + + class ListRerankResponse(ResponseData): """GET /other/rerank 数据结构""" - result: list[str] = Field(default=[], description="重排序模型的列表数据结构") + result: list[RerankMethod] = Field(default=[], description="重排序模型的列表数据结构") class ListTokenizerResponse(ResponseData): diff --git a/data_chain/parser/handler/deep_pdf_parser.py b/data_chain/parser/handler/deep_pdf_parser.py index 9eee88938ee812569132856498e7a78a16a0243e..82dfa2c6a1054aa9eab98c2fe2471bb344121828 100644 --- a/data_chain/parser/handler/deep_pdf_parser.py +++ b/data_chain/parser/handler/deep_pdf_parser.py @@ -644,6 +644,4 @@ class DeepPdfParser(BaseParser): if node.type == ChunkType.IMAGE: # 处理图片节点 continue - return parse_result -import asyncio -result=asyncio.run(DeepPdfParser.parser("./test.pdf")) \ No newline at end of file + return parse_result \ No newline at end of file diff --git a/data_chain/rag/base_searcher.py b/data_chain/rag/base_searcher.py index 6e44b9fd18d1c345b3d92e91a3758ec367d14922..f6a48771bbb9ae47057142bca724aec74a264055 100644 --- a/data_chain/rag/base_searcher.py +++ b/data_chain/rag/base_searcher.py @@ -8,9 +8,11 @@ from data_chain.apps.base.convertor import Convertor from data_chain.stores.database.database import ChunkEntity from data_chain.parser.tools.token_tool import TokenTool from data_chain.manager.chunk_manager import ChunkManager +from data_chain.entities.enum import RerankType from data_chain.entities.response_data import Chunk, DocChunk from data_chain.rerank.rerank import Rerank + class BaseSearcher: 
@staticmethod def find_worker_class(worker_name: str): @@ -40,24 +42,25 @@ class BaseSearcher: err = f"[BaseSearch] 检索器不存在,search_method: {search_method}" logging.exception(err) raise Exception(err) - + @staticmethod - async def rerank(chunk_entities: list[ChunkEntity],rerank_method:Union[None,str], query: str) -> list[ChunkEntity]: + async def rerank(chunk_entities: list[ChunkEntity], rerank_method: Union[None, str], query: str) -> list[ChunkEntity]: """ 重新排序 :param list: 检索结果 :param query: 查询 :return: 重新排序后的结果 """ - if rerank_method is None: + if rerank_method is None or rerank_method == RerankType.ALGORITHM.value: score_chunk_entities = [] for chunk_entity in chunk_entities: score = TokenTool.cal_jac(chunk_entity.text, query) score_chunk_entities.append((score, chunk_entity)) score_chunk_entities.sort(key=lambda x: x[0], reverse=True) - sorted_chunk_entities = [chunk_entity for _, chunk_entity in score_chunk_entities] + sorted_chunk_entities = [chunk_entity for _,
 chunk_entity in score_chunk_entities] else: - text=[] + text = [] for chunk_entity in chunk_entities: text.append(chunk_entity.text) try: diff --git a/data_chain/stores/database/database.py b/data_chain/stores/database/database.py index 432ddd4c0656ea4999a1047fd40e9eb28ef970dc..8764261c1da738f0d3d497bb4bda00cabd136f04 100644 --- a/data_chain/stores/database/database.py +++ b/data_chain/stores/database/database.py @@ -235,7 +235,8 @@ class KnowledgeBaseEntity(Base): tokenizer = Column(String, default=Tokenizer.ZH.value) # 分词器 description = Column(String, default='') # 资产描述 embedding_model = Column(String) # 资产向量化模型 - rerank_model = Column(String) # 资产rerank模型 + rerank_method = Column(String) + rerank_name = Column(String) spearating_characters = Column(String) # 资产分块的分隔符 doc_cnt = Column(Integer, default=0) # 资产文档个数 doc_size = Column(Integer, default=0) # 资产下所有文档大小(TODO: 单位kb或者字节) diff --git a/test.pdf b/test.pdf deleted file mode 100644 index 4b18ba4cebc4c00f682109b641cac29b92b5df4a..0000000000000000000000000000000000000000 Binary files 
a/test.pdf and /dev/null differ