diff --git a/ACL_PyTorch/built-in/ocr/MinerU/README.md b/ACL_PyTorch/built-in/ocr/MinerU/README.md
index af345076b8c1e53ee78a9d3e24d245f2947f81cf..00f78f10c41d99fc5ada514d2fbe0c1174acc213 100644
--- a/ACL_PyTorch/built-in/ocr/MinerU/README.md
+++ b/ACL_PyTorch/built-in/ocr/MinerU/README.md
@@ -44,6 +44,8 @@ MinerU是由上海人工智能实验室OpenDataLab团队开发的开源文档解
 
 1. 获取`Pytorch`源码
    ```
+   git clone https://gitee.com/ascend/ModelZoo-PyTorch.git
+   cd ModelZoo-PyTorch/ACL_PyTorch/built-in/ocr/MinerU
    git clone https://github.com/opendatalab/MinerU.git
    cd MinerU
    git reset --hard de41fa58590263e43b783fe224b6d07cae290a33
@@ -71,13 +73,15 @@ MinerU是由上海人工智能实验室OpenDataLab团队开发的开源文档解
 3. 修改第三方库
    进入第三方库安装路径,默认为`source_path = /usr/local/lib/python3.11/site-packages`,通过工作目录`workdir`(自定义)中的`ultralytics.patch`和`doclayout_yolo.patch`进行修改
    ```
-   source_path=/usr/local/lib/python3.11/site-packages
+   workdir=$(pwd)
+   source_path=$(pip show ultralytics | grep Location | awk '{print $2}')
    cd ${source_path}/ultralytics
-   patch -p2 < ${workdir}/ultralytics.patch
+   patch -p1 < ${workdir}/ultralytics.patch
    cd ${source_path}/doclayout_yolo
-   patch -p2 < ${workdir}/doclayout_yolo.patch
-   cd ${workdir}
-   patch -p0 < mfr_encoder_mhsa.patch
+   patch -p1 < ${workdir}/doclayout_yolo.patch
+   cd ${workdir}/MinerU
+   git apply ../mineru.patch
+   cd ..
    ```
 
 ## 获取权重
@@ -155,33 +159,76 @@ python3 infer.py --data_path=OmniDocBench_dataset --model_source=local
 1. 推理结果整理
-   将解析结果文件夹中的markdown文件整理放置于同一目录,本例将所有markdown文件存放于OmniDocBench_dataset目录下的results_md文件夹
+   将解析结果文件夹中的markdown文件整理放置于同一目录,本例将所有markdown文件存放于OmniDocBench_dataset目录下的`end2end`文件夹
    ```
