diff --git a/torch_npu/csrc/aten/common/CopyKernelNpu.cpp b/torch_npu/csrc/aten/common/CopyKernelNpu.cpp index a4fac7ab3ef41cac06b9073286c959a99a717461..d4133b16fb3fed3dae4eac564c339cae804e49ce 100644 --- a/torch_npu/csrc/aten/common/CopyKernelNpu.cpp +++ b/torch_npu/csrc/aten/common/CopyKernelNpu.cpp @@ -63,13 +63,13 @@ void copy_kernel_npu( OpCommand cmd; cmd.Name("ViewCopy") .InputWithoutContiguous(self) - .Input(self_size) - .Input(self_stride) - .InputScalarToNPUTensor(at::Scalar(0), at::kLong) + .Input(self_size, at::kLong, CompileType::MEMORY_HOST_COMPILE_INDEPENDENT) + .Input(self_stride, at::kLong, CompileType::MEMORY_HOST_COMPILE_INDEPENDENT) + .Input(at::Scalar(0), at::kLong) .InputWithoutContiguous(src) - .Input(src_size) - .Input(src_stride) - .InputScalarToNPUTensor(at::Scalar(0), at::kLong) + .Input(src_size, at::kLong, CompileType::MEMORY_HOST_COMPILE_INDEPENDENT) + .Input(src_stride, at::kLong, CompileType::MEMORY_HOST_COMPILE_INDEPENDENT) + .Input(at::Scalar(0), at::kLong) .Output(self) .Run(); diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 1ea5baaf5ae304840af8ddd3aec3d9b9800ade2e..b177754c30a541f7612318d4f8a5312d57b8c2ab 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -1841,7 +1841,6 @@ custom: - func: npu_conv3d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, *, Tensor(a!) out) -> Tensor(a!) - func: npu_conv3d_backward(Tensor input, Tensor grad, Tensor weight, int[] stride, int[] padding, int[] dilation, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) - func: get_npu_format(Tensor self) -> int - - func: npu_format_cast(Tensor self, int acl_format) -> Tensor - func: npu_format_cast.Tensor(Tensor self, Tensor dst) -> Tensor - func: npu_format_cast_.acl_format(Tensor(a!) self, int acl_format) -> Tensor(a!) - func: npu_format_cast_(Tensor(a!) self, Tensor src) -> Tensor(a!) @@ -1934,6 +1933,7 @@ custom_autograd: - func: npu_softmax_cross_entropy_with_logits(Tensor self, Tensor labels) -> Tensor - func: npu_max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) - func: npu_max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + - func: npu_format_cast(Tensor self, int acl_format) -> Tensor - func: npu_bmmV2(Tensor self, Tensor mat2, int[] output_sizes) -> Tensor variants: function, method - func: npu_dtype_cast(Tensor self, ScalarType dtype) -> Tensor diff --git a/torch_npu/csrc/aten/ops/IndexKernelNpu.cpp b/torch_npu/csrc/aten/ops/IndexKernelNpu.cpp index bf3b1d21eaafa42aa7eb96432724613ff1e2f66e..06aa1bf93a32c1d559cce9b2d41b2fa45c6c674e 100644 --- a/torch_npu/csrc/aten/ops/IndexKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/IndexKernelNpu.cpp @@ -21,14 +21,14 @@ namespace native { at::Tensor& index_out_nocheck_npu( const at::Tensor& self, - const at::Tensor& masksTensor, + const at::IntArrayRef masks, const at::TensorList& indices, at::Tensor& result) { OpCommand cmd; cmd.Name("Index") .Input(self) - .Input(masksTensor) - .Input(result.sizes()); + .Input(masks, at::kLong, CompileType::MEMORY_HOST_COMPILE_INDEPENDENT) + .Input(result.sizes(), at::kLong, CompileType::MEMORY_HOST_COMPILE_INDEPENDENT); for (int i = 0; i < indices.size(); i++) { std::string name = "indices" + std::to_string(i); cmd.Input(indices[i], name); @@ -61,12 +61,8 @@ at::Tensor NPUNativeFunctions::index(const at::Tensor& self, const torch::List