diff --git a/ACL_PyTorch/built-in/ocr/MinerU/README.md b/ACL_PyTorch/built-in/ocr/MinerU/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cdf01eb0ea57ab63777d4a2e3da27ce279acc74f
--- /dev/null
+++ b/ACL_PyTorch/built-in/ocr/MinerU/README.md
@@ -0,0 +1,195 @@
+# MinerU (TorchAir) Inference Guide
+
+- [MinerU (TorchAir) Inference Guide](#mineru-torchair-inference-guide)
+- [Overview](#overview)
+- [Inference Environment](#inference-environment)
+- [Quick Start](#quick-start)
+  - [Get the Source Code](#get-the-source-code)
+  - [Get the Weights](#get-the-weights)
+  - [Get the Dataset](#get-the-dataset)
+  - [Run Inference](#run-inference)
+  - [Accuracy Test](#accuracy-test)
+
+******
+
+# Overview
+MinerU is an open-source document parsing tool developed by the OpenDataLab team at Shanghai AI Laboratory. It targets the extraction of high-quality structured data for LLM training and RAG (retrieval-augmented generation) applications: complex documents (PDFs, web pages, e-books) are converted into machine-readable Markdown or JSON while the semantic structure and multimodal elements of the original document are preserved.
+
+- Version:
+
+  ```
+  url=https://github.com/opendatalab/MinerU.git
+  commit_id=de41fa58590263e43b783fe224b6d07cae290a33
+  model_name=MinerU
+  ```
+
+# Inference Environment
+
+- The model requires the following plugins and drivers.
+
+  **Table 1** Version compatibility
+
+  | Component                    | Version  | Setup Guide |
+  | ---------------------------- | -------- | ----------- |
+  | Firmware & driver            | 25.2.RC1 | [PyTorch inference environment setup](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) |
+  | CANN                         | 8.3.0    | - |
+  | Python                       | 3.11     | - |
+  | PyTorch                      | 2.6.0    | - |
+  | Ascend Extension for PyTorch | 2.6.0    | - |
+
+  Note: For Atlas 800I A2 / Atlas 300I Pro inference cards, pick the firmware and driver versions that match your CANN version.
+
+# Quick Start
+
+## Get the Source Code
+
+1. Get the MinerU source code
+
+   ```
+   git clone https://github.com/opendatalab/MinerU.git
+   cd MinerU
+   git reset --hard de41fa58590263e43b783fe224b6d07cae290a33
+   pip3 install -e .
+   cd ..
+   ```
+
+2. Install dependencies
+
+   ```
+   pip3 install -r requirements.txt
+   # The Torchvision Adapter must be installed as well
+   git clone https://gitee.com/ascend/vision.git vision_npu
+   cd vision_npu
+   git checkout v0.21.0-7.1.0
+   pip3 install -r requirement.txt
+   source /usr/local/Ascend/ascend-toolkit/set_env.sh  # default path, adjust if needed
+   python setup.py bdist_wheel
+   cd dist
+   # the exact wheel name depends on your build environment
+   pip install torchvision_npu-0.21.0+git22ca6b2-cp311-cp311-linux_aarch64.whl
+   cd ../../
+   ```
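+
+   Before patching anything, a quick import check confirms the NPU stack is usable (a minimal sketch; it assumes the CANN environment variables have been sourced as above):
+
+   ```
+   # check_env.py — hypothetical helper, not part of MinerU
+   import torch
+   import torch_npu        # registers the NPU device backend
+   import torchvision
+   import torchvision_npu  # the Torchvision Adapter built above
+
+   print("torch", torch.__version__, "| torchvision", torchvision.__version__)
+   print("NPU available:", torch.npu.is_available())
+   ```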
+3. Modify the third-party libraries
+
+   Go to the site-packages directory (by default `source_path=/usr/local/lib/python3.11/site-packages`) and apply `ultralytics.patch` and `doclayout_yolo.patch` from your working directory `workdir` (user-defined); `mfr_encoder_mhsa.patch` is applied to the MinerU source from `workdir` itself:
+
+   ```
+   source_path=/usr/local/lib/python3.11/site-packages
+   cd ${source_path}/ultralytics
+   patch -p2 < ${workdir}/ultralytics.patch
+   cd ${source_path}/doclayout_yolo
+   patch -p2 < ${workdir}/doclayout_yolo.patch
+   cd ${workdir}
+   patch -p0 < mfr_encoder_mhsa.patch
+   ```
+
+## Get the Weights
+
+Run the following command to download the [model weights](https://www.modelscope.cn/models/OpenDataLab/PDF-Extract-Kit-1.0/summary); by default they are stored under `/root/.cache/modelscope/hub/models/OpenDataLab/PDF-Extract-Kit-1___0`:
+
+```
+mineru-models-download --source modelscope --model_type pipeline
+```
+After the download completes, a `mineru.json` file is generated under the root directory by default. If you move the weights elsewhere, update the `"pipeline"` entry under `"models-dir"` in `/root/mineru.json` to the new weight location.
+
+The weight directory roughly looks like:
+```text
+📁 models
+├── 📁 Layout
+│   └── 📁 YOLO
+│       └── doclayout_yolo_docstructbench_imgsz1280_2501.pt
+├── 📁 MFD
+│   └── 📁 YOLO
+│       └── yolo_v8_ft.pt
+├── 📁 MFR
+│   └── 📁 unimernet_hf_small_2503
+│       ├── model.safetensors
+│       ├── ……
+│       └── tokenizer_config.json
+├── 📁 OCR
+│   └── 📁 paddleocr_torch
+│       ├── Multilingual_PP-OCRv3_det_infer.pth
+│       ├── arabic_PP-OCRv3_rec_infer.pth
+│       ├── ……
+│       └── th_PP-OCRv5_rec_infer.pth
+├── 📁 ReadingOrder
+│   └── 📁 layout_reader
+│       ├── config.json
+│       └── model.safetensors
+└── 📁 TabRec
+    └── 📁 SlanetPlus
+        └── slanet-plus.onnx
+```
+
+
+## Get the Dataset
+
+Create a dataset directory `OmniDocBench_dataset`, download the [pdfs and annotations](https://opendatalab.com/OpenDataLab/OmniDocBench) of the document parsing benchmark `OmniDocBench`, and unpack them into `OmniDocBench_dataset`.
+The directory layout should roughly be:
+  ```
+  📁 workdir
+  ├── infer.py
+  ├── ……
+  ├── 📁 MinerU
+  └── 📁 OmniDocBench_dataset
+      ├── OmniDocBench.json
+      └── 📁 pdfs
+          └── ***.pdf
+  ```
+
+## Run Inference
+
+Run the inference script `infer.py`:
+
+```
+python3 infer.py --data_path=OmniDocBench_dataset --model_source=local
+```
+
+- Parameters
+  - data_path: dataset path
+  - model_source: weight source; `local` uses local files, `modelscope`/`huggingface` pulls the weights online
+
+When inference finishes, the parsing results are written to `OmniDocBench_dataset/output/`. Besides the main markdown output, several auxiliary files are produced for debugging, quality checking, and further processing.
+
+## Accuracy Test
+
+Accuracy is measured with the evaluation code that accompanies the `OmniDocBench` dataset.
+
+1. Collect the inference results
+
+   Gather the markdown files from the result folders into a single directory; in this example all markdown files go to the `results_md` folder under `OmniDocBench_dataset`:
+   ```
+   cp OmniDocBench_dataset/output/*/auto/*.md OmniDocBench_dataset/results_md/
+   ```
+
+2. Get the evaluation code and set up its environment
+
+   ```
+   git clone https://github.com/opendatalab/OmniDocBench.git
+   cd OmniDocBench
+   conda create -n omnidocbench python=3.10
+   conda activate omnidocbench
+   pip install -r requirements.txt
+   ```
+
+3. Adjust the evaluation config
+
+   We use the end-to-end evaluation config. In `configs/end2end.yaml`, point `ground_truth.data_path` at the downloaded `OmniDocBench.json` and `prediction.data_path` at the folder of collected markdown results:
+   ```
+   # ----- the part that needs to be modified -----
+   dataset:
+     dataset_name: end2end_dataset
+     ground_truth:
+       data_path: ../OmniDocBench_dataset/OmniDocBench.json
+     prediction:
+       data_path: ../OmniDocBench_dataset/results_md
+   ```
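+
+   An optional sanity check before evaluating (a minimal sketch; paths as configured above) is to confirm that every input PDF produced a markdown prediction:
+
+   ```
+   # check_results.py — hypothetical helper, not part of OmniDocBench
+   from pathlib import Path
+
+   pdfs = list(Path("OmniDocBench_dataset/pdfs").glob("*.pdf"))
+   mds = list(Path("OmniDocBench_dataset/results_md").glob("*.md"))
+   print(f"{len(pdfs)} pdfs, {len(mds)} markdown predictions")
+   assert len(mds) == len(pdfs), "some documents were not parsed"
+   ```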
+4. Run the evaluation
+
+   With the config file in place, pass it as an argument and run:
+   ```
+   python pdf_validation.py --config ./configs/end2end.yaml
+   ```
+
+   Accuracy on the `OmniDocBench` dataset:
+
+   | Model  | Chip        | overall_EN | overall_CH |
+   | ------ | ----------- | ---------- | ---------- |
+   | MinerU | 300I DUO    | 0.1588     | 0.2527     |
+   | MinerU | 800I A2 64G | 0.1580     | 0.2510     |
+
diff --git a/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch b/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch
new file mode 100644
index 0000000000000000000000000000000000000000..7cf22c0b32204da3a0f7529818fa771bd739fe1b
--- /dev/null
+++ b/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch
@@ -0,0 +1,50 @@
+diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/engine/predictor.py doclayout_yolo-0.0.4_fix/doclayout_yolo/engine/predictor.py
+--- doclayout_yolo-0.0.4/doclayout_yolo/engine/predictor.py 2025-02-11 15:49:31.000000000 +0800
++++ doclayout_yolo-0.0.4_fix/doclayout_yolo/engine/predictor.py 2025-09-09 16:05:20.011737230 +0800
+@@ -152,7 +152,8 @@
+             (list): A list of transformed images.
+         """
+         same_shapes = len({x.shape for x in im}) == 1
+-        letterbox = LetterBox(self.imgsz, auto=same_shapes and self.model.pt, stride=self.model.stride)
++        letterbox = LetterBox(self.imgsz, auto=False, stride=self.model.stride)
++        # letterbox = LetterBox(self.imgsz, auto=same_shapes and self.model.pt, stride=self.model.stride)
+         return [letterbox(image=x) for x in im]
+ 
+     def postprocess(self, preds, img, orig_imgs):
+@@ -225,7 +226,8 @@
+ 
+         # Warmup model
+         if not self.done_warmup:
+-            self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
++            # self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
++            self.model.warmup(imgsz=(self.dataset.bs, 3, *self.imgsz))
+             self.done_warmup = True
+ 
+         self.seen, self.windows, self.batch = 0, [], None
+
+diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/nn/modules/block.py doclayout_yolo-0.0.4_fix/doclayout_yolo/nn/modules/block.py
+--- doclayout_yolo-0.0.4/doclayout_yolo/nn/modules/block.py 2025-02-11 15:49:31.000000000 +0800
++++ doclayout_yolo-0.0.4_fix/doclayout_yolo/nn/modules/block.py 2025-09-09 16:05:20.019737230 +0800
+@@ -230,7 +230,9 @@
+     def forward(self, x):
+         """Forward pass through C2f layer."""
+         y = list(self.cv1(x).chunk(2, 1))
+-        y.extend(m(y[-1]) for m in self.m)
++        # y.extend(m(y[-1]) for m in self.m)
++        for m in self.m:
++            y.append(m(y[-1]))
+         return self.cv2(torch.cat(y, 1))
+ 
+     def forward_split(self, x):
+
+diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/utils/tal.py doclayout_yolo-0.0.4_fix/doclayout_yolo/utils/tal.py
+--- doclayout_yolo-0.0.4/doclayout_yolo/utils/tal.py 2025-02-11 15:49:31.000000000 +0800
++++ doclayout_yolo-0.0.4_fix/doclayout_yolo/utils/tal.py 2025-09-09 16:05:20.023737230 +0800
+@@ -328,7 +328,8 @@
+         sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
+         sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
+         anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
+-        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
++        # stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
++        stride_tensor.append(torch.ones((h * w, 1), dtype=dtype, device=device)*stride)
+     return torch.cat(anchor_points), torch.cat(stride_tensor)
\ No newline at end of file
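
The `doclayout_yolo.patch` above (like the matching hunk in `ultralytics.patch` further down) pins `LetterBox(..., auto=False)`, so every page is padded to the full `imgsz` square and the compiled NPU graph always receives the same input shape. A standalone sketch of that fixed-shape preprocessing (illustrative only — a dependency-free nearest-neighbour resize, not the ultralytics implementation):

```
import numpy as np

def letterbox_fixed(img, size=1280, pad=114):
    """Resize the long side to `size` and pad the rest: the output shape is constant."""
    h, w = img.shape[:2]
    r = size / max(h, w)
    nh, nw = round(h * r), round(w * r)
    canvas = np.full((size, size, 3), pad, dtype=img.dtype)
    # nearest-neighbour resize, just to keep the sketch dependency-free
    ys = (np.arange(nh) / r).astype(int).clip(0, h - 1)
    xs = (np.arange(nw) / r).astype(int).clip(0, w - 1)
    canvas[:nh, :nw] = img[ys][:, xs]
    return canvas

page = np.zeros((900, 640, 3), dtype=np.uint8)
print(letterbox_fixed(page).shape)  # (1280, 1280, 3) regardless of aspect ratio
```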
diff --git a/ACL_PyTorch/built-in/ocr/MinerU/infer.py b/ACL_PyTorch/built-in/ocr/MinerU/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b04545b7fc1da6c31eba709bdee13a5e22bad531
--- /dev/null
+++ b/ACL_PyTorch/built-in/ocr/MinerU/infer.py
@@ -0,0 +1,278 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+import inspect
+
+from pathlib import Path
+import argparse
+from loguru import logger
+import pypdfium2 as pdfium
+
+import torch
+import torch_npu
+import torch.nn as nn
+import torchvision
+import torchvision_npu
+import torchair as tng
+from torchair.configs.compiler_config import CompilerConfig
+
+from mineru.backend.pipeline.model_list import AtomicModel
+from mineru.model.mfr.unimernet.unimernet_hf.unimer_swin.modeling_unimer_swin import UnimerSwinSelfAttention
+from mineru.backend.pipeline.model_init import (
+    AtomModelSingleton,
+    table_model_init,
+    mfd_model_init,
+    mfr_model_init,
+    doclayout_yolo_model_init,
+    ocr_model_init,
+    )
+from mineru.backend.pipeline.batch_analyze import (
+    YOLO_LAYOUT_BASE_BATCH_SIZE,
+    MFD_BASE_BATCH_SIZE,
+    MFR_BASE_BATCH_SIZE,
+    )
+
+from transformers.generation.utils import GenerationMixin
+from MinerU.demo.demo import parse_doc
+
+
+def parse_args():
+    parser = argparse.ArgumentParser("MinerU infer")
+    parser.add_argument("--model_source", type=str, default="local", help="model checkpoint source")
+    parser.add_argument("--data_path", type=str, default="OmniDocBench_dataset")
+    parser.add_argument("--warmup", type=int, default=2, help="warm-up iterations")
+    parser.add_argument("--warmup_data_path", type=str, default="OmniDocBench_dataset/pdfs/jiaocai_71434495.pdf_0.pdf")
+    args = parser.parse_args()
+    return args
+
+
+def atom_model_init_compile(model_name: str, **kwargs):
+    """Build an atomic model and wrap its compute-heavy parts with torch.compile (TorchAir backend)."""
+    atom_model = None
+    if model_name == AtomicModel.Layout:
+        atom_model = doclayout_yolo_model_init(
+            kwargs.get('doclayout_yolo_weights'),
+            kwargs.get('device')
+        )
+        atom_model.model.model = compile_model(atom_model.model.model, False, True)
+        # register the admissible batch sizes (full batch and tail batch) for the static graph
+        npu_input = torch.zeros((batch_candidate[AtomicModel.Layout][0], 3, atom_model.imgsz, atom_model.imgsz))
+        tng.inference.set_dim_gears(npu_input, {0: batch_candidate[AtomicModel.Layout]})
+
+    elif model_name == AtomicModel.MFD:
+        atom_model = mfd_model_init(
+            kwargs.get('mfd_weights'),
+            kwargs.get('device')
+        )
+        atom_model.model.model = compile_model(atom_model.model.model, False, True)
+        npu_input = torch.zeros((batch_candidate[AtomicModel.MFD][0], 3, atom_model.imgsz, atom_model.imgsz))
+        tng.inference.set_dim_gears(npu_input, {0: batch_candidate[AtomicModel.MFD]})
+
+    elif model_name == AtomicModel.MFR:
+        atom_model = mfr_model_init(
+            kwargs.get('mfr_weight_dir'),
+            kwargs.get('device')
+        )
+
+        modify_mfr_model(atom_model.model)
+
+        # the encoder runs with a fixed (padded) batch, the decoder with dynamic shapes
+        atom_model.model.encoder = compile_model(atom_model.model.encoder, False, True)
+        atom_model.model.decoder = compile_model(atom_model.model.decoder, True, True)
+
+    elif model_name == AtomicModel.OCR:
+        atom_model = ocr_model_init(
+            kwargs.get('det_db_box_thresh'),
+            kwargs.get('lang'),
+            kwargs.get('det_limit_side_len'),
+        )
+
+    elif model_name == AtomicModel.Table:
+        atom_model = table_model_init(
+            kwargs.get('lang'),
+        )
+
+    else:
+        logger.error('model name not allowed')
+        raise ValueError("model name not allowed")
+
+    if atom_model is None:
+        logger.error('model init failed')
+        raise RuntimeError("model init failed")
+
+    return atom_model
+
+
+def rewrite_mfr_encoder_multi_head_attention_forward(model):
+    """Fuse the separate q/k/v projections of a UnimerSwin self-attention block into one linear layer."""
+    wq = model.query.weight
+    wk = model.key.weight
+    wv = model.value.weight
+    model.qkv = nn.Linear(in_features=wk.shape[1], out_features=wq.shape[0] + wk.shape[0] + wv.shape[0])
+    model.qkv.weight = nn.Parameter(torch.concat([wq, wk, wv], dim=0), requires_grad=False)
+    wq_bias = model.query.bias if model.query.bias is not None else torch.zeros(wq.shape[0])
+    wk_bias = model.key.bias if model.key.bias is not None else torch.zeros(wk.shape[0])
+    wv_bias = model.value.bias if model.value.bias is not None else torch.zeros(wv.shape[0])
+    model.qkv.bias = nn.Parameter(torch.concat([wq_bias, wk_bias, wv_bias], dim=0), requires_grad=False)
+
+
+def modify_mfr_model(model):
+    # rewrite the attention forward of every encoder self-attention block
+    for _, module in model.encoder.named_modules():
+        if isinstance(module, UnimerSwinSelfAttention):
+            rewrite_mfr_encoder_multi_head_attention_forward(module)
+    rewrite_mfr_encoder_forward()
+
+
+def compile_model(model, dynamic, fullgraph):
+    config = CompilerConfig()
+    config.experimental_config.frozen_parameter = True
+    config.experimental_config.tiling_schedule_optimize = True
+    npu_backend = tng.get_npu_backend(compiler_config=config)
+    compiled_model = torch.compile(model, dynamic=dynamic, fullgraph=fullgraph, backend=npu_backend)
+    return compiled_model
+
+
+def rewrite_model_init():
+    def _patched_getmodel(self, atom_model_name: str, **kwargs):
+        lang = kwargs.get('lang', None)
+        table_model_name = kwargs.get('table_model_name', None)
+
+        if atom_model_name in [AtomicModel.OCR]:
+            key = (atom_model_name, lang)
+        elif atom_model_name in [AtomicModel.Table]:
+            key = (atom_model_name, table_model_name, lang)
+        else:
+            key = atom_model_name
+
+        if key not in self._models:
+            self._models[key] = atom_model_init_compile(model_name=atom_model_name, **kwargs)
+        return self._models[key]
+    AtomModelSingleton.get_atom_model = _patched_getmodel
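+
+
+# Note: rewrite_model_init() replaces AtomModelSingleton.get_atom_model, so it
+# must be called before the first document is parsed; models created earlier
+# would be cached without their TorchAir-compiled graphs.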
+
+
+def rewrite_mfr_encoder_forward():
+    def _patched_prepare_encoder_decoder_kwargs_for_generation(self,
+            inputs_tensor: torch.Tensor,
+            model_kwargs,
+            model_input_name,
+            generation_config,
+    ):
+        # 1. get encoder
+        encoder = self.get_encoder()
+
+        # 2. Prepare encoder args and encoder kwargs from model kwargs and generation config.
+        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+        encoder_kwargs = {
+            argument: value
+            for argument, value in model_kwargs.items()
+            if not any(argument.startswith(p) for p in irrelevant_prefix)
+        }
+        encoder_signature = set(inspect.signature(encoder.forward).parameters)
+        encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
+        if not encoder_accepts_wildcard:
+            encoder_kwargs = {
+                argument: value
+                for argument, value in encoder_kwargs.items()
+                if argument in encoder_signature
+            }
+        encoder_kwargs["output_attentions"] = generation_config.output_attentions
+        encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+
+        # 3. make sure that encoder returns `ModelOutput`
+        model_input_name = model_input_name if model_input_name is not None else self.main_input_name
+        encoder_kwargs["return_dict"] = True
+
+        # pad inputs_tensor to the fixed MFR batch size so the compiled encoder always sees the same shape
+        pad_count = 0
+        if batch_candidate[AtomicModel.MFR] != inputs_tensor.shape[0]:
+            pad_count = batch_candidate[AtomicModel.MFR] - inputs_tensor.shape[0]
+            padding_tensor = torch.zeros(pad_count, *inputs_tensor.shape[1:], dtype=inputs_tensor.dtype, device=inputs_tensor.device)
+            inputs_tensor = torch.cat((inputs_tensor, padding_tensor), dim=0)
+
+        encoder_kwargs[model_input_name] = inputs_tensor
+        output = encoder(**encoder_kwargs)  # type: ignore
+        # drop the padded rows again before generation continues
+        if pad_count != 0:
+            output.last_hidden_state = output.last_hidden_state[:-pad_count]
+            output.pooler_output = output.pooler_output[:-pad_count]
+        model_kwargs["encoder_outputs"] = output
+        return model_kwargs
+
+    GenerationMixin._prepare_encoder_decoder_kwargs_for_generation = _patched_prepare_encoder_decoder_kwargs_for_generation
+
+
+def warmup(data_path, warmup_iters):
+    data_path = Path(data_path)
+
+    output_dir = Path(data_path).parent
+    output_dir = os.path.join(output_dir, "warmup_res")
+    pdf_suffixes = [".pdf"]
+    image_suffixes = [".png", ".jpeg", ".jpg"]
+    supported_suffixes = pdf_suffixes + image_suffixes
+
+    if data_path.suffix.lower() not in supported_suffixes:
+        raise ValueError(
+            f"Unsupported file type: '{data_path.suffix}'. "
+            f"Supported types: {supported_suffixes}"
+        )
+
+    # enough copies of the page to exercise both the full batch and the tail batch
+    doc_path_list = [data_path] * sum(batch_candidate[AtomicModel.Layout])
+    for _ in range(warmup_iters):
+        parse_doc(doc_path_list, output_dir, backend="pipeline")
+
+
+def get_pdf_page_count(pdf_path):
+    pdf = pdfium.PdfDocument(pdf_path)
+    try:
+        return len(pdf)
+    finally:
+        pdf.close()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    os.environ['MINERU_MODEL_SOURCE'] = args.model_source
+
+    __dir__ = args.data_path
+    pdf_files_dir = os.path.join(__dir__, "pdfs")
+    output_dir = os.path.join(__dir__, "output")
+    pdf_suffixes = [".pdf"]
+    image_suffixes = [".png", ".jpeg", ".jpg"]
+
+    print(pdf_files_dir)
+    batch_ratio = 16
+
+    rewrite_model_init()
+
+    doc_path_list = []
+    pdfs_page_count = 0
+    for doc_path in Path(pdf_files_dir).glob('*'):
+        if doc_path.suffix in pdf_suffixes + image_suffixes:
+            doc_path_list.append(doc_path)
+            pdfs_page_count += get_pdf_page_count(doc_path)
+
+    # full batch plus tail batch for the YOLO models; one fixed batch size for MFR
+    batch_candidate = {
+        AtomicModel.Layout: [YOLO_LAYOUT_BASE_BATCH_SIZE, pdfs_page_count % YOLO_LAYOUT_BASE_BATCH_SIZE],
+        AtomicModel.MFD: [MFD_BASE_BATCH_SIZE, pdfs_page_count % MFD_BASE_BATCH_SIZE],
+        AtomicModel.MFR: batch_ratio * MFR_BASE_BATCH_SIZE,
+    }
+    print(len(doc_path_list), batch_candidate)
+    warmup(args.warmup_data_path, args.warmup)
+
+    print("******** accuracy test **********")
+    start_time = time.time()
+    parse_doc(doc_path_list, output_dir, backend="pipeline")
+    print(f"per page process time: {(time.time()-start_time)/pdfs_page_count:.2f}s")
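
The patch below works together with `rewrite_mfr_encoder_multi_head_attention_forward` in `infer.py`: the script concatenates the pretrained q/k/v projection weights into a single `qkv` linear layer, and the patched forward splits the fused output back into per-head tensors. A standalone sketch (hypothetical sizes, not the UnimerSwin code) of the invariant this relies on:

```
import torch
import torch.nn as nn

torch.manual_seed(0)
d = 64
q, k, v = (nn.Linear(d, d) for _ in range(3))

# fuse: one matmul instead of three, as done for the MFR encoder
qkv = nn.Linear(d, 3 * d)
with torch.no_grad():
    qkv.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
    qkv.bias.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))

x = torch.randn(2, 10, d)
fused = qkv(x).chunk(3, dim=-1)
assert all(torch.allclose(a, m(x), atol=1e-6) for a, m in zip(fused, (q, k, v)))
print("fused qkv matches the separate projections")
```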
diff --git a/ACL_PyTorch/built-in/ocr/MinerU/mfr_encoder_mhsa.patch b/ACL_PyTorch/built-in/ocr/MinerU/mfr_encoder_mhsa.patch
new file mode 100644
index 0000000000000000000000000000000000000000..1fe80a05cbfbdbee80ee84508469f256c48f777d
--- /dev/null
+++ b/ACL_PyTorch/built-in/ocr/MinerU/mfr_encoder_mhsa.patch
@@ -0,0 +1,23 @@
+--- MinerU/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py 2025-09-02 17:58:15.032000000 +0800
++++ copy_mfr.py 2025-09-10 13:58:36.616000000 +0800
+@@ -465,11 +465,15 @@
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[torch.Tensor]:
+         batch_size, dim, num_channels = hidden_states.shape
+-        mixed_query_layer = self.query(hidden_states)
+
+-        key_layer = self.transpose_for_scores(self.key(hidden_states))
+-        value_layer = self.transpose_for_scores(self.value(hidden_states))
+-        query_layer = self.transpose_for_scores(mixed_query_layer)
++        # Fuse q/k/v into one large matmul; the relative position bias rules out the fused-attention (PFA) kernel, so only the projections are fused for now
++        batch_size, dim, num_channels = hidden_states.shape
++        qkv = self.qkv(hidden_states)
++        q, k, v = qkv.chunk(3, dim=-1)
++
++        query_layer = q.view(*q.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3)
++        key_layer = k.view(*k.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3)
++        value_layer = v.view(*v.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3)
+
+         # Take the dot product between "query" and "key" to get the raw attention scores.
+         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
diff --git a/ACL_PyTorch/built-in/ocr/MinerU/requirements.txt b/ACL_PyTorch/built-in/ocr/MinerU/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b98527828a2fec706eb348301bb463a1a0ec238a
--- /dev/null
+++ b/ACL_PyTorch/built-in/ocr/MinerU/requirements.txt
@@ -0,0 +1,34 @@
+boto3==1.40.24
+click==8.2.1
+loguru==0.7.3
+numpy==2.2.6
+pandas==2.3.2
+pdfminer.six==20250506
+tqdm==4.67.1
+requests
+httpx
+pillow==11.3.0
+pypdfium2==4.30.0
+pypdf==6.0.0
+reportlab==4.4.3
+pdftext==0.6.3
+modelscope==1.29.2
+huggingface-hub==0.34.4
+json-repair==0.50.0
+opencv-python==4.12.0.88
+fast-langdetect==0.2.5
+matplotlib==3.10.6
+ultralytics==8.3.193
+doclayout_yolo==0.0.4
+dill==0.3.8
+rapid_table==1.0.5
+PyYAML==6.0.2
+ftfy==6.3.1
+openai==1.106.1
+shapely==2.1.1
+pyclipper==1.3.0.post6
+omegaconf==2.3.0
+torch==2.6.0
+torch_npu==2.6.0
+torchvision==0.21.0
+transformers==4.56.1
\ No newline at end of file
diff --git a/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch b/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch
new file mode 100644
index 0000000000000000000000000000000000000000..4fab87d6054cf5d7128bb382e97b4cc33f6f6951
--- /dev/null
+++ b/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch
@@ -0,0 +1,77 @@
+diff -ruN ultralytics-8.3.193/ultralytics/engine/predictor.py ultralytics_/ultralytics/engine/predictor.py
+--- ultralytics-8.3.193/ultralytics/engine/predictor.py 2025-09-04 19:51:11.000000000 +0800
++++ ultralytics_/ultralytics/engine/predictor.py 2025-09-09 14:56:14.535737230 +0800
+@@ -196,9 +196,10 @@
+         same_shapes = len({x.shape for x in im}) == 1
+         letterbox = LetterBox(
+             self.imgsz,
+-            auto=same_shapes
+-            and self.args.rect
+-            and (self.model.pt or (getattr(self.model, "dynamic", False) and not self.model.imx)),
++            # auto=same_shapes
++            # and self.args.rect
++            # and (self.model.pt or (getattr(self.model, "dynamic", False) and not self.model.imx)),
++            auto=False,
+             stride=self.model.stride,
+         )
+         return [letterbox(image=x) for x in im]
+@@ -311,8 +312,11 @@
+ 
+         # Warmup model
+         if not self.done_warmup:
++            # self.model.warmup(
++            #     imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, self.model.ch, *self.imgsz)
++            # )
+             self.model.warmup(
+-                imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, self.model.ch, *self.imgsz)
++                imgsz=(self.dataset.bs, self.model.ch, *self.imgsz)
+             )
+             self.done_warmup = True
+ 
+@@ -400,7 +404,8 @@
+                 dnn=self.args.dnn,
+                 data=self.args.data,
+                 fp16=self.args.half,
+-                fuse=True,
++                # fuse=True,
++                fuse=False,
+                 verbose=verbose,
+             )
+ 
+diff -ruN ultralytics-8.3.193/ultralytics/nn/modules/block.py ultralytics_/ultralytics/nn/modules/block.py
+--- ultralytics-8.3.193/ultralytics/nn/modules/block.py 2025-09-04 19:51:11.000000000 +0800 ++++ ultralytics_/ultralytics/nn/modules/block.py 2025-09-09 14:56:14.543737230 +0800 +@@ -237,7 +237,9 @@ + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply sequential pooling operations to input and return concatenated feature maps.""" + y = [self.cv1(x)] +- y.extend(self.m(y[-1]) for _ in range(3)) ++ # y.extend(self.m(y[-1]) for _ in range(3)) ++ for _ in range(3): ++ y.append(self.m(y[-1])) + return self.cv2(torch.cat(y, 1)) + + +@@ -315,7 +317,9 @@ + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through C2f layer.""" + y = list(self.cv1(x).chunk(2, 1)) +- y.extend(m(y[-1]) for m in self.m) ++ # y.extend(m(y[-1]) for m in self.m) ++ for m in self.m: ++ y.append(m(y[-1])) + return self.cv2(torch.cat(y, 1)) + + def forward_split(self, x: torch.Tensor) -> torch.Tensor: + +diff -ruN ultralytics-8.3.193/ultralytics/utils/tal.py ultralytics_/ultralytics/utils/tal.py +--- ultralytics-8.3.193/ultralytics/utils/tal.py 2025-09-04 19:51:11.000000000 +0800 ++++ ultralytics_/ultralytics/utils/tal.py 2025-09-09 14:56:14.551737230 +0800 +@@ -375,7 +375,8 @@ + sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y + sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx) + anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2)) +- stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) ++ # stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) ++ stride_tensor.append(torch.ones((h * w, 1), dtype=dtype, device=device)*stride) + return torch.cat(anchor_points), torch.cat(stride_tensor) \ No newline at end of file
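
For reference, the block rewrites in both YOLO patches replace generator-based `list.extend` calls with explicit loops. The result is numerically identical; the explicit loop is simply friendlier to `torch.compile` graph capture on the NPU backend. A minimal standalone illustration of the transformed pattern (a toy module, not ultralytics code):

```
import torch
import torch.nn as nn

class C2fLike(nn.Module):
    """Toy stand-in for the patched C2f.forward."""
    def __init__(self, c: int, n: int = 2):
        super().__init__()
        self.m = nn.ModuleList(nn.Conv2d(c, c, 3, padding=1) for _ in range(n))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = [x]
        # Instead of: y.extend(m(y[-1]) for m in self.m)
        # the patched version appends explicitly, one block at a time.
        for m in self.m:
            y.append(m(y[-1]))
        return torch.cat(y, 1)

out = C2fLike(8)(torch.randn(1, 8, 16, 16))
print(out.shape)  # torch.Size([1, 24, 16, 16]) — the input plus two appended feature maps
```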