diff --git a/data_chain/apps/router/document.py b/data_chain/apps/router/document.py index 5c84cea18c48e8057c1ed6c54826ac87b9f0b4fb..a2a3de0021ad5edf362789b3358f36c761b450ed 100644 --- a/data_chain/apps/router/document.py +++ b/data_chain/apps/router/document.py @@ -1,8 +1,14 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. import urllib +import os from typing import Dict, List import uuid from fastapi import HTTPException,status +from httpx import AsyncClient +from fastapi import Depends +from fastapi import APIRouter, File, UploadFile +from fastapi.responses import StreamingResponse +from data_chain.logger.logger import logger as logging from data_chain.models.service import DocumentDTO, TemporaryDocumentDTO from data_chain.apps.service.user_service import verify_csrf_token, get_user_id, verify_user from data_chain.exceptions.err_code import ErrorCode @@ -18,10 +24,6 @@ from data_chain.apps.service.document_service import _validate_doucument_belong_ get_file_name_and_extension, init_temporary_document_parse_task, delete_temporary_document, get_temporary_document_parse_status, \ get_related_document -from httpx import AsyncClient -from fastapi import Depends -from fastapi import APIRouter, File, UploadFile -from fastapi.responses import StreamingResponse router = APIRouter(prefix='/doc', tags=['Document']) @@ -170,7 +172,7 @@ async def parser_temporary_doc(req: ParserTemporaryDocumenRequest): if tmp_dict['type']=='application/pdf': tmp_dict['type']='.pdf' elif tmp_dict['type']=='text/html': - tmp_dict['type']='.html' + tmp_dict['type']='.txt'#html解析方式暂时替代为txt elif tmp_dict['type']=='text/plain': tmp_dict['type']='.txt' elif tmp_dict['type']=='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': @@ -181,6 +183,12 @@ async def parser_temporary_doc(req: ParserTemporaryDocumenRequest): tmp_dict['type']='.docx' elif tmp_dict['type']=='application/msword': tmp_dict['type']='.doc' + else: + try: + tmp_dict['type']=os.path.splitext(tmp_dict['name'])[1] + except Exception as e: + logging.error(f"Error in get document type due to: {e}") + continue temporary_document_list.append(tmp_dict) result = await init_temporary_document_parse_task(temporary_document_list) return BaseResponse(data=result) diff --git a/data_chain/apps/router/knowledge_base.py b/data_chain/apps/router/knowledge_base.py index 418a56074014bc7b47bd3b769582611f35ce2039..3e91e138880d3665dec1bcf1a0d4cff724106c02 100644 --- a/data_chain/apps/router/knowledge_base.py +++ b/data_chain/apps/router/knowledge_base.py @@ -220,7 +220,7 @@ async def get_stream_answer(req: QueryRequest, response: Response): bac_info += '段落'+str(j)+':\n\n' bac_info += chunk_list[j]+'\n\n' bac_info = split_chunk(bac_info) - if len(bac_info) > max_tokens: + if len(bac_info) > max_tokens//8*7: bac_info = '' for i in range(len(document_chunk_list)): document_name = document_chunk_list[i]['document_name'] @@ -263,7 +263,7 @@ async def get_answer(req: QueryRequest): bac_info += '段落'+str(j)+':\n\n' bac_info += chunk_list[j]+'\n\n' bac_info = split_chunk(bac_info) - if len(bac_info) > max_tokens: + if len(bac_info) > max_tokens//8*7: bac_info = '' for i in range(len(document_chunk_list)): document_name = document_chunk_list[i]['document_name'] diff --git a/data_chain/common/stopwords.txt b/data_chain/common/stopwords.txt index 5784b4462a67442a7301abb939b8ca17fa791598..c96cc1dd6f1a448c8db4e31c4f0a518fc26940dd 100644 --- a/data_chain/common/stopwords.txt +++ b/data_chain/common/stopwords.txt @@ -79,16 +79,12 @@ meanwhile ' ( 
) -* 可是 怪 here’s + , yourselves -- -. -/ [⑥] 甚或 集中 @@ -103,7 +99,6 @@ eleven 于是乎 much ? -@ 第二单元 A 够瞧的 @@ -1083,7 +1078,6 @@ couldn’t 喽 since .. -./ 倘若 we’ve 更为 @@ -1101,7 +1095,6 @@ hither 个人 基于 无 -// 嗡 certainly 造成 @@ -1687,7 +1680,6 @@ awfully latterly amongst 敞开儿 -etc 然后 net 这么 @@ -2406,8 +2398,7 @@ what's 上 下 光是 -恰恰 -不 +恰 somewhere 与 [②⑦] @@ -2765,7 +2756,6 @@ circa only should 结合 -:// 依 多数 再者说 @@ -2849,24 +2839,17 @@ $ ' ( ) -* + , -- -- -. .. ... ...... ................... -./ .一 .数 .日 -/ -// : -:// :: ; < @@ -2874,7 +2857,6 @@ $ > >> ? -@ A Lex [ @@ -2998,7 +2980,6 @@ sup 下去 下来 下面 -不 不一 不下 不久 diff --git a/data_chain/parser/handler/html_parser.py b/data_chain/parser/handler/html_parser.py index ee3ee74f26e39f70f076a6405a8e278b1c48191b..5b0a0471e0111b3aca0ef3f65842b5ba552777ea 100644 --- a/data_chain/parser/handler/html_parser.py +++ b/data_chain/parser/handler/html_parser.py @@ -44,7 +44,7 @@ class HtmlService(BaseService): return node_dict - def parser(self, file_path): + async def parser(self, file_path): html_content = self.open_file(file_path) # 解析 HTML 内容 soup = BeautifulSoup(html_content, 'lxml') diff --git a/utils/README.md b/utils/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6a0cc3592760860e974304334566dc7d2f2fc783 --- /dev/null +++ b/utils/README.md @@ -0,0 +1,218 @@ +# 脚本使用指南 + +## 1. 生成问答对 + +### 1.1 功能简介 + +该功能用于将原始数据转换为问答对,并将问答对保存为 json (默认) /xlsx/yaml 格式的文件。 + +### 1.2 功能参数说明 + +| 参数名 | 默认值 | 说明 | +| --- | --- |-----------------------------------------------| +| path | | 必填,指定待处理的文件路径,支持 docx、pdf、txt 等格式 | +| output_path | | 必填,指定输出路径 | +| output_format | json | 可选,指定输出格式,支持 json、xlsx、yaml 三种格式 | +| enhance | False | 可选,是否使用增强模式,增强模式下,将通过验证机制增强生成的问答对的准确率 | +| qa_count | 5 | 可选,指定生成问答对的数量,对于每个文档随机选择若干chunk生成qa_count个问答对 | + + +### 1.3 使用示范 +```bash +python utils/main.py \ +qa_generate \ +--path docs/examples.xlsx \ +--output_path output \ +--output_format json \ +--enhance \ +--qa_count 10 +``` + +### 1.4 结果输出 + +结果输出在 output_path 目录下 + +## 2. 文档治理 + +### 2.1 功能简介 + +该功能用于优化文档的表现形式,功能包括 +1. 去除重复段落、文本 +2. 敏感信息脱敏(自定义敏感词、敏感内容格式) +3. 文档内容标准化,包括统一编码格式、统一全半角等 +4. 文档内容格式化,包括通用场景(段落总结,支持自定义格式)、开发场景(代码注释)、运维场景(案例整理)三种类别 + +### 2.2 功能参数说明 + +| 参数名 | 默认值 | 说明 | +| --- |---------|--------------------------------------------------------| +| method | | 必填,指定脚本的功能,此处为 "document_governance" | +| path | | 必填,指定待处理的文件路径,支持 docx、pdf、txt 等格式 | +| output_path | | 必填,指定输出路径 | +| output_format | json | 可选,指定输出格式,支持 json、xlsx、yaml 三种格式 | +| standardize | False | 可选,是否进行文档内容标准化,包括统一编码格式、统一全半角等 | +| unique | False | 可选,是否去除重复段落、文本 | +| format | False | 可选,是否进行文档内容格式化,包括通用场景(段落总结)、开发场景(代码注释)、运维场景(案例整理)三种类别 | +| format_mode | general | 可选,指定文档内容格式化模式,包括 "general"、"develop"、"OPS" 分别对应上述三种场景 | + + +### 2.3 使用示范 +```bash +python3 utils/main.py \ + document_governance \ + --path docs/test_for_document.txt \ + --output_path output/document \ + --standardize \ + --format \ + --unique \ + --output_format md \ + --format_mode develop +``` + +### 2.4 自定义内容 + +#### 2.4.1 自定义敏感词 +敏感词文件为 sensitive_words.txt,每行一个敏感词,示例如下: +```text +暴力 +色情 +赌博 +毒品 +诈骗 +``` +敏感格式表文件为 sensitive_pattern.txt,每行一个敏感句式,通过正则表达式匹配,示例如下: +```text +\b(赌|博)\b +\b(毒|品)\b +\b(诈|骗)\b +\b(暴|力)\b +\b(色|情)\b +``` +术语替换文件为 term_replacements.txt,每行一个替换词,示例如下: +```text +医生:医师 +护士:护理人员 +医院:医疗机构 +手术:外科操作 +药物:药剂 +``` + +### 3.5 结果输出 + +结果输出在 output_path 目录下 + +## 3. 
向量模型微调 + +### 3.1 功能简介 + +该功能用于微调指定的向量模型,包括 bge-large-zh、bge-large-en、bge-small-zh、bge-small-en 等。 + + +### 3.2 数据集/测试集 + +#### 3.2.1 数据格式 +数据集格式为jsonl,示例如下: +``` +{"query": str, "pos": List[str], "neg":List[str], "pos_scores": List[int], "neg_scores": List[int], "prompt": str, "type": str} +``` +其中 query 为问题,pos 为正例,neg 为负例,pos_scores 为正例的打分,neg_scores 为负例的打分,prompt 为提示词,type 为数据分类。 + +测试集格式为json,示例如下: +``` +{ + "corpus": { + "content": list[str], + }, + "test_data": { + "query": list[str], + "mapper": { # key: query, value: answer + str: str, + } + } +} + +``` +其中 corpus 为语料库,test_data 为测试集,query 为问题,mapper 为正确答案。 + +#### 3.2.2 生成方式: +1) 使用脚本生成问答对到output_path,或者使用数据集和测试集自行构造成xlsx格式,格式为question列为query,answer列为答案。 +2) 使用脚本生成训练集和测试集,生成方法如下:请执行 data_processing.py 脚本将你的数据处理成如下 jsonl 格式: + +``` +{"query": str, "pos": List[str], "neg":List[str]} +``` +query是查询指令,pos是正例列表,neg是负例列表 + +``` +python utils/my_tools/bge_finetune/data_processing.py \ +--input_dir data_path \ +--output_dir output_path \ +--train_num 10000 +``` +- input_dir 问答对数据存放目录 +- output_dir 训练集和测试机输出目录 +- train_num 训练集数量 + +如果数据中没有负例,则可以使用以下命令从整个语料库中随机抽取样本做负例增强: +``` +python ./utils/my_tools/bge_finetune/hn_mine.py \ + --model_name_or_path BAAI/bge-large-zh-v1.5 \ + --input_file path/to/data.jsonl \ + --output_file path/to/data_minedHN.jsonl \ + --range_for_sampling 2-200 \ + --negative_number 15 \ + --use_gpu_for_searching +``` +- input_file:jsonl 格式的原始训练数据 +- output_file:负例增强后输出的 jsonl 数据的路径 +- range_for_sampling:采样区间,例如,2-100表示从 top2-top200 文档中采样负样本 +- negative_number:采样负样本的数量 +- use_gpu_for_searching:是否使用 faiss-gpu 来检索 + +### 3.3 功能参数说明 + + + +| 参数名 | 默认值 | 说明 | +|--------------------| --- |-------------------------------------| +| method | | 必填,指定脚本的功能,此处为 "embedding_training" | +| train_data | | 必填,指定训练数据路径,支持 jsonl 格式 | +| test_data | | 必填,指定测试数据路径,支持 json 格式 | +| output_path | | 必填,指定模型输出路径 | +| batch_size | 8 | 可选,指定训练批次大小 | +| learning_rate | 5e-5 | 可选,指定学习率 | +| deepspeed | | 可选,指定 deepspeed 配置文件路径,用于优化微调速度 | +| epochs | 3 | 可选,指定训练轮数 | +| save_steps | 1000 | 可选,指定保存模型的步数 | +| logging_steps | 100 | 可选,指定日志输出的步数 | +| gpu_num | 1 | 可选,指定使用的 GPU 数量 | +| model_name_or_path | | 可选,指定微调的模型路径,默认为 bge-large-zh-v1.5 | +| temperature | 0.02 | 可选,指定温度参数,默认为 0.02 | +| warmup | 0.1 | 可选,指定预热比例,默认为 0.1 | + + + + +### 3.4 使用示范 + +```bash +python3 utils/main.py \ + embedding_training \ + --data_path output/bge/train_data_mineHN.jsonl \ + --output_path output/test_encoder_only_base_bge-large-en-v1.5 \ + --batch_size 2 \ + --learning_rate 5e-5 \ + --deepspeed utils/my_tools/bge_finetune/ds_stage0.json \ + --epochs 1 \ + --save_steps 1000 \ + --logging_steps 100 \ + --gpu_num 4 \ + --model_name_or_path ./bge_model/bge-large-en-v1.5 \ + --temperature 0.02 \ + --warmup 0.1 +``` + +### 3.5 结果输出 +微调后模型输出在 **output_path** 对应目录下,报告输出在 **./report/embedding/{训练完成时间}** 目录下,报告包含训练过程曲线图、模型预测结果等。 + +需要进行模型评估和合并,请参考./utils/my_tools/bge_finetune/README.md \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/common/.env.example b/utils/common/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..92655194328456babd60757419e80d45a022331f --- /dev/null +++ b/utils/common/.env.example @@ -0,0 +1,15 @@ +# PROMPT_PATH +PROMPT_PATH= + +# DOCS_PATH +SENSITIVE_WORDS_PATH= +SENSITIVE_PATTERNS_PATH= +TERM_REPLACEMENTS_PATH= + +#LLM config +MODEL_NAME= +OPENAI_API_BASE= 
+OPENAI_API_KEY= +REQUEST_TIMEOUT= +MAX_TOKENS= +MODEL_ENH= \ No newline at end of file diff --git a/utils/config/config.py b/utils/config/config.py new file mode 100644 index 0000000000000000000000000000000000000000..35246d54d7dc3385c1ae84503f8d1d3632995b3c --- /dev/null +++ b/utils/config/config.py @@ -0,0 +1,41 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +import os + +from dotenv import dotenv_values +from pydantic import BaseModel, Field + + +class ConfigModel(BaseModel): + # Prompt file + PROMPT_PATH: str = Field(None, description="prompt路径") + + # PATH + SENSITIVE_WORDS_PATH: str = Field(None, description="敏感词表存放位置") + TERM_REPLACEMENTS_PATH: str = Field(None, description="术语替换表存放位置") + SENSITIVE_PATTERNS_PATH: str = Field(None, description="敏感词匹配表存放位置") + # LLM my_tools + MODEL_NAME: str = Field(None, description="使用的语言模型名称或版本") + OPENAI_API_BASE: str = Field(None, description="语言模型服务的基础URL") + OPENAI_API_KEY: str = Field(None, description="语言模型访问密钥") + REQUEST_TIMEOUT: int = Field(None, description="大模型请求超时时间") + MAX_TOKENS: int = Field(None, description="单次请求中允许的最大Token数") + MODEL_ENH: bool = Field(None, description="是否使用大模型能力增强") +class Config: + config: ConfigModel + + def __init__(self): + if os.getenv("CONFIG"): + config_file = os.getenv("CONFIG") + else: + config_file = "utils/common/.env" + self.config = ConfigModel(**(dotenv_values(config_file))) + if os.getenv("PROD"): + os.remove(config_file) + + def __getitem__(self, key): + if key in self.config.__dict__: + return self.config.__dict__[key] + return None + + +config = Config() \ No newline at end of file diff --git a/utils/docs/prompt.yaml b/utils/docs/prompt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..227550d94adac451fa0e76eae5996fdbd54cd3d1 --- /dev/null +++ b/utils/docs/prompt.yaml @@ -0,0 +1,194 @@ +GENERATE_QA: "你是一个问答生成专家,你的任务是根据我提供的段落内容和已有的问题,生成{qa_count}个不重复的针对该段落内容的问题与回答, +并判断这个问答对的属于领域,并只输出问题、回答、领域。 + +注意: + +1. 单个回答长度必须大于30字小于120字 + +2. 问题不能出现重复 + +3. 领域分为'openEuler版本信息','openEuler通用知识','openEuler技术原理',‘openEuler使用指导','openEuler软硬件支持', +'openEuler社区信息','openEuler行业应用案例' + +4. 请指定明确的场景,如'xx公司', 'xx系统', 'xx项目', ‘xx软件'等 + +5. 问题中不要使用模糊的指代词, 如'这'、'那' + +6. 划分领域的时候请忽略上下文内容 + +7. 问题必须与段落内容有逻辑关系 + +8. 问题与回答在不重复的前提下,应当尽可能多地包含段落内容 + +9. 输出的格式为: +[ + +{{ + \"question\": \" 问题 \", + \"answer\": \" 回答 \", + \"type\": \" 领域 \" +}} + +, + +{{ + \"question\": \" 问题 \", + \"answer\": \" 回答 \", + \"type\": \" 领域 \" +}} + +] + +10. 不要输出多余内容 + +下面是给出的段落内容: + +{chunk} + +下面是段落的上下文内容: + +{text} + +下面是段落的来源文件 +{file_name} +" +FORMAT_DOCUMENT_GENERAL: "你是一个文档治理专家,你的任务是根据我提供的上下文,给出一个符合要求的文档结构。 + +文档结构为: + +\"主要信息\": xxx,\n + +\"次要信息\": xxx,\n + +\"补充信息\": xxx,\n + +注意: + +1. 输入的结构为 +{{\n + 'text':xxx,\n + 'type':xxx,\n +}}\n +其中text表示内容,type表示段落的类别,image为图片,para为文字段落,table为列表的某一行 + +2. 输出的格式为 + +\"主要信息\": xxx,\n + +\"次要信息\": xxx,\n + +\"补充信息\": xxx,\n + +3. 主要信息、次要信息、补充信息的内容必须分别大于150字,130字,70字 + +4. 所有文本必须与段落内容相关 + +5. 给出的所有文本必须包含段落的主要内容和尽可能多的详细信息 + +6. 次要信息请分3点以上给出 + +7. 补充信息必须补充所有没有在主要信息和次要信息中提到的内容 + +8. 请不要输出多余内容 + +下面是给出的段落内容: +{chunk} + +下面是段落的上下文内容: +{text} + +下面是之前文本的三段论结果 +{front_text} + +下面是段落的来源文件 +{file_name} + +" +FORMAT_DOCUMENT_DEVELOP: "你是一个文档代码格式化专家,你的任务是根据我提供的段落内容和上下文,以函数的粒度为单位分成列表,给出段落中代码的注释 + +注意: + +1. 输入的结构为 +{{\n + 'text':xxx,\n + 'type':xxx,\n +}}\n +其中text表示内容,type表示段落的类别,image为图片,para为文字段落,table为列表的某一行 + +2. 请使用自然语言进行注释,但每个函数的注释需要超过30字,不能超过100字,每个函数的注释全部输出在同一行。 + +3. 如果不包含代码,那么就不用输出注释 + +4. 输出的时候需要包含函数的具体实现 + +5. 注释必须使用自然语言描述函数的读入、功能、效果 + +6. 
如果原来已经有注释,尽可能补全原有的注释,但必须使用自然语言,并且在同一行 + +7. 非函数的代码请原样输出 + +下面是给出的段落内容: +{chunk} + +下面是段落的上下文内容: +{text} + +下面是之前的函数定义和注释: +{front_text} + +下面是段落的来源文件 +{file_name} +" +FORMAT_DOCUMENT_OPS: "你是一个文档测试用例格式化专家,你的任务是根据我提供的段落内容和上下文,补全上下文的运维案例,并按照标准格式输出若干个运维案例。 + +注意: + +1. 输入的结构为 +{{\n + 'text':xxx,\n + 'type':xxx,\n +}}\n +其中text表示内容,type表示段落的类别,image为图片,para为文字段落,table为列表的某一行 + +2. 案例名称是通过对运维案例的概括得到的用来查询案例的名称 + +3. 环境描述表示环境信息,包括操作系统版本、CPU架构、内存大小、磁盘大小、网络类型等 + +4. 问题详情表示问题描述,包括问题类型、问题级别、问题原因、问题影响等 + +5. 解决方案表示问题的解决方案,包括解决方法、验证方法等 + +6. 其他内容表示一些该案例中的非文字内容,包括图片、文档、超链接等 + +7. 请只输出通过上下文补全后能够完整,并且没有在的案例 + +8. 输出格式 +[ + {{ + \"案例名称A\": ,\n + \"环境描述\": ,\n + \"问题详情\": ,\n + \"解决方案\": ,\n + \"其他内容\": \n + }},\n + {{\n + \"案例名称B\": ,\n + \"环境描述\": ,\n + \"问题详情\": ,\n + \"解决方案\": ,\n + \"其他内容\": \n + }} +] + +下面是给出的段落内容: +{chunk} + +下面是段落的上下文内容: +{text} + +下面是之前的函数定义和注释: +{front_text} + +下面是段落的来源文件 +{file_name} +" \ No newline at end of file diff --git a/utils/docs/sensitive_patterns.txt b/utils/docs/sensitive_patterns.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/docs/sensitive_words.txt b/utils/docs/sensitive_words.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/docs/term_replacements.txt b/utils/docs/term_replacements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/main.py b/utils/main.py new file mode 100644 index 0000000000000000000000000000000000000000..bd4d56f80c9531db4184f211c0eccee147e06c8c --- /dev/null +++ b/utils/main.py @@ -0,0 +1,260 @@ +import argparse +import glob +import logging +import os +import time +from datetime import datetime + +import pandas as pd +import yaml + +from utils.my_tools.llm import LLM +from utils.config.config import config +from utils.parser.service.parser_service import EasyParser +from utils.service.embedding_training import EmbeddingTraining +from utils.service.qa_generate import QAgenerator +from utils.service.document_governance import DocumentGovernance + + +class CLIService: + def __init__(self): + self.prompt_dict = self.get_prompt_dict() + self.args = self.parse_arguments() + self.llm = LLM(model_name=config['MODEL_NAME'], + openai_api_base=config['OPENAI_API_BASE'], + openai_api_key=config['OPENAI_API_KEY'], + max_tokens=config['MAX_TOKENS'], + request_timeout=60, + temperature=0.3) + + @staticmethod + def get_prompt_dict(): + """ + 获取prompt表 + """ + try: + with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: + prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) + return prompt_dict + except Exception as e: + logging.error(f'Get prompt failed : {e}') + raise e + + async def parser_file(self, file_path): + model = EasyParser() + answer = await model.parser(file_path, self.llm) + return answer['chunk_list'] + + import os + import glob + + async def get_all_file(self): + path_list = [] + + # 获取文件路径 + if os.path.isfile(self.args.path): + # 如果是单个文件,直接添加到路径列表 + path_list.append(self.args.path) + elif os.path.isdir(self.args.path): + # 如果是目录,遍历目录下的所有文件 + for root, _, files in os.walk(self.args.path): + for file in files: + path_list.append(os.path.join(root, file)) + else: + # 如果是通配符路径,使用 glob 匹配文件 + for path in glob.glob(self.args.path): + if os.path.isfile(path): + path_list.append(path) + + return path_list + + async def solve(self): + """ + 处理 + """ + times = [] + start_time = time.time() + if 
self.args.mode == 'embedding_training': + await self.embedding_training() + times.append({'file': 'embedding_training', "time": time.time() - start_time}) + else: + path_list = await self.get_all_file() + print(f"文件列表: {path_list}") + print(f"文件总数:", len(path_list)) + sum_qa = 0 + ret_count = 0 + sum_chunks = 0 + for file in path_list: + temp_start_time = time.time() + chunks = [] + if self.args.mode in ['qa_generate', 'document_governance']: + try: + chunks = await self.parser_file(file) + new_chunks = [] + for chunk in chunks: + new_chunk = { + 'text': chunk['text'], + 'type': chunk['type'].split('.')[1], + } + new_chunks.append(new_chunk) + chunks = new_chunks + except Exception as e: + print(f'parser {file} error:', e) + continue + try: + if self.args.mode == 'qa_generate': + qa_count = self.args.qa_count + ret_count + results, ans = await self.qa_generate(chunks, file, qa_count) + if ans < qa_count: + ret_count = qa_count - ans + else: + ret_count = 0 + sum_qa = sum_qa + ans + elif self.args.mode == 'document_governance': + chunks = await self.document_governance(chunks, file) + sum_chunks = sum_chunks + len(chunks) + except Exception as e: + print(f'solve {file} error:', e) + continue + + print(f"文件 {file} 完成,用时:{time.time() - temp_start_time}") + times.append({'file': file, "time": time.time() - temp_start_time, + "count": ans if self.args.mode == 'qa_generate' else len(chunks)}) + # logging.info(f"文件 {file} 完成,用时:{time.time() - temp_start_time}") + if self.args.mode == 'qa_generate': + print(f"问答对总数:{sum_qa}") + # logging.info(f"问答对总数:{sum_qa}") + if self.args.mode == 'document_governance': + print(f"文档治理完成,共处理 {sum_chunks} 个文本块") + # logging.info(f"问答对总数:{sum_chunks}") + + print(f"处理完成,耗时:{time.time() - start_time}") + # logging.info(f"处理完成,耗时:{time.time() - start_time}") + # 输出times到path_time.xlsx + df = pd.DataFrame(times) + os.makedirs("logs", exist_ok=True) + df.to_excel(f"logs/{self.args.mode}_{datetime.now().strftime('%Y%m%d%H%M%S')}.xlsx", index=False) + print(f'Excel结果已输出到logs/{self.args.mode}_{datetime.now().strftime("%Y%m%d%H%M%S")}.xlsx') + + async def qa_generate(self, chunks, file, qa_count): + model = QAgenerator() + prompt = self.prompt_dict.get('GENERATE_QA') + file = file.split('/')[-1].split('.')[:-1] + file = ''.join(file) + sum_qa = 0 + results, ans = await model.qa_generate(chunks, file, qa_count, prompt, self.llm, self.args.enhance) + sum_qa = sum_qa + ans + await model.output_results(results, file, self.args.output_path, self.args.output_format) + return results, sum_qa + + async def document_governance(self, chunks, file): + """ + 文档治理 + """ + model = DocumentGovernance() + if self.args.format: + if self.args.format_mode == 'develop': + prompt = self.prompt_dict.get('FORMAT_DOCUMENT_DEVELOP') + elif self.args.format_mode == 'OPS': + prompt = self.prompt_dict.get('FORMAT_DOCUMENT_OPS') + else: + prompt = self.prompt_dict.get('FORMAT_DOCUMENT_GENERAL') + chunks = await model.format(chunks, prompt, self.llm, file) + if self.args.standardize: + chunks = await model.standardize(chunks) + if self.args.unique: + chunks = await model.unique(chunks) + # print(chunks) + # for chunk in chunks: + # print(chunk['text']) + file = file.split('/')[-1].split('.')[0] + model.output_chunks_to_file(self.args.output_path, chunks, file, self.args.output_format) + # print(json.dumps(chunks, indent=4, ensure_ascii=False)) + return chunks + + async def embedding_training(self): + model = EmbeddingTraining() + model.run(self.args) + return True + + def parse_arguments(self): + """ + 
解析命令行参数,根据功能模式区分并解析其他参数。 + """ + # 创建参数解析器 + parser = argparse.ArgumentParser(description="文档处理工具:问答对生成、文档优化和embedding模型微调") + + # 创建子解析器 + subparsers = parser.add_subparsers(dest="mode", required=True, + help="功能模式:qa_generate(问答对生成)、document_governance(文档优化)、embedding_training(embedding模型微调)") + + # 子解析器:qa_generate + qa_parser = subparsers.add_parser("qa_generate", help="问答对生成") + qa_parser.add_argument("--path", type=str, required=True, + help="待处理文件或目录路径(默认值:docs)") + qa_parser.add_argument("--output_path", type=str, required=True, + help="结果存放目录(默认值:output)") + qa_parser.add_argument("--output_format", type=str, choices=["json", "yaml", "xlsx"], + default="json", + help="答案输出支持导出的格式(默认值:json)") + qa_parser.add_argument("--enhance", action="store_true", + help="是否启用生成强化策略(默认值:False)") + qa_parser.add_argument("--qa_count", type=int, default=1, + help="每片段问答对数量目标(默认值:1)") + + # 子解析器:document_governance + doc_parser = subparsers.add_parser("document_governance", help="文档优化") + doc_parser.add_argument("--path", type=str, required=True, + help="待处理文件或目录路径(默认值:docs)") + doc_parser.add_argument("--output_path", type=str, required=True, + help="结果存放目录(默认值:output)") + doc_parser.add_argument("--unique", action="store_true", + help="是否启用去重策略(默认值:False)") + doc_parser.add_argument("--standardize", action="store_true", + help="是否启用标准化策略(默认值:False)") + doc_parser.add_argument("--format", action="store_true", + help="是否启用格式化策略(默认值:False)") + doc_parser.add_argument("--output_format", type=str, choices=["md", "docx"], + default="json", + help="答案输出支持导出的格式(默认值:md)") + doc_parser.add_argument("--format_mode", type=str, choices=["develop", "OPS", "general"], default="general", + help="格式化模式(默认值:general)") + + # 子解析器:embedding_training + embedding_parser = subparsers.add_parser("embedding_training", help="embedding模型微调") + embedding_parser.add_argument("--train_data", type=str, required=True, + help="训练数据集路径") + embedding_parser.add_argument("--test_data", type=str, required=True, + help="测试数据集路径") + embedding_parser.add_argument("--output_path", type=str, required=True, + help="模型微调结果存放目录") + embedding_parser.add_argument("--batch_size", type=int, default=32, + help="每次迭代中使用的样本数量(默认值:32)") + embedding_parser.add_argument("--learning_rate", type=float, default=5e-5, + help="控制模型学习的速度(默认值:5e-5)") + embedding_parser.add_argument("--deepspeed", type=str, default="utils/bge_finetune/ds_stage0.json", + help="DeepSpeed 配置文件的路径(默认值:utils/bge_finetune/ds_stage0.json)") + embedding_parser.add_argument("--epochs", type=int, default=10, + help="整个训练过程中遍历数据集的次数(默认值:10)") + embedding_parser.add_argument("--save_steps", type=int, default=1000, + help="指定每多少个步骤保存一次模型检查点(默认值:1000)") + embedding_parser.add_argument("--logging_steps", type=int, default=100, + help="指定每多少个步骤输出一次日志信息(默认值:100)") + embedding_parser.add_argument("--gpu_num", type=int, default=2, + help="每个节点上使用的GPU数量(默认值:2)") + embedding_parser.add_argument("--model_name_or_path", type=str, default="None", + help="预训练模型的路径或名称") + embedding_parser.add_argument("--temperature", type=float, default=0.02, + help="对比学习中的温度参数(默认值:0.02)") + embedding_parser.add_argument("--warmup", type=float, default=0.1, + help="学习率预热的比例(默认值:0.1)") + embedding_parser.add_argument("--tokens", type=int, default=512, + help="每个样本的问答对token上限(默认值:512)") + + return parser.parse_args() + + +import asyncio + +if __name__ == "__main__": + asyncio.run(CLIService().solve()) diff --git a/utils/my_tools/__init__.py b/utils/my_tools/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/my_tools/bge_finetune/README.md b/utils/my_tools/bge_finetune/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2af6820a29a4bef32d399d15f58c487ba7a616b6 --- /dev/null +++ b/utils/my_tools/bge_finetune/README.md @@ -0,0 +1,112 @@ +# bge-large-zh微调指南 + +### BGE简介 + +BGE(BAAI General Embedding)是北京智源人工智能研究院开源的一系列embedding大模型,其核心功能是将文本转换为高维向量表示。这些向量捕捉了文本中的语义信息,为后续的相似性搜索提供了便利。 + +bge-large-zh是BGE系列中参数规模最大的中文向量大模型,参数3.26亿。输入序列512,输出维度1024,是本文所使用的BGE模型版本。 + +### 准备工作 + +#### 安装 FlagEmbedding +- 使用pip +``` +pip install -U FlagEmbedding +``` + +- 源码安装 +``` +git clone https://github.com/FlagOpen/FlagEmbedding.git +cd FlagEmbedding +pip install . +``` +### 数据处理 +请执行 data_processing.py 脚本将你的数据处理成如下 jsonl 格式: +``` +{"query": str, "pos": List[str], "neg":List[str]} +``` +query是查询指令,pos是正例列表,neg是负例列表 + +``` +python data_processing.py --input_dir data_path --output_dir output_path --train_num 10000 +``` +- input_dir 问答对数据存放目录 +- output_dir 训练集和测试机输出目录 +- train_num 训练集数量 + +如果数据中没有负例,则可以使用以下命令从整个语料库中随机抽取样本做负例增强: +``` +python -m FlagEmbedding.baai_general_embedding.finetune.hn_mine \ + --model_name_or_path BAAI/bge-large-zh-v1.5 \ + --input_file path/to/data.jsonl \ + --output_file path/to/data_minedHN.jsonl \ + --range_for_sampling 2-200 \ + --negative_number 15 \ + --use_gpu_for_searching +``` +- input_file:jsonl 格式的原始训练数据 +- output_file:负例增强后输出的 jsonl 数据的路径 +- range_for_sampling:采样区间,例如,2-100表示从 top2-top200 文档中采样负样本 +- negative_number:采样负样本的数量 +- use_gpu_for_searching:是否使用 faiss-gpu 来检索 + +### 运行微调 +``` +torchrun \ + --nproc_per_node 2 \ + -m FlagEmbedding.baai_general_embedding.finetune.run \ + --output_dir path/to/fine_tuned_model \ + --model_name_or_path BAAI/bge-large-zh-v1.5 \ + --train_data path/to/data_minedHN.jsonl \ + --learning_rate 1e-5 \ + --fp16 \ + --num_train_epochs 30 \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 10 \ + --dataloader_drop_last True \ + --normlized True \ + --temperature 0.02 \ + --query_max_len 64 \ + --passage_max_len 512 \ + --train_group_size 2 \ + --negatives_cross_device \ + --logging_steps 10 \ + --save_steps 0.2 \ + --query_instruction_for_retrieval "" +``` +### 模型合并[可选的] +``` +from LM_Cocktail import mix_models, mix_models_with_data + +model = mix_models( + model_names_or_paths=["BAAI/bge-large-zh-v1.5", "your_fine-tuned_model"], + model_type='encoder', + weights=[0.5, 0.5], # you can change the weights to get a better trade-off. 
+ output_path='./mixed_model_1') +``` +### 评估 +- 首先,安装 faiss (近似最近邻搜索库): +``` +pip install faiss-gpu +``` +- 执行 eval.py 脚本评估模型,计算召回率和 MRR 指标: +``` +python eval.py \ + --encoder your_model_path \ + --fp16 \ + --max_query_length 64 \ + --max_passage_length 512 \ + --batch_size 4 \ + --val_data your_val_dataset_path_ \ + --add_instruction \ + --k 100 +``` +- 评估结果示例: +``` +# 微调前 +{'MRR@1': 0.48, 'MRR@10': 0.4985634920634917, 'MRR@100': 0.5121167012782633, 'Recall@1': 0.296, 'Recall@10': 0.368, 'Recall@100': 0.65} +# 微调后 +{'MRR@1': 0.702, 'MRR@10': 0.7113190476190476, 'MRR@100': 0.721750086730327, 'Recall@1': 0.418, 'Recall@10': 0.49, 'Recall@100': 0.918} +# 微调前后模型合并[0.5, 0.5] +{'MRR@1': 0.722, 'MRR@10': 0.7387634920634922, 'MRR@100': 0.7472510142865891, 'Recall@1': 0.432, 'Recall@10': 0.519, 'Recall@100': 0.901} +``` \ No newline at end of file diff --git a/utils/my_tools/bge_finetune/__init__.py b/utils/my_tools/bge_finetune/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/my_tools/bge_finetune/data_processing.py b/utils/my_tools/bge_finetune/data_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..6bf520c70223bc3b0e6e57e90460d8873bfd9c5c --- /dev/null +++ b/utils/my_tools/bge_finetune/data_processing.py @@ -0,0 +1,145 @@ +import glob + +import pandas as pd +import jsonlines +import json +import os +import random +import argparse + + +def get_all_file(path): + path_list = [] + + # 获取文件路径 + if os.path.isfile(path): + # 如果是单个文件,直接添加到路径列表 + path_list.append(path) + elif os.path.isdir(path): + # 如果是目录,遍历目录下的所有文件 + for root, _, files in os.walk(path): + for file in files: + path_list.append(os.path.join(root, file)) + else: + # 如果是通配符路径,使用 glob 匹配文件 + for path in glob.glob(path): + if os.path.isfile(path): + path_list.append(path) + + return path_list + +parser = argparse.ArgumentParser() +parser.add_argument("--input_dir", required=True, type=str, help="path to data dir") +parser.add_argument("--output_dir", required=True, type=str, help="path to output dir") +parser.add_argument("--train_num", required=True, type=int, help="number of train data") + +args = parser.parse_args() + +#TODO: 适配输入格式 +data_dir = args.input_dir +output_dir = args.output_dir +file_list = get_all_file(data_dir) +df = pd.DataFrame(columns=['question', 'answer', 'text']) +for file in file_list: + path = os.path.join(data_dir, file) + print('current: ', path) + _df = pd.read_excel(path).loc[:, ['question', 'answer', 'text']] + _df.dropna(inplace=True) + df = pd.concat([df, _df]) +df.reset_index(drop=True, inplace=True) + +ids = df.index.to_list() +#TODO: 统一输入转换成list + + +random.shuffle(ids) +ids_count = len(ids) +if args.train_num > ids_count: + print('train_num > ids_count') + exit() +train_ids, val_ids = ids[:args.train_num], ids[args.train_num:] + +print("train_data:", len(train_ids), "test_data:", len(val_ids)) + +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +# 生成训练集 +print('正在生成训练集...') +train_data_path = os.path.join(output_dir, 'train_data.jsonl') +with jsonlines.open(train_data_path, 'a') as f: + for i in train_ids: + dic = {} + dic['query'] = df['question'][i] + # score = df['score'][i] + # if not score == "是": + # continue + dic['pos'], dic['neg'] = [], [] + pos = str(df['answer'][i]) + length = len(pos) + cut_idx = length // 2 + if length > 700: + pos1 = pos[:cut_idx] + pos2 = pos[cut_idx:] + dic['pos'].append(pos1) + dic['pos'].append(pos2) + else: + dic['pos'].append(pos) + 
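+        # The same >700-character halving is applied to the source text below, so both
+        # the answer and its originating chunk are added as positives in pieces short
+        # enough for the embedding model's input window (presumably the 512-token limit).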
+ pos = str(df['text'][i]) + length = len(pos) + cut_idx = length // 2 + if length > 700: + pos1 = pos[:cut_idx] + pos2 = pos[cut_idx:] + dic['pos'].append(pos1) + dic['pos'].append(pos2) + else: + dic['pos'].append(pos) + f.write(dic) +print(f'训练集生成完成: {train_data_path}') + +# 生成测试集 +print('正在生成测试集...') +test_data_path = os.path.join(output_dir, 'test_data.json') +test_dataset = { + "corpus": { + "content": [] + }, + "test_data": { + "query": [], + "mapper": { + } + } +} +for idx in range(len(df)): + cor = df['answer'][idx] + length = len(cor) + cut_idx = length // 2 + if length > 700: + cor1 = cor[:cut_idx] + cor2 = cor[cut_idx:] + test_dataset['corpus']['content'].append(cor1) + test_dataset['corpus']['content'].append(cor2) + else: + test_dataset['corpus']['content'].append(cor) +for idx in val_ids: + query = df['question'][idx] + test_dataset['test_data']['query'].append(query) + poses = [] + pos = df['answer'][idx] + length = len(pos) + cut_idx = length // 2 + if length > 700: + pos1 = pos[:cut_idx] + pos2 = pos[cut_idx:] + poses.append(pos1) + poses.append(pos2) + else: + poses.append(pos) + test_dataset['test_data']['mapper'][query] = poses + +with open(test_data_path, 'w', encoding='utf-8') as f: + json.dump(test_dataset, f, indent=2, ensure_ascii=False) + +print(f'测试集生成完成: {test_data_path}') diff --git a/utils/my_tools/bge_finetune/ds_stage0.json b/utils/my_tools/bge_finetune/ds_stage0.json new file mode 100644 index 0000000000000000000000000000000000000000..f8db062b407d32b468a099582ca4438fdae41ef0 --- /dev/null +++ b/utils/my_tools/bge_finetune/ds_stage0.json @@ -0,0 +1,45 @@ +{ + "zero_optimization": { + "stage": 0 + }, + + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 12, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + "bf16": { + "enabled": "auto" + }, + + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto", + "total_num_steps": "auto" + } + }, + + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 100, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/utils/my_tools/bge_finetune/eval.py b/utils/my_tools/bge_finetune/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c018205268240bb298b527adf64dc75078eee4 --- /dev/null +++ b/utils/my_tools/bge_finetune/eval.py @@ -0,0 +1,291 @@ +import faiss +import torch +import logging +import datasets +import json +import numpy as np +from tqdm import tqdm +from typing import Optional +from dataclasses import dataclass, field +from transformers import HfArgumentParser +from FlagEmbedding import FlagModel + +logger = logging.getLogger(__name__) + + +@dataclass +class Args: + encoder: str = field( + default="BAAI/bge-base-en-v1.5", + metadata={'help': 'The encoder name or path.'} + ) + fp16: bool = field( + default=False, + metadata={'help': 'Use fp16 in inference?'} + ) + add_instruction: bool = field( + default=False, + metadata={'help': 'Add query-side instruction?'} + ) + test_data: str = field( + default=False, + metadata={'help': 'test_data path'} + ) + max_query_length: int = field( + default=32, + metadata={'help': 'Max query length.'} + ) + max_passage_length: int = field( + default=128, 
+ metadata={'help': 'Max passage length.'} + ) + batch_size: int = field( + default=256, + metadata={'help': 'Inference batch size.'} + ) + index_factory: str = field( + default="Flat", + metadata={'help': 'Faiss index factory.'} + ) + k: int = field( + default=100, + metadata={'help': 'How many neighbors to retrieve?'} + ) + + save_embedding: bool = field( + default=False, + metadata={'help': 'Save embeddings in memmap at save_dir?'} + ) + load_embedding: bool = field( + default=False, + metadata={'help': 'Load embeddings from save_dir?'} + ) + save_path: str = field( + default="embeddings.memmap", + metadata={'help': 'Path to save embeddings.'} + ) + +def index(model: FlagModel, corpus: datasets.Dataset, batch_size: int = 256, max_length: int=512, index_factory: str = "Flat", save_path: str = None, save_embedding: bool = False, load_embedding: bool = False): + if load_embedding: + test = model.encode("test") + dtype = test.dtype + dim = len(test) + + corpus_embeddings = np.memmap( + save_path, + mode="r", + dtype=dtype + ).reshape(-1, dim) + + else: + corpus_embeddings = model.encode_corpus(corpus["content"], batch_size=batch_size, max_length=max_length) + dim = corpus_embeddings.shape[-1] + + if save_embedding: + logger.info(f"saving embeddings at {save_path}...") + memmap = np.memmap( + save_path, + shape=corpus_embeddings.shape, + mode="w+", + dtype=corpus_embeddings.dtype + ) + + length = corpus_embeddings.shape[0] + # add in batch + save_batch_size = 10000 + if length > save_batch_size: + for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"): + j = min(i + save_batch_size, length) + memmap[i: j] = corpus_embeddings[i: j] + else: + memmap[:] = corpus_embeddings + + # create faiss index + faiss_index = faiss.index_factory(dim, index_factory, faiss.METRIC_INNER_PRODUCT) + + # if model.device == torch.device("cuda"): + if False: + co = faiss.GpuMultipleClonerOptions() + co.useFloat16 = True + # faiss_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, faiss_index, co) + faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co) + + # NOTE: faiss only accepts float32 + logger.info("Adding embeddings...") + corpus_embeddings = corpus_embeddings.astype(np.float32) + faiss_index.train(corpus_embeddings) + faiss_index.add(corpus_embeddings) + return faiss_index + + +def search(model: FlagModel, queries: datasets, faiss_index: faiss.Index, k:int = 100, batch_size: int = 256, max_length: int=512): + query_embeddings = model.encode_queries(queries["query"], batch_size=batch_size, max_length=max_length) + query_size = len(query_embeddings) + + all_scores = [] + all_indices = [] + + for i in tqdm(range(0, query_size, batch_size), desc="Searching"): + j = min(i + batch_size, query_size) + query_embedding = query_embeddings[i: j] + score, indice = faiss_index.search(query_embedding.astype(np.float32), k=k) + all_scores.append(score) + all_indices.append(indice) + + all_scores = np.concatenate(all_scores, axis=0) + all_indices = np.concatenate(all_indices, axis=0) + return all_scores, all_indices + + +import numpy as np + + +def evaluate(preds, labels, cutoffs=[1, 5, 10, 30]): + metrics = {} + + # MRR + mrrs = np.zeros(len(cutoffs)) + for pred, label in zip(preds, labels): + jump = False + for i, x in enumerate(pred, 1): + if x in label: + for k, cutoff in enumerate(cutoffs): + if i <= cutoff: + mrrs[k] += 1 / i + jump = True + if jump: + break + mrrs /= len(preds) + for i, cutoff in enumerate(cutoffs): + mrr = mrrs[i] + metrics[f"MRR@{cutoff}"] = mrr + 
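+    # Worked example (hypothetical inputs): with cutoffs=[1, 5], a prediction list
+    # ["b", "a"] and a label list ["a"], the first relevant item sits at rank 2,
+    # so MRR@1 = 0.0 and MRR@5 = 1/2 = 0.5.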
+ # Recall + recalls = np.zeros(len(cutoffs)) + for pred, label in zip(preds, labels): + for k, cutoff in enumerate(cutoffs): + recall = np.intersect1d(label, pred[:cutoff]) + recalls[k] += len(recall) / len(label) + recalls /= len(preds) + for i, cutoff in enumerate(cutoffs): + recall = recalls[i] + metrics[f"Recall@{cutoff}"] = recall + + # Precision + precisions = np.zeros(len(cutoffs)) + for pred, label in zip(preds, labels): + for k, cutoff in enumerate(cutoffs): + precision = np.intersect1d(label, pred[:cutoff]) + precisions[k] += len(precision) / cutoff + precisions /= len(preds) + for i, cutoff in enumerate(cutoffs): + precision = precisions[i] + metrics[f"Precision@{cutoff}"] = precision + + # F1-Score@K + f1_scores = np.zeros(len(cutoffs)) + for pred, label in zip(preds, labels): + for k, cutoff in enumerate(cutoffs): + recall = np.intersect1d(label, pred[:cutoff]) + precision = np.intersect1d(label, pred[:cutoff]) + recall_value = len(recall) / len(label) + precision_value = len(precision) / cutoff + if precision_value + recall_value > 0: + f1_score = 2 * (precision_value * recall_value) / (precision_value + recall_value) + else: + f1_score = 0.0 + f1_scores[k] += f1_score + f1_scores /= len(preds) + for i, cutoff in enumerate(cutoffs): + f1_score = f1_scores[i] + metrics[f"F1-Score@{cutoff}"] = f1_score + + # Hit Rate + hit_rates = np.zeros(len(cutoffs)) + for pred, label in zip(preds, labels): + for k, cutoff in enumerate(cutoffs): + hit = np.intersect1d(label, pred[:cutoff]) + if len(hit) > 0: + hit_rates[k] += 1 + hit_rates /= len(preds) + for i, cutoff in enumerate(cutoffs): + hit_rate = hit_rates[i] + metrics[f"HitRate@{cutoff}"] = hit_rate + + # NDCG + ndcgs = np.zeros(len(cutoffs)) + for pred, label in zip(preds, labels): + for k, cutoff in enumerate(cutoffs): + pred_cutoff = pred[:cutoff] + dcg = 0.0 + idcg = 0.0 + for i, item in enumerate(pred_cutoff, 1): + if item in label: + dcg += 1.0 / np.log2(i + 1) + for i in range(1, len(label) + 1): + idcg += 1.0 / np.log2(i + 1) + if idcg == 0: + ndcgs[k] += 0 + else: + ndcgs[k] += dcg / idcg + ndcgs /= len(preds) + for i, cutoff in enumerate(cutoffs): + ndcg = ndcgs[i] + metrics[f"NDCG@{cutoff}"] = ndcg + + return metrics +def main(): + parser = HfArgumentParser([Args]) + args: Args = parser.parse_args_into_dataclasses()[0] + + with open(args.test_data, "r") as f: + test_data = json.load(f) + # print(json.dumps(test_data, indent=4, ensure_ascii=False)) + # print(json.dumps(test_data, indent=4, ensure_ascii=False)) + corpus = test_data['corpus'] + test_data = test_data['test_data'] + + model = FlagModel( + args.encoder, + query_instruction_for_retrieval="Represent this sentence for searching relevant passages: " if args.add_instruction else None, + use_fp16=args.fp16 + ) + + faiss_index = index( + model=model, + corpus=corpus, + batch_size=args.batch_size, + max_length=args.max_passage_length, + index_factory=args.index_factory, + save_path=args.save_path, + save_embedding=args.save_embedding, + load_embedding=args.load_embedding + ) + + scores, indices = search( + model=model, + queries=test_data, + faiss_index=faiss_index, + k=args.k, + batch_size=args.batch_size, + max_length=args.max_query_length + ) + retrieval_results = [] + for indice in indices: + # filter invalid indices + indice = indice[indice != -1].tolist() + conts = [] + for ind in indice: + conts.append(corpus['content'][ind]) + retrieval_results.append(conts) + # print(retrieval_results) + ground_truths = [] + for sample in test_data['query']: + 
ground_truths.append(test_data['mapper'][sample]) + metrics = evaluate(retrieval_results, ground_truths) + for k, v in metrics.items(): + print('{'+f"\"{k}\": {v:.4f}"+'}') + + +if __name__ == "__main__": + main() diff --git a/utils/my_tools/bge_finetune/get_report.py b/utils/my_tools/bge_finetune/get_report.py new file mode 100644 index 0000000000000000000000000000000000000000..8aec082d3124d022e0852a4d48729cd3a6e53e7a --- /dev/null +++ b/utils/my_tools/bge_finetune/get_report.py @@ -0,0 +1,123 @@ +import datetime +import json + +import numpy as np +import pandas as pd +import os +import matplotlib.pyplot as plt + +# 读取数据 +data = [] +with open('./logs/embedding/temp.log', 'r') as file: + for line in file: + try: + data.append(eval(line)) + except: + print(line) +file_path = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") +# 将数据转换为DataFrame +df = pd.DataFrame(data) +if os.path.exists(f'./report'): + pass +else: + os.mkdir(f'./report') +if os.path.exists(f'./report/{file_path}'): + pass +else: + os.mkdir(f'./report/{file_path}') + +print(df) +plt.xticks(np.arange(0, int(df['epoch'].max()) + 1)) +# 绘制损失随时间变化的图表 +plt.figure(figsize=(10, 5)) +plt.plot(df['epoch'], df['loss'], label='Loss') +plt.title('Training Loss Over Epochs') +plt.xlabel('Epoch') +plt.ylabel('Loss') +plt.legend() +plt.grid(True) +plt.savefig(f'./report/{file_path}/loss_over_epochs.png') + +# 绘制学习率随时间变化的图表 +plt.figure(figsize=(10, 5)) +plt.plot(df['epoch'], df['learning_rate'], label='Learning Rate', color='orange') +plt.title('Learning Rate Over Epochs') +plt.xlabel('Epoch') +plt.ylabel('Learning Rate') +plt.legend() +plt.grid(True) +plt.savefig(f'./report/{file_path}/learning_rate_over_epochs.png') + +df.to_excel(f'./report/{file_path}/training_report.xlsx', index=False) + +# 读取日志文件 +def read_log(file_path): + data = {} + with open(file_path, 'r') as f: + for line in f: + record = json.loads(line.strip()) + key = list(record.keys())[0] + value = record[key] + data[key] = value + return data + +# 读取 base.log 和 target.log +base_data = read_log('./logs/embedding/base.log') +target_data = read_log('./logs/embedding/target.log') + +# 生成表格 +def generate_table(base_data, target_data): + + # 生成 Markdown 表格 + table = "| Metric | Base Value | final Value |\n" + table += "|--------|------------|--------------|\n" + for key,v in base_data.items(): + base_value = base_data.get(key, "-") + target_value = target_data.get(key, "-") + table += f"| {key} | {base_value} | {target_value} |\n" + return table + +# 生成表格 +table = generate_table(base_data, target_data) + + +# 生成训练报告 +report = f""" +训练报告 +================ + +- **总训练时间**: {df['train_runtime'].iloc[-1]:.2f} 秒 +- **平均每秒样本数**: {df['train_samples_per_second'].iloc[-1]:.2f} +- **平均每秒步骤数**: {df['train_steps_per_second'].iloc[-1]:.2f} +- **最终损失值**: {df['loss'].iloc[-2]} +- **最终学习率**: {df['learning_rate'].iloc[-2]} +- **最终轮次**: {df['epoch'].iloc[-1]:.2f} + +损失和学习率图表: +----------------------------- +![各轮次的损失变化](./loss_over_epochs.png) +![各轮次的学习率变化](./learning_rate_over_epochs.png) + +选取的指标: +------------------- +在评估模型性能时,我们选择了以下几类关键指标来衡量不同方面的表现。这些指标有助于全面理解模型在推荐或检索任务中的效果。 + +- **MRR (Mean Reciprocal Rank)**: 用于评估排序质量,特别是对于首位相关性高的项目。MRR 能够反映出模型将相关项排在前面的能力,较高的MRR值表明模型更早地呈现了相关项。 +- **Recall (召回率)**: 表示正确推荐的项目占所有相关项目的比例,反映了模型覆盖真实相关项的能力。高召回率意味着模型能够找到更多的相关项。 +- **Precision (精确率)**: 表示推荐列表中真正相关的项目比例,尤其对高位置的推荐项重要。高精确率意味着推荐的大多数项都是用户感兴趣的。 +- **F1-Score**: 是精确率和召回率的调和平均数,提供了一个综合考虑两者平衡的评价标准。F1-Score 在精确率和召回率之间寻求平衡,特别适用于两者都重要的场景。当模型需要同时优化精确率和召回率时,F1-Score 是一个非常有用的指标。 +- **HitRate (命中率)**: 
指的是是否至少有一个相关项被推荐,是二元评价标准。HitRate 强调了模型是否有能力将任何相关项包含在推荐列表中。 +- **NDCG (Normalized Discounted Cumulative Gain)**: 考虑了推荐列表中项目的位置权重,越靠前的相关项得分越高。NDCG 是一个综合性的评价指标,既考虑了相关性也考虑了位置因素。 + +上述指标均在不同的K值下计算(如@1, @5, @10等),以评估不同长度结果列表的性能。选择这些指标是因为它们在信息检索和推荐系统中广泛应用,并且可以从不同角度全面评估模型的性能。 + +指标对比: +------------------- +{table} +""" + +# 保存报告到文件 +with open(f'./report/{file_path}/training_report.md', 'w') as file: + file.write(report) + +print(f"Training report generated successfully. Report in ./report/{file_path}.") \ No newline at end of file diff --git a/utils/my_tools/bge_finetune/hn_mine.py b/utils/my_tools/bge_finetune/hn_mine.py new file mode 100644 index 0000000000000000000000000000000000000000..67980e1ae7712a6c741dbfaa7053b628e5fe95eb --- /dev/null +++ b/utils/my_tools/bge_finetune/hn_mine.py @@ -0,0 +1,237 @@ +import json +import random +import numpy as np +from tqdm import tqdm +from typing import Optional +from dataclasses import dataclass, field + +import faiss +from transformers import HfArgumentParser +from FlagEmbedding import FlagAutoModel +from FlagEmbedding.abc.inference import AbsEmbedder + + +@dataclass +class DataArgs: + """ + Data arguments for hard negative mining. + """ + input_file: str = field( + metadata={"help": "The input file for hard negative mining."} + ) + output_file: str = field( + metadata={"help": "The output file for hard negative mining."} + ) + candidate_pool: Optional[str] = field( + default=None, metadata={"help": "The candidate pool for hard negative mining. If provided, it should be a jsonl file, each line is a dict with a key 'text'."} + ) + range_for_sampling: str = field( + default="10-210", metadata={"help": "The range to sample negatives."} + ) + negative_number: int = field( + default=15, metadata={"help": "The number of negatives."} + ) + use_gpu_for_searching: bool = field( + default=False, metadata={"help": "Whether to use faiss-gpu for searching."} + ) + search_batch_size: int = field( + default=64, metadata={"help": "The batch size for searching."} + ) + + +@dataclass +class ModelArgs: + """ + Model arguments for embedder. + """ + embedder_name_or_path: str = field( + metadata={"help": "The embedder name or path.", "required": True} + ) + embedder_model_class: Optional[str] = field( + default=None, metadata={"help": "The embedder model class. Available classes: ['encoder-only-base', 'encoder-only-m3', 'decoder-only-base', 'decoder-only-icl']. Default: None. 
For the custom model, you need to specifiy the model class.", "choices": ["encoder-only-base", "encoder-only-m3", "decoder-only-base", "decoder-only-icl"]} + ) + normalize_embeddings: bool = field( + default=True, metadata={"help": "whether to normalize the embeddings"} + ) + pooling_method: str = field( + default="cls", metadata={"help": "The pooling method fot the embedder."} + ) + use_fp16: bool = field( + default=True, metadata={"help": "whether to use fp16 for inference"} + ) + devices: Optional[str] = field( + default=None, metadata={"help": "Devices to use for inference.", "nargs": "+"} + ) + query_instruction_for_retrieval: Optional[str] = field( + default=None, metadata={"help": "Instruction for query"} + ) + query_instruction_format_for_retrieval: str = field( + default="{}{}", metadata={"help": "Format for query instruction"} + ) + examples_for_task: Optional[str] = field( + default=None, metadata={"help": "Examples for task"} + ) + examples_instruction_format: str = field( + default="{}{}", metadata={"help": "Format for examples instruction"} + ) + trust_remote_code: bool = field( + default=False, metadata={"help": "Trust remote code"} + ) + cache_dir: str = field( + default=None, metadata={"help": "Cache directory for models."} + ) + # ================ for inference =============== + batch_size: int = field( + default=3000, metadata={"help": "Batch size for inference."} + ) + embedder_query_max_length: int = field( + default=512, metadata={"help": "Max length for query."} + ) + embedder_passage_max_length: int = field( + default=512, metadata={"help": "Max length for passage."} + ) + + +def create_index(embeddings: np.ndarray, use_gpu: bool = False): + index = faiss.IndexFlatIP(len(embeddings[0])) + embeddings = np.asarray(embeddings, dtype=np.float32) + if use_gpu: + co = faiss.GpuMultipleClonerOptions() + co.shard = True + co.useFloat16 = True + index = faiss.index_cpu_to_all_gpus(index, co=co) + index.add(embeddings) + return index + + +def batch_search( + index: faiss.Index, + query: np.ndarray, + topk: int = 200, + batch_size: int = 64 +): + all_scores, all_inxs = [], [] + for start_index in tqdm(range(0, len(query), batch_size), desc="Batches", disable=len(query) < 256): + batch_query = query[start_index:start_index + batch_size] + batch_scores, batch_inxs = index.search(np.asarray(batch_query, dtype=np.float32), k=topk) + all_scores.extend(batch_scores.tolist()) + all_inxs.extend(batch_inxs.tolist()) + return all_scores, all_inxs + + +def get_corpus(candidate_pool: str): + corpus = [] + with open(candidate_pool, "r", encoding="utf-8") as f: + for line in f.readlines(): + line = json.loads(line.strip()) + corpus.append(line['text']) + return corpus + + +def find_knn_neg( + model: AbsEmbedder, + input_file: str, + output_file: str, + candidate_pool: Optional[str] = None, + sample_range: str = "10-210", + negative_number: int = 15, + use_gpu: bool = False +): + corpus = [] + queries = [] + train_data = [] + for line in open(input_file): + try: + line = json.loads(line.strip()) + train_data.append(line) + corpus.extend(line['pos']) + if 'neg' in line: + corpus.extend(line['neg']) + queries.append(line['query']) + except: + print(line) + continue + + if candidate_pool is not None: + if not isinstance(candidate_pool, list): + candidate_pool = get_corpus(candidate_pool) + corpus = list(set(candidate_pool)) + else: + corpus = list(set(corpus)) + + print(f'inferencing embedding for corpus (number={len(corpus)})--------------') + p_vecs = model.encode(corpus) + 
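+    # Query embeddings are computed next; each query is then searched against the Faiss
+    # index, candidates inside range_for_sampling that are neither the query itself nor
+    # one of its positives are kept, and at most negative_number of them are sampled as
+    # hard negatives.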
print(f'inferencing embedding for queries (number={len(queries)})--------------') + q_vecs = model.encode_queries(queries) + + print('create index and search------------------') + index = create_index(p_vecs, use_gpu=use_gpu) + _, all_inxs = batch_search(index, q_vecs, topk=sample_range[-1]) + assert len(all_inxs) == len(train_data) + + for i, data in enumerate(train_data): + query = data['query'] + inxs = all_inxs[i][sample_range[0]:sample_range[1]] + filtered_inx = [] + for inx in inxs: + if inx == -1: break + if corpus[inx] not in data['pos'] and corpus[inx] != query: + filtered_inx.append(inx) + + if len(filtered_inx) > negative_number: + filtered_inx = random.sample(filtered_inx, negative_number) + data['neg'] = [corpus[inx] for inx in filtered_inx] + + with open(output_file, 'w') as f: + for data in train_data: + if len(data['neg']) < negative_number: + samples = random.sample(corpus, negative_number - len(data['neg']) + len(data['pos'])) + samples = [sent for sent in samples if sent not in data['pos']] + data['neg'].extend(samples[: negative_number - len(data['neg'])]) + f.write(json.dumps(data, ensure_ascii=False) + '\n') + + +def load_model(model_args: ModelArgs): + model = FlagAutoModel.from_finetuned( + model_name_or_path=model_args.embedder_name_or_path, + model_class=model_args.embedder_model_class, + normalize_embeddings=model_args.normalize_embeddings, + pooling_method=model_args.pooling_method, + use_fp16=model_args.use_fp16, + query_instruction_for_retrieval=model_args.query_instruction_for_retrieval, + query_instruction_format=model_args.query_instruction_format_for_retrieval, + devices=model_args.devices, + examples_for_task=model_args.examples_for_task, + examples_instruction_format=model_args.examples_instruction_format, + trust_remote_code=model_args.trust_remote_code, + cache_dir=model_args.cache_dir, + batch_size=model_args.batch_size, + query_max_length=model_args.embedder_query_max_length, + passage_max_length=model_args.embedder_passage_max_length, + ) + return model + + +def main(data_args: DataArgs, model_args: ModelArgs): + model = load_model(model_args) + + find_knn_neg( + model=model, + input_file=data_args.input_file, + output_file=data_args.output_file, + candidate_pool=data_args.candidate_pool, + sample_range=[int(x) for x in data_args.range_for_sampling.split('-')], + negative_number=data_args.negative_number, + use_gpu=data_args.use_gpu_for_searching + ) + + +if __name__ == "__main__": + parser = HfArgumentParser(( + DataArgs, + ModelArgs + )) + data_args, model_args = parser.parse_args_into_dataclasses() + data_args: DataArgs + model_args: ModelArgs + main(data_args, model_args) \ No newline at end of file diff --git a/utils/my_tools/llm.py b/utils/my_tools/llm.py new file mode 100644 index 0000000000000000000000000000000000000000..396adf51c4827bc472851182e72c31c7a7426285 --- /dev/null +++ b/utils/my_tools/llm.py @@ -0,0 +1,61 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
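+# Thin wrapper around langchain_openai.ChatOpenAI: assemble_chat builds the system/user
+# message list, nostream awaits a single full completion, and stream pushes astream
+# frames through an asyncio queue and yields them as SSE-style "data: {...}" chunks.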
+import asyncio +import time +import json +from langchain_openai import ChatOpenAI +from langchain.schema import SystemMessage, HumanMessage +from utils.my_tools.logger import logger as logging + + +class LLM: + def __init__(self, openai_api_key, openai_api_base, model_name, max_tokens, request_timeout=60, temperature=0.1): + self.client = ChatOpenAI(model_name=model_name, + openai_api_base=openai_api_base, + openai_api_key=openai_api_key, + request_timeout=request_timeout, + max_tokens=max_tokens, + temperature=temperature) + + def assemble_chat(self, chat=None, system_call='', user_call=''): + if chat is None: + chat = [] + chat.append(SystemMessage(content=system_call)) + chat.append(HumanMessage(content=user_call)) + return chat + + async def nostream(self, chat, system_call, user_call): + chat = self.assemble_chat(chat, system_call, user_call) + response = await self.client.ainvoke(chat) + return response.content + + async def data_producer(self, q: asyncio.Queue, history, system_call, user_call): + message = self.assemble_chat(history, system_call, user_call) + try: + async for frame in self.client.astream(message): + await q.put(frame.content) + except Exception as e: + await q.put(None) + logging.error(f"Error in data producer due to: {e}") + return + await q.put(None) + + async def stream(self, chat, system_call, user_call): + st = time.time() + q = asyncio.Queue(maxsize=10) + + # 启动生产者任务 + producer_task = asyncio.create_task(self.data_producer(q, chat, system_call, user_call)) + first_token_reach = False + while True: + data = await q.get() + if data is None: + break + if not first_token_reach: + first_token_reach = True + logging.info(f"大模型回复第一个字耗时 = {time.time() - st}") + for char in data: + yield "data: " + json.dumps({'content': char}, ensure_ascii=False) + '\n\n' + await asyncio.sleep(0.03) # 使用异步 sleep + + yield "data: [DONE]" + logging.info(f"大模型回复耗时 = {time.time() - st}") diff --git a/utils/my_tools/logger.py b/utils/my_tools/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..dff30ce50c3b7b9a0be3c9a6ec348ea39086e91b --- /dev/null +++ b/utils/my_tools/logger.py @@ -0,0 +1,42 @@ +import logging +from utils.config.config import config + + +class LoggerSingleton: + _instance = None + + @staticmethod + def get_logger(): + if LoggerSingleton._instance is None: + LoggerSingleton._instance = LoggerSingleton._initialize_logger() + return LoggerSingleton._instance + + @staticmethod + def _initialize_logger(): + # 创建一个 logger 对象 + logger = logging.getLogger('my_logger') + logger.setLevel(logging.INFO) + + # 禁用父级 logger 的传播 + logger.propagate = False + + # 创建一个 formatter 对象 + formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s') + + # 根据配置选择添加不同的 handler + if config['LOG_METHOD'] != 'stdout': + # 添加 FileHandler 到 logger + file_handler = logging.FileHandler('apps.log') + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + else: + # 添加 StreamHandler 到 logger + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) + + return logger + + +# 使用单例模式获取 logger +logger = LoggerSingleton.get_logger() diff --git a/utils/parser/handler/base_parser.py b/utils/parser/handler/base_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..16101bcfad2f74c6c83939983ce049a30d3e8ee8 --- /dev/null +++ b/utils/parser/handler/base_parser.py @@ -0,0 +1,397 @@ +import os +import uuid +from utils.my_tools.logger import logger as 
logging +from pandas import DataFrame +from docx.table import Table as DocxTable +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.metrics.pairwise import cosine_similarity +from utils.parser.tools.split import split_tools + + +# TODO chunk和chunk_link可以封装成类 + + +class BaseService: + + def __init__(self): + self.vectorizer = None + self.llm_max_tokens = None + self.llm = None + self.tokens = None + + async def init_service(self, llm_entity, llm_max_tokens, tokens, parser_method): + self.parser_method = parser_method + if llm_entity is None: + self.llm = None + self.llm_max_tokens = None + else: + self.llm = llm_entity + self.llm_max_tokens = llm_max_tokens + self.tokens = tokens + self.vectorizer = TfidfVectorizer() + + @staticmethod + def get_uuid(): + """ + 获取uuid + 返回: + 生成的uuid + """ + return uuid.uuid4() + + def check_similarity(self, text1, text2): + """ + TODO :获取段落相似度,具体数值待微调 + """ + # 将文本转换为TF-IDF向量 + if len(text1) < len(text2)*10: + tfidf_matrix = self.vectorizer.fit_transform([text1, text2]) + + # 计算余弦相似度 + cosine_sim = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2]) + if cosine_sim > 0.85: + return True + return False + + def merge_texts(self, texts): + now_len = 0 + now_text = "" + new_texts = [] + for text in texts: + if text['type'] == 'para': + if text['text'] == "": + continue + token_len = split_tools.get_tokens(text) + if now_len + token_len < max(self.tokens // 2, 128) or ( + now_len + token_len < self.tokens and self.check_similarity(now_text, text['text'])): + now_text += text['text'] + '\n' + now_len += token_len + else: + new_texts.append({'text': now_text, 'type': 'para'}) + now_text = text['text'] + '\n' + now_len = token_len + else: + if now_len: + new_texts.append({'text': now_text, 'type': 'para'}) + now_text = "" + now_len = 0 + new_texts.append(text) + if now_len: + new_texts.append({'text': now_text, 'type': 'para'}) + return new_texts + + @staticmethod + def split_sentences(text, TOKENS=1024): + """ + 分句,不超过Tokens数量 + """ + try: + words = split_tools.split_words(text) + current_length = 0 + current_sentence = "" + result = [] + for word in words: + current_sentence = current_sentence + word + current_length = current_length + 1 + if current_length >= TOKENS: + result.append(current_sentence) + current_sentence = "" + current_length = 0 + result.append(current_sentence) + return result + except Exception as e: + logging.error(f"split sentences error as {e}") + return [] + + def split_table(self, table): + """ + 按照行分表 + """ + + if table is None: + return [] + result = [] + new_table = [] + cell_num = 1 + try: + if isinstance(table, DataFrame): + for index, row in table.iterrows(): + row_string_list = [s.replace('|', '||') for s in row.astype(str).tolist()] + cell_num = max(cell_num, len(row_string_list)) + new_table.append(row_string_list) + elif isinstance(table, DocxTable): + if table.rows: + for row in table.rows: + row_string_list = [s.replace('|', '||') for s in (cell.text.strip() for cell in row.cells)] + cell_num = max(cell_num, len(row_string_list)) + new_table.append(row_string_list) + else: + logging.error(f"table type Error as{type(table)}") + return [] + except Exception as e: + logging.error(f"split tables error as{e}") + return [] + + max_tokens = (self.tokens - cell_num) // cell_num + for row in new_table: + new_line = [] + max_len = 0 + for cell in row: + cell = self.split_sentences(cell, max_tokens) + if not cell: + cell = [''] + new_line.append(cell) + max_len = max(max_len, len(cell)) + for i in 
range(max_len): + row_text = ' | '.join([cell[i] if len(cell) > i else ' ' for cell in new_line]) + row_text = row_text.replace('\n', '\\n') + result.append(row_text) + + return result + + def package_to_chunk(self, **kwargs): + """ + 整合成chunk + + 参数: + - id (str, optional): 目标uuid,默认生成一个新的UUID + - text (str, optional): 目标内容,默认为空字符串 + - tokens (int, optional): 词数,默认为0 + - status (str, optional): 状态,默认为空字符串 + - type_from (str, optional): 来源类型,默认为general + - type_big (str, optional): 大类型,默认为para + - type_small (str, optional): 小类型,默认为line + - type_attr (str, optional): 属性类型,默认为normal + - link_to (str, optional): 链接目标uuid,默认为空字符串 + - offset_in_document (int, optional): 在文档中的偏移量,默认为0 + + 返回: + - dict: 包含chunk信息的字典 + """ + # TODO:可以进行封装 + default_values = { + 'id': self.get_uuid(), + 'text': "", + 'tokens': 0, + 'status': "", + 'type_from': 'general', + 'type_big': 'para', + 'type_small': 'line', + 'type_attr': 'normal', + 'link_to': "", + 'enabled': True, + 'local_offset': 0, + 'global_offset': 0, + } + + # 更新默认值为传入的参数值 + for key, value in kwargs.items(): + if key in default_values: + default_values[key] = value + chunk_type = f"{default_values['type_from']}.{default_values['type_big']}." \ + f"{default_values['type_small']}.{default_values['type_attr']}" + + # 构建chunk字典 + chunk = { + 'id': default_values['id'], + 'text': default_values['text'], + 'type': chunk_type, + 'tokens': default_values['tokens'], + 'global_offset': default_values['global_offset'], + 'local_offset': default_values['local_offset'], + 'enabled': default_values['enabled'], + 'status': default_values['status'], + 'link_to': default_values['link_to'], + } + + return chunk + + def package_to_link(self, chunk_a, chunk_b, **kwargs): + """ + 打包link + 参数: + - chunk_a (str): 出发chunk的id + - chunk_b (str): 目标chunk的id + - is_global (str, optional): link为全局边或者局部边,默认为local + - structure (str, optional): link的小类,表示link属于line/tree/map,默认为line + - model (str, optional): link的模型,默认为pre + - jump (str or int, optional): 跳转值,默认为0 + + 返回: + - dict: 包含link信息的字典 + """ + + default_values = { + 'is_global': 'local', + 'structure': 'line', + 'model': 'pre', + 'jump': 0 + } + + # 更新默认值为传入的参数值 + for key, value in kwargs.items(): + if key in default_values: + default_values[key] = value + + # 确保 jump 是字符串类型 + jump = str(default_values['jump']) + + # 构建 link 字典 + link_type = (f"{default_values['is_global']}.{default_values['model']}." 
+ f"{default_values['structure']}.{jump}") + link = { + 'id': self.get_uuid(), + 'chunk_a': chunk_a, + 'chunk_b': chunk_b, + 'type': link_type, + } + + return link + + def build_chunks_by_lines(self, sentences): + """ + chunks 连接函数 + sentences中需要type和text字段 + """ + sentences = self.merge_texts(sentences) + chunks = [] + local_count = 0 + para_count = 0 + now_type = 'None' + last_para = None + last_local = None + chunks_para = [] + local_offset = 0 + global_offset = 0 + for part in sentences: + global_offset += 1 + local_offset += 1 + if now_type != part['type']: + last_local = None + local_offset = 0 + if part['type'] == now_type: + local_count += 1 + type_attr = 'normal' + else: + type_attr = 'head' + local_count = 0 + if part['type'] == 'para': + link_to = last_para + para_count += 1 + else: + link_to = last_local + + now_type = part['type'] + if 'id' not in part: + part['id'] = self.get_uuid() + chunk = self.package_to_chunk(id=part["id"], text=part["text"], tokens=split_tools.get_tokens(part["text"]), + link_to=link_to, status="", type_from="general", type_big=part["type"], + type_small="line", type_attr=type_attr, global_offset=global_offset, + local_offset=local_offset, ) + last_local = chunk['id'] + chunks.append(chunk) + if now_type == 'para': + last_para = chunk['id'] + chunks_para.append(chunk) + return chunks + + def build_chunks_and_links_by_tree(self, tree: dict): + """ + chunks 连接函数 + tree为dict表示的树结构, + """ + chunks = [] + + def get_edges(node, parent_id=None, dep=0): + chunk = self.package_to_chunk(text=node["text"], tokens=split_tools.get_tokens(tree["text"]), status="", + type_big=node["type"], type_small='tree', type_attr=node['type_attr'], + global_offset=dep, link_to=parent_id, ) + node['id'] = chunk['id'] + chunks.append(chunk) + + # 如果当前节点有子节点,则遍历每个子节点 + if 'children' in node and node['children']: + for child in node['children']: + # 递归处理子节点 + get_edges(child, node['id'], dep + 1) + + get_edges(tree) + + chunk_links = [] + chunk_links.extend(self.edge_link(chunks, 'global', 'tree')) + return chunks, chunk_links + + def build_chunk_links_by_line(self, chunks): + """ + 线性分割chunks并构建上下文关系 + """ + chunk_links = [] + chunks_para = [] + tmp_chunks = [] + for chunk in chunks: + if chunk['type'] == 'para': + if tmp_chunks is not None and len(tmp_chunks) > 0: + chunk_links.extend(self.edge_link(tmp_chunks, 'local', 'line')) + tmp_chunks = [] + else: + tmp_chunks.append(chunk) + chunk_links.extend(self.edge_link(chunks_para, 'local', 'line')) + chunk_links.extend(self.edge_link(chunks, 'global', 'line')) + return chunk_links + + def edge_link(self, chunks, is_global, structure, **kwargs): + """ + 根据给定的块列表构建边缘链接。 + 该函数通过遍历每个块,并为每个块与其链接的目标块创建双向链接数据。 + 然后,根据这些链接数据生成链接对象列表。 + + """ + links = [] + links_data = [] + for chunk in chunks: + links_data.append({ + 'chunk_a': chunk['id'], + 'chunk_b': chunk['link_to'], + 'is_global': is_global, + 'structure': structure, + 'model': 'next', + 'jump': 0 + }) + links_data.append({ + 'chunk_a': chunk['link_to'], + 'chunk_b': chunk['id'], + 'is_global': is_global, + 'structure': structure, + 'model': 'pre', + 'jump': 0 + }) + + for data in links_data: + if data['chunk_a'] is None or data['chunk_b'] is None: + continue + links.append( + self.package_to_link(chunk_a=data['chunk_a'], chunk_b=data['chunk_b'], is_global=data['is_global'], + structure=data['structure'], model=data['model'], jump=data['jump'], )) + return links + + async def insert_image_to_tmp_folder(self, image_bytes, image_id, image_extension): + """ + 插入图像字节流到临时文件夹中(用于插入到minIO) + 
参数: + - image_bytes: 图像字节流(可以是多个图像) + - image_id: 用于保存图像文件的id + """ + output_dir = None + try: + if not isinstance(type(image_bytes), list): + image_bytes = [image_bytes] + for image in image_bytes: + output_dir = os.path.join('./parser', str(image_id)) + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, str(image_id)+'.'+image_extension) + with open(output_path, 'wb') as f: + f.write(image) + return True + except Exception as e: + logging.error(f'Insert images {image_id} error: {e}') + return False diff --git a/utils/parser/handler/doc_parser.py b/utils/parser/handler/doc_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..5f7f58f6578ae26cf96d64db21895bb995d731cd --- /dev/null +++ b/utils/parser/handler/doc_parser.py @@ -0,0 +1,38 @@ +from utils.my_tools.logger import logger as logging +from tika import parser + +from utils.parser.handler.base_parser import BaseService + + + +class DocService(BaseService): + def extract_paragraph(self, paragraph): + sentences = self.split_sentences(paragraph, self.tokens) + results = [] + for sentence in sentences: + results.append({ + "type": "para", + "text": sentence, + }) + return results + + @staticmethod + def open_file(file_path): + return open(file_path, 'rb') + + async def parser(self, file_path): + binary = self.open_file(file_path) + try: + content = parser.from_buffer(binary) + except Exception as e: + logging.error(f"Error opening file {file_path} :{e}") + raise e + + paragraphs = content.split('\n') + sentences = [] + for paragraph in paragraphs: + sentences.extend(self.extract_paragraph(paragraph)) + chunks = self.build_chunks_by_lines(sentences) + chunk_links = self.build_chunk_links_by_line(chunks) + return chunks, chunk_links, [] + diff --git a/utils/parser/handler/docx_parser.py b/utils/parser/handler/docx_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..5a229cc3632887762d15bb8c7cfe047815a17ea5 --- /dev/null +++ b/utils/parser/handler/docx_parser.py @@ -0,0 +1,218 @@ +from io import BytesIO +from PIL import Image +import numpy as np +import docx +from docx.document import Document +from docx.text.paragraph import Paragraph +from docx.parts.image import ImagePart +from docx.table import _Cell, Table +from docx.oxml.table import CT_Tbl +from docx.oxml.text.paragraph import CT_P +from docx.oxml.shape import CT_Picture +import mimetypes +from utils.parser.handler.base_parser import BaseService +from utils.parser.tools.ocr import BaseOCR +from utils.my_tools.logger import logger as logging + + +class DocxService(BaseService): + def __init__(self): + super().__init__() + self.image_model = None + + def open_file(self, file_path): + try: + doc = docx.Document(file_path) + return doc + except Exception as e: + logging.error(f"Error opening file {file_path} :{e}") + raise e + + def is_image(self, graph: Paragraph, doc: Document): + images = graph._element.xpath('.//pic:pic') + for image in images: + for img_id in image.xpath('.//a:blip/@r:embed'): + part = doc.part.related_parts[img_id] + if isinstance(part, ImagePart): + return True + return False + + # 获取run中的所有图片 + def get_imageparts_from_run(self, run, doc: Document): + image_parts = [] + drawings = run._r.xpath('.//w:drawing') # 获取所有图片 + for drawing in drawings: + for img_id in drawing.xpath('.//a:blip/@r:embed'): # 获取图片id + part = doc.part.related_parts[img_id] # 根据图片id获取对应的图片 + if isinstance(part, ImagePart): + image_parts.append(part) + return image_parts + + # 遍历文档中的块级元素 + def get_lines(self, parent): 
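+        # Walk the block-level XML children of a Document (or table _Cell) in
+        # document order and flatten them into (content, type) tuples:
+        # paragraph text becomes ('text', 'para'), each table row becomes a
+        # ('row text', 'table') entry via split_table(), and embedded pictures
+        # become ([(PIL.Image, extension)], 'image'); consecutive image entries
+        # are merged at the end so adjacent pictures form a single group.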
+ if isinstance(parent, Document): + parent_elm = parent.element.body + elif isinstance(parent, _Cell): + parent_elm = parent._tc + else: + logging.error("Unsupported parent type: %s", type(parent)) + return [] + lines = [] + for child in parent_elm.iterchildren(): + if isinstance(child, CT_P): + paragraph = Paragraph(child, parent) + if self.is_image(paragraph, parent): + text_part = '' + run_index = 0 + runs = paragraph.runs + + while run_index < len(runs): + run = runs[run_index] + image_parts = self.get_imageparts_from_run(run, parent) + if image_parts: + if text_part: + lines.append((text_part, 'para')) + text_part = '' + # 处理图片 + for image_part in image_parts: + image_blob = image_part.image.blob + content_type = image_part.content_type + extension = mimetypes.guess_extension(content_type).replace('.', '') + lines.append(([(Image.open(BytesIO(image_blob)), extension)], 'image')) + else: + # 处理文字 + text_part += run.text + run_index += 1 + + if text_part: + lines.append((text_part, 'para')) + else: + lines.append((paragraph.text, 'para')) + elif isinstance(child, CT_Tbl): + table = Table(child, parent) + rows=self.split_table(table) + for row in rows: + lines.append((row, 'table')) + elif isinstance(child, CT_Picture): + img_id = child.xpath('.//a:blip/@r:embed')[0] + part = parent.part.related_parts[img_id] + if isinstance(part, ImagePart): + image_blob = part.image.blob + content_type = part.content_type + extension = mimetypes.guess_extension(content_type).replace('.', '') + lines.append(([(Image.open(BytesIO(image_blob)), extension)], 'image')) + new_lines = [] + for i in range(len(lines)): + if lines[i][1] == 'image': + if len(new_lines) > 0 and new_lines[-1][1] == 'image': + new_lines[-1][0].append(lines[i][0][0]) + else: + new_lines.append(lines[i]) + else: + new_lines.append(lines[i]) + return new_lines + + async def solve_lines(self, lines, method): + """ + 修整处理lines,根据不同的类型(图像、段落、表格)处理每一行,并根据method参数决定处理方式。 + + 参数: + - lines (list): 需要处理的行列表,每行包含内容和类型。 + - method (str): 处理方法,可能是"ocr"、"llm-Enhance"或其他。 + + 返回: + - tuple: 包含处理后的句子列表和图像列表的元组。 + """ + sentences = [] + images = [] + last_para = "" + last_para_id = None + for line in lines: + if line[1] == 'image': + # 处理图像 + for image_tuple in line[0]: + image_id = self.get_uuid() + image = image_tuple[0] + image_bytes = image.tobytes() + image_extension = image_tuple[1] + await self.insert_image_to_tmp_folder(image_bytes, image_id, image_extension) + if method in ['ocr', 'enhanced']: + # 将图片关联到图片的描述chunk上 + chunk_id = self.get_uuid() + sentences.append({'id': chunk_id, + 'type': 'image'}) + sentences[-1]['near_text'] = last_para + sentences[-1]['image'] = np.array(image) + images.append({ + 'id': image_id, + 'chunk_id': chunk_id, + 'extension': image_extension, + }) + else: + # 将图片关联到上一个段落chunk上 + images.append({ + 'id': image_id, + 'chunk_id': last_para_id, + 'extension': 'png', + }) + + elif line[1] == 'para': + # 处理段落 + sentences.append({'id': self.get_uuid(), + 'text': line[0], + 'type': line[1]}) + last_para = line[0] + last_para_id = sentences[-1]['id'] + + elif line[1] == 'table': + # 处理表格 + sentences.append({'id': self.get_uuid(), + 'text': line[0], + 'type': line[1]}) + + if method in ['ocr', 'enhanced']: + sentences = await self.get_near_text(sentences) + return sentences, images + + async def get_near_text(self, sentences): + # 获取图像相邻文本 + last_para = "" + len_sentences = len(sentences) + for i in range(len_sentences - 1, -1, -1): + sentence = sentences[i] + if sentence['type'] == 'image': + sentences[i]['near_text'] = 
sentences[i]['near_text'] + last_para + elif sentence['type'] == 'para': + last_para = sentence['text'] + elif sentence['type'] == 'table': + pass + for sentence in sentences: + if sentence['type'] == 'image': + # 通过ocr/llm-Enhance进行强化 + sentence['text'] = await self.image_model.run(sentence['image'], text=sentence['near_text']) + return sentences + + async def parser(self, file_path): + """ + 解析文件并提取其中的文本和图像信息。 + + 参数: + - file_path (str): 文件的路径。 + + 返回: + - tuple: 包含分块的文本信息、分块间的链接信息和提取的图像信息的元组。 + 如果文件无法打开或解析失败,则返回 None。 + """ + doc = self.open_file(file_path) + if not doc: + return None + method = self.parser_method + if method != "general": + self.image_model = BaseOCR(llm=self.llm, llm_max_tokens=self.llm_max_tokens, method=method) + lines = self.get_lines(doc) + + sentences, images = await self.solve_lines(lines, method) + + chunks = self.build_chunks_by_lines(sentences) + chunk_links = self.build_chunk_links_by_line(chunks) + return chunks, chunk_links, images diff --git a/utils/parser/handler/html_parser.py b/utils/parser/handler/html_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..b258c2960a61a2dfd3492566949401d9ee0f14a0 --- /dev/null +++ b/utils/parser/handler/html_parser.py @@ -0,0 +1,54 @@ +from utils.my_tools import logger as logging +from bs4 import BeautifulSoup +from utils.parser.handler.base_parser import BaseService + + + +class HtmlService(BaseService): + # 读取 HTML 文件 + + @staticmethod + def open_file(file_path): + try: + with open(file_path, 'r', encoding='utf-8', errors='ignore') as file: + html_content = file.read() + return html_content + except Exception as e: + logging.error(f"Error opening file {file_path} :{e}") + raise e + + def element_to_dict(self, element): + node_dict = { + "tag": element.name, # 当前节点的标签名 + "attributes": element.attrs if element.attrs else None, # 标签的属性(如果有) + "text": element.get_text(strip=True) if element.string else None, # 标签内的文字 + "children": [], # 子节点列表 + "id": self.get_uuid(), + "type": "general", + "type_attr": 'leaf', + } + + # 处理图片 + if element.name == "img": + node_dict["img"] = element.get('src', None) + # 处理列表 + elif element.name in ["ul", "ol"]: + node_dict["list"] = [li.get_text(strip=True) for li in element.find_all('li')] + + # 递归处理子元素 + for child in element.children: + if child.name: # 如果子节点是标签而不是字符串 + node_dict['type_attr'] = 'node' + child_node = self.element_to_dict(child) + node_dict["children"].append(child_node) + + return node_dict + + def parser(self, file_path): + html_content = self.open_file(file_path) + # 解析 HTML 内容 + soup = BeautifulSoup(html_content, 'lxml') + tree = self.element_to_dict(soup) + chunks, chunk_links = self.build_chunks_and_links_by_tree(tree) + return chunks, chunk_links, [] + diff --git a/utils/parser/handler/md_parser.py b/utils/parser/handler/md_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..4121832350a9d34b82fa52b466c11f8c5421bea7 --- /dev/null +++ b/utils/parser/handler/md_parser.py @@ -0,0 +1,50 @@ +from utils.my_tools.logger import logger as logging +from utils.parser.handler.base_parser import BaseService + + +class MdService(BaseService): + + + @staticmethod + def read_md(file_path): + # 打开并读取Markdown文件 + try: + with open(file_path, 'r', encoding='utf-8',errors='ignore') as file: + data = file.read() + return data + except Exception as e: + logging.error(f"Error opening file {file_path} :{e}") + raise e + + # 提取列表分词结果 + def extract_from_md(self, data) -> dict: + md = data + lines = md.split('\n') + results = [] + if 
len(lines) > 1: + type = "table" + else: + type = "para" + lines = lines[0] + lines = self.split_sentences(lines, self.tokens) + for line in lines: + results.append({ + 'type': type, + 'text': line, + }) + return results + + async def parser(self, file_path): + data = self.read_md(file_path) + parts = data.split('\n\n') #分割 + sentences = [] + for part in parts: + sentences.extend(self.extract_from_md(part)) + chunks = self.build_chunks_by_lines(sentences) + chunk_links = self.build_chunk_links_by_line(chunks) + return chunks, chunk_links, [] + + +if __name__ == '__main__': + model = MdService() + chunks, links, images = model.parser('test.md') diff --git a/utils/parser/handler/pdf_parser.py b/utils/parser/handler/pdf_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..749c8f30c3344d6200d364894524207695868fec --- /dev/null +++ b/utils/parser/handler/pdf_parser.py @@ -0,0 +1,215 @@ +import io +import fitz +import numpy as np +from utils.my_tools.logger import logger as logging +from PIL import Image +from utils.parser.tools.ocr import BaseOCR +from utils.parser.handler.base_parser import BaseService + + + + +class PdfService(BaseService): + + def __init__(self): + super().__init__() + self.image_model = None + self.page_numbers = None + self.pdf = None + + def open_pdf(self, file_path): + try: + self.pdf = fitz.open(file_path) + self.page_numbers = len(self.pdf) + except Exception as e: + logging.error(f"Error opening file {file_path} :{e}") + raise e + + def extract_text(self, page_number): + page = self.pdf.load_page(page_number) + lines = [] + + text_blocks = page.get_text('blocks') + for block in text_blocks: + if block[6] == 0: # 确保是文本块 + text = block[4].strip() + rect = block[:4] # (x0, y0, x1, y1) + if text: + lines.append({'bbox': rect, + 'text': text, + 'type': 'para', + }) + sorted_lines = sorted(lines, key=lambda x: (x['bbox'][1], x['bbox'][0])) + return sorted_lines + + def extract_table(self, page_number): + """ + 读取pdf中的列表 + :param page_number:pdf页码 + :返回 pdf中的列表的内容(pandas格式)和坐标(x0,y0,x1,y1) + """ + page = self.pdf.load_page(page_number) + tabs = page.find_tables() + dfs = [] + for tab in tabs: + tab_bbox = fitz.Rect(tab.bbox) + page.add_redact_annot(tab.bbox) + df = tab.to_pandas() + lines = self.split_table(df) + for line in lines: + dfs.append({ + 'text': line, + 'bbox': tab_bbox, + 'type': 'table' + }) + + page.apply_redactions() + return dfs + + async def extract_image(self, page_number, text): + page = self.pdf.load_page(page_number) + image_list = page.get_images(full=True) + results = [] + image_chunks = [] + for img in image_list: + # 获取图像的xref + xref = img[0] + # 获取图像的base图像(如果存在) + base_image = self.pdf.extract_image(xref) + pos = page.get_image_rects(xref)[0] + # 获取图像的二进制数据 + image_bytes = base_image["image"] + # 获取图像的扩展名 + image_ext = base_image["ext"] + # 获取图像的位置信息 + + bbox = (pos.x0, pos.y0, pos.x1, pos.y1) + near = self.find_near_words(bbox, text) + + image = Image.open(io.BytesIO(image_bytes)) + image_id = self.get_uuid() + + await self.insert_image_to_tmp_folder(image_bytes, image_id,image_ext) + + img_np = np.array(image) + ocr_results = await self.image_model.run(img_np, text=near) + + # 获取OCR + chunk_id = self.get_uuid() + results.append({ + 'type': 'image', + 'text': ocr_results, + 'bbox': bbox, + 'xref': xref, + 'id': chunk_id, + }) + image_chunks.append({ + 'id': image_id, + 'chunk_id': chunk_id, + 'extension': image_ext, + }) + + return results, image_chunks + + def extract_text_with_position(self, page_number): + 
"""获取带坐标的文本块""" + page = self.pdf.load_page(page_number) + text_blocks = [] + for block in page.get_text("dict")["blocks"]: + if "lines" in block: # 确保是文本块 + for line in block["lines"]: + for span in line["spans"]: + text_blocks.append({ + 'text': span['text'], + 'bbox': span['bbox'], # 文本的矩形区域 (x0, y0, x1, y1) + 'type': 'para' + }) + return text_blocks + + def find_near_words(self, bbox, texts): + """寻找相邻文本""" + nearby_text = [] + image_x0, image_y0, image_x1, image_y1 = bbox + threshold = 100 + image_x0 -= threshold + image_y0 -= threshold + image_x1 += threshold + image_y1 += threshold + line = "" + for text in texts: + text_x0, text_y0, text_x1, text_y1 = text['bbox'] + text_content = text['text'] + # 左右相邻:水平距离小于等于阈值,且垂直方向有重叠 + horizontally_adjacent = (text_x1 >= image_x0 - threshold and text_x0 <= image_x1 + threshold) + # 上下相邻:垂直距离小于等于阈值,且水平方向有重叠 + vertically_adjacent = (text_y1 >= image_y0 - threshold and text_y0 <= image_y1 + threshold) + # 判断相交或相邻 + if horizontally_adjacent and vertically_adjacent: + line = line + text_content + + return line + + @staticmethod + def merge_list(list_a, list_b): + """ + 按照x0,y0,x1,y1合并list_a,list_b + :param + list_a:文字list + list_b:图像或者列表的list + """ + if list_a is None: + return list_b + if list_b is None: + return list_a + len_b = len(list_b) + now_b = 0 + max_x = 0 + result_list = [] + + for part_a in list_a: + max_x = max(max_x, part_a['bbox'][2]) + if now_b < len_b: + part_b = list_b[now_b] + while now_b < len_b and part_b['bbox'][0] < max_x and part_b['bbox'][1] < part_a['bbox'][1]: + result_list.append(part_b) + now_b += 1 + if now_b < len_b: + part_b = list_b[now_b] + result_list.append(part_a) + while now_b < len_b: + part_b = list_b[now_b] + result_list.append(part_b) + now_b += 1 + + return result_list + + async def parser(self, file_path): + self.open_pdf(file_path) + method = self.parser_method + sentences = [] + all_image_chunks = [] + if method != "general": + self.image_model = BaseOCR(llm=self.llm, llm_max_tokens=self.llm_max_tokens, + method=self.parser_method) + for page_num in range(self.page_numbers): + tables = self.extract_table(page_num) + text = self.extract_text(page_num) + temp_list = self.merge_list(text, tables) + if method != "general": + images, image_chunks = await self.extract_image(page_num, text) + merge_list = self.merge_list(temp_list, images) + all_image_chunks.extend(image_chunks) + else: + merge_list = temp_list + sentences.extend(merge_list) + + chunks = self.build_chunks_by_lines(sentences) + chunk_links = self.build_chunk_links_by_line(chunks) + return chunks, chunk_links, all_image_chunks + + def __del__(self): + if self.pdf: + self.pdf.close() + self.page_numbers = None + self.pdf = None + self.image_model = None diff --git a/utils/parser/handler/txt_parser.py b/utils/parser/handler/txt_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..67256b6f539219db3dd5ecb5b694639b89230772 --- /dev/null +++ b/utils/parser/handler/txt_parser.py @@ -0,0 +1,51 @@ +import uuid +import chardet +from utils.my_tools.logger import logger as logging +from utils.parser.handler.base_parser import BaseService + +Empty_id = uuid.UUID(int=0) + + +class TxtService(BaseService): + + # 提取段落分词结果 + def extract_paragraph(self, paragraph): + sentences = self.split_sentences(paragraph, self.tokens) + results = [] + for sentence in sentences: + results.append({ + "type": "para", + "text": sentence, + }) + return results + + @staticmethod + # 获取编码方式 + def detect_encoding(file_path): + with open(file_path, 'rb') as 
file: + raw_data = file.read() + result = chardet.detect(raw_data) + encoding = result['encoding'] + return encoding + + # 获取段落 + def read_text_file_by_paragraph(self, file_path): + try: + encoding = self.detect_encoding(file_path) + with open(file_path, 'r', encoding=encoding,errors='ignore') as file: # 打开文件 + content = file.read() + paragraphs = content.split('\n') + return paragraphs + except Exception as e: + logging.error(f"Error opening file {file_path} :{e}") + + async def parser(self, file_path): + # 使用函数 + paragraphs = self.read_text_file_by_paragraph(file_path) + sentences = [] + for paragraph in paragraphs: + sentences.extend(self.extract_paragraph(paragraph)) + chunks = self.build_chunks_by_lines(sentences) + chunk_links = self.build_chunk_links_by_line(chunks) + # 打印每个段落 + return chunks, chunk_links, [] diff --git a/utils/parser/handler/xlsx_parser.py b/utils/parser/handler/xlsx_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..97606d7176106359ec5a3f1f0cfde5128610e60c --- /dev/null +++ b/utils/parser/handler/xlsx_parser.py @@ -0,0 +1,34 @@ +import pandas as pd +from utils.my_tools.logger import logger as logging +from utils.parser.handler.base_parser import BaseService + + +class XlsxService(BaseService): + + # 打开Excel文件 + @staticmethod + def read_xlsx(file_path): + try: + data = pd.read_excel(file_path) + return data + except Exception as e: + logging.error(f"Error opening file {file_path} :{e}") + raise e + + # 提取列表分词结果 + def extract_table(self, data): + lines = self.split_table(data) + results = [] + for line in lines: + results.append({ + 'type': 'table', + 'text': line, + }) + return results + + async def parser(self, file_path): + data = self.read_xlsx(file_path) + sentences = self.extract_table(data) + chunks = self.build_chunks_by_lines(sentences) + chunk_links = self.build_chunk_links_by_line(chunks) + return chunks, chunk_links, [] diff --git a/utils/parser/model/ocr/ch_PP-OCRv4_det_infer/inference.pdiparams b/utils/parser/model/ocr/ch_PP-OCRv4_det_infer/inference.pdiparams new file mode 100644 index 0000000000000000000000000000000000000000..089594aeeb05fcdf9da804ddd5e7a5623bec5156 Binary files /dev/null and b/utils/parser/model/ocr/ch_PP-OCRv4_det_infer/inference.pdiparams differ diff --git a/utils/parser/model/ocr/ch_PP-OCRv4_det_infer/inference.pdiparams.info b/utils/parser/model/ocr/ch_PP-OCRv4_det_infer/inference.pdiparams.info new file mode 100644 index 0000000000000000000000000000000000000000..082c148e06ff888838eb930dc90113ae82e790a2 Binary files /dev/null and b/utils/parser/model/ocr/ch_PP-OCRv4_det_infer/inference.pdiparams.info differ diff --git a/utils/parser/model/ocr/ch_PP-OCRv4_det_infer/inference.pdmodel b/utils/parser/model/ocr/ch_PP-OCRv4_det_infer/inference.pdmodel new file mode 100644 index 0000000000000000000000000000000000000000..223b8614e44abd584620c62e07f9c39e0a967a8d Binary files /dev/null and b/utils/parser/model/ocr/ch_PP-OCRv4_det_infer/inference.pdmodel differ diff --git a/utils/parser/model/ocr/ch_PP-OCRv4_rec_infer/inference.pdiparams b/utils/parser/model/ocr/ch_PP-OCRv4_rec_infer/inference.pdiparams new file mode 100644 index 0000000000000000000000000000000000000000..4c3d9e9cbbf3881c7fd87299cd49291cad140665 Binary files /dev/null and b/utils/parser/model/ocr/ch_PP-OCRv4_rec_infer/inference.pdiparams differ diff --git a/utils/parser/model/ocr/ch_PP-OCRv4_rec_infer/inference.pdiparams.info b/utils/parser/model/ocr/ch_PP-OCRv4_rec_infer/inference.pdiparams.info new file mode 100644 index 
0000000000000000000000000000000000000000..923329f5eca3a9db10dbbe287c24421a4e662a5d Binary files /dev/null and b/utils/parser/model/ocr/ch_PP-OCRv4_rec_infer/inference.pdiparams.info differ diff --git a/utils/parser/model/ocr/ch_PP-OCRv4_rec_infer/inference.pdmodel b/utils/parser/model/ocr/ch_PP-OCRv4_rec_infer/inference.pdmodel new file mode 100644 index 0000000000000000000000000000000000000000..dccddcc7c88cdea9cccd399d9dc6fa0d7ae7a8f2 Binary files /dev/null and b/utils/parser/model/ocr/ch_PP-OCRv4_rec_infer/inference.pdmodel differ diff --git a/utils/parser/model/ocr/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams b/utils/parser/model/ocr/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams new file mode 100644 index 0000000000000000000000000000000000000000..3449efb5c6e3737c9340e11a9a21b43c7c86775e Binary files /dev/null and b/utils/parser/model/ocr/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams differ diff --git a/utils/parser/model/ocr/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams.info b/utils/parser/model/ocr/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams.info new file mode 100644 index 0000000000000000000000000000000000000000..f31a15752c98484d5364d3f3ff700522dc5256b7 Binary files /dev/null and b/utils/parser/model/ocr/ch_ppocr_mobile_v2.0_cls_infer/inference.pdiparams.info differ diff --git a/utils/parser/model/ocr/ch_ppocr_mobile_v2.0_cls_infer/inference.pdmodel b/utils/parser/model/ocr/ch_ppocr_mobile_v2.0_cls_infer/inference.pdmodel new file mode 100644 index 0000000000000000000000000000000000000000..b90c1550df13a0909b4c2eac9f0d1cf8491d2cba Binary files /dev/null and b/utils/parser/model/ocr/ch_ppocr_mobile_v2.0_cls_infer/inference.pdmodel differ diff --git a/utils/parser/service/parser_service.py b/utils/parser/service/parser_service.py new file mode 100644 index 0000000000000000000000000000000000000000..31112b3d802e88ea8e25c4fc35a48e716a4ae036 --- /dev/null +++ b/utils/parser/service/parser_service.py @@ -0,0 +1,40 @@ +from utils.my_tools.logger import logger as logging +from utils.parser.handler.docx_parser import DocxService +from utils.parser.handler.html_parser import HtmlService +from utils.parser.handler.xlsx_parser import XlsxService +from utils.parser.handler.txt_parser import TxtService +from utils.parser.handler.pdf_parser import PdfService +from utils.parser.handler.md_parser import MdService +from utils.parser.handler.doc_parser import DocService + +class EasyParser: + # TODO:把user_id和doc_id提取到这层 + def __init__(self): + self.doc = None + + async def parser(self, file_path, llm_entity, llm_max_tokens=8096, chunk_size=1024, parser_method='general'): + model_map = { + ".docx": DocxService, + ".doc": DocService, + ".txt": TxtService, + ".pdf": PdfService, + ".xlsx": XlsxService, + ".md": MdService, + ".html": HtmlService, + } + file_extension = '.'+file_path.split(".")[-1] + try: + if file_extension in model_map: + model = model_map[file_extension]() # 判断文件类型 + await model.init_service(llm_entity=llm_entity, + llm_max_tokens=llm_max_tokens, + tokens=chunk_size, + parser_method=parser_method) + chunk_list, chunk_link_list, image_chunks = await model.parser(file_path) + else: + logging.error(f"No service available for file type: {file_extension}") + return {"chunk_list": [], "chunk_link_list": [], "image_chunks": []} + except Exception as e: + logging.error(f'fail with exception:{e}') + raise e + return {"chunk_list": chunk_list, "chunk_link_list": chunk_link_list, "image_chunks": image_chunks} diff --git a/utils/parser/tools/ocr.py b/utils/parser/tools/ocr.py 
new file mode 100644 index 0000000000000000000000000000000000000000..497e714a8dd6798336b5b420ac16f8f1b078b838 --- /dev/null +++ b/utils/parser/tools/ocr.py @@ -0,0 +1,170 @@ +import yaml +from utils.my_tools.logger import logger as logging +from paddleocr import PaddleOCR +from utils.config.config import config +from utils.parser.tools.split import split_tools + + +class BaseOCR: + + def __init__(self, llm=None, llm_max_tokens=None, method='general'): + # 指定模型文件的路径 + det_model_dir = 'utils/parser/model/ocr/ch_PP-OCRv4_det_infer' + rec_model_dir = 'utils/parser/model/ocr/ch_PP-OCRv4_rec_infer' + cls_model_dir = 'utils/parser/model/ocr/ch_ppocr_mobile_v2.0_cls_infer' + + # 创建 PaddleOCR 实例,指定模型路径 + self.model = PaddleOCR( + det_model_dir=det_model_dir, + rec_model_dir=rec_model_dir, + cls_model_dir=cls_model_dir, + use_angle_cls=True, # 是否使用角度分类模型 + use_space_char=True # 是否使用空格字符 + ) + self.llm = llm + if llm is None and method == 'enhanced': + method = 'ocr' + else: + self.max_tokens = llm_max_tokens + self.method = method + + def ocr(self, image): + """ + ocr识别文字 + 参数: + image:图像文件 + **kwargs:可选参数,如语言、gpu + 返回: + 一个list,包含了所有识别出的文字以及对应坐标 + """ + try: + # get my_tools + results = self.model.ocr(image) + logging.info(f"OCR job down {results}") + return results + except Exception as e: + logging.error(f"OCR job error {e}") + raise e + + @staticmethod + def get_text_from_ocr_results(ocr_results): + results = '' + if ocr_results[0] is None: + return '' + try: + for result in ocr_results[0][0]: + results += result[1][0] + return results + except Exception as e: + logging.error(f'Get text from ocr result failed with {e}') + return '' + + @staticmethod + def split_list(image_result, max_tokens): + """ + 分句,不超过Tokens数量 + """ + sum_tokens = 0 + result = [] + temp = [] + for sentences in image_result[0]: + if sentences is not None and len(sentences) > 0: + tokens = split_tools.get_tokens(sentences) + if sum_tokens + tokens > max_tokens: + result.append(temp) + temp = [sentences] + sum_tokens = tokens + else: + temp.append(sentences) + sum_tokens += tokens + if temp: + result.append(temp) + return result + + @staticmethod + def get_prompt_dict(): + try: + with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: + prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) + return prompt_dict + except Exception as e: + logging.error(f'Get prompt failed : {e}') + raise e + + async def improve(self, image_results, text): + """ + llm强化接口 + 参数: + - image_results:ocr识别结果,包含了文字坐标、内容、置信度 + - text:图片组对应的前后文 + """ + try: + user_call = '请详细输出图片的总结,不要输出其他内容' + split_images = [] + max_tokens = self.max_tokens // 2 + for image in image_results: + split_result = self.split_list(image, max_tokens) + split_images.append(split_result) + front_text = text + front_image_description = "" + front_part_description = "" + prompt_dict = self.get_prompt_dict() + for image in split_images: + for part in image: + prompt = prompt_dict.get('OCR_ENHANCED_PROMPT', '') + try: + prompt = prompt.format( + front_text=front_text, + front_image_description=front_image_description, + front_part_description=front_part_description, + part=part) + front_part_description = await self.llm.nostream([], prompt, user_call) + except Exception as e: + raise e + front_image_description = front_part_description + answer = front_image_description + return answer + except Exception as e: + raise e + + async def run(self, image, text): + """ + 执行ocr的接口 + 输入: + image:图像文件 + """ + method = self.method + if not isinstance(image, list): + image = [image] + + 
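+        # Two-stage flow: first OCR every image (errors are logged and the
+        # image is skipped), then post-process according to self.method --
+        # 'ocr' just concatenates the recognized text, while 'enhanced' feeds
+        # the OCR output plus the nearby text to the LLM via improve() and
+        # falls back to the raw OCR text if the LLM call fails or returns nothing.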
image_results = self.process_images(image) + results = await self.generate_results(method, image_results, text) + + return results + + def process_images(self, images): + image_results = [] + for every_image in images: + try: + ocr_result = self.ocr(every_image) + image_results.append(ocr_result) + except Exception as e: + # 记录异常信息,可以选择日志记录或其他方式 + logging.error(f"Error processing image: {e}") + return image_results + + async def generate_results(self, method, image_results, text): + if method == 'ocr': + results = self.get_text_from_ocr_results(image_results) + return f'{results}' + elif method == 'enhanced': + try: + results = await self.improve(image_results, text) + if len(results.strip()) == 0: + return self.get_text_from_ocr_results(image_results) + return results + except Exception as e: + logging.error(f"LLM ERROR with: {e}") + return self.get_text_from_ocr_results(image_results) + else: + return "" diff --git a/utils/parser/tools/split.py b/utils/parser/tools/split.py new file mode 100644 index 0000000000000000000000000000000000000000..4f21b9e2353ba3669ba96d1ee84f6872d91d2665 --- /dev/null +++ b/utils/parser/tools/split.py @@ -0,0 +1,14 @@ +import jieba + + +class SplitTools: + def get_tokens(self, content): + sum_tokens = len(self.split_words(content)) + return sum_tokens + + @staticmethod + def split_words(text): + return list(jieba.cut(str(text))) + + +split_tools = SplitTools() diff --git a/utils/service/document_governance.py b/utils/service/document_governance.py new file mode 100644 index 0000000000000000000000000000000000000000..c279973f86695570949fc931eb949d5598fac194 --- /dev/null +++ b/utils/service/document_governance.py @@ -0,0 +1,335 @@ +# TODO: 文档治理,包含去重、标准化、格式化操作,接下来是具体解释 + +import os +import re +import time + +import pandas as pd +from datasketch import MinHash, MinHashLSH +import jieba +from docx import Document + +from utils.config.config import config +from sklearn.feature_extraction.text import CountVectorizer +from sklearn.metrics.pairwise import cosine_similarity +import json + +from utils.service.qa_generate import tokenize + +from utils.my_tools.logger import logger as logging + +class UniqueTools: + @staticmethod + def tokenize(sentence): + return list(jieba.cut(sentence)) + + @staticmethod + def generate_minhash(text, num_perm=128): + """ + 生成文本的 MinHash 签名。 + """ + m = MinHash(num_perm=num_perm) + for word in text.split(): + m.update(word.encode('utf8')) + return m + + def deduplicate_blocks(self, blocks, threshold=0.8): + """ + 文本块间去重,结合 MinHash 和 LSH。 + :param blocks: 文本块列表 + :param threshold: 去重的相似度阈值(默认0.8) + :return: 去重后的文本块列表 + """ + lsh = MinHashLSH(threshold=threshold, num_perm=128) + unique_blocks = [] + for i, block in enumerate(blocks): + # 生成 MinHash + m = self.generate_minhash(block['text']) + # 检查是否已存在相似块 + if not lsh.query(m): + unique_blocks.append(block) + lsh.insert(f"block_{i}", m) + return unique_blocks + + # 计算余弦相似度 + @staticmethod + def jaccard_similarity(list1, list2): + set1 = set(list1) + set2 = set(list2) + intersection = set1.intersection(set2) + union = set1.union(set2) + return len(intersection) / float(len(union)) + + def deduplicate(self, chunk, threshold=0.9): + unique_sentences = [] + unique_tokenized = [] + sentences = re.split(r'(?<=[.!?]) |\n', chunk) + tokenized_sentences = [self.tokenize(s) for s in sentences] + for i, tokens in enumerate(tokenized_sentences): + if not any(self.jaccard_similarity(tokens, ut) > threshold for ut in unique_tokenized): + unique_sentences.append(sentences[i]) + 
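+                # keep the tokenised form as well, so every later sentence is
+                # compared (Jaccard similarity > threshold) against all sentences
+                # already retained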
unique_tokenized.append(tokens) + return ''.join(unique_sentences) + + +class NormalizeTools: + # TODO: 文档标准化工具 + def __init__(self): + self.sensitive_words = self.load_sensitive_words(config['SENSITIVE_WORDS_PATH']) + self.term_replacements = self.load_term_replacements(config['TERM_REPLACEMENTS_PATH']) + self.sensitive_patterns = self.load_sensitive_patterns(config['SENSITIVE_PATTERNS_PATH']) + + @staticmethod + def load_sensitive_words(file_path): + """ + 读取敏感词表 + """ + # 加载敏感词表 + with open(file_path, 'r', encoding='utf-8') as file: + sensitive_words = set(line.strip() for line in file) + return sensitive_words + + @staticmethod + def load_term_replacements(file_path): + """ + 读取术语替换表 + """ + with open(file_path, 'r', encoding='utf-8') as file: + term_replacements = {} + for line in file: + parts = line.strip().split(':') + if len(parts) == 2: + term_replacements[parts[0]] = parts[1] + else: + print(f"Warning: Invalid format in line: {line.strip()}") + return term_replacements + + @staticmethod + def load_sensitive_patterns(file_path): + """ + 读取敏感格式表 + """ + # 加载敏感词表 + with open(file_path, 'r', encoding='utf-8') as file: + sensitive_patterns = set(line.strip() for line in file) + return sensitive_patterns + + @staticmethod + def normalize_punctuation(text): + """ + 标准化文本中的标点符号,全角转半角 + """ + + def to_half_width(c): # 全角转半角 + if 65281 <= ord(c) <= 65374: # 全角字符范围 + return chr(ord(c) - 65248) + elif c == ' ': # 全角空格 + return ' ' + else: + return c + + return ''.join(to_half_width(c) for c in text) + + @staticmethod + def normalize_text_format(text): + """ + 文本格式化,包括去除多余的空格、重复的换行符,单独的换行符需要保留等 + """ + # 去除多余的空格 + text = re.sub(r'[ \t]+', ' ', text) + # 将多个换行符替换为单个换行符 + text = re.sub(r'\n+', '\n', text) + # 去除行首和行尾的空格 + text = text.strip() + return text + + @staticmethod + def normalize_text_content(text, term_replacements, number_format='%d'): + """ + 文本内容标准化,包括术语替换和数字格式化等操作 + """ + # 统一大小写 + text = text.lower() + # 术语替换 + for old, new in term_replacements.items(): + text = text.replace(old, new) + # 数字格式统一 + text = re.sub(r'\d+', lambda x: number_format % int(x.group()), text) + return text + + #文档编码标准化 + @staticmethod + def normalize_encoding(text): + """ + 文本编码标准化,统一转换成utf-8 + """ + return text.encode('utf-8', errors='ignore').decode('utf-8') + + @staticmethod + def mask_sensitive_data(text, sensitive_words, sensitive_patterns): + """ + 文本内容脱敏,包括敏感词和正则表达式匹配等 + """ + for word in sensitive_words: + text = re.sub(re.escape(word), '***', text) + for pattern in sensitive_patterns: + text = re.sub(pattern, '***', text) + return text + + def run_all_tools(self, text): + """ + 运行所有标准化工具 + """ + text = self.normalize_text_format(text) + text = self.normalize_punctuation(text) + text = self.normalize_text_content(text, self.term_replacements) + text = self.mask_sensitive_data(text, self.sensitive_words, self.sensitive_patterns) + text = self.normalize_encoding(text) + return text + + +class FormatTools: + @staticmethod + async def chat_with_llm(llm, prompt, text, prev_text, front_text, file_name) -> dict: + """ + 对于给定的文本,通过llm按照prompt的格式 + params: + - llm: LLm + - text: str + - prompt: str + return: + - qa_pairs: list[dict] + + """ + text.replace("\"", "\\\"") + user_call = (f"文本内容来自于{file_name},请以标准格式的输出格式化后的段落") + prompt = prompt.format(chunk=text, text=prev_text, file_name=file_name, front_text=front_text) + # print(prompt) + # logging.info(f"prompt: {prompt}") + answer = await llm.nostream([], prompt, user_call) + # logging.info(f"answer: {answer}") + # print(answer) + + return answer + + +# 
文档去重 +class DocumentGovernance: + + async def unique(self, chunks): + """ + 文档去重,输入为单个文档的路径,可以实现对于文档内容的解析,文本内容按照段落划分 + 对于解析出的文本块,实现文本块内去重+文本块间去重 + """ + unique_model = UniqueTools() + new_chunks = [] + # 文本块内去重 + for chunk in chunks: + if chunk['type'] == 'image': + new_chunks.append(chunk) + continue + new_text = unique_model.deduplicate(chunk['text']) + new_chunks.append({ + 'text': new_text, + 'type': chunk['type'], + }) + # 文本块间去重 + return unique_model.deduplicate_blocks(new_chunks) + + async def standardize(self, chunks): + """ + 文档标准化,输入为单个文档的路径,对于每一段实现文档的标准化, + 标准化主要为: + - 文本格式标准化:缩进、空白字符、换行等 + - 标点符号标准化:全半角等 + - 文本内容标准化:大小写、术语替换、数字格式统一等 + - 文档编码标准化:确保文档文件的编码一致,如统一采用utf-8。 + - 敏感数据处理:敏感词、敏感信息屏蔽,支持自定义,敏感词表存储在common/sensitive_words.txt, + 敏感正则表达式保存在common/sensitive_patterns.txt; + 敏感信息包括密码、账号名等,需要修改成***。 + 输出统一为docx或md, 图像不处理, 正常插入到原位置。 + """ + new_chunks = [] + model = NormalizeTools() + for chunk in chunks: + if chunk['type'] == 'image': + new_chunks.append(chunk) + continue + chunk['text'] = model.run_all_tools(chunk['text']) + new_chunks.append(chunk) + return new_chunks + + @staticmethod + async def format(chunks, prompt, llm, file_name): + """ + 文档格式化。输入为单个文档的路径,实现文档的格式化,如: + 1、默认格式:每个文本段使用三段论的方式进行总结。 + 2、自定义模式:每个文本段使用用户自定义的prompt进行总结。 + """ + new_chunks = [] + answer = "" + prev_texts = [] + now_text = "" + for i, chunk in enumerate(chunks): + prev_text = "\n".join(prev_texts) + now_text = now_text + str(chunk) + if i != len(chunks) - 1 and tokenize(now_text) < config['MAX_TOKENS'] // 8: + continue + while tokenize(prev_text) + tokenize(now_text) > config['MAX_TOKENS'] // 4: + prev_texts.pop(0) + prev_text = "\n".join(prev_texts) + count = 0 + while count < 5: + try: + answer = await FormatTools.chat_with_llm(llm=llm, prompt=prompt, text=now_text, prev_text=prev_text, + front_text=answer, file_name=file_name) + new_chunks.append({ + 'text': answer, + 'type': chunk['type'], + 'original_text': now_text, + }) + count = 5 + except Exception as e: + count = count + 1 + print(f"retry {count} times due to error:", e) + # logging.error(f"Failed to chat with llm due to: {e}") + time.sleep(1) + now_text = "" + + return new_chunks + + @staticmethod + def output_chunks_to_file(output_path, chunks, file_name, file_extension="doc"): + # 检测output_path是否存在 + if not os.path.exists(output_path): + os.makedirs(output_path) + + # 构建完整的文件路径 + file_path = os.path.join(output_path, f"{file_name}.{file_extension}") + file_path_xlsx = os.path.join(output_path, f"{file_name}.xlsx") + df = pd.DataFrame(chunks) + os.makedirs(os.path.dirname(file_path_xlsx), exist_ok=True) + df.to_excel(file_path_xlsx, index=False) + print(f'Excel结果已输出到{file_path_xlsx}') + if file_extension == "docx": + # 创建一个新的Word文档 + doc = Document() + + # 将每个chunk['text']添加到文档中 + for chunk in chunks: + doc.add_paragraph(chunk['text'] + '\n') + + # 保存文档到指定路径 + doc.save(file_path) + print(f"文档已保存到{file_path}") + + elif file_extension == "md": + # 打开文件以写入Markdown内容 + with open(file_path, 'w', encoding='utf-8') as md_file: + for chunk in chunks: + md_file.write(f"{chunk['text']}\n\n") + print(f"文档已保存到{file_path}") + + else: + print("不支持的格式") + # logging.info(f"Document saved to {file_path} down") diff --git a/utils/service/embedding_training.py b/utils/service/embedding_training.py new file mode 100644 index 0000000000000000000000000000000000000000..3f0678950826492a0705ed0711a07229d8a923df --- /dev/null +++ b/utils/service/embedding_training.py @@ -0,0 +1,80 @@ +import os + +import numpy as np +import subprocess +class EmbeddingTraining: + 
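+    # Fine-tunes a FlagEmbedding encoder-only model by shelling out to torchrun:
+    # parser() assembles the command line from the argparse namespace (stdout
+    # redirected to ./logs/embedding/temp.log), run_command() joins and executes
+    # it via subprocess, and run() chains fine-tuning, evaluation of the base and
+    # fine-tuned encoders with bge_finetune/eval.py, and report generation.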
+ @staticmethod + def parser(args): + # 定义要执行的命令 + command = [ + "torchrun", + "--nproc_per_node", str(args.gpu_num), + "-m", "FlagEmbedding.finetune.embedder.encoder_only.base", + "--model_name_or_path", str(args.model_name_or_path), + "--cache_dir", "./cache/model", + "--train_data", str(args.train_data), + "--cache_path", "./cache/data", + "--train_group_size", "8", + "--query_max_len", str(args.tokens), + "--passage_max_len", str(args.tokens), + "--pad_to_multiple_of", "8", + "--query_instruction_for_retrieval", "\"Represent this sentence for searching relevant passages: \"", + "--query_instruction_format", "{}{}", + "--knowledge_distillation", "False", + "--output_dir", str(args.output_path), + "--overwrite_output_dir", + "--learning_rate", str(args.learning_rate), + "--num_train_epochs", str(args.epochs), + "--per_device_train_batch_size", str(args.batch_size), + "--dataloader_drop_last", "True", + "--warmup_ratio", str(args.warmup), + "--gradient_checkpointing", + "--deepspeed", str(args.deepspeed), + "--logging_steps", str(args.logging_steps), + "--save_steps", str(args.save_steps), + "--negatives_cross_device", + "--temperature", str(args.temperature), + "--sentence_pooling_method", "cls", + "--normalize_embeddings", "True", + "--kd_loss_type", "kl_div", + "--ddp_find_unused_parameters=False", + '>', './logs/embedding/temp.log' # 将输出重定向到 temp.log + ] + return command + + # 执行 torchrun 命令 + @staticmethod + def run_command(command): + try: + # 使用 subprocess 执行命令并捕获输出 + print("running...") + result = subprocess.run(' '.join(command), shell=True, check=True) + print("running successfully.") + except subprocess.CalledProcessError as e: + print(f"Error occurred during running: {e}") + return False + return True + + # 执行 get_reports.py 脚本 + def run(self, args): + if not os.path.exists("./logs"): + os.mkdir("./logs") + if not os.path.exists("./logs/embedding"): + os.mkdir("./logs/embedding") + if not self.run_command(self.parser(args)): + return + if not self.run_command(["python", "utils/my_tools/bge_finetune/eval.py", "--encoder", f"{args.model_name_or_path}", + "--test_data", f"{args.test_data}", '>', './logs/embedding/base.log']): + return + if not self.run_command(["python", "utils/my_tools/bge_finetune/eval.py", "--encoder", f"{args.output_path}", + "--test_data", f"{args.test_data}", '>', './logs/embedding/final.log']): + return + if not self.run_command(["python", "utils/my_tools/bge_finetune/get_report.py"]): + return + print("Embedding fine tuning down") + + + +if __name__ == '__main__': + EmbeddingTraining.run() \ No newline at end of file diff --git a/utils/service/qa_generate.py b/utils/service/qa_generate.py new file mode 100644 index 0000000000000000000000000000000000000000..db0dc1d253e19495fe6db3be29d11ccda87defdc --- /dev/null +++ b/utils/service/qa_generate.py @@ -0,0 +1,197 @@ +import asyncio +import json +import time + +import pandas as pd +import yaml +import os +from utils.config.config import config +import random +import jieba + +from utils.my_tools.logger import logger as logging + +def tokenize(text): + return len(list(jieba.cut(str(text))))*1.5 + + +def get_random_number(l, r): + if l >= r: + return r - 1 + return random.randint(l, r - 1) + + +class QAgenerator: + + async def qa_generate(self, chunks, file, qa_count, prompt, llm, enhance): + """ + 多线程生成问答对 + """ + start_time = time.time() + results = [] + prev_texts = [] + ans = 0 + # 使用 asyncio.gather 来并行处理每个 chunk + tasks = [] + # 获取 chunks 的长度 + num_chunks = len(chunks) + if qa_count > 20 * (num_chunks-2): + 
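+            # cap the request at 20 * (num_chunks - 2) QA pairs in total, with a floor of 1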
qa_count = max(1,20 * (num_chunks-2)) + for chunk in chunks: + chunk['count'] = 0 + for i in range(qa_count): + x = get_random_number(min(3, num_chunks - 1), num_chunks) + if x >= num_chunks: + x = num_chunks-1 + chunks[x]['count'] = chunks[x]['count'] + 1 + + now_text = "" + count = 0 + for chunk in chunks: + now_text = now_text + chunk['text'] + '\n' + count = count + chunk['count'] + if count >= 10 or count >= qa_count: + while count > 0: + temp_count = min(count, 10) + if len(tasks) == 5: + await asyncio.gather(*tasks) + tasks = [] + tasks.append( + self.generate(llm, prompt, now_text, prev_texts, results, file, temp_count, chunk['type'], enhance)) + qa_count = qa_count - temp_count + count = count - temp_count + ans = ans + temp_count + if tokenize(now_text) > (config['MAX_TOKENS'] // 8): + prev_texts.append(now_text) + now_text = '' + if qa_count > 0: + tasks.append( + self.generate(llm, prompt, now_text, prev_texts, results, file, qa_count, chunks[-1]['type'], enhance)) + ans = ans + qa_count + # 等待所有任务完成 + await asyncio.gather(*tasks) + print('问答对案例:', results[0]) + print("问答对生成总计用时:", time.time() - start_time) + print(f"文件 {file} 总计生成 {ans} 条问答对") + return results, ans + + async def check_qa(self, llm, now_text, prev_text, results_temp): + prompt_check = ''' +你是一个问答对检查专家,请根据给出的上下文和段落内容,判断问答对是否能够描述段落的内容 + +注意: + +1. 只要回答"是"或者"否" + +2. 不要输出多余内容 + +下面是给出的段落内容: +{chunk} + +下面是段落的上下文内容: +{text} + +下面是被生成的问答对: +{qa} + ''' + prompt_check = prompt_check.format(chunk=now_text, text=prev_text, qa=results_temp) + user_call = "请判断问答对是否正确,如果正确,请输出“是”,否则请输出“否,不要输出多余内容" + return await llm.nostream([], prompt_check, user_call) + + async def generate(self, llm, prompt, now_text, prev_texts, results, file, temp_count, text_type, enhance=False): + """ + 生成问答 + """ + prev_text = '\n'.join(prev_texts) + + while tokenize(prev_text) > (config['MAX_TOKENS'] // 12): + prev_texts.pop(0) + prev_text = '\n'.join(prev_texts) + count = 0 + while count < 5: + try: + # 使用多线程处理 chat_with_llm 调用 + print(f"本次生成{temp_count}对") + result_temp = await self.chat_with_llm(llm, prompt, now_text, prev_text, + temp_count, file) + if enhance and await self.check_qa(llm, now_text, prev_text, result_temp) == "否": + count += 1 + print("重新生成") + continue + for result in result_temp: + result['text'] = now_text + result['text_with_prev'] =prev_text + now_text + result['text_type'] = text_type + results.append(result) + count = 5 + except Exception as e: + count += 1 + print('error:', e, 'retry times', count) + time.sleep(1) + if count == 5: + for i in range(temp_count): + results.append( + {'text': now_text, 'question': '无法生成问答对', 'answer': '无法生成问答对', 'type': 'error', + 'text_type': text_type}) + + @staticmethod + async def chat_with_llm(llm, prompt, text, prev_text, qa_count, file_name) -> dict: + """ + 对于给定的文本,通过llm生成问题-答案-段落对。 + params: + - llm: LLm + - text: str + - prompt: str + return: + - qa_pairs: list[dict] + + """ + text.replace("\"", "\\\"") + user_call = (f"文本内容来自于{file_name},请以正确的JSON格式输出{qa_count}对不同的问题-答案-领域,格式为[" + "{" + "\"question\": \" 问题 \", " + "\"answer\": \" 回答 \"," + "\"type\": \" 领域 \"" + "}\n" + "],并且必须将问题和回答中和未被转义的双引号和逗号转义,元素标签请用双引号括起来,不要输出多余内容") + prompt = prompt.format(chunk=text, qa_count=qa_count, text=prev_text, file_name=file_name) + # print(prompt)0 + # logging.info(f"prompt: {prompt}") + qa_pair = await llm.nostream([], prompt, user_call) + # 提取问题、答案段落对的list,字符串格式为["问题","答案","段落对"] + # logging.info(f"qa_pair: {qa_pair}") + # print(qa_pair) + # print("原文:", text) + qa_pair = json.loads(qa_pair) + 
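+        # json.loads raises on malformed model output; generate() catches the
+        # exception and retries up to 5 times before emitting error placeholders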
return qa_pair
+
+    async def output_results(self, results, file_name, output_path, output_format):
+        """
+        Write the generated results to the given path.
+        params:
+        - results: list[dict]
+        - output_path: str
+        - output_format: str
+        """
+        if not os.path.exists(output_path):
+            os.makedirs(output_path)
+        try:
+            # the output file is named <file_name>.<output_format> under output_path
+            output_file = os.path.join(output_path, f'{file_name}.{output_format}')
+            if output_format == 'json':
+                with open(output_file, 'w', encoding='utf-8') as f:
+                    json.dump(results, f, ensure_ascii=False, indent=4)
+                print(f'JSON results written to {output_file}')
+            elif output_format == 'yaml':
+                with open(output_file, 'w', encoding='utf-8') as f:
+                    yaml.dump(results, f, allow_unicode=True)
+                print(f'YAML results written to {output_file}')
+            elif output_format == 'xlsx':  # write results to an Excel workbook
+                df = pd.DataFrame(results)
+                os.makedirs(os.path.dirname(output_file), exist_ok=True)
+                df.to_excel(output_file, index=False)
+                print(f'Excel results written to {output_file}')
+        except Exception as e:
+            logging.error(f'Failed to write results to {output_path}: {e}')
+            raise e
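
Usage sketch (not part of the patch): one way to wire the new parser, LLM wrapper and QA generator together. The endpoint, API key, model name, input file, output path and the prompt template below are placeholders and must be replaced with values from your own deployment; the template only has to expose the {chunk}, {qa_count}, {text} and {file_name} fields that QAgenerator.chat_with_llm() formats into it.

import asyncio
from utils.my_tools.llm import LLM
from utils.parser.service.parser_service import EasyParser
from utils.service.qa_generate import QAgenerator

# Placeholder prompt template with the fields consumed by chat_with_llm().
QA_PROMPT = ("Context:\n{text}\n\nPassage from {file_name}:\n{chunk}\n\n"
             "Generate {qa_count} question-answer pairs for this passage.")

async def demo():
    # Placeholder endpoint, key and model name.
    llm = LLM(openai_api_key="sk-placeholder",
              openai_api_base="http://localhost:8000/v1",
              model_name="my-chat-model",
              max_tokens=8096)
    # Parse a document into chunks and chunk links (general mode: no OCR, no LLM enhancement).
    parsed = await EasyParser().parser("example.docx", llm_entity=llm,
                                       parser_method="general")
    # Generate 10 QA pairs from the parsed chunks and write them out as JSON.
    qa = QAgenerator()
    results, count = await qa.qa_generate(parsed["chunk_list"], "example.docx",
                                          qa_count=10, prompt=QA_PROMPT,
                                          llm=llm, enhance=False)
    await qa.output_results(results, "example", "./output", "json")

asyncio.run(demo())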