diff --git a/data_chain/apps/base/convertor/model_convertor.py b/data_chain/apps/base/convertor/model_convertor.py index 59ad4dd4afc6dfa57d875bc7f2e3381410717ac1..23a60e5c6ea189d8d0d7d233d64a2f4ed78f39ed 100644 --- a/data_chain/apps/base/convertor/model_convertor.py +++ b/data_chain/apps/base/convertor/model_convertor.py @@ -1,17 +1,36 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. from typing import Optional +import json from data_chain.models.service import ModelDTO from data_chain.stores.postgres.postgres import ModelEntity +from data_chain.apps.base.security.security import Security +from data_chain.config.config import ModelConfig class ModelConvertor(): @staticmethod - def convert_entity_to_dto(model_entity: Optional[ModelEntity]=None) -> ModelDTO: + def convert_entity_to_dto(model_entity: Optional[ModelEntity] = None) -> ModelDTO: if model_entity is None: return ModelDTO() return ModelDTO( id=str(model_entity.id), model_name=model_entity.model_name, + model_type=model_entity.model_type, openai_api_base=model_entity.openai_api_base, + openai_api_key=Security.decrypt( + model_entity.encrypted_openai_api_key, + json.loads(model_entity.encrypted_config) + ), max_tokens=model_entity.max_tokens, + is_online=model_entity.is_online + ) + + @staticmethod + def convert_config_to_entity(model_config: ModelConfig) -> ModelDTO: + if model_config is None: + return ModelEntity() + return ModelDTO( + id=str(model_config['MODEL_ID']), + model_name=model_config['MODEL_NAME'], + model_type=model_config['MODEL_TYPE'], ) diff --git a/data_chain/apps/base/model/llm.py b/data_chain/apps/base/model/llm.py index a830da2a94588448b1424dbcab0a20f6b93734cb..ad0263bb4eba5e75acc7a64b23288b722e67cb19 100644 --- a/data_chain/apps/base/model/llm.py +++ b/data_chain/apps/base/model/llm.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. import asyncio import time +import re import json import tiktoken from langchain_openai import ChatOpenAI @@ -27,7 +28,8 @@ class LLM: async def nostream(self, chat, system_call, user_call): chat = self.assemble_chat(chat, system_call, user_call) response = await self.client.ainvoke(chat) - return response.content + content = re.sub(r'.*?\n\n', '', response.content, flags=re.DOTALL) + return content async def data_producer(self, q: asyncio.Queue, history, system_call, user_call): message = self.assemble_chat(history, system_call, user_call) diff --git a/data_chain/apps/base/session/session.py b/data_chain/apps/base/session/session.py index a9d61e4b60db2c5825ab369cca14db42a7e1772d..10e096fda3acab9ba4f393769e7caaf8d91a5288 100644 --- a/data_chain/apps/base/session/session.py +++ b/data_chain/apps/base/session/session.py @@ -88,7 +88,7 @@ class SessionManager: csrf_value = f"{session_id}{rand}" csrf_b64 = base64.b64encode(bytes.fromhex(csrf_value)) - hmac_processor = hmac.new(key=bytes.fromhex(config["CSRF_KEY"]), msg=csrf_b64, digestmod=hashlib.sha256) + hmac_processor = hmac.new(key=base64.b64decode(config["CSRF_KEY"]), msg=csrf_b64, digestmod=hashlib.sha256) signature = base64.b64encode(hmac_processor.digest()) csrf_b64 = csrf_b64.decode("utf-8") @@ -120,7 +120,7 @@ class SessionManager: except Exception as e: logging.error(f"Get csrf token from session error: {e}") - hmac_obj = hmac.new(key=bytes.fromhex(config["CSRF_KEY"]), + hmac_obj = hmac.new(key=base64.b64decode(config["CSRF_KEY"]), msg=token_msg[0].encode("utf-8"), digestmod=hashlib.sha256) signature = hmac_obj.digest() current_signature = base64.b64decode(token_msg[1]) diff --git a/data_chain/apps/router/chunk.py b/data_chain/apps/router/chunk.py index d6ef4b4d68f47f9235b24ab474601047d8a0249e..53baa68d79e4df780c0cfff9d6e096fc2c1ec058 100644 --- a/data_chain/apps/router/chunk.py +++ b/data_chain/apps/router/chunk.py @@ -25,7 +25,7 @@ async def list(req: ListChunkRequest, user_id=Depends(get_user_id)): total=total, data_list=chunk_list) return BaseResponse(data=chunk_page) - except DocumentException as e: + except Exception as e: return BaseResponse(retcode=ErrorCode.CREATE_CHUNK_ERROR, retmsg=str(e.args[0])) @@ -39,5 +39,5 @@ async def switch(req: SwitchChunkRequest, user_id=Depends(get_user_id)): for id in req.ids: await switch_chunk(id, req.enabled) return BaseResponse(data='success') - except DocumentException as e: + except Exception as e: return BaseResponse(retcode=ErrorCode.SWITCH_CHUNK_ERROR, retmsg=str(e.args[0])) diff --git a/data_chain/apps/router/document.py b/data_chain/apps/router/document.py index 5c84cea18c48e8057c1ed6c54826ac87b9f0b4fb..9a947f632e55168e7318a39d67f74417d8f3cdbc 100644 --- a/data_chain/apps/router/document.py +++ b/data_chain/apps/router/document.py @@ -40,7 +40,7 @@ async def list(req: ListDocumentRequest, user_id=Depends(get_user_id)): total=document_list_tuple[1], data_list=document_list_tuple[0]) return BaseResponse(data=document_page) - except DocumentException as e: + except Exception as e: return BaseResponse(retcode=ErrorCode.LIST_DOCUMENT_ERROR, retmsg=str(e.args[0]), data=None) @@ -53,7 +53,7 @@ async def update(req: UpdateDocumentRequest, user_id=Depends(get_user_id)): tmp_dict = dict(req) document = await update_document(tmp_dict) return BaseResponse(data=document) - except DocumentException as e: + except Exception as e: return BaseResponse(retcode=ErrorCode.RENAME_DOCUMENT_ERROR, retmsg=str(e.args[0]), data=None) @@ -70,7 +70,7 @@ async def run(reqs: RunDocumentRequest, user_id=Depends(get_user_id)): document = await run_document(dict(id=req_id, run=run)) document_dto_list.append(document) return BaseResponse(data=document_dto_list) - except DocumentException as e: + except Exception as e: return BaseResponse(retcode=ErrorCode.RUN_DOCUMENT_ERROR, retmsg=str(e.args[0]), data=None) @@ -82,7 +82,7 @@ async def switch(req: SwitchDocumentRequest, user_id=Depends(get_user_id)): await _validate_doucument_belong_to_user(user_id, req.id) document = await switch_document(req.id, req.enabled) return BaseResponse(data=document) - except DocumentException as e: + except Exception as e: return BaseResponse(retcode=ErrorCode.SWITCH_DOCUMENT_ERROR, retmsg=str(e.args[0]), data=None) @@ -95,7 +95,7 @@ async def rm(req: DeleteDocumentRequest, user_id=Depends(get_user_id)): await _validate_doucument_belong_to_user(user_id, id) deleted_cnt = await delete_document(req.ids) return BaseResponse(data=deleted_cnt) - except DocumentException as e: + except Exception as e: return BaseResponse(retcode=ErrorCode.DELETE_DOCUMENT_ERROR, retmsg=str(e.args[0]), data=None) @@ -121,7 +121,7 @@ async def upload(kb_id: str, files: List[UploadFile] = File(...), user_id=Depend await _validate_knowledge_base_belong_to_user(user_id, kb_id) res = await submit_upload_document_task(user_id, kb_id, files) return BaseResponse(data=res) - except DocumentException as e: + except Exception as e: return BaseResponse(retcode=ErrorCode.UPLOAD_DOCUMENT_ERROR, retmsg=str(e.args[0]), data=None) @@ -149,7 +149,7 @@ async def download(id: uuid.UUID, user_id=Depends(get_user_id)): else: return BaseResponse( retcode=ErrorCode.EXPORT_KNOWLEDGE_BASE_ERROR, retmsg="Failed to retrieve the file.", data=None) - except DocumentException as e: + except Exception as e: return BaseResponse(retcode=ErrorCode.DOWNLOAD_DOCUMENT_ERROR, retmsg=str(e.args[0]), data=None) diff --git a/data_chain/apps/router/knowledge_base.py b/data_chain/apps/router/knowledge_base.py index 418a56074014bc7b47bd3b769582611f35ce2039..9871b0e974ade35586b6f429c8387d96ecea31b6 100644 --- a/data_chain/apps/router/knowledge_base.py +++ b/data_chain/apps/router/knowledge_base.py @@ -7,8 +7,8 @@ from httpx import AsyncClient from fastapi import APIRouter, File, UploadFile, status, HTTPException from fastapi import Depends from fastapi.responses import StreamingResponse, HTMLResponse, Response -from data_chain.logger.logger import logger as logging +from data_chain.logger.logger import logger as logging from data_chain.apps.service.user_service import verify_csrf_token, get_user_id, verify_user from data_chain.exceptions.err_code import ErrorCode from data_chain.exceptions.exception import KnowledgeBaseException @@ -18,6 +18,7 @@ from data_chain.models.api import Page, BaseResponse, ExportKnowledgeBaseRequest from data_chain.apps.service.knwoledge_base_service import _validate_knowledge_base_belong_to_user, \ create_knowledge_base, list_knowledge_base, rm_knowledge_base, generate_knowledge_base_download_link, submit_import_knowledge_base_task, \ update_knowledge_base, list_knowledge_base_task, stop_knowledge_base_task, submit_export_knowledge_base_task, rm_knowledge_base_task, rm_all_knowledge_base_task +from data_chain.apps.service.model_service import get_model_by_kb_id from data_chain.models.constant import KnowledgeLanguageEnum, TaskConstant from data_chain.models.service import KnowledgeBaseDTO from data_chain.apps.service.task_service import _validate_task_belong_to_user @@ -36,7 +37,7 @@ async def create(req: CreateKnowledgeBaseRequest, user_id=Depends(get_user_id)): tmp_dict['user_id'] = user_id knowledge_base = await create_knowledge_base(tmp_dict) return BaseResponse(data=knowledge_base) - except KnowledgeBaseException as e: + except Exception as e: logging.error(f"Create knowledge base failed due to: {e}") return BaseResponse(retcode=ErrorCode.CREATE_KNOWLEDGE_BASE_ERROR, retmsg=str(e.args[0]), data=None) @@ -50,7 +51,7 @@ async def update(req: UpdateKnowledgeBaseRequest, user_id=Depends(get_user_id)): update_dict['user_id'] = user_id knowledge_base = await update_knowledge_base(update_dict) return BaseResponse(data=knowledge_base) - except KnowledgeBaseException as e: + except Exception as e: logging.error(f"Update knowledge base failed due to: {e}") return BaseResponse(retcode=ErrorCode.UPDATE_KNOWLEDGE_BASE_ERROR, retmsg=str(e.args[0]), data=None) @@ -69,7 +70,7 @@ async def list(req: ListKnowledgeBaseRequest, user_id=Depends(get_user_id)): total=knowledge_base_list_tuple[1], data_list=knowledge_base_list_tuple[0]) return BaseResponse(data=knowledge_base_page) - except KnowledgeBaseException as e: + except Exception as e: logging.error(f"List knowledge base failed due to: {e}") return BaseResponse(retcode=ErrorCode.LIST_KNOWLEDGE_BASE_ERROR, retmsg=str(e.args[0]), data=None) @@ -82,7 +83,7 @@ async def rm(req: DeleteKnowledgeBaseRequest, user_id=Depends(get_user_id)): await _validate_knowledge_base_belong_to_user(user_id, req.id) res = await rm_knowledge_base(req.id) return BaseResponse(data=res) - except KnowledgeBaseException as e: + except Exception as e: logging.error(f"Rmove knowledge base failed due to: {e}") return BaseResponse(retcode=ErrorCode.DELETE_KNOWLEDGE_BASE_ERROR, retmsg=str(e.args[0]), data=None) @@ -207,11 +208,19 @@ async def rm_kb_task(req: RmoveTaskRequest, user_id=Depends(get_user_id)): @router.post('/get_stream_answer', response_class=HTMLResponse) async def get_stream_answer(req: QueryRequest, response: Response): + model_dto = await get_model_by_kb_id(req.kb_sn) + if model_dto is None: + if len(config['MODELS']) > 0: + tokens_upper = config['MODELS'][0]['MAX_TOKENS'] + else: + tokens_upper = 0 + else: + tokens_upper = model_dto.max_tokens try: - question = await question_rewrite(req.history, req.question) - max_tokens = config['MAX_TOKENS']//3 + question = await question_rewrite(req.history, req.question, model_dto) + max_tokens = tokens_upper//3 bac_info = '' - document_chunk_list = await get_similar_chunks(content=question,kb_id=req.kb_sn,temporary_document_ids=req.document_ids, max_tokens=config['MAX_TOKENS']//2, topk=req.top_k) + document_chunk_list = await get_similar_chunks(content=question, kb_id=req.kb_sn, temporary_document_ids=req.document_ids, max_tokens=tokens_upper//2, topk=req.top_k) for i in range(len(document_chunk_list)): document_name = document_chunk_list[i]['document_name'] chunk_list = document_chunk_list[i]['chunk_list'] @@ -237,7 +246,7 @@ async def get_stream_answer(req: QueryRequest, response: Response): logging.error(f"get bac info failed due to: {e}") try: response.headers["Content-Type"] = "text/event-stream" - res = await get_llm_answer(req.history, bac_info, req.question) + res = await get_llm_answer(req.history, bac_info, req.question, is_stream=True, model_dto=model_dto) return StreamingResponse( res, status_code=status.HTTP_200_OK, @@ -250,11 +259,19 @@ async def get_stream_answer(req: QueryRequest, response: Response): @router.post('/get_answer', response_model=BaseResponse[dict]) async def get_answer(req: QueryRequest): + model_dto = await get_model_by_kb_id(req.kb_sn) + if model_dto is None: + if len(config['MODELS']) > 0: + tokens_upper = config['MODELS'][0]['MAX_TOKENS'] + else: + tokens_upper = 0 + else: + tokens_upper = model_dto.max_tokens try: - question = await question_rewrite(req.history, req.question) - max_tokens = config['MAX_TOKENS']//3 + question = await question_rewrite(req.history, req.question, model_dto) + max_tokens = tokens_upper//3 bac_info = '' - document_chunk_list = await get_similar_chunks(content=question,kb_id=req.kb_sn,temporary_document_ids=req.document_ids, max_tokens=config['MAX_TOKENS']//2, topk=req.top_k) + document_chunk_list = await get_similar_chunks(content=question, kb_id=req.kb_sn, temporary_document_ids=req.document_ids, max_tokens = tokens_upper//2, topk=req.top_k) for i in range(len(document_chunk_list)): document_name = document_chunk_list[i]['document_name'] chunk_list = document_chunk_list[i]['chunk_list'] @@ -279,7 +296,7 @@ async def get_answer(req: QueryRequest): bac_info = '' logging.error(f"get bac info failed due to: {e}") try: - answer = await get_llm_answer(req.history, bac_info, req.question, is_stream=False) + answer = await get_llm_answer(req.history, bac_info, req.question, is_stream=False, model_dto=model_dto) tmp_dict = { 'answer': answer, } diff --git a/data_chain/apps/router/model.py b/data_chain/apps/router/model.py index c90eabb7df99dadc87b4bda5670e3d5f85b448dd..acec0ab05d1c04344039c887cbabdf58ccf1ddf6 100644 --- a/data_chain/apps/router/model.py +++ b/data_chain/apps/router/model.py @@ -1,14 +1,14 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. from fastapi import Depends from fastapi import APIRouter - +from typing import List from data_chain.models.service import ModelDTO from data_chain.apps.service.user_service import verify_csrf_token, get_user_id, verify_user from data_chain.exceptions.err_code import ErrorCode from data_chain.exceptions.exception import DocumentException from data_chain.models.api import BaseResponse from data_chain.models.api import UpdateModelRequest -from data_chain.apps.service.model_service import get_model, update_model +from data_chain.apps.service.model_service import get_model_by_user_id, list_offline_model, update_model router = APIRouter(prefix='/model', tags=['Model']) @@ -20,18 +20,32 @@ router = APIRouter(prefix='/model', tags=['Model']) async def update(req: UpdateModelRequest, user_id=Depends(get_user_id)): try: update_dict = dict(req) + update_dict['user_id']=user_id model_dto = await update_model(user_id, update_dict) + model_dto.openai_api_key=None return BaseResponse(data=model_dto) - except DocumentException as e: + except Exception as e: return BaseResponse(retcode=ErrorCode.UPDATE_MODEL_ERROR, retmsg=str(e.args[0]), data=None) @router.get('/get', response_model=BaseResponse[ModelDTO], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) + dependencies=[Depends(verify_user), + Depends(verify_csrf_token)]) async def get(user_id=Depends(get_user_id)): try: - model_dto = await get_model(user_id) + model_dto = await get_model_by_user_id(user_id) + model_dto.openai_api_key = None return BaseResponse(data=model_dto) - except DocumentException as e: + except Exception as e: + return BaseResponse(retcode=ErrorCode.UPDATE_MODEL_ERROR, retmsg=str(e.args[0]), data=None) + + +@router.get('/list', response_model=BaseResponse[List[ModelDTO]], + dependencies=[Depends(verify_user), + Depends(verify_csrf_token)]) +async def list(): + try: + model_dto_list = await list_offline_model() + return BaseResponse(data=model_dto_list) + except Exception as e: return BaseResponse(retcode=ErrorCode.UPDATE_MODEL_ERROR, retmsg=str(e.args[0]), data=None) diff --git a/data_chain/apps/router/user.py b/data_chain/apps/router/user.py index 39352fcdee32dce474c231374cb31e72e87b3fcc..74217d7166bbfda3512c806088e9a4fb906f0ab7 100644 --- a/data_chain/apps/router/user.py +++ b/data_chain/apps/router/user.py @@ -21,6 +21,7 @@ router = APIRouter( @router.post("/add", response_model=BaseResponse) async def add_user(request: AddUserRequest): name = request.name + email = request.email account = request.account passwd = request.passwd user_entity = await UserManager.get_user_info_by_account(account) @@ -30,12 +31,19 @@ async def add_user(request: AddUserRequest): retmsg="Sign failed due to duplicate account", data={} ) - - user_entity = await UserManager.add_user(name, account, passwd) + if email is not None: + user_entity = await UserManager.get_user_info_by_email(email) + if user_entity is not None: + return BaseResponse( + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + retmsg="Sign failed due to duplicate email", + data={} + ) + user_entity = await UserManager.add_user(name,email,account, passwd) if user_entity is None: return BaseResponse( code=status.HTTP_422_UNPROCESSABLE_ENTITY, - retmsg="Sign failed due to duplicate account", + retmsg="Sign failed due to add user failed", data={} ) return BaseResponse( @@ -47,14 +55,14 @@ async def add_user(request: AddUserRequest): @router.post("/del", response_model=BaseResponse, dependencies=[Depends(verify_user), Depends(verify_csrf_token)]) async def del_user(request: Request, response: Response, user_id=Depends(get_user_id)): - session_id = request.cookies['ECSESSION'] + session_id = request.cookies['WD_ECSESSION'] if not SessionManager.verify_user(session_id): logging.info("User already logged out.") return BaseResponse(code=200, retmsg="ok", data={}) SessionManager.delete_session(user_id) - response.delete_cookie("ECSESSION") - response.delete_cookie("_csrf_tk") + response.delete_cookie("WD_ECSESSION") + response.delete_cookie("wd_csrf_tk") await UserManager.del_user_by_user_id(user_id) response_data = BaseResponse( code=status.HTTP_200_OK, @@ -89,16 +97,16 @@ async def login(request: Request, response: Response, account: str): new_csrf_token = SessionManager.create_csrf_token(current_session) if config['COOKIE_MODE'] == 'DEBUG': response.set_cookie( - "_csrf_tk", + "wd_csrf_tk", new_csrf_token ) response.set_cookie( - "ECSESSION", + "WD_ECSESSION", current_session ) else: response.set_cookie( - "_csrf_tk", + "wd_csrf_tk", new_csrf_token, max_age=config["SESSION_TTL"] * 60, secure=config['SSL_ENABLE'], @@ -106,7 +114,7 @@ async def login(request: Request, response: Response, account: str): samesite="strict" ) response.set_cookie( - "ECSESSION", + "WD_ECSESSION", current_session, max_age=config["SESSION_TTL"] * 60, secure=config['SSL_ENABLE'], @@ -127,14 +135,14 @@ async def login(request: Request, response: Response, account: str): @router.get("/logout", response_model=BaseResponse, dependencies=[Depends(verify_csrf_token)]) async def logout(request: Request, response: Response, user_id=Depends(get_user_id)): - session_id = request.cookies['ECSESSION'] + session_id = request.cookies['WD_ECSESSION'] if not SessionManager.verify_user(session_id): logging.info("User already logged out.") return BaseResponse(code=200, retmsg="ok", data={}) SessionManager.delete_session(user_id) - response.delete_cookie("ECSESSION") - response.delete_cookie("_csrf_tk") + response.delete_cookie("WD_ECSESSION") + response.delete_cookie("wd_csrf_tk") return { "code": status.HTTP_200_OK, "rtmsg": "Logout success", diff --git a/data_chain/apps/service/chunk_service.py b/data_chain/apps/service/chunk_service.py index 3e2896f278bba0e185184d98ae8663a24bb323fb..e5aace69912bd4c19d7b663190d8039bb980d298 100644 --- a/data_chain/apps/service/chunk_service.py +++ b/data_chain/apps/service/chunk_service.py @@ -5,10 +5,12 @@ import time import jieba import traceback import asyncio + +import jieba.analyse from data_chain.logger.logger import logger as logging from data_chain.apps.service.llm_service import get_question_chunk_relation from data_chain.models.constant import ChunkRelevance -from data_chain.manager.document_manager import DocumentManager,TemporaryDocumentManager +from data_chain.manager.document_manager import DocumentManager, TemporaryDocumentManager from data_chain.manager.knowledge_manager import KnowledgeBaseManager from data_chain.manager.chunk_manager import ChunkManager, TemporaryChunkManager from data_chain.manager.vector_items_manager import VectorItemsManager, TemporaryVectorItemsManager @@ -19,6 +21,7 @@ from data_chain.apps.service.embedding_service import Vectorize from data_chain.config.config import config from data_chain.apps.base.convertor.chunk_convertor import ChunkConvertor + async def _validate_chunk_belong_to_user(user_id: uuid.UUID, chunk_id: uuid.UUID) -> bool: chunk_entity = await ChunkManager.select_by_chunk_id(chunk_id) if chunk_entity is None: @@ -31,13 +34,14 @@ async def list_chunk(params, page_number, page_size): doc_entity = await DocumentManager.select_by_id(params['document_id']) if doc_entity is None or doc_entity.status == DocumentEmbeddingConstant.DOCUMENT_EMBEDDING_STATUS_RUNNING: return [], 0 - - chunk_entity_list,total= await ChunkManager.select_by_page(params, page_number, page_size) - chunk_dto_list=[] + + chunk_entity_list, total = await ChunkManager.select_by_page(params, page_number, page_size) + chunk_dto_list = [] for chunk_entity in chunk_entity_list: chunk_dto = ChunkConvertor.convert_entity_to_dto(chunk_entity) chunk_dto_list.append(chunk_dto) - return (chunk_dto_list,total) + return (chunk_dto_list, total) + async def switch_chunk(id, enabled): await ChunkManager.update(id, {'enabled': enabled}) @@ -133,6 +137,21 @@ async def filter_or_expand_chunk_by_llm(kb_id, content, document_para_dict, maxt st = en +async def get_keywords_from_content(content: str, top_k: int = 3): + words = list(jieba.cut(content)) + keywords = set(jieba.analyse.extract_tags(content, topK=top_k)) + result = [] + exist_words=set() + for word in words: + if word in keywords and word not in exist_words: + exist_words.add(word) + result.append(word) + return result + + +async def rerank_chunks(content: str, chunks: list[str], top_k: int = 3): + pass + async def get_similar_chunks(content, kb_id=None, temporary_document_ids=None, max_tokens=4096, topk=3): # # 这里返回的chunk_tuple_list是个n*5二维列表 @@ -141,8 +160,8 @@ async def get_similar_chunks(content, kb_id=None, temporary_document_ids=None, m st = time.time() if temporary_document_ids: chunk_tuple_list = await TemporaryChunkManager.find_top_k_similar_chunks( - temporary_document_ids, - content, + temporary_document_ids, + content, max(topk // 2, 1)) elif kb_id: chunk_tuple_list = await ChunkManager.find_top_k_similar_chunks( @@ -160,16 +179,16 @@ async def get_similar_chunks(content, kb_id=None, temporary_document_ids=None, m if target_vector is not None: st = time.time() if temporary_document_ids: - chunk_id_list=[] + chunk_id_list = [] for i in range(retry_times): try: chunk_id_list = await asyncio.wait_for(TemporaryVectorItemsManager.find_top_k_similar_temporary_vectors( target_vector, temporary_document_ids, topk-len(chunk_tuple_list) - ), + ), timeout=1 - ) + ) break except Exception as e: logging.error(f"检索临时向量时出错: {e}") @@ -183,10 +202,10 @@ async def get_similar_chunks(content, kb_id=None, temporary_document_ids=None, m vector_items_id = kb_entity.vector_items_id dim = embedding_model_out_dimensions[embedding_model] vector_items_table = await PostgresDB.get_dynamic_vector_items_table(vector_items_id, dim) - chunk_id_list=[] + chunk_id_list = [] for i in range(retry_times): try: - chunk_id_list = await asyncio.wait_for(VectorItemsManager.find_top_k_similar_vectors(vector_items_table, target_vector, kb_id, topk-len(chunk_tuple_list)),timeout=1) + chunk_id_list = await asyncio.wait_for(VectorItemsManager.find_top_k_similar_vectors(vector_items_table, target_vector, kb_id, topk-len(chunk_tuple_list)), timeout=1) break except Exception as e: logging.error(f"检索向量时出错: {e}") @@ -200,7 +219,7 @@ async def get_similar_chunks(content, kb_id=None, temporary_document_ids=None, m logging.info(f"向量化结果关联片段耗时: {time.time()-st}") st = time.time() document_para_dict = {} - exist_chunk_id_set=set() + exist_chunk_id_set = set() for chunk_tuple in chunk_tuple_list: document_id = chunk_tuple[1] if document_id not in document_para_dict.keys(): @@ -214,7 +233,7 @@ async def get_similar_chunks(content, kb_id=None, temporary_document_ids=None, m new_document_para_dict = {} ex_tokens = max_tokens//len(exist_chunk_id_set) st = time.time() - leave_ex_tokens=0 + leave_ex_tokens = 0 for document_id in document_para_dict.keys(): global_offset_set = set() new_document_para_dict[document_id] = [] @@ -222,9 +241,9 @@ async def get_similar_chunks(content, kb_id=None, temporary_document_ids=None, m document_id = chunk_tuple[1] global_offset = chunk_tuple[2] tokens = chunk_tuple[3] - leave_ex_tokens+=ex_tokens + leave_ex_tokens += ex_tokens if temporary_document_ids: - ex_chunk_tuple_list = await expand_chunk(document_id, global_offset, expand_method='all', max_tokens=leave_ex_tokens-tokens,is_temporary_document=True) + ex_chunk_tuple_list = await expand_chunk(document_id, global_offset, expand_method='all', max_tokens=leave_ex_tokens-tokens, is_temporary_document=True) elif kb_id: ex_chunk_tuple_list = await expand_chunk(document_id, global_offset, expand_method='all', max_tokens=leave_ex_tokens-tokens) ex_chunk_tuple_list.append(chunk_tuple) @@ -233,9 +252,9 @@ async def get_similar_chunks(content, kb_id=None, temporary_document_ids=None, m if global_offset not in global_offset_set: new_document_para_dict[document_id].append(ex_chunk_tuple) global_offset_set.add(global_offset) - leave_ex_tokens-=ex_chunk_tuple[3] - if leave_ex_tokens<=0: - leave_ex_tokens=0 + leave_ex_tokens -= ex_chunk_tuple[3] + if leave_ex_tokens <= 0: + leave_ex_tokens = 0 new_document_para_dict[document_id] = sorted(new_document_para_dict[document_id], key=lambda x: x[2]) logging.info(f"上下文关联耗时: {time.time()-st}") # if config['MODEL_ENH']: @@ -245,15 +264,15 @@ async def get_similar_chunks(content, kb_id=None, temporary_document_ids=None, m document_para_dict = new_document_para_dict docuemnt_chunk_list = [] for document_id in document_para_dict: - document_entity=None + document_entity = None if docuemnt_chunk_list: - document_entity=await TemporaryDocumentManager.select_by_id(document_id) + document_entity = await TemporaryDocumentManager.select_by_id(document_id) elif kb_id: document_entity = await DocumentManager.select_by_id(document_id) if document_entity is not None: document_name = document_entity.name else: - document_name='' + document_name = '' chunk_list = [] st = 0 en = 0 @@ -261,7 +280,7 @@ async def get_similar_chunks(content, kb_id=None, temporary_document_ids=None, m text = '' while en < len( document_para_dict[document_id]) and ( - en == st or document_para_dict[document_id][en][2] + en == st or document_para_dict[document_id][en][2] - document_para_dict[document_id][en - 1][2] == 1): text += document_para_dict[document_id][en][4] diff --git a/data_chain/apps/service/embedding_service.py b/data_chain/apps/service/embedding_service.py index 5edd909e3de45c6f977ea168b044205c96ec455e..28d2d0e7a7890bbba05a63976b4e3f49ee0ee96c 100644 --- a/data_chain/apps/service/embedding_service.py +++ b/data_chain/apps/service/embedding_service.py @@ -10,14 +10,19 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) class Vectorize(): @staticmethod async def vectorize_embedding(text): + headers = { + "Authorization": f"Bearer {config['EMBEDDING_API_KEY']}" + } data = { - "texts": [text] + "input": text, + "model": config["EMBEDDING_MODEL_NAME"], + "encoding_format": "float" } try: - res = requests.post(url=config["REMOTE_EMBEDDING_ENDPOINT"], json=data, verify=False) + res = requests.post(url=config["EMBEDDING_ENDPOINT"],headers=headers, json=data, verify=False) if res.status_code != 200: return None - return res.json()[0] + return res.json()['data'][0]['embedding'] except Exception as e: logging.error(f"Embedding error failed due to: {e}") return None diff --git a/data_chain/apps/service/llm_service.py b/data_chain/apps/service/llm_service.py index 0eeb1fe6b9b4cc75b40cd91e01f83a0f0c1bd6ce..272d3427092fd498014a1567c749406881846710 100644 --- a/data_chain/apps/service/llm_service.py +++ b/data_chain/apps/service/llm_service.py @@ -1,13 +1,14 @@ from typing import List import time import yaml +import json import jieba +from data_chain.models.service import ModelDTO from data_chain.logger.logger import logger as logging from data_chain.config.config import config from data_chain.apps.base.model.llm import LLM from data_chain.parser.tools.split import split_tools - - +from data_chain.apps.base.security.security import Security def load_stopwords(file_path): with open(file_path, 'r', encoding='utf-8') as f: stopwords = set(line.strip() for line in f) @@ -21,7 +22,7 @@ def filter_stopwords(text): return filtered_words -async def question_rewrite(history: List[dict], question: str) -> str: +async def question_rewrite(history: List[dict], question: str,model_dto:ModelDTO=None) -> str: if not history: return question try: @@ -65,10 +66,17 @@ async def question_rewrite(history: List[dict], question: str) -> str: history_prompt = ''.join(splited_prompt) prompt = prompt.format(history=history_prompt, question=question) user_call = "请输出改写后的问题" - default_llm = LLM(model_name=config['MODEL_NAME'], - openai_api_base=config['OPENAI_API_BASE'], - openai_api_key=config['OPENAI_API_KEY'], - max_tokens=config['MAX_TOKENS'], + default_llm = LLM(model_name=config['MODELS'][0]['MODEL_NAME'], + openai_api_base=config['MODELS'][0]['OPENAI_API_BASE'], + openai_api_key=config['MODELS'][0]['OPENAI_API_KEY'], + max_tokens=config['MODELS'][0]['MAX_TOKENS'], + request_timeout=60, + temperature=0.35) + if model_dto is not None: + default_llm = LLM(model_name=model_dto.model_name, + openai_api_base=model_dto.openai_api_base, + openai_api_key=model_dto.openai_api_key, + max_tokens=model_dto.max_tokens, request_timeout=60, temperature=0.35) rewrite_question = await default_llm.nostream([], prompt, user_call) @@ -85,7 +93,7 @@ async def question_split(question: str) -> List[str]: return [question] -async def get_llm_answer(history, bac_info, question, is_stream=True): +async def get_llm_answer(history, bac_info, question, is_stream=True,model_dto:ModelDTO=None): try: with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) @@ -95,28 +103,41 @@ async def get_llm_answer(history, bac_info, question, is_stream=True): logging.error(f'Get prompt failed : {e}') raise e llm = LLM( - openai_api_key=config['OPENAI_API_KEY'], - openai_api_base=config['OPENAI_API_BASE'], - model_name=config['MODEL_NAME'], - max_tokens=config['MAX_TOKENS']) + openai_api_key=config['MODELS'][0]['OPENAI_API_KEY'], + openai_api_base=config['MODELS'][0]['OPENAI_API_BASE'], + model_name=config['MODELS'][0]['MODEL_NAME'], + max_tokens=config['MODELS'][0]['MAX_TOKENS']) + if model_dto is not None: + llm = LLM(model_name=model_dto.model_name, + openai_api_base=model_dto.openai_api_base, + openai_api_key=model_dto.openai_api_key, + max_tokens=model_dto.max_tokens + ) if is_stream: return llm.stream(history, prompt, question) res = await llm.nostream(history, prompt, question) return res -async def get_question_chunk_relation(question, chunk): +async def get_question_chunk_relation(question, chunk,model_dto:ModelDTO=None): with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: prompt_template_dict = yaml.load(f, Loader=yaml.SafeLoader) prompt = prompt_template_dict['DETERMINE_ANSWER_AND_QUESTION'] prompt = prompt.format(chunk=chunk, question=question) user_call = "判断,并输出关联性编号" - default_llm = LLM(model_name=config['MODEL_NAME'], - openai_api_base=config['OPENAI_API_BASE'], - openai_api_key=config['OPENAI_API_KEY'], - max_tokens=config['MAX_TOKENS'], - request_timeout=60, - temperature=0.35) + default_llm = LLM(model_name=config['MODELS'][0]['MODEL_NAME'], + openai_api_base=config['MODELS'][0]['OPENAI_API_BASE'], + openai_api_key=config['MODELS'][0]['OPENAI_API_KEY'], + max_tokens=config['MODELS'][0]['MAX_TOKENS'], + request_timeout=60, + temperature=0.35) + if model_dto is not None: + default_llm = LLM(model_name=model_dto.model_name, + openai_api_base=model_dto.openai_api_base, + openai_api_key=model_dto.openai_api_key, + max_tokens=model_dto.max_tokens, + request_timeout=60, + temperature=0.35) ans = await default_llm.nostream([], prompt, user_call) return ans diff --git a/data_chain/apps/service/model_service.py b/data_chain/apps/service/model_service.py index e490df523f62414c48d3f2ebe4c05ff8c888cf3c..2f0c848523b47b93111a6af0f913549dce014153 100644 --- a/data_chain/apps/service/model_service.py +++ b/data_chain/apps/service/model_service.py @@ -1,8 +1,11 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. import uuid -from data_chain.logger.logger import logger as logging import json +import asyncio +from data_chain.logger.logger import logger as logging +from data_chain.config.config import config +from data_chain.manager.knowledge_manager import KnowledgeBaseManager from data_chain.manager.model_manager import ModelManager from data_chain.exceptions.exception import ModelException from data_chain.apps.base.convertor.model_convertor import ModelConvertor @@ -11,7 +14,6 @@ from data_chain.apps.base.security.security import Security from data_chain.stores.postgres.postgres import ModelEntity - async def _validate_model_belong_to_user(user_id: uuid.UUID, model_id: uuid.UUID) -> bool: model_entity = await ModelManager.select_by_id(model_id) if model_entity is None: @@ -21,41 +23,75 @@ async def _validate_model_belong_to_user(user_id: uuid.UUID, model_id: uuid.UUID async def test_model_connection(model_name, openai_api_base, openai_api_key, max_tokens): - return True try: - llm = LLM(model_name, openai_api_base, openai_api_key, max_tokens) - await llm.nostream([],"hello world", "hello world") + llm = LLM(openai_api_key, openai_api_base, model_name, max_tokens) + await asyncio.wait_for(llm.nostream([], "hello world", "hello world"), timeout=60) return True except Exception as e: logging.error(f"test model connection error:{e}") return False -async def get_model(user_id): + +async def get_model_by_user_id(user_id): model_entity = await ModelManager.select_by_user_id(user_id) return ModelConvertor.convert_entity_to_dto(model_entity) +async def get_model_by_kb_id(kb_id): + model_entity = None + kb_entity = await KnowledgeBaseManager.select_by_id(kb_id) + if kb_entity is not None: + model_entity = await ModelManager.select_by_user_id(kb_entity.user_id) + return ModelConvertor.convert_entity_to_dto(model_entity) + + +async def list_offline_model(): + model_configs = config['MODELS'] + model_dto_list = [] + for model_config in model_configs: + try: + model_dto_list.append(ModelConvertor.convert_config_to_entity(model_config)) + except Exception as e: + logging.error(f"load model config error due to:{e}") + continue + return model_dto_list + + async def update_model(user_id, update_dict): - model_name = update_dict['model_name'] - openai_api_base = update_dict['openai_api_base'] - openai_api_key = update_dict['openai_api_key'] - max_tokens = update_dict['max_tokens'] - encrypted_openai_api_key, encrypted_config = Security.encrypt(openai_api_key) - if not await test_model_connection(model_name, openai_api_base, openai_api_key, max_tokens): + if not update_dict['is_online']: + for model_config in config['MODELS']: + if model_config['MODEL_ID'] == update_dict['id']: + update_dict['model_name'] = model_config['MODEL_NAME'] + update_dict['model_type'] = model_config['MODEL_TYPE'] + update_dict['openai_api_base'] = model_config['OPENAI_API_BASE'] + update_dict['openai_api_key'] = model_config['OPENAI_API_KEY'] + update_dict['max_tokens'] = model_config['MAX_TOKENS'] + if 'id' in update_dict.keys(): + del update_dict['id'] + else: + update_dict['model_type'] = '' + encrypted_openai_api_key, encrypted_config = Security.encrypt(update_dict['openai_api_key']) + if not await test_model_connection( + update_dict['model_name'], + update_dict['openai_api_base'], + update_dict['openai_api_key'], + update_dict['max_tokens']): raise ModelException("Model connection test failed") model_entity = await ModelManager.select_by_user_id(user_id) if model_entity is None: model_entity = ModelEntity( - model_name=model_name, - user_id=user_id, - openai_api_base=openai_api_base, + model_name=update_dict['model_name'], + model_type=update_dict['model_type'], + is_online=update_dict['is_online'], + user_id=update_dict['user_id'], + openai_api_base=update_dict['openai_api_base'], encrypted_openai_api_key=encrypted_openai_api_key, encrypted_config=json.dumps(encrypted_config), - max_tokens=max_tokens + max_tokens=update_dict['max_tokens'] ) await ModelManager.insert(model_entity) else: - update_dict['encrypted_openai_api_key']=encrypted_openai_api_key - update_dict['encrypted_config']=json.dumps(encrypted_config) + update_dict['encrypted_openai_api_key'] = encrypted_openai_api_key + update_dict['encrypted_config'] = json.dumps(encrypted_config) await ModelManager.update_by_user_id(user_id, update_dict) - return ModelConvertor.convert_entity_to_dto(model_entity) \ No newline at end of file + return ModelConvertor.convert_entity_to_dto(model_entity) diff --git a/data_chain/apps/service/user_service.py b/data_chain/apps/service/user_service.py index 9cdb5d4f1a651f8242953a6a4cd5cdeaa88067d0..431ad6e4eb03eba2f8379797a147c06b04274b7a 100644 --- a/data_chain/apps/service/user_service.py +++ b/data_chain/apps/service/user_service.py @@ -20,7 +20,7 @@ class UserHTTPException(HTTPException): def verify_user(request: HTTPConnection): try: - session_id = request.cookies["ECSESSION"] + session_id = request.cookies["WD_ECSESSION"] except: raise UserHTTPException(status_code=status.HTTP_401_UNAUTHORIZED, retcode=401, rtmsg="Authentication Error.", data="") @@ -31,7 +31,7 @@ def verify_user(request: HTTPConnection): def get_session(request: HTTPConnection): try: - session_id = request.cookies["ECSESSION"] + session_id = request.cookies["WD_ECSESSION"] except: raise UserHTTPException(status_code=status.HTTP_401_UNAUTHORIZED, retcode=401, rtmsg="Authentication Error.", data="") @@ -43,7 +43,7 @@ def get_session(request: HTTPConnection): def get_user_id(request: HTTPConnection) -> uuid: try: - session_id = request.cookies["ECSESSION"] + session_id = request.cookies["WD_ECSESSION"] except: raise UserHTTPException(status_code=status.HTTP_401_UNAUTHORIZED, retcode=401, rtmsg="Authentication Error.", data="") @@ -78,7 +78,7 @@ def verify_csrf_token(request: Request, response: Response): return try: csrf_token = request.headers.get('x-csrf-token').strip("\"") - session = request.cookies.get('ECSESSION') + session = request.cookies.get('WD_ECSESSION') except: raise UserHTTPException(status_code=status.HTTP_401_UNAUTHORIZED, retcode=401, rtmsg="Authentication Error.", data="") @@ -91,6 +91,6 @@ def verify_csrf_token(request: Request, response: Response): raise UserHTTPException(status_code=status.HTTP_401_UNAUTHORIZED, retcode=401, rtmsg="Renew CSRF token failed.", data="") - response.set_cookie("_csrf_tk", new_csrf_token, max_age=config["SESSION_TTL"] * 60, + response.set_cookie("wd_csrf_tk", new_csrf_token, max_age=config["SESSION_TTL"] * 60, secure=True, domain=config["DOMAIN"], samesite="strict") return response diff --git a/data_chain/config/config.py b/data_chain/config/config.py index 28d574ce7c007eb40f1e9ed2a0078b48af6740ad..b7810e2093bce5898bf338a4664cff67beeea56f 100644 --- a/data_chain/config/config.py +++ b/data_chain/config/config.py @@ -1,11 +1,28 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. import os - +import uuid from dotenv import dotenv_values from pydantic import BaseModel, Field +from typing import List + + +class DictBaseModel(BaseModel): + def __getitem__(self, key): + if key in self.__dict__: + return getattr(self, key) + return None + +class ModelConfig(DictBaseModel): + MODEL_ID: str = Field(default_factory=lambda: uuid.uuid4(), description="模型ID") + MODEL_NAME: str = Field(..., description="使用的语言模型名称或版本") + MODEL_TYPE: str = Field(..., description="语言模型类型") + OPENAI_API_BASE: str = Field(..., description="语言模型服务的基础URL") + OPENAI_API_KEY: str = Field(..., description="语言模型访问密钥") + MAX_TOKENS: int = Field(..., description="单次请求中允许的最大Token数") -class ConfigModel(BaseModel): + +class ConfigModel(DictBaseModel): # FastAPI UVICORN_IP: str = Field(None, description="FastAPI 服务的IP地址") UVICORN_PORT: int = Field(None, description="FastAPI 服务的端口号") @@ -13,7 +30,7 @@ class ConfigModel(BaseModel): SSL_KEYFILE: str = Field(None, description="SSL密钥文件的路径") SSL_ENABLE: bool = Field(None, description="是否启用SSL连接") # LOG METHOD - LOG_METHOD:str = Field('stdout', description="日志记录方式") + LOG_METHOD: str = Field('stdout', description="日志记录方式") # Postgres DATABASE_URL: str = Field(None, description="Postgres数据库链接url") # MinIO @@ -28,11 +45,14 @@ class ConfigModel(BaseModel): REDIS_PENDING_TASK_QUEUE_NAME: str = Field(default='rag_pending_task_queue', description="redis等待开始任务队列名称") REDIS_SUCCESS_TASK_QUEUE_NAME: str = Field(default='rag_success_task_queue', description="redis已经完成任务队列名称") REDIS_RESTART_TASK_QUEUE_NAME: str = Field(default='rag_restart_task_queue', description="redis等待重启任务队列名称") - REDIS_SILENT_ERROR_TASK_QUEUE_NAME: str = Field(default='rag_silent_error_task_queue', description="redis等待重启任务队列名称") + REDIS_SILENT_ERROR_TASK_QUEUE_NAME: str = Field( + default='rag_silent_error_task_queue', description="redis等待重启任务队列名称") # Task TASK_RETRY_TIME: int = Field(None, description="任务重试次数") # Embedding - REMOTE_EMBEDDING_ENDPOINT: str = Field(None, description="远程embedding服务url地址") + EMBEDDING_API_KEY: str = Field(None, description="embedding服务api key") + EMBEDDING_ENDPOINT: str = Field(None, description="embedding服务url地址") + EMBEDDING_MODEL_NAME: str = Field(None, description="embedding模型名称") # Token SESSION_TTL: int = Field(None, description="用户session过期时间") CSRF_KEY: str = Field(None, description="csrf的密钥") @@ -45,19 +65,17 @@ class ConfigModel(BaseModel): # Stop Words PATH STOP_WORDS_PATH: str = Field(None, description="停用词表存放位置") # LLM config - MODEL_NAME: str = Field(None, description="使用的语言模型名称或版本") - OPENAI_API_BASE: str = Field(None, description="语言模型服务的基础URL") - OPENAI_API_KEY: str = Field(None, description="语言模型访问密钥") - REQUEST_TIMEOUT: int = Field(None, description="大模型请求超时时间") - MAX_TOKENS: int = Field(None, description="单次请求中允许的最大Token数") + MODELS: List[ModelConfig] = Field(..., description="多个大模型的配置列表") MODEL_ENH: bool = Field(None, description="是否使用大模型能力增强") # DEFAULT USER DEFAULT_USER_ACCOUNT: str = Field(default='admin', description="默认用户账号") - DEFAULT_USER_PASSWD: str = Field(default='123456', description="默认用户密码") + DEFAULT_USER_PASSWD: str = Field(default='8d969eef6ecad3c29a3a629280e686cf0c3f5d5a86aff3ca12020c923adc6c92', description="默认用户密码") DEFAULT_USER_NAME: str = Field(default='admin', description="默认用户名称") DEFAULT_USER_LANGUAGE: str = Field(default='zh', description="默认用户语言") # DOCUMENT PARSER - DOCUMENT_PARSE_USE_CPU_LIMIT:int=Field(default=4,description="文档解析器使用CPU核数") + DOCUMENT_PARSE_USE_CPU_LIMIT: int = Field(default=4, description="文档解析器使用CPU核数") + + class Config: config: ConfigModel @@ -66,7 +84,20 @@ class Config: config_file = os.getenv("CONFIG") else: config_file = "data_chain/common/.env" - self.config = ConfigModel(**(dotenv_values(config_file))) + env_vars = dotenv_values(config_file) + + models_configs = [] + model_keys = set([k.split('_')[1] for k in env_vars.keys() if k.startswith('MODEL_')]) # 提取模型标识符 + + for model_key in model_keys: + single_model_config = { + k.replace(f'MODEL_{model_key}_', ''): v for k, v in env_vars.items() + if k.startswith(f'MODEL_{model_key}')} + models_configs.append(single_model_config) + self.config = ConfigModel( + MODELS=[ModelConfig(**model_cfg) for model_cfg in models_configs], + **{k: v for k, v in env_vars.items() if not k.startswith('MODEL_')} + ) if os.getenv("PROD"): os.remove(config_file) diff --git a/data_chain/manager/user_manager.py b/data_chain/manager/user_manager.py index 3d7d995e9a1356c42aa1b631f3767c707cc21ba3..90c9eb7a0262d14de65c16919df285d5b7d4453b 100644 --- a/data_chain/manager/user_manager.py +++ b/data_chain/manager/user_manager.py @@ -10,9 +10,10 @@ from data_chain.stores.postgres.postgres import PostgresDB, User class UserManager: @staticmethod - async def add_user(name, account, passwd): + async def add_user(name,email, account, passwd): user_slice = User( name=name, + email=email, account=account, passwd=passwd ) @@ -80,7 +81,17 @@ class UserManager: except Exception as e: logging.error(f"Failed to get user info by account: {e}") return None - + @staticmethod + async def get_user_info_by_email(email): + try: + async with await PostgresDB.get_session() as session: + stmt = select(User).where(User.email == email) + result = await session.execute(stmt) + user = result.scalars().first() + return user + except Exception as e: + logging.error(f"Failed to get user info by account: {e}") + return None @staticmethod async def get_user_info_by_user_id(user_id): result = None @@ -92,3 +103,4 @@ class UserManager: except Exception as e: logging.error(f"Get user failed due to error: {e}") return result + diff --git a/data_chain/models/api.py b/data_chain/models/api.py index 8ad95500bbdb1a8e1e5a7d320bc97346d1de912e..dc8f314bce6ba18f97e7bd172167ea7865280159 100644 --- a/data_chain/models/api.py +++ b/data_chain/models/api.py @@ -6,7 +6,7 @@ from typing import Dict, Generic, List, Optional, TypeVar from data_chain.models.service import DocumentTypeDTO -from pydantic import BaseModel, Field, validator,constr +from pydantic import BaseModel, Field, validator, constr T = TypeVar('T') @@ -31,22 +31,24 @@ class Page(DictionaryBaseModel, Generic[T]): total: int data_list: Optional[List[T]] + class CreateKnowledgeBaseRequest(DictionaryBaseModel): - name: str=Field(...,min_length=1, max_length=150) - language: str=Field(...,pattern=r"^(zh|en)$") - description: Optional[str]=Field(None, max_length=150) - embedding_model: str=Field(...,pattern=r"^(bge_large_zh|bge_large_en)$") + name: str = Field(..., min_length=1, max_length=150) + language: str = Field(..., pattern=r"^(zh|en)$") + description: Optional[str] = Field(None, max_length=150) + embedding_model: str = Field(..., pattern=r"^(bge_large_zh|bge_large_en)$") default_parser_method: str default_chunk_size: int = Field(1024, ge=128, le=1024) document_type_list: Optional[List[str]] + class UpdateKnowledgeBaseRequest(DictionaryBaseModel): id: uuid.UUID - name: Optional[str]=Field(None,min_length=1, max_length=150) - language: Optional[str]=Field(None,pattern=r"^(zh|en)$") + name: Optional[str] = Field(None, min_length=1, max_length=150) + language: Optional[str] = Field(None, pattern=r"^(zh|en)$") description: Optional[str] - embedding_model: Optional[str]=Field(None,pattern=r"^(bge_large_zh|bge_large_en)$") - default_parser_method: Optional[str]=None + embedding_model: Optional[str] = Field(None, pattern=r"^(bge_large_zh|bge_large_en)$") + default_parser_method: Optional[str] = None default_chunk_size: Optional[int] = Field(None, ge=128, le=1024) document_type_list: Optional[List[DocumentTypeDTO]] = None @@ -92,6 +94,7 @@ class ListTaskRequest(DictionaryBaseModel): page_size: int = 10 created_time_order: Optional[str] = 'desc' # 取值desc降序, asc升序 + class ListDocumentRequest(DictionaryBaseModel): kb_id: Optional[uuid.UUID] = None id: Optional[uuid.UUID] = None @@ -105,6 +108,7 @@ class ListDocumentRequest(DictionaryBaseModel): parser_method: Optional[List[str]] = None page_number: int = 1 page_size: int = 10 + @validator('status', each_item=True) def check_types(cls, v): # 定义允许的类型正则表达式 @@ -113,18 +117,18 @@ class ListDocumentRequest(DictionaryBaseModel): raise ValueError(f'Invalid type value "{v}". Must match pattern {allowed_type_pattern}.') return v + class UpdateDocumentRequest(DictionaryBaseModel): id: uuid.UUID - name: Optional[str] = Field(None,min_length=1, max_length=128) - parser_method: Optional[str] = Field(None,pattern=r"^(general|ocr|enhanced)$") + name: Optional[str] = Field(None, min_length=1, max_length=128) + parser_method: Optional[str] = Field(None, pattern=r"^(general|ocr|enhanced)$") type_id: Optional[uuid.UUID] = None chunk_size: Optional[int] = Field(None, gt=127, lt=1025) - class RunDocumentRequest(DictionaryBaseModel): ids: List[uuid.UUID] - run: str=Field(...,pattern=r"^(run|cancel)$")# run运行或者cancel取消 + run: str = Field(..., pattern=r"^(run|cancel)$") # run运行或者cancel取消 class SwitchDocumentRequest(DictionaryBaseModel): @@ -139,26 +143,35 @@ class DeleteDocumentRequest(DictionaryBaseModel): class DownloadDocumentRequest(DictionaryBaseModel): ids: List[uuid.UUID] + class GetTemporaryDocumentStatusRequest(DictionaryBaseModel): ids: List[uuid.UUID] + class TemporaryDocumentInParserRequest(DictionaryBaseModel): id: uuid.UUID - name:str=Field(...,min_length=1, max_length=128) - type:str=Field(...,min_length=1, max_length=128) - bucket_name:str=Field(...,min_length=1, max_length=128) - parser_method:str=Field("ocr",pattern=r"^(general|ocr)$") - chunk_size:int=Field(1024,ge=128,le=1024) + name: str = Field(..., min_length=1, max_length=128) + type: str = Field(..., min_length=1, max_length=128) + bucket_name: str = Field(..., min_length=1, max_length=128) + parser_method: str = Field("ocr", pattern=r"^(general|ocr)$") + chunk_size: int = Field(1024, ge=128, le=1024) + + class ParserTemporaryDocumenRequest(DictionaryBaseModel): - document_list:List[TemporaryDocumentInParserRequest] + document_list: List[TemporaryDocumentInParserRequest] + + class DeleteTemporaryDocumentRequest(DictionaryBaseModel): ids: List[uuid.UUID] + + class RelatedTemporaryDocumenRequest(DictionaryBaseModel): - content:str - top_k:int=Field(5, ge=0, le=10) + content: str + top_k: int = Field(5, ge=0, le=10) kb_sn: Optional[uuid.UUID] = None document_ids: Optional[List[uuid.UUID]] = None + class ListChunkRequest(DictionaryBaseModel): document_id: uuid.UUID text: Optional[str] = None @@ -166,22 +179,28 @@ class ListChunkRequest(DictionaryBaseModel): types: Optional[List[str]] = None page_size: int = 10 # 定义一个验证器来确保types中的每个元素都符合正则表达式 + @validator('types', each_item=True) def check_types(cls, v): # 定义允许的类型正则表达式 - allowed_type_pattern = r"^(para|table|image)$" # 替换为你需要的正则表达式 + allowed_type_pattern = r"^(para|table|image)$" # 替换为你需要的正则表达式 if not re.match(allowed_type_pattern, v): raise ValueError(f'Invalid type value "{v}". Must match pattern {allowed_type_pattern}.') return v + + class SwitchChunkRequest(DictionaryBaseModel): ids: List[uuid.UUID] # 支持批量操作 enabled: bool # True启用, False未启用 -class AddUserRequest(DictionaryBaseModel): - name: str - account: str - passwd: str +class AddUserRequest(BaseModel): + name: str = Field(..., min_length=1, max_length=10, description="用户名,长度在1到10个字符") + email: Optional[str] = Field(None, min_length=5, max_length=30, + pattern='^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', description="邮箱,长度在5到30个字符") + account: str = Field(..., min_length=5, max_length=20, + pattern="^[a-z0-9]+$", description="账号,由小写字母和数字组成,长度在5到20个字符") + passwd: str = Field(..., min_length=63, max_length=65, description="密码的哈希") class UpdateUserRequest(DictionaryBaseModel): @@ -192,17 +211,20 @@ class UpdateUserRequest(DictionaryBaseModel): status: Optional[str] = None language: Optional[str] = None + class UpdateModelRequest(DictionaryBaseModel): - model_name: str=Field(...,min_length=1, max_length=128) - openai_api_base: str=Field(...,min_length=1, max_length=128) - openai_api_key: str=Field(...,min_length=1, max_length=128) - max_tokens: int=Field(1024, ge=1024, le=8192) + id: Optional[uuid.UUID] = Field(None) + model_name: Optional[str] = Field(None, min_length=1, max_length=128) + openai_api_base: Optional[str] = Field(None, min_length=1, max_length=128) + openai_api_key: Optional[str] = Field(None, min_length=1, max_length=128) + max_tokens: Optional[int] = Field(None, ge=1024, le=8192) + is_online: bool = Field(default=True) class QueryRequest(BaseModel): question: str kb_sn: Optional[uuid.UUID] = None - document_ids : Optional[List[uuid.UUID]] = None + document_ids: Optional[List[uuid.UUID]] = None top_k: int = Field(5, ge=0, le=10) fetch_source: bool = False history: Optional[List] = [] diff --git a/data_chain/models/service.py b/data_chain/models/service.py index 8cf437b8a806db2ce4ceb4f2a32084a4cdc18939..b4a5da4282b0e95c0e11172e98906e010d21f861 100644 --- a/data_chain/models/service.py +++ b/data_chain/models/service.py @@ -81,5 +81,8 @@ class ChunkDTO(DictionaryBaseModelDTO): class ModelDTO(DictionaryBaseModelDTO): id: Optional[str] = None model_name: Optional[str] = None + model_type: Optional[str] = None openai_api_base: Optional[str] = None + openai_api_key: Optional[str] = None max_tokens: Optional[int] = None + is_online: Optional[bool] = None diff --git a/data_chain/stores/postgres/postgres.py b/data_chain/stores/postgres/postgres.py index 16afba1295fa71d3006f06e38843be53a4ad25dd..83ee876cedec0d2d6f697fb3293c96b5c9964c12 100644 --- a/data_chain/stores/postgres/postgres.py +++ b/data_chain/stores/postgres/postgres.py @@ -12,7 +12,7 @@ from sqlalchemy.orm import declarative_base, relationship from data_chain.config.config import config from data_chain.models.api import CreateKnowledgeBaseRequest -from data_chain.models.constant import KnowledgeStatusEnum,ParseMethodEnum +from data_chain.models.constant import KnowledgeStatusEnum, ParseMethodEnum Base = declarative_base() @@ -22,6 +22,7 @@ class User(Base): id = Column(UUID, default=uuid4, primary_key=True) # 用户id account = Column(String, unique=True) # 用户账号 + email = Column(String, unique=True) # 用户邮箱 passwd = Column(String) name = Column(String) language = Column(String, default='zh') @@ -41,7 +42,9 @@ class ModelEntity(Base): __tablename__ = 'model' id = Column(UUID, default=uuid4, primary_key=True) user_id = Column(UUID, ForeignKey('users.id', ondelete="CASCADE")) + is_online = Column(Boolean, default=False) model_name = Column(String) + model_type = Column(String) openai_api_base = Column(String) encrypted_openai_api_key = Column(String) encrypted_config = Column(String) @@ -58,16 +61,16 @@ class KnowledgeBaseEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) user_id = Column(UUID, ForeignKey('users.id', ondelete="CASCADE")) # 用户id - name = Column(String,default='') # 知识库名资产名 + name = Column(String, default='') # 知识库名资产名 language = Column(String, default='zh') # 资产文档语言 - description = Column(String,default='') # 资产描述 + description = Column(String, default='') # 资产描述 embedding_model = Column(String) # 资产向量化模型 - document_number = Column(Integer,default=0) # 资产文档个数 - document_size = Column(Integer,default=0) # 资产下所有文档大小(TODO: 单位kb或者字节) - default_parser_method = Column(String,default=ParseMethodEnum.GENERAL) # 默认解析方法 - default_chunk_size = Column(Integer,default=1024) # 默认分块大小 + document_number = Column(Integer, default=0) # 资产文档个数 + document_size = Column(Integer, default=0) # 资产下所有文档大小(TODO: 单位kb或者字节) + default_parser_method = Column(String, default=ParseMethodEnum.GENERAL) # 默认解析方法 + default_chunk_size = Column(Integer, default=1024) # 默认分块大小 vector_items_id = Column(UUID, default=uuid4) # 向量表id - status = Column(String,default=KnowledgeStatusEnum.IDLE) + status = Column(String, default=KnowledgeStatusEnum.IDLE) created_time = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp()) updated_time = Column( TIMESTAMP(timezone=True), @@ -219,10 +222,10 @@ class TemporaryDocumentEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) name = Column(String) extension = Column(String) - bucket_name=Column(String) - parser_method=Column(String) + bucket_name = Column(String) + parser_method = Column(String) chunk_size = Column(Integer) # 文档分块大小 - status=Column(String) + status = Column(String) created_time = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp())