From a69e2bc5e60906415f94911e9fee2daf2b4cbcff Mon Sep 17 00:00:00 2001 From: TD_lihan Date: Fri, 23 Feb 2024 17:30:08 +0800 Subject: [PATCH] add npu_tome_unmerge --- torch_npu/meta/meta_registrations.py | 8 ++++++++ torch_npu/onnx/wrapper_onnx_ops.py | 19 +++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py index 7193e4d899..f901475f06 100644 --- a/torch_npu/meta/meta_registrations.py +++ b/torch_npu/meta/meta_registrations.py @@ -224,6 +224,14 @@ def npu_mm_all_reduce_base_forward(x1, x2, hcom, reduce_op='sum', bias=None, ant return x1.new_empty(tuple(dim_list)) +@impl(m, "npu_tome_unmerge") +def npu_tome_unmerge_meta(atten_out, ori_indice_a, ori_indice_b, topk_indice, arg_max, top_r_rate=0.5): + dim_list = [] + dim_list.append(atten_out.size(0)) + dim_list.append(ori_indice_a.size(1) + ori_indice_b.size(1)) + dim_list.append(atten_out.size(2)) + return atten_out.new_empty(tuple(dim_list)) + @impl(m, "npu_weight_quant_batchmatmul") def npu_weight_quant_batchmatmul_meta(x, weight, antiquant_scale, antiquant_offset=None, quant_scale=None, quant_offset=None, bias=None, antiquant_group_size=0): diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py index 1fde7f18df..d0a30ac9bd 100644 --- a/torch_npu/onnx/wrapper_onnx_ops.py +++ b/torch_npu/onnx/wrapper_onnx_ops.py @@ -651,6 +651,19 @@ class NPURotaryMulOP(torch.autograd.Function): return g.op("npu::NPURotaryMul", x, r1, r2) +class NPUTomeUnmergeOP(torch.autograd.Function): + + @staticmethod + def forward(ctx, *args, **kwargs): + return torch.ops.npu.npu_tome_unmerge(*args, **kwargs) + + @staticmethod + def symbolic(g, atten_out: Tensor, ori_indice_a: Tensor, ori_indice_b: Tensor, topk_indice: Tensor, + arg_max: Tensor, top_r_rate: float = 0.5): + return g.op("npu::NPUTomeUnmerge", atten_out, ori_indice_a, ori_indice_b, topk_indice, arg_max, + top_r_rate_f=top_r_rate) + + class 
NPUPromptFlashAttentionOP(torch.autograd.Function): @staticmethod @@ -813,6 +826,11 @@ def wrapper_npu_batch_nms(self, scores, score_threshold, iou_threshold, def wrapper_npu_fast_gelu(self): return NPUFastGeluOP.apply(self) +def wrapper_npu_tome_unmerge(atten_out, ori_indice_a, ori_indice_b, topk_indice, + arg_max, top_r_rate=0.5): + return NPUTomeUnmergeOP.apply(atten_out, ori_indice_a, ori_indice_b, topk_indice, + arg_max, top_r_rate) + def wrapper_npu_geglu(self, dim=-1, approximate=1, activate_left=False): return NPUGeGluOP.apply(self, dim, approximate, activate_left) @@ -1060,6 +1078,7 @@ def add_onnx_ops(): torch_npu.npu_batch_nms = wrapper_npu_batch_nms torch_npu.fast_gelu = wrapper_npu_fast_gelu torch_npu.npu_fast_gelu = wrapper_npu_fast_gelu + torch_npu.npu_tome_unmerge = wrapper_npu_tome_unmerge torch_npu.npu_geglu = wrapper_npu_geglu torch_npu.npu_fused_attention_score = wrapper_npu_fused_attention_score torch_npu.npu_ciou = wrapper_npu_ciou -- Gitee