-   cp OmniDocBench_dataset/output/*/auto/*.md OmniDocBench_dataset/results_md/
+   cp OmniDocBench_dataset/output/*/auto/*.md OmniDocBench_dataset/end2end/
    ```
 2. 获取测评源码并构建环境
+
+   - 安装OmniDocBench基础环境
    ```
    git clone https://github.com/opendatalab/OmniDocBench.git
    cd OmniDocBench
-   git reset --hard dc96d812d219960773399c02ae8f89e4706120d4
+   git reset --hard 523fd1d529c3e9d0088c662e983aa70fb9585c9a
    conda create -n omnidocbench python=3.10
    conda activate omnidocbench
    pip install -r requirements.txt
    ```
+   - 公式精度指标CDM需要额外安装环境
+
+     step.1 install nodejs
+     ```
+     wget https://nodejs.org/dist/v16.13.1/node-v16.13.1-linux-arm64.tar.xz
+     tar -xf node-v16.13.1-linux-arm64.tar.xz
+     mkdir -p /usr/local/nodejs && mv node-v16.13.1-linux-arm64/* /usr/local/nodejs/
+     ln -s /usr/local/nodejs/bin/node /usr/local/bin
+     ln -s /usr/local/nodejs/bin/npm /usr/local/bin
+     node -v
+     ```
+
+     step.2 install ImageMagick
+     ```
+     git clone https://github.com/ImageMagick/ImageMagick.git ImageMagick-7.1.2
+     cd ImageMagick-7.1.2
+     apt-get update && apt-get install -y libpng-dev zlib1g-dev
+     apt-get install -y ghostscript
+     ./configure
+     make
+     sudo make install
+     sudo ldconfig /usr/local/lib
+     convert --version
+     ```
+
+     step.3 install pdflatex
+     ```
+     sudo apt-get install texlive-full
+     ```
+
+     step.4 install python requirements
+     ```
+     pip install -r metrics/cdm/requirements.txt
+     ```
+
 3. 测评配置修改
    修改`OmniDocBench`测评代码中的config文件,具体来说,我们使用端到端测评配置,修改configs/end2end.yaml文件中的ground_truth的data_path为下载的OmniDocBench.json路径,修改prediction的data_path中提供整理的推理结果的文件夹路径,如下:
    ```
    # -----以下是需要修改的部分 -----
+   display_formula:
+     metric:
+       - Edit_dist
+       - CDM   ### 安装好CDM环境后,可以在config文件中设置并直接计算
+       - CDM_plain
+   ...
    dataset:
      dataset_name: end2end_dataset
      ground_truth:
        data_path: ../OmniDocBench_dataset/OmniDocBench.json
      prediction:
-      data_path: ../OmniDocBench_dataset/results_md
+      data_path: ../OmniDocBench_dataset/end2end
    ```
 4. 
精度测量结果 @@ -190,10 +237,18 @@ python3 infer.py --data_path=OmniDocBench_dataset --model_source=local ``` python pdf_validation.py --config ./configs/end2end.yaml ``` + 评测结果将会存储在result目录下,Overall指标的计算方式为: + $$\text{Overall} = \frac{(1-\textit{Text Edit Distance}) \times 100 + \textit{Table TEDS} +\textit{Formula CDM}}{3}$$ + + 运行overall_metric.py可以得到精度结果: + ``` + cd .. + python overall_metric.py + ``` - 在`OmniDocBench`数据集上的精度为: - |模型|芯片|overall_EN|overall_CH| + 在`OmniDocBench`数据集上的精度和性能数据分别为: + |模型|芯片|overall|性能(s)| |------|------|------|------| - |MinerU|300I DUO|0.1588|0.2527| - |MinerU|800I A2 64G|0.1580|0.2510| + |MinerU|300I DUO|81.68| 3.37 | + |MinerU|800I A2 64G|81.51| 1.85 | diff --git a/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch b/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch index b5fd6669aa2dec34a5a4038305bb63deabe8c673..291a2914abbd98d4b04ead6c11c4022e4840e514 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch +++ b/ACL_PyTorch/built-in/ocr/MinerU/doclayout_yolo.patch @@ -1,7 +1,123 @@ -diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/engine/predictor.py doclayout_yolo-0.0.4_fix/doclayout_yolo/engine/predictor.py +diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/data/loaders.py doclayout_yolo/data/loaders.py +--- doclayout_yolo-0.0.4/doclayout_yolo/data/loaders.py 2025-02-11 15:49:31.000000000 +0800 ++++ doclayout_yolo/data/loaders.py 2025-10-19 01:27:41.984000000 +0800 +@@ -14,6 +14,7 @@ + import requests + import torch + from PIL import Image ++from torchvision.transforms import functional as TF + + from doclayout_yolo.data.utils import IMG_FORMATS, VID_FORMATS + from doclayout_yolo.utils import LOGGER, is_colab, is_kaggle, ops +@@ -411,7 +412,7 @@ + self.bs = len(self.im0) + + @staticmethod +- def _single_check(im): ++ def __single_check(im): ## origin _single_check + """Validate and format an image to numpy array.""" + assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}" + if isinstance(im, Image.Image): +@@ -419,6 +420,18 @@ + im = im.convert("RGB") + im = np.asarray(im)[:, :, ::-1] + im = np.ascontiguousarray(im) # contiguous ++ ++ return im ++ ++ @staticmethod ++ def _single_check(im): ++ """Validate and format an image to numpy array.""" ++ assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}" ++ if isinstance(im, Image.Image): ++ if im.mode != "RGB": ++ im = im.convert("RGB") ++ im = np.asarray(im) ++ + return im + + def __len__(self): +diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/engine/model.py doclayout_yolo/engine/model.py +--- doclayout_yolo-0.0.4/doclayout_yolo/engine/model.py 2025-02-11 15:49:31.000000000 +0800 ++++ doclayout_yolo/engine/model.py 2025-10-19 01:27:41.988000000 +0800 +@@ -143,6 +143,8 @@ + else: + self._load(model, task=task) + ++ self.model.half() ++ + def __call__( + self, + source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None, +diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/engine/predictor.py doclayout_yolo/engine/predictor.py --- doclayout_yolo-0.0.4/doclayout_yolo/engine/predictor.py 2025-02-11 15:49:31.000000000 +0800 -+++ doclayout_yolo-0.0.4_fix/doclayout_yolo/engine/predictor.py 2025-09-09 16:05:20.011737230 +0800 -@@ -152,7 +152,8 @@ ++++ doclayout_yolo/engine/predictor.py 2025-10-19 01:27:41.988000000 +0800 +@@ -47,6 +47,8 @@ + from doclayout_yolo.utils.files import increment_path + from doclayout_yolo.utils.torch_utils import select_device, smart_inference_mode + ++import 
torch.nn.functional as F ++ + STREAM_WARNING = """ + WARNING ⚠️ inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory + errors for large sources or long-running streams and videos. See https://docs.doclayout_yolo.com/modes/predict/ for help. +@@ -112,7 +114,7 @@ + self._lock = threading.Lock() # for automatic thread-safe inference + callbacks.add_integration_callbacks(self) + +- def preprocess(self, im): ++ def _preprocess(self, im): ### origin preprocess + """ + Prepares input image before inference. + +@@ -132,6 +134,46 @@ + im /= 255 # 0 - 255 to 0.0 - 1.0 + return im + ++ ++ def preprocess(self, images): ### adapt preprocess ++ """ ++ Prepares input image before inference. ++ ++ Args: ++ images (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list. ++ """ ++ new_shape = (new_shape, new_shape) if isinstance(self.imgsz, int) else self.imgsz ++ tensors = [] ++ for im in images: ++ im = torch.from_numpy(im).to(self.device).permute((2, 0, 1)) / 255.0 ++ ++ c, h, w = im.shape ++ ++ r = min(new_shape[0] / h, new_shape[1] / w) ++ ++ new_unpad = (int(round(w * r)), int(round(h * r))) ++ ++ if (w, h) != new_unpad: ++ im = F.interpolate(im.unsqueeze(0), size=(new_unpad[1], new_unpad[0]), ++ mode="bilinear", align_corners=False).squeeze(0) ++ ++ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] ++ dw /= 2 ++ dh /= 2 ++ left, right = int(dw), int(dw + 0.5) ++ top, bottom = int(dh), int(dh + 0.5) ++ im = F.pad(im, (left, right, top, bottom), value=114/255.0) ++ ++ _, H, W = im.shape ++ assert (H, W) == (new_shape[0], new_shape[1]), f"Expected image size do not match: padding image size:{(H, W)} != expected image size: {(new_shape[0], new_shape[1])}" ++ ++ im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32 ++ ++ tensors.append(im) ++ ++ return torch.stack(tensors, dim=0) ++ ++ + def inference(self, im, *args, **kwargs): + """Runs inference on a given image using the specified model and arguments.""" + visualize = ( +@@ -152,7 +194,8 @@ (list): A list of transformed images. 
""" same_shapes = len({x.shape for x in im}) == 1 @@ -11,7 +127,7 @@ diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/engine/predictor.py doclayout_yolo return [letterbox(image=x) for x in im] def postprocess(self, preds, img, orig_imgs): -@@ -225,7 +226,8 @@ +@@ -225,7 +268,8 @@ # Warmup model if not self.done_warmup: @@ -21,10 +137,9 @@ diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/engine/predictor.py doclayout_yolo self.done_warmup = True self.seen, self.windows, self.batch = 0, [], None - -diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/nn/modules/block.py doclayout_yolo-0.0.4_fix/doclayout_yolo/nn/modules/block.py +diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/nn/modules/block.py doclayout_yolo/nn/modules/block.py --- doclayout_yolo-0.0.4/doclayout_yolo/nn/modules/block.py 2025-02-11 15:49:31.000000000 +0800 -+++ doclayout_yolo-0.0.4_fix/doclayout_yolo/nn/modules/block.py 2025-09-09 16:05:20.019737230 +0800 ++++ doclayout_yolo/nn/modules/block.py 2025-10-19 01:27:41.996000000 +0800 @@ -230,7 +230,9 @@ def forward(self, x): """Forward pass through C2f layer.""" @@ -36,10 +151,9 @@ diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/nn/modules/block.py doclayout_yolo return self.cv2(torch.cat(y, 1)) def forward_split(self, x): - -diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/utils/tal.py doclayout_yolo-0.0.4_fix/doclayout_yolo/utils/tal.py +diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/utils/tal.py doclayout_yolo/utils/tal.py --- doclayout_yolo-0.0.4/doclayout_yolo/utils/tal.py 2025-02-11 15:49:31.000000000 +0800 -+++ doclayout_yolo-0.0.4_fix/doclayout_yolo/utils/tal.py 2025-09-09 16:05:20.023737230 +0800 ++++ doclayout_yolo/utils/tal.py 2025-10-19 01:27:42.000000000 +0800 @@ -328,7 +328,8 @@ sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx) @@ -48,3 +162,4 @@ diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/utils/tal.py doclayout_yolo-0.0.4_ + # stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) + stride_tensor.append(torch.ones((h * w, 1), dtype=dtype, device=device)*stride) return torch.cat(anchor_points), torch.cat(stride_tensor) + diff --git a/ACL_PyTorch/built-in/ocr/MinerU/mfr_encoder_mhsa.patch b/ACL_PyTorch/built-in/ocr/MinerU/mfr_encoder_mhsa.patch deleted file mode 100644 index 1fe80a05cbfbdbee80ee84508469f256c48f777d..0000000000000000000000000000000000000000 --- a/ACL_PyTorch/built-in/ocr/MinerU/mfr_encoder_mhsa.patch +++ /dev/null @@ -1,23 +0,0 @@ ---- MinerU/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py 2025-09-02 17:58:15.032000000 +0800 -+++ copy_mfr.py 2025-09-10 13:58:36.616000000 +0800 -@@ -465,11 +465,15 @@ - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor]: - batch_size, dim, num_channels = hidden_states.shape -- mixed_query_layer = self.query(hidden_states) - -- key_layer = self.transpose_for_scores(self.key(hidden_states)) -- value_layer = self.transpose_for_scores(self.value(hidden_states)) -- query_layer = self.transpose_for_scores(mixed_query_layer) -+ # """融合qk为大矩阵,由于加入相对位置编码,PFA接口用不了,暂时只修改矩阵乘法""" -+ batch_size, dim, num_channels = hidden_states.shape -+ qkv = self.qkv(hidden_states) -+ q, k, v = qkv.chunk(3, dim=-1) -+ -+ query_layer = q.view(*q.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) -+ key_layer = k.view(*k.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) -+ value_layer = v.view(*v.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 
3) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - diff --git a/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch b/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch new file mode 100644 index 0000000000000000000000000000000000000000..c031d1413a5ac52bbfeb5c5de663ff5c1fe83884 --- /dev/null +++ b/ACL_PyTorch/built-in/ocr/MinerU/mineru.patch @@ -0,0 +1,881 @@ +diff --git a/demo/demo.py b/demo/demo.py +index 36433c45..6f28620f 100644 +--- a/demo/demo.py ++++ b/demo/demo.py +@@ -86,7 +86,7 @@ def do_parse( + image_dir = str(os.path.basename(local_image_dir)) + content_list = pipeline_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir) + md_writer.write_string( +- f"{pdf_file_name}_content_list.json", ++ f"{pdf_file_name}_content.json", + json.dumps(content_list, ensure_ascii=False, indent=4), + ) + +@@ -142,7 +142,8 @@ def do_parse( + image_dir = str(os.path.basename(local_image_dir)) + content_list = vlm_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir) + md_writer.write_string( +- f"{pdf_file_name}_content_list.json", ++ # f"{pdf_file_name}_content_list.json", ++ f"{pdf_file_name}_content.json", ## 文件名太长了,linux文件系统ext4超过255字节无法保存 + json.dumps(content_list, ensure_ascii=False, indent=4), + ) + +diff --git a/mineru/backend/pipeline/batch_analyze.py b/mineru/backend/pipeline/batch_analyze.py +index c88a52a3..b0b79a80 100644 +--- a/mineru/backend/pipeline/batch_analyze.py ++++ b/mineru/backend/pipeline/batch_analyze.py +@@ -3,6 +3,9 @@ from loguru import logger + from tqdm import tqdm + from collections import defaultdict + import numpy as np ++import time ++import torch ++import torch_npu + + from .model_init import AtomModelSingleton + from ...utils.config_reader import get_formula_enable, get_table_enable +@@ -95,6 +98,7 @@ class BatchAnalyze: + }) + + # OCR检测处理 ++ from concurrent.futures import ThreadPoolExecutor, as_completed + if self.enable_ocr_det_batch: + # 批处理模式 - 按语言和分辨率分组 + # 收集所有需要OCR检测的裁剪图像 +@@ -139,79 +143,73 @@ class BatchAnalyze: + ) + + # 按分辨率分组并同时完成padding ++ stride = 64 + resolution_groups = defaultdict(list) + for crop_info in lang_crop_list: + cropped_img = crop_info[0] + h, w = cropped_img.shape[:2] + # 使用更大的分组容差,减少分组数量 + # 将尺寸标准化到32的倍数 +- normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数 +- normalized_w = ((w + 32) // 32) * 32 ++ normalized_h = ((h + stride) // stride) * stride # 向上取整到stride的倍数 ++ normalized_w = ((w + stride) // stride) * stride + group_key = (normalized_h, normalized_w) + resolution_groups[group_key].append(crop_info) + +- # 对每个分辨率组进行批处理 +- for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"): +- +- # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数) +- max_h = max(crop_info[0].shape[0] for crop_info in group_crops) +- max_w = max(crop_info[0].shape[1] for crop_info in group_crops) +- target_h = ((max_h + 32 - 1) // 32) * 32 +- target_w = ((max_w + 32 - 1) // 32) * 32 +- +- # 对所有图像进行padding到统一尺寸 +- batch_images = [] +- for crop_info in group_crops: +- img = crop_info[0] +- h, w = img.shape[:2] +- # 创建目标尺寸的白色背景 +- padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255 +- # 将原图像粘贴到左上角 +- padded_img[:h, :w] = img +- batch_images.append(padded_img) +- +- # 批处理检测 +- det_batch_size = min(len(batch_images), self.batch_ratio * OCR_DET_BASE_BATCH_SIZE) # 增加批处理大小 +- # logger.debug(f"OCR-det batch: {det_batch_size} images, target size: {target_h}x{target_w}") +- batch_results = ocr_model.text_detector.batch_predict(batch_images, 
det_batch_size) +- +- # 处理批处理结果 +- for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)): +- new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info +- +- if dt_boxes is not None and len(dt_boxes) > 0: +- # 直接应用原始OCR流程中的关键处理步骤 +- from mineru.utils.ocr_utils import ( +- merge_det_boxes, update_det_boxes, sorted_boxes +- ) + +- # 1. 排序检测框 +- if len(dt_boxes) > 0: +- dt_boxes_sorted = sorted_boxes(dt_boxes) +- else: +- dt_boxes_sorted = [] +- +- # 2. 合并相邻检测框 +- if dt_boxes_sorted: +- dt_boxes_merged = merge_det_boxes(dt_boxes_sorted) +- else: +- dt_boxes_merged = [] +- +- # 3. 根据公式位置更新检测框(关键步骤!) +- if dt_boxes_merged and adjusted_mfdetrec_res: +- dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res) +- else: +- dt_boxes_final = dt_boxes_merged +- +- # 构造OCR结果格式 +- ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final] +- +- if ocr_res: +- ocr_result_list = get_ocr_result_list( +- ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang +- ) +- +- ocr_res_list_dict['layout_res'].extend(ocr_result_list) ++ def _run_one_group_ocr(group_key, group_crops): ++ ++ max_h = max(ci[0].shape[0] for ci in group_crops) ++ max_w = max(ci[0].shape[1] for ci in group_crops) ++ target_h = ((max_h + stride - 1) // stride) * stride ++ target_w = ((max_w + stride - 1) // stride) * stride ++ ++ batch_images = [] ++ for ci in group_crops: ++ img = ci[0] ++ h, w = img.shape[:2] ++ padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255 ++ padded_img[:h, :w] = img ++ batch_images.append(padded_img) ++ ++ det_batch_size = min(len(batch_images), self.batch_ratio * OCR_DET_BASE_BATCH_SIZE) ++ ++ batch_results = ocr_model.text_detector.batch_predict(batch_images, det_batch_size) ++ ++ for i, (ci, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)): ++ new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = ci ++ if dt_boxes is not None and len(dt_boxes) > 0: ++ from mineru.utils.ocr_utils import merge_det_boxes, update_det_boxes, sorted_boxes ++ ++ if len(dt_boxes) > 0: ++ dt_boxes_sorted = sorted_boxes(dt_boxes) ++ else: ++ dt_boxes_sorted = [] ++ ++ if dt_boxes_sorted: ++ dt_boxes_merged = merge_det_boxes(dt_boxes_sorted) ++ else: ++ dt_boxes_merged = [] ++ ++ if dt_boxes_merged and adjusted_mfdetrec_res: ++ dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res) ++ else: ++ dt_boxes_final = dt_boxes_merged ++ ++ ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final] ++ if ocr_res: ++ ocr_result_list = get_ocr_result_list( ++ ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang ++ ) ++ ocr_res_list_dict['layout_res'].extend(ocr_result_list) ++ ++ MAX_WORKERS = 4 ++ start = time.time() ++ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex: ++ futures = [ex.submit(_run_one_group_ocr, gk, gcs) for gk, gcs in resolution_groups.items()] ++ for f in as_completed(futures): ++ f.result() ++ end = time.time() ++ logger.info(f"ocr det run time : {end -start}") + else: + # 原始单张处理模式 + for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"): +@@ -247,7 +245,7 @@ class BatchAnalyze: + + # 表格识别 table recognition + if self.table_enable: +- for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"): ++ def _run_one_group_table(table_res_dict): + _lang = table_res_dict['lang'] + table_model = atom_model_manager.get_atom_model( + atom_model_name='table', 
+@@ -271,6 +269,16 @@ class BatchAnalyze: + 'table recognition processing fails, not get html return' + ) + ++ ++ MAX_WORKERS = 4 ++ start = time.time() ++ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex: ++ futures = [ex.submit(_run_one_group_table, table_res_dict) for table_res_dict in table_res_list_all_page] ++ for f in as_completed(futures): ++ f.result() ++ end = time.time() ++ logger.info(f"table run time : {end - start}") ++ + # Create dictionaries to store items by language + need_ocr_lists_by_lang = {} # Dict of lists for each language + img_crop_lists_by_lang = {} # Dict of lists for each language +diff --git a/mineru/model/layout/doclayout_yolo.py b/mineru/model/layout/doclayout_yolo.py +index 5667a909..fc5056bb 100644 +--- a/mineru/model/layout/doclayout_yolo.py ++++ b/mineru/model/layout/doclayout_yolo.py +@@ -66,6 +66,7 @@ class DocLayoutYOLOModel: + conf=self.conf, + iou=self.iou, + verbose=False, ++ half=True + ) + for pred in predictions: + results.append(self._parse_prediction(pred)) +diff --git a/mineru/model/mfd/yolo_v8.py b/mineru/model/mfd/yolo_v8.py +index 33dac091..1fb4b50e 100644 +--- a/mineru/model/mfd/yolo_v8.py ++++ b/mineru/model/mfd/yolo_v8.py +@@ -31,7 +31,8 @@ class YOLOv8MFDModel: + conf=self.conf, + iou=self.iou, + verbose=False, +- device=self.device ++ device=self.device, ++ half=True + ) + return [pred.cpu() for pred in preds] if is_batch else preds[0].cpu() + +diff --git a/mineru/model/mfr/unimernet/Unimernet.py b/mineru/model/mfr/unimernet/Unimernet.py +index ae3879da..23e56f2a 100644 +--- a/mineru/model/mfr/unimernet/Unimernet.py ++++ b/mineru/model/mfr/unimernet/Unimernet.py +@@ -1,7 +1,7 @@ + import torch + from torch.utils.data import DataLoader, Dataset + from tqdm import tqdm +- ++import numpy as np + + class MathDataset(Dataset): + def __init__(self, image_paths, transform=None): +@@ -61,7 +61,7 @@ class UnimernetModel(object): + res["latex"] = latex + return formula_list + +- def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list: ++ def _batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list: + images_formula_list = [] + mf_image_list = [] + backfill_list = [] +@@ -137,3 +137,94 @@ class UnimernetModel(object): + res["latex"] = latex + + return images_formula_list ++ ++ ++ def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list: ++ ++ images_formula_list = [] ++ mf_image_list = [] ++ backfill_list = [] ++ image_info = [] # Store (area, original_index, image) tuples ++ ++ # Collect images with their original indices ++ for image_index in range(len(images_mfd_res)): ++ mfd_res = images_mfd_res[image_index] ++ pil_img = images[image_index] ++ # split代替多次索引 ++ data = mfd_res.boxes.data.numpy() ++ xyxy, conf, cla = np.split(data, [4, 5], axis=-1) ++ ++ cla = cla.reshape(-1).astype(int).tolist() ++ conf = np.round(conf.reshape(-1).astype(float), 2).tolist() ++ ++ xyxy = xyxy.astype(np.int32) ++ xmin, ymin, xmax, ymax = xyxy[:, 0], xyxy[:, 1], xyxy[:, 2], xyxy[:, 3] ++ # area 直接矩阵运算 ++ areas = (xmax - xmin) * (ymax - ymin) ++ ++ num_boxes = len(conf) ++ ++ formula_list = [] ++ for i in range(num_boxes): ++ xmin_i, ymin_i, xmax_i, ymax_i = xyxy[i].tolist() ++ formula_list.append({ ++ "category_id": 13 + cla[i], ++ "poly": [xmin_i, ymin_i, xmax_i, ymin_i, ++ xmax_i, ymax_i, xmin_i, ymax_i], ++ "score": conf[i], ++ "latex": "", ++ }) ++ ++ # bbox_img 截取 ++ # bbox_img = pil_img[:, ymin_i:ymax_i, xmin_i:xmax_i] ++ bbox_img = pil_img.crop((xmin_i, 
ymin_i, xmax_i, ymax_i)) ++ curr_idx = len(mf_image_list) ++ image_info.append((areas[i], curr_idx, bbox_img)) ++ mf_image_list.append(bbox_img) ++ ++ images_formula_list.append(formula_list) ++ backfill_list += formula_list ++ ++ # Stable sort by area ++ image_info.sort(key=lambda x: x[0]) # sort by area ++ sorted_indices = [x[1] for x in image_info] ++ sorted_images = [x[2] for x in image_info] ++ ++ # Create mapping for results ++ index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)} ++ ++ # Create dataset with sorted images ++ dataset = MathDataset(sorted_images, transform=self.model.transform) ++ ++ # 如果batch_size > len(sorted_images),则设置为不超过len(sorted_images)的2的幂 ++ batch_size = min(batch_size, max(1, 2 ** (len(sorted_images).bit_length() - 1))) if sorted_images else 1 ++ ++ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0) ++ ++ # Process batches and store results ++ mfr_res = [] ++ # for mf_img in dataloader: ++ ++ with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar: ++ for index, mf_img in enumerate(dataloader): ++ mf_img = mf_img.to(dtype=self.model.dtype) ++ mf_img = mf_img.to(self.device) ++ with torch.no_grad(): ++ output = self.model.generate({"image": mf_img}, batch_size=batch_size) ++ mfr_res.extend(output["fixed_str"]) ++ ++ # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size ++ current_batch_size = min(batch_size, len(sorted_images) - index * batch_size) ++ pbar.update(current_batch_size) ++ ++ # Restore original order ++ unsorted_results = [""] * len(mfr_res) ++ for new_idx, latex in enumerate(mfr_res): ++ original_idx = index_mapping[new_idx] ++ unsorted_results[original_idx] = latex ++ ++ # Fill results back ++ for res, latex in zip(backfill_list, unsorted_results): ++ res["latex"] = latex ++ ++ return images_formula_list +diff --git a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +index 98d1deee..3866a257 100644 +--- a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py ++++ b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py +@@ -5,7 +5,9 @@ import cv2 + import albumentations as alb + from albumentations.pytorch import ToTensorV2 + from torchvision.transforms.functional import resize +- ++import torch ++import torch_npu ++import torch.nn.functional as F + + # TODO: dereference cv2 if possible + class UnimerSwinImageProcessor(BaseImageProcessor): +@@ -25,10 +27,53 @@ class UnimerSwinImageProcessor(BaseImageProcessor): + ] + ) + +- def __call__(self, item): ++ self.NORMALIZE_DIVISOR = torch.tensor(255.0, dtype=torch.float16, device="npu") ++ self.weights = torch.tensor([[[0.2989]], [[0.5870]], [[0.1140]]], dtype=torch.float16, device="npu") ++ self.mean = torch.tensor(0.7931, dtype=torch.float16, device="npu") ++ self.std = torch.tensor(0.1738, dtype=torch.float16, device="npu") ++ ++ self._mul_buf = torch.empty((3, *self.input_size), dtype=torch.float16, device="npu") # 预分配 [3,H,W] ++ self._gray_buf = torch.empty((1, *self.input_size), dtype=torch.float16, device="npu") # 预分配 [1,H,W] ++ ++ ++ def ___call__(self, item): + image = self.prepare_input(item) + return self.transform(image=image)['image'][:1] + ++ def pil_to_npu(self, pil_img, device="npu"): ++ img = torch.from_numpy(np.asarray(pil_img, dtype=np.float16)) ++ img = img.to(device).permute(2, 0, 1) / self.NORMALIZE_DIVISOR ++ return img ++ ++ def __call__(self, 
item): ++ ++ img = self.crop_margin(item) ++ img = self.pil_to_npu(img) ++ ++ _, h, w = img.shape ++ target_h, target_w = self.input_size ++ scale = min(target_h / h, target_w / w) ++ new_h, new_w = int(h*scale), int(w*scale) ++ ++ img = img.view(1, *img.shape) # [1,C,H,W] ++ img = F.interpolate(img, size=(new_h, new_w), mode='bilinear', align_corners=False) ++ img = img.view(*img.shape[1:]) ++ ++ dw, dh = target_w - new_w, target_h - new_h ++ dw /= 2 ++ dh /= 2 ++ left, right = int(dw), int(dw + 0.5) ++ top, bottom = int(dh), int(dh + 0.5) ++ img = F.pad(img, (left, right, top, bottom), value=0.0) ++ ++ # RGB -> Gray ++ gray_tensor = (img * self.weights).sum(dim=0, keepdim=True) # [1, H, W] ++ ++ # Normalize ++ gray_tensor.sub_(self.mean).div_(self.std) ++ return gray_tensor ++ ++ + @staticmethod + def crop_margin(img: Image.Image) -> Image.Image: + data = np.array(img.convert("L")) +@@ -44,6 +89,32 @@ class UnimerSwinImageProcessor(BaseImageProcessor): + a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box + return img.crop((a, b, w + a, h + b)) + ++ def crop_margin_tensor(self, img): ++ """ ++ img: [C,H,W] tensor, uint8 或 float ++ """ ++ ++ gray = (img * self.weights).sum(dim=0) ++ ++ gray = gray.to(torch.uint8) ++ max_val = gray.max() ++ min_val = gray.min() ++ ++ if max_val == min_val: ++ return img ++ ++ norm_gray = (gray - min_val) / (max_val - min_val) ++ ++ mask = (norm_gray < self.threshold) ++ ++ coords = mask.nonzero(as_tuple=False) ++ if coords.shape[0] == 0: ++ return img ++ ymin, xmin = coords.min(0)[0] ++ ymax, xmax = coords.max(0)[0] ++ ++ return img[:, ymin:ymax+1, xmin:xmax+1] ++ + @staticmethod + def crop_margin_numpy(img: np.ndarray) -> np.ndarray: + """Crop margins of image using NumPy operations""" +diff --git a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +index 1b808e8b..0fe54751 100644 +--- a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py ++++ b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py +@@ -465,11 +465,15 @@ class UnimerSwinSelfAttention(nn.Module): + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + batch_size, dim, num_channels = hidden_states.shape +- mixed_query_layer = self.query(hidden_states) + +- key_layer = self.transpose_for_scores(self.key(hidden_states)) +- value_layer = self.transpose_for_scores(self.value(hidden_states)) +- query_layer = self.transpose_for_scores(mixed_query_layer) ++ # """融合qk为大矩阵,由于加入相对位置编码,PFA接口用不了,暂时只修改矩阵乘法""" ++ batch_size, dim, num_channels = hidden_states.shape ++ qkv = self.qkv(hidden_states) ++ q, k, v = qkv.chunk(3, dim=-1) ++ ++ query_layer = q.view(*q.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) ++ key_layer = k.view(*k.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) ++ value_layer = v.view(*v.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) +diff --git a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py +index 3de483ac..23813db9 100755 +--- a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py ++++ b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py +@@ -117,6 +117,10 @@ class TextDetector(BaseOCRV20): + self.net.eval() + self.net.to(self.device) + ++ ++ import threading ++ self._dev_lock = getattr(self, "_dev_lock", threading.Lock()) ++ + def _batch_process_same_size(self, img_list): + """ + 对相同尺寸的图像进行批处理 +@@ -162,12 +166,12 @@ class TextDetector(BaseOCRV20): + return batch_results, time.time() - starttime + + # 批处理推理 +- with torch.no_grad(): +- inp = torch.from_numpy(batch_tensor) +- inp = inp.to(self.device) +- outputs = self.net(inp) +- +- # 处理输出 ++ with self._dev_lock: ++ with torch.no_grad(): ++ inp = torch.from_numpy(batch_tensor) ++ inp = inp.to(self.device) ++ outputs = self.net(inp) ++ # 处理输出 + preds = {} + if self.det_algorithm == "EAST": + preds['f_geo'] = outputs['f_geo'].cpu().numpy() +@@ -304,10 +308,11 @@ class TextDetector(BaseOCRV20): + img = img.copy() + starttime = time.time() + +- with torch.no_grad(): +- inp = torch.from_numpy(img) +- inp = inp.to(self.device) +- outputs = self.net(inp) ++ with self._dev_lock: ++ with torch.no_grad(): ++ inp = torch.from_numpy(img) ++ inp = inp.to(self.device) ++ outputs = self.net(inp) + + preds = {} + if self.det_algorithm == "EAST": +diff --git a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +index c06ca5fe..d865b201 100755 +--- a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py ++++ b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py +@@ -94,6 +94,9 @@ class TextRecognizer(BaseOCRV20): + self.net.eval() + self.net.to(self.device) + ++ import threading ++ self._dev_lock = getattr(self, "_dev_lock", threading.Lock()) ++ + def resize_norm_img(self, img, max_wh_ratio): + imgC, imgH, imgW = self.rec_image_shape + if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR': +@@ -301,74 +304,78 @@ class TextRecognizer(BaseOCRV20): + rec_res = [['', 0.0]] * img_num + batch_num = self.rec_batch_num + elapse = 0 +- # for beg_img_no in range(0, img_num, batch_num): +- with tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar: +- index = 0 +- for beg_img_no in range(0, img_num, batch_num): +- end_img_no = min(img_num, beg_img_no + batch_num) +- norm_img_batch = [] +- max_wh_ratio = 0 +- for ino in range(beg_img_no, end_img_no): +- # h, w = img_list[ino].shape[0:2] +- h, w = img_list[indices[ino]].shape[0:2] +- wh_ratio = w * 1.0 / h +- max_wh_ratio = max(max_wh_ratio, wh_ratio) +- for ino in range(beg_img_no, end_img_no): +- if self.rec_algorithm == "SAR": +- norm_img, _, _, valid_ratio = self.resize_norm_img_sar( +- img_list[indices[ino]], self.rec_image_shape) +- norm_img = norm_img[np.newaxis, :] +- valid_ratio = np.expand_dims(valid_ratio, axis=0) +- valid_ratios = [] +- valid_ratios.append(valid_ratio) +- norm_img_batch.append(norm_img) +- +- elif self.rec_algorithm == "SVTR": +- norm_img = self.resize_norm_img_svtr(img_list[indices[ino]], +- self.rec_image_shape) +- norm_img = norm_img[np.newaxis, :] +- norm_img_batch.append(norm_img) +- elif self.rec_algorithm == "SRN": +- norm_img = self.process_image_srn(img_list[indices[ino]], +- self.rec_image_shape, 8, +- 
self.max_text_length) +- encoder_word_pos_list = [] +- gsrm_word_pos_list = [] +- gsrm_slf_attn_bias1_list = [] +- gsrm_slf_attn_bias2_list = [] +- encoder_word_pos_list.append(norm_img[1]) +- gsrm_word_pos_list.append(norm_img[2]) +- gsrm_slf_attn_bias1_list.append(norm_img[3]) +- gsrm_slf_attn_bias2_list.append(norm_img[4]) +- norm_img_batch.append(norm_img[0]) +- elif self.rec_algorithm == "CAN": +- norm_img = self.norm_img_can(img_list[indices[ino]], +- max_wh_ratio) +- norm_img = norm_img[np.newaxis, :] +- norm_img_batch.append(norm_img) +- norm_image_mask = np.ones(norm_img.shape, dtype='float32') +- word_label = np.ones([1, 36], dtype='int64') +- norm_img_mask_batch = [] +- word_label_list = [] +- norm_img_mask_batch.append(norm_image_mask) +- word_label_list.append(word_label) +- else: +- norm_img = self.resize_norm_img(img_list[indices[ino]], +- max_wh_ratio) +- norm_img = norm_img[np.newaxis, :] +- norm_img_batch.append(norm_img) +- norm_img_batch = np.concatenate(norm_img_batch) +- norm_img_batch = norm_img_batch.copy() +- +- if self.rec_algorithm == "SRN": +- starttime = time.time() +- encoder_word_pos_list = np.concatenate(encoder_word_pos_list) +- gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) +- gsrm_slf_attn_bias1_list = np.concatenate( +- gsrm_slf_attn_bias1_list) +- gsrm_slf_attn_bias2_list = np.concatenate( +- gsrm_slf_attn_bias2_list) + ++ # for beg_img_no in range(0, img_num, batch_num): ++ from concurrent.futures import ThreadPoolExecutor, as_completed ++ def _rec_batch_worker(beg_img_no: int, end_img_no: int): ++ ++ ++ max_wh_ratio = 0.0 ++ norm_img_batch = [] ++ for ino in range(beg_img_no, end_img_no): ++ # h, w = img_list[ino].shape[0:2] ++ h, w = img_list[indices[ino]].shape[0:2] ++ wh_ratio = w * 1.0 / h ++ max_wh_ratio = max(max_wh_ratio, wh_ratio) ++ for ino in range(beg_img_no, end_img_no): ++ if self.rec_algorithm == "SAR": ++ norm_img, _, _, valid_ratio = self.resize_norm_img_sar( ++ img_list[indices[ino]], self.rec_image_shape) ++ norm_img = norm_img[np.newaxis, :] ++ valid_ratio = np.expand_dims(valid_ratio, axis=0) ++ valid_ratios = [] ++ valid_ratios.append(valid_ratio) ++ norm_img_batch.append(norm_img) ++ ++ elif self.rec_algorithm == "SVTR": ++ norm_img = self.resize_norm_img_svtr(img_list[indices[ino]], ++ self.rec_image_shape) ++ norm_img = norm_img[np.newaxis, :] ++ norm_img_batch.append(norm_img) ++ elif self.rec_algorithm == "SRN": ++ norm_img = self.process_image_srn(img_list[indices[ino]], ++ self.rec_image_shape, 8, ++ self.max_text_length) ++ encoder_word_pos_list = [] ++ gsrm_word_pos_list = [] ++ gsrm_slf_attn_bias1_list = [] ++ gsrm_slf_attn_bias2_list = [] ++ encoder_word_pos_list.append(norm_img[1]) ++ gsrm_word_pos_list.append(norm_img[2]) ++ gsrm_slf_attn_bias1_list.append(norm_img[3]) ++ gsrm_slf_attn_bias2_list.append(norm_img[4]) ++ norm_img_batch.append(norm_img[0]) ++ elif self.rec_algorithm == "CAN": ++ norm_img = self.norm_img_can(img_list[indices[ino]], ++ max_wh_ratio) ++ norm_img = norm_img[np.newaxis, :] ++ norm_img_batch.append(norm_img) ++ norm_image_mask = np.ones(norm_img.shape, dtype='float32') ++ word_label = np.ones([1, 36], dtype='int64') ++ norm_img_mask_batch = [] ++ word_label_list = [] ++ norm_img_mask_batch.append(norm_image_mask) ++ word_label_list.append(word_label) ++ else: ++ norm_img = self.resize_norm_img(img_list[indices[ino]], ++ max_wh_ratio) ++ norm_img = norm_img[np.newaxis, :] ++ norm_img_batch.append(norm_img) ++ norm_img_batch = np.concatenate(norm_img_batch) ++ norm_img_batch = 
norm_img_batch.copy() ++ ++ starttime = time.time() ++ ++ if self.rec_algorithm == "SRN": ++ starttime = time.time() ++ encoder_word_pos_list = np.concatenate(encoder_word_pos_list) ++ gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) ++ gsrm_slf_attn_bias1_list = np.concatenate( ++ gsrm_slf_attn_bias1_list) ++ gsrm_slf_attn_bias2_list = np.concatenate( ++ gsrm_slf_attn_bias2_list) ++ ++ with self._dev_lock: + with torch.no_grad(): + inp = torch.from_numpy(norm_img_batch) + encoder_word_pos_inp = torch.from_numpy(encoder_word_pos_list) +@@ -384,58 +391,67 @@ class TextRecognizer(BaseOCRV20): + + backbone_out = self.net.backbone(inp) # backbone_feat + prob_out = self.net.head(backbone_out, [encoder_word_pos_inp, gsrm_word_pos_inp, gsrm_slf_attn_bias1_inp, gsrm_slf_attn_bias2_inp]) +- # preds = {"predict": prob_out[2]} +- preds = {"predict": prob_out["predict"]} +- +- elif self.rec_algorithm == "SAR": +- starttime = time.time() +- # valid_ratios = np.concatenate(valid_ratios) +- # inputs = [ +- # norm_img_batch, +- # valid_ratios, +- # ] +- ++ # preds = {"predict": prob_out[2]} ++ preds = {"predict": prob_out["predict"]} ++ ++ elif self.rec_algorithm == "SAR": ++ starttime = time.time() ++ # valid_ratios = np.concatenate(valid_ratios) ++ # inputs = [ ++ # norm_img_batch, ++ # valid_ratios, ++ # ] ++ ++ with self._dev_lock: + with torch.no_grad(): + inp = torch.from_numpy(norm_img_batch) + inp = inp.to(self.device) + preds = self.net(inp) + +- elif self.rec_algorithm == "CAN": +- starttime = time.time() +- norm_img_mask_batch = np.concatenate(norm_img_mask_batch) +- word_label_list = np.concatenate(word_label_list) +- inputs = [norm_img_batch, norm_img_mask_batch, word_label_list] ++ elif self.rec_algorithm == "CAN": ++ starttime = time.time() ++ norm_img_mask_batch = np.concatenate(norm_img_mask_batch) ++ word_label_list = np.concatenate(word_label_list) ++ inputs = [norm_img_batch, norm_img_mask_batch, word_label_list] + +- inp = [torch.from_numpy(e_i) for e_i in inputs] +- inp = [e_i.to(self.device) for e_i in inp] ++ inp = [torch.from_numpy(e_i) for e_i in inputs] ++ inp = [e_i.to(self.device) for e_i in inp] ++ with self._dev_lock: + with torch.no_grad(): + outputs = self.net(inp) + outputs = [v.cpu().numpy() for k, v in enumerate(outputs)] + +- preds = outputs +- +- else: +- starttime = time.time() ++ preds = outputs + ++ else: ++ with self._dev_lock: + with torch.no_grad(): +- inp = torch.from_numpy(norm_img_batch) +- inp = inp.to(self.device) ++ inp = torch.from_numpy(norm_img_batch).to(self.device) + prob_out = self.net(inp) ++ preds = [v.cpu().numpy() for v in prob_out] if isinstance(prob_out, list) else prob_out.cpu().numpy() + +- if isinstance(prob_out, list): +- preds = [v.cpu().numpy() for v in prob_out] +- else: +- preds = prob_out.cpu().numpy() ++ rec_result = self.postprocess_op(preds) + +- rec_result = self.postprocess_op(preds) +- for rno in range(len(rec_result)): +- rec_res[indices[beg_img_no + rno]] = rec_result[rno] +- elapse += time.time() - starttime ++ for rno in range(len(rec_result)): ++ global_idx = indices[beg_img_no + rno] ++ rec_res[global_idx] = rec_result[rno] ++ ++ batch_elapse = time.time() - starttime ++ return len(rec_result), batch_elapse ++ ++ MAX_WORKERS = 4 ++ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex, \ ++ tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar: ++ ++ futures = [] ++ for beg_img_no in range(0, img_num, batch_num): ++ end_img_no = min(img_num, beg_img_no + batch_num) ++ 
futures.append(ex.submit(_rec_batch_worker, beg_img_no, end_img_no)) + +- # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size +- current_batch_size = min(batch_num, img_num - index * batch_num) +- index += 1 +- pbar.update(current_batch_size) ++ for fut in as_completed(futures): ++ n_done, batch_elapse = fut.result() ++ elapse += batch_elapse ++ pbar.update(n_done) + + # Fix NaN values in recognition results + for i in range(len(rec_res)): +diff --git a/mineru/model/table/rapid_table.py b/mineru/model/table/rapid_table.py +index 174a8052..dd796bcc 100644 +--- a/mineru/model/table/rapid_table.py ++++ b/mineru/model/table/rapid_table.py +@@ -21,6 +21,8 @@ class RapidTableModel(object): + self.table_model = RapidTable(input_args) + self.ocr_engine = ocr_engine + ++ import threading ++ self._dev_lock = getattr(self, "_dev_lock", threading.Lock()) + + def predict(self, image): + bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) +@@ -30,44 +32,45 @@ class RapidTableModel(object): + img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0 + img_is_portrait = img_aspect_ratio > 1.2 + +- if img_is_portrait: ++ with self._dev_lock: ++ if img_is_portrait: + +- det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0] +- # Check if table is rotated by analyzing text box aspect ratios +- is_rotated = False +- if det_res: +- vertical_count = 0 ++ det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0] ++ # Check if table is rotated by analyzing text box aspect ratios ++ is_rotated = False ++ if det_res: ++ vertical_count = 0 + +- for box_ocr_res in det_res: +- p1, p2, p3, p4 = box_ocr_res ++ for box_ocr_res in det_res: ++ p1, p2, p3, p4 = box_ocr_res + +- # Calculate width and height +- width = p3[0] - p1[0] +- height = p3[1] - p1[1] ++ # Calculate width and height ++ width = p3[0] - p1[0] ++ height = p3[1] - p1[1] + +- aspect_ratio = width / height if height > 0 else 1.0 ++ aspect_ratio = width / height if height > 0 else 1.0 + +- # Count vertical vs horizontal text boxes +- if aspect_ratio < 0.8: # Taller than wide - vertical text +- vertical_count += 1 +- # elif aspect_ratio > 1.2: # Wider than tall - horizontal text +- # horizontal_count += 1 ++ # Count vertical vs horizontal text boxes ++ if aspect_ratio < 0.8: # Taller than wide - vertical text ++ vertical_count += 1 ++ # elif aspect_ratio > 1.2: # Wider than tall - horizontal text ++ # horizontal_count += 1 + +- # If we have more vertical text boxes than horizontal ones, +- # and vertical ones are significant, table might be rotated +- if vertical_count >= len(det_res) * 0.3: +- is_rotated = True ++ # If we have more vertical text boxes than horizontal ones, ++ # and vertical ones are significant, table might be rotated ++ if vertical_count >= len(det_res) * 0.3: ++ is_rotated = True + +- # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}") ++ # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}") + +- # Rotate image if necessary +- if is_rotated: +- # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise") +- image = cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE) +- bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) ++ # Rotate image if necessary ++ if is_rotated: ++ # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise") ++ image = cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE) ++ bgr_image = cv2.cvtColor(image, 
cv2.COLOR_RGB2BGR) + +- # Continue with OCR on potentially rotated image +- ocr_result = self.ocr_engine.ocr(bgr_image)[0] ++ # Continue with OCR on potentially rotated image ++ ocr_result = self.ocr_engine.ocr(bgr_image)[0] + if ocr_result: + ocr_result = [[item[0], escape_html(item[1][0]), item[1][1]] for item in ocr_result if + len(item) == 2 and isinstance(item[1], tuple)] + diff --git a/ACL_PyTorch/built-in/ocr/MinerU/overall_metric.py b/ACL_PyTorch/built-in/ocr/MinerU/overall_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..5d89b41693aefb448d1bcd8c3cc3b9ec8f881c3b --- /dev/null +++ b/ACL_PyTorch/built-in/ocr/MinerU/overall_metric.py @@ -0,0 +1,47 @@ +import os +import json +import argparse + +import pandas as pd +import numpy as np + +parser = argparse.ArgumentParser(description='result path') +parser.add_argument('--result', type=str, default='OmniDocBench/result') +args = parser.parse_args() + + +ocr_types_dict = { + 'end2end': 'end2end' +} + +result_folder = args.result + +match_name = 'quick_match' + +# overall result: not distinguishing between Chinese and English, page-level average + +dict_list = [] + +for ocr_type in ocr_types_dict.values(): + result_path = os.path.join(result_folder, f'{ocr_type}_{match_name}_metric_result.json') + + with open(result_path, 'r') as f: + result = json.load(f) + + save_dict = {} + + for category_type, metric in [("text_block", "Edit_dist"), ("display_formula", "CDM"), ("table", "TEDS"), ("table", "TEDS_structure_only"), ("reading_order", "Edit_dist")]: + if metric == 'CDM' or metric == "TEDS" or metric == "TEDS_structure_only": + if result[category_type]["page"].get(metric): + save_dict[category_type + '_' + metric] = result[category_type]["page"][metric]["ALL"] * 100 # page级别的avg + else: + save_dict[category_type + '_' + metric] = 0 + else: + save_dict[category_type + '_' + metric] = result[category_type]["all"][metric].get("ALL_page_avg", np.nan) + + dict_list.append(save_dict) + +df = pd.DataFrame(dict_list, index=ocr_types_dict.keys()).round(3) +df['overall'] = ((1 - df['text_block_Edit_dist']) * 100 + df['display_formula_CDM'] + df['table_TEDS']) / 3 + +print(df) diff --git a/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch b/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch index 5511fa6a9e750819a60e1d89c46380d4548cd49d..4baf8c7b1500f9941ddf1413d12f7bfd354bee22 100644 --- a/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch +++ b/ACL_PyTorch/built-in/ocr/MinerU/ultralytics.patch @@ -1,7 +1,120 @@ -diff -ruN ultralytics-8.3.193/ultralytics/engine/predictor.py ultralytics_/ultralytics/engine/predictor.py ---- ultralytics-8.3.193/ultralytics/engine/predictor.py 2025-09-04 19:51:11.000000000 +0800 -+++ ultralytics_/ultralytics/engine/predictor.py 2025-09-09 14:56:14.535737230 +0800 -@@ -196,9 +196,10 @@ +diff -ruN ultralytics/data/loaders.py ultralytics/data/loaders.py +--- ultralytics/data/loaders.py 2025-09-04 19:51:11.000000000 +0800 ++++ ultralytics/data/loaders.py 2025-10-19 01:27:48.412000000 +0800 +@@ -534,7 +534,7 @@ + self.bs = len(self.im0) + + @staticmethod +- def _single_check(im: Image.Image | np.ndarray, flag: str = "RGB") -> np.ndarray: ++ def __single_check(im: Image.Image | np.ndarray, flag: str = "RGB") -> np.ndarray: + """Validate and format an image to numpy array, ensuring RGB order and contiguous memory.""" + assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}" + if isinstance(im, Image.Image): +@@ -546,6 +546,19 @@ + im = im[..., None] 
+ return im + ++ @staticmethod ++ def _single_check(im: Image.Image | np.ndarray, flag: str = "RGB") -> np.ndarray: ++ """Validate and format an image to numpy array, ensuring RGB order and contiguous memory.""" ++ assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}" ++ if isinstance(im, Image.Image): ++ if im.mode != "RGB": ++ im = im.convert("RGB") ++ im = np.asarray(im) ++ elif im.ndim == 2: # grayscale in numpy form ++ im = im[..., None] ++ return im ++ ++ + def __len__(self) -> int: + """Return the length of the 'im0' attribute, representing the number of loaded images.""" + return len(self.im0) +diff -ruN ultralytics/engine/model.py ultralytics/engine/model.py +--- ultralytics/engine/model.py 2025-09-04 19:51:11.000000000 +0800 ++++ ultralytics/engine/model.py 2025-10-19 01:27:48.412000000 +0800 +@@ -152,6 +152,8 @@ + else: + self._load(model, task=task) + ++ self.model.half() ++ + # Delete super().training for accessing self.model.training + del self.training + +diff -ruN ultralytics/engine/predictor.py ultralytics/engine/predictor.py +--- ultralytics/engine/predictor.py 2025-09-04 19:51:11.000000000 +0800 ++++ ultralytics/engine/predictor.py 2025-10-19 01:27:48.412000000 +0800 +@@ -43,6 +43,7 @@ + import cv2 + import numpy as np + import torch ++import torch.nn.functional as F + + from ultralytics.cfg import get_cfg, get_save_dir + from ultralytics.data import load_inference_source +@@ -149,7 +150,7 @@ + self._lock = threading.Lock() # for automatic thread-safe inference + callbacks.add_integration_callbacks(self) + +- def preprocess(self, im: torch.Tensor | list[np.ndarray]) -> torch.Tensor: ++ def _preprocess(self, im: torch.Tensor | list[np.ndarray]) -> torch.Tensor: + """ + Prepare input image before inference. + +@@ -174,6 +175,51 @@ + im /= 255 # 0 - 255 to 0.0 - 1.0 + return im + ++ def preprocess(self, images: torch.Tensor | list[np.ndarray]) -> torch.Tensor: ++ """ ++ Prepare input image before inference. ++ ++ Args: ++ images (torch.Tensor | List[np.ndarray]): Images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for list. ++ ++ Returns: ++ (torch.Tensor): Preprocessed image tensor of shape (N, 3, H, W). 
++ """ ++ ++ new_shape = (new_shape, new_shape) if isinstance(self.imgsz, int) else self.imgsz ++ tensors = [] ++ for im in images: ++ im = torch.from_numpy(im).to(self.device).permute((2, 0, 1)) / 255.0 ++ ++ c, h, w = im.shape ++ ++ r = min(new_shape[0] / h, new_shape[1] / w) ++ ++ new_unpad = (int(round(w * r)), int(round(h * r))) ++ ++ if (w, h) != new_unpad: ++ im = F.interpolate(im.unsqueeze(0), size=(new_unpad[1], new_unpad[0]), ++ mode="bilinear", align_corners=False).squeeze(0) ++ ++ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] ++ dw /= 2 ++ dh /= 2 ++ left, right = int(dw), int(dw + 0.5) ++ top, bottom = int(dh), int(dh + 0.5) ++ im = F.pad(im, (left, right, top, bottom), value=114/255.0) ++ ++ _, H, W = im.shape ++ assert (H, W) == (new_shape[0], new_shape[1]), f"Expected image size do not match: padding image size:{(H, W)} != expected image size: {(new_shape[0], new_shape[1])}" ++ ++ im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32 ++ ++ tensors.append(im) ++ ++ return torch.stack(tensors, dim=0) ++ ++ ++ ++ + def inference(self, im: torch.Tensor, *args, **kwargs): + """Run inference on a given image using the specified model and arguments.""" + visualize = ( +@@ -196,9 +242,10 @@ same_shapes = len({x.shape for x in im}) == 1 letterbox = LetterBox( self.imgsz, @@ -15,7 +128,7 @@ diff -ruN ultralytics-8.3.193/ultralytics/engine/predictor.py ultralytics_/ultra stride=self.model.stride, ) return [letterbox(image=x) for x in im] -@@ -311,8 +312,11 @@ +@@ -311,8 +358,11 @@ # Warmup model if not self.done_warmup: @@ -28,7 +141,7 @@ diff -ruN ultralytics-8.3.193/ultralytics/engine/predictor.py ultralytics_/ultra ) self.done_warmup = True -@@ -400,7 +404,8 @@ +@@ -400,7 +450,8 @@ dnn=self.args.dnn, data=self.args.data, fp16=self.args.half, @@ -38,9 +151,9 @@ diff -ruN ultralytics-8.3.193/ultralytics/engine/predictor.py ultralytics_/ultra verbose=verbose, ) -diff -ruN ultralytics-8.3.193/ultralytics/nn/modules/block.py ultralytics_/ultralytics/nn/modules/block.py ---- ultralytics-8.3.193/ultralytics/nn/modules/block.py 2025-09-04 19:51:11.000000000 +0800 -+++ ultralytics_/ultralytics/nn/modules/block.py 2025-09-09 14:56:14.543737230 +0800 +diff -ruN ultralytics/nn/modules/block.py ultralytics/nn/modules/block.py +--- ultralytics/nn/modules/block.py 2025-09-04 19:51:11.000000000 +0800 ++++ ultralytics/nn/modules/block.py 2025-10-19 01:27:48.424000000 +0800 @@ -237,7 +237,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply sequential pooling operations to input and return concatenated feature maps.""" @@ -63,10 +176,9 @@ diff -ruN ultralytics-8.3.193/ultralytics/nn/modules/block.py ultralytics_/ultra return self.cv2(torch.cat(y, 1)) def forward_split(self, x: torch.Tensor) -> torch.Tensor: - -diff -ruN ultralytics-8.3.193/ultralytics/utils/tal.py ultralytics_/ultralytics/utils/tal.py ---- ultralytics-8.3.193/ultralytics/utils/tal.py 2025-09-04 19:51:11.000000000 +0800 -+++ ultralytics_/ultralytics/utils/tal.py 2025-09-09 14:56:14.551737230 +0800 +diff -ruN ultralytics/utils/tal.py ultralytics/utils/tal.py +--- ultralytics/utils/tal.py 2025-09-04 19:51:11.000000000 +0800 ++++ ultralytics/utils/tal.py 2025-10-19 01:27:48.428000000 +0800 @@ -375,7 +375,8 @@ sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx) @@ -75,3 +187,4 @@ diff -ruN ultralytics-8.3.193/ultralytics/utils/tal.py ultralytics_/ultralytics/ + # 
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) + stride_tensor.append(torch.ones((h * w, 1), dtype=dtype, device=device)*stride) return torch.cat(anchor_points), torch.cat(stride_tensor) +