From 4b4ea012759882f8986aca045ea9b71f5621b379 Mon Sep 17 00:00:00 2001
From: xiaxia3
Date: Thu, 20 Oct 2022 19:49:59 +0800
Subject: [PATCH] Fix duplicate compilation of the index operator && change
 viewcopy input compile options
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 torch_npu/csrc/aten/common/CopyKernelNpu.cpp  | 12 ++++++------
 torch_npu/csrc/aten/npu_native_functions.yaml |  2 +-
 torch_npu/csrc/aten/ops/IndexKernelNpu.cpp    | 12 ++++--------
 3 files changed, 11 insertions(+), 15 deletions(-)

diff --git a/torch_npu/csrc/aten/common/CopyKernelNpu.cpp b/torch_npu/csrc/aten/common/CopyKernelNpu.cpp
index a4fac7ab3e..d4133b16fb 100644
--- a/torch_npu/csrc/aten/common/CopyKernelNpu.cpp
+++ b/torch_npu/csrc/aten/common/CopyKernelNpu.cpp
@@ -63,13 +63,13 @@ void copy_kernel_npu(
   OpCommand cmd;
   cmd.Name("ViewCopy")
       .InputWithoutContiguous(self)
-      .Input(self_size)
-      .Input(self_stride)
-      .InputScalarToNPUTensor(at::Scalar(0), at::kLong)
+      .Input(self_size, at::kLong, CompileType::MEMORY_HOST_COMPILE_INDEPENDENT)
+      .Input(self_stride, at::kLong, CompileType::MEMORY_HOST_COMPILE_INDEPENDENT)
+      .Input(at::Scalar(0), at::kLong)
       .InputWithoutContiguous(src)
-      .Input(src_size)
-      .Input(src_stride)
-      .InputScalarToNPUTensor(at::Scalar(0), at::kLong)
+      .Input(src_size, at::kLong, CompileType::MEMORY_HOST_COMPILE_INDEPENDENT)
+      .Input(src_stride, at::kLong, CompileType::MEMORY_HOST_COMPILE_INDEPENDENT)
+      .Input(at::Scalar(0), at::kLong)
       .Output(self)
       .Run();
 
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 1ea5baaf5a..b177754c30 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1841,7 +1841,6 @@ custom:
   - func: npu_conv3d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, *, Tensor(a!) out) -> Tensor(a!)
   - func: npu_conv3d_backward(Tensor input, Tensor grad, Tensor weight, int[] stride, int[] padding, int[] dilation, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
   - func: get_npu_format(Tensor self) -> int
-  - func: npu_format_cast(Tensor self, int acl_format) -> Tensor
   - func: npu_format_cast.Tensor(Tensor self, Tensor dst) -> Tensor
   - func: npu_format_cast_.acl_format(Tensor(a!) self, int acl_format) -> Tensor(a!)
   - func: npu_format_cast_(Tensor(a!) self, Tensor src) -> Tensor(a!)
@@ -1934,6 +1933,7 @@ custom_autograd:
   - func: npu_softmax_cross_entropy_with_logits(Tensor self, Tensor labels) -> Tensor
   - func: npu_max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
   - func: npu_max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+  - func: npu_format_cast(Tensor self, int acl_format) -> Tensor
   - func: npu_bmmV2(Tensor self, Tensor mat2, int[] output_sizes) -> Tensor
     variants: function, method
   - func: npu_dtype_cast(Tensor self, ScalarType dtype) -> Tensor
diff --git a/torch_npu/csrc/aten/ops/IndexKernelNpu.cpp b/torch_npu/csrc/aten/ops/IndexKernelNpu.cpp
index bf3b1d21ea..06aa1bf93a 100644
--- a/torch_npu/csrc/aten/ops/IndexKernelNpu.cpp
+++ b/torch_npu/csrc/aten/ops/IndexKernelNpu.cpp
@@ -21,14 +21,14 @@ namespace native {
 
 at::Tensor& index_out_nocheck_npu(
     const at::Tensor& self,
-    const at::Tensor& masksTensor,
+    const at::IntArrayRef masks,
     const at::TensorList& indices,
     at::Tensor& result) {
   OpCommand cmd;
   cmd.Name("Index")
       .Input(self)
-      .Input(masksTensor)
-      .Input(result.sizes());
+      .Input(masks, at::kLong, CompileType::MEMORY_HOST_COMPILE_INDEPENDENT)
+      .Input(result.sizes(), at::kLong, CompileType::MEMORY_HOST_COMPILE_INDEPENDENT);
   for (int i = 0; i < indices.size(); i++) {
     std::string name = "indices" + std::to_string(i);
     cmd.Input(indices[i], name);
@@ -61,12 +61,8 @@ at::Tensor NPUNativeFunctions::index(const at::Tensor& self, const torch::List
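
Note on the changed helper signature: index_out_nocheck_npu now receives the per-dimension masks as a plain at::IntArrayRef, which the patched code feeds to the "Index" op as a host-memory, compile-independent input (the duplicate-compilation issue named in the commit subject). A minimal sketch of how a caller might assemble that argument follows; the build_index_masks helper, the 0/1 mask convention, and the names in the usage comment are illustrative assumptions, not code from this patch.

    // Hypothetical helper, not part of the patch: build one 0/1 entry per
    // indexing dimension, assuming 1 marks a dimension with a supplied index tensor.
    #include <vector>
    #include <ATen/ATen.h>
    #include <ATen/core/List.h>

    std::vector<int64_t> build_index_masks(
        const c10::List<c10::optional<at::Tensor>>& orig_indices) {
      std::vector<int64_t> masks;
      for (c10::optional<at::Tensor> idx : orig_indices) {
        masks.push_back((idx.has_value() && idx->defined()) ? 1 : 0);
      }
      return masks;
    }

    // The caller keeps the vector alive for the duration of the op call and
    // passes it where at::IntArrayRef is expected, e.g. (illustrative names):
    //   auto masks = build_index_masks(orig_indices);
    //   index_out_nocheck_npu(self, masks, expanded_indices, result);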