diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py index a7cfdcef152bc46ebd9453c429c46048ece65e10..f0e506cdfc7cb4860a6652f6b66cae8e9794c4e6 100644 --- a/torch_npu/meta/meta_registrations.py +++ b/torch_npu/meta/meta_registrations.py @@ -394,6 +394,15 @@ def npu_quant_matmul_meta(x1, x2, scale, offset=None, bias=None, output_dtype=No return shape_long.new_empty(tuple(dim_list), dtype=torch.int8) +@impl(m, "npu_tome_unmerge") +def npu_tome_unmerge_meta(atten_out, ori_indice_a, ori_indice_b, topk_indice, arg_max, top_r_rate=0.5): + dim_list = [] + dim_list.append(atten_out.size(0)) + dim_list.append(ori_indice_a.size(1) + ori_indice_b.size(1)) + dim_list.append(atten_out.size(2)) + return atten_out.new_empty(tuple(dim_list)) + + @impl(m, "npu_trans_quant_param") def npu_trans_quant_param_meta(scale, offset=None): scale_dim_num = scale.dim() diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py index 5c1903dcb1bd9fd0ed9dc66e7ea9af6a4739d1b3..59223893eda143be89b948f95f850457663de32a 100644 --- a/torch_npu/onnx/wrapper_onnx_ops.py +++ b/torch_npu/onnx/wrapper_onnx_ops.py @@ -20,6 +20,19 @@ class NPUOneHotOP(torch.autograd.Function): on_value_i=on_value, off_value_i=off_value) +class NPUTomeUnmergeOP(torch.autograd.Function): + + @staticmethod + def forward(ctx, *args, **kwargs): + return torch.ops.npu.npu_tome_unmerge(*args, **kwargs) + + @staticmethod + def symbolic(g, atten_out: Tensor, ori_indice_a: Tensor, ori_indice_b: Tensor, topk_indice: Tensor, + arg_max: Tensor, top_r_rate: float = 0.5): + return g.op("npu::NPUTomeUnmerge", atten_out, ori_indice_a, ori_indice_b, topk_indice, arg_max, + top_r_rate_f=top_r_rate) + + class NPUSliceOP(torch.autograd.Function): @staticmethod @@ -790,6 +803,12 @@ def wrapper_npu_iou(bboxes, gtboxes, mode=0): return NPUIouOP.apply(bboxes, gtboxes, mode) +def wrapper_npu_tome_unmerge(self, atten_out, ori_indice_a, ori_indice_b, topk_indice, + arg_max, top_r_rate=0.5): + return NPUTomeUnmergeOP.apply(self, atten_out, ori_indice_a, ori_indice_b, topk_indice, + arg_max, top_r_rate) + + def wrapper_npu_batch_nms(self, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size, change_coordinate_frame=False, transpose_box=False): @@ -1041,6 +1060,7 @@ def add_onnx_ops(): torch_npu.npu_roi_align = wrapper_npu_roi_align torch_npu.npu_group_norm_silu = wrapper_npu_group_norm_silu torch_npu.npu_iou = wrapper_npu_iou + torch_npu.npu_tome_unmerge = wrapper_npu_tome_unmerge torch_npu.npu_batch_nms = wrapper_npu_batch_nms torch_npu.fast_gelu = wrapper_npu_fast_gelu torch_npu.npu_fast_gelu = wrapper_npu_fast_gelu