From 7b8db33760c81f142f59d3af09802a2088c4f3c5 Mon Sep 17 00:00:00 2001
From: TD_lihan
Date: Fri, 23 Feb 2024 17:54:58 +0800
Subject: [PATCH] add tomeunmerge

Add meta and ONNX export support for the npu_tome_unmerge operator.

* meta_registrations.py: register npu_tome_unmerge_meta. The result is a
  new empty tensor created from atten_out with shape
  (atten_out.size(0),
   ori_indice_a.size(1) + ori_indice_b.size(1),
   atten_out.size(2)).
* wrapper_onnx_ops.py: add NPUTomeUnmergeOP, whose forward dispatches to
  torch.ops.npu.npu_tome_unmerge and whose symbolic emits an
  "npu::NPUTomeUnmerge" node with top_r_rate passed as the top_r_rate_f
  attribute; add wrapper_npu_tome_unmerge and install it as
  torch_npu.npu_tome_unmerge in add_onnx_ops().

wrapper_npu_tome_unmerge takes exactly the arguments that the meta
function and the symbolic declare (five tensors plus top_r_rate); it has
no extra leading "self" argument, so the positional arguments forwarded
to NPUTomeUnmergeOP.apply line up with symbolic().
---
Review note: wrapper_npu_tome_unmerge originally declared and forwarded an
extra leading "self" argument (seven positional arguments against the six
accepted by symbolic() and the meta function). That has been removed here;
the blob id on the wrapper_onnx_ops.py "index" line predates this fix.

 torch_npu/meta/meta_registrations.py |  9 +++++++++
 torch_npu/onnx/wrapper_onnx_ops.py   | 20 ++++++++++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py
index a7cfdcef15..f0e506cdfc 100644
--- a/torch_npu/meta/meta_registrations.py
+++ b/torch_npu/meta/meta_registrations.py
@@ -394,6 +394,15 @@ def npu_quant_matmul_meta(x1, x2, scale, offset=None, bias=None, output_dtype=No
     return shape_long.new_empty(tuple(dim_list), dtype=torch.int8)
 
 
+@impl(m, "npu_tome_unmerge")
+def npu_tome_unmerge_meta(atten_out, ori_indice_a, ori_indice_b, topk_indice, arg_max, top_r_rate=0.5):
+    dim_list = []
+    dim_list.append(atten_out.size(0))
+    dim_list.append(ori_indice_a.size(1) + ori_indice_b.size(1))
+    dim_list.append(atten_out.size(2))
+    return atten_out.new_empty(tuple(dim_list))
+
+
 @impl(m, "npu_trans_quant_param")
 def npu_trans_quant_param_meta(scale, offset=None):
     scale_dim_num = scale.dim()
diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py
index 5c1903dcb1..59223893ed 100644
--- a/torch_npu/onnx/wrapper_onnx_ops.py
+++ b/torch_npu/onnx/wrapper_onnx_ops.py
@@ -20,6 +20,19 @@ class NPUOneHotOP(torch.autograd.Function):
                     on_value_i=on_value, off_value_i=off_value)
 
 
+class NPUTomeUnmergeOP(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        return torch.ops.npu.npu_tome_unmerge(*args, **kwargs)
+
+    @staticmethod
+    def symbolic(g, atten_out: Tensor, ori_indice_a: Tensor, ori_indice_b: Tensor, topk_indice: Tensor,
+                 arg_max: Tensor, top_r_rate: float = 0.5):
+        return g.op("npu::NPUTomeUnmerge", atten_out, ori_indice_a, ori_indice_b, topk_indice, arg_max,
+                    top_r_rate_f=top_r_rate)
+
+
 class NPUSliceOP(torch.autograd.Function):
 
     @staticmethod
@@ -790,6 +803,12 @@ def wrapper_npu_iou(bboxes, gtboxes, mode=0):
     return NPUIouOP.apply(bboxes, gtboxes, mode)
 
 
+def wrapper_npu_tome_unmerge(atten_out, ori_indice_a, ori_indice_b, topk_indice,
+                             arg_max, top_r_rate=0.5):
+    return NPUTomeUnmergeOP.apply(atten_out, ori_indice_a, ori_indice_b, topk_indice,
+                                  arg_max, top_r_rate)
+
+
 def wrapper_npu_batch_nms(self, scores, score_threshold, iou_threshold,
                           max_size_per_class, max_total_size,
                           change_coordinate_frame=False, transpose_box=False):
@@ -1041,6 +1060,7 @@ def add_onnx_ops():
     torch_npu.npu_roi_align = wrapper_npu_roi_align
     torch_npu.npu_group_norm_silu = wrapper_npu_group_norm_silu
     torch_npu.npu_iou = wrapper_npu_iou
+    torch_npu.npu_tome_unmerge = wrapper_npu_tome_unmerge
     torch_npu.npu_batch_nms = wrapper_npu_batch_nms
     torch_npu.fast_gelu = wrapper_npu_fast_gelu
     torch_npu.npu_fast_gelu = wrapper_npu_fast_gelu
-- 
Gitee