diff --git a/test/test_network_ops/test_masked_scatter.py b/test/test_network_ops/test_masked_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..e509e089c390630c307350c2c0f4d80e613f43a3
--- /dev/null
+++ b/test/test_network_ops/test_masked_scatter.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+
+class TestMaskedScatter(TestCase):
+    """Compare masked_scatter on NPU against the CPU reference result."""
+
+    def cpu_op_exec(self, input1, maskbool, source):
+        """Return out-of-place torch.masked_scatter on CPU as a numpy array."""
+        cpu_output = torch.masked_scatter(input1, maskbool, source)
+        return cpu_output.numpy()
+
+    def npu_op_exec(self, input1, maskbool, source):
+        """Move all operands to NPU, run torch.masked_scatter, return numpy."""
+        input1 = input1.to("npu")
+        maskbool = maskbool.to("npu")
+        source = source.to("npu")
+        npu_output = torch.masked_scatter(input1, maskbool, source)
+        return npu_output.to("cpu").numpy()
+
+    def cpu_inp_op_exec(self, input1, maskbool, source):
+        """Run in-place masked_scatter_ on CPU (mutates input1); return numpy."""
+        cpu_output = input1.masked_scatter_(maskbool, source)
+        return cpu_output.numpy()
+
+    def npu_inp_op_exec(self, input1, maskbool, source):
+        """Run in-place masked_scatter_ (mutates input1); return numpy.
+
+        Only the mask is moved to NPU here; input1 and source are used as
+        passed in.
+        """
+        maskbool = maskbool.to("npu")
+        npu_output = input1.masked_scatter_(maskbool, source)
+        return npu_output.to("cpu").numpy()
+
+    def _check_inplace(self, dtype_list, format_list):
+        """Assert CPU and NPU masked_scatter_ agree for each dtype/format.
+
+        The mask is fixed rather than torch.randn(4, 1).ge(0.5): that random
+        mask comes out all False in roughly a quarter of runs, and then
+        masked_scatter_ changes nothing, so no scatter is exercised.
+        """
+        shape_list = [[4, 5], [3, 4, 5], [2, 3, 4, 5]]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+        # Shape (4, 1) broadcasts against every shape in shape_list.
+        maskbool = torch.tensor([[True], [False], [True], [False]])
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            cpu_source, npu_source = create_common_tensor(item, 0, 100)
+            cpu_output = self.cpu_inp_op_exec(cpu_input, maskbool, cpu_source)
+            npu_output = self.npu_inp_op_exec(npu_input, maskbool, npu_source)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_masked_scatter_float(self, device):
+        """Check in-place masked_scatter_ on float32 with format ids 0, 3."""
+        self._check_inplace([np.float32], [0, 3])
+
+    def test_masked_scatter_int(self, device):
+        """Check in-place masked_scatter_ on int32/int64 with format id 0."""
+        self._check_inplace([np.int32, np.int64], [0])
+
+
+instantiate_device_type_tests(TestMaskedScatter, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/ops/MaskedScatterKernelNpu.cpp b/torch_npu/csrc/aten/ops/MaskedScatterKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..24931b3f566b7bb13acc0f04686fa5d3f77d8212
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/MaskedScatterKernelNpu.cpp
@@ -0,0 +1,73 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+// Launches the "MaskedScatter" operator with inputs (self, mask, source) and
+// output result. A mask that is not already bool is cast to bool first.
+at::Tensor& masked_scatter_out_npu_nocheck(
+    at::Tensor& result,
+    const at::Tensor& self,
+    const at::Tensor& mask,
+    const at::Tensor& source) {
+  at::Tensor maskBool = (mask.scalar_type() == at::kBool)
+      ? mask
+      : NPUNativeFunctions::npu_dtype_cast(mask, at::kBool);
+
+  OpCommand cmd;
+  cmd.Name("MaskedScatter")
+      .Input(self)
+      .Input(maskBool)
+      .Input(source)
+      .Output(result)
+      .Run();
+  return result;
+}
+
+// In-place masked_scatter. Half self/source are cast to Float for the operator
+// call; self is only written into (never rebound) and is returned as-is.
+at::Tensor& NPUNativeFunctions::masked_scatter_(
+    at::Tensor& self,
+    const at::Tensor& mask,
+    const at::Tensor& source) {
+  c10::SmallVector<at::Tensor, N> inputs = {self, mask, source};
+  c10::SmallVector<at::Tensor, N> outputs = {self};
+  CalcuOpUtil::check_memory_over_laps(inputs, outputs);
+
+  at::Tensor selfFp32 = self;
+  at::Tensor sourceFp32 = source;
+  if (self.scalar_type() == at::ScalarType::Half) {
+    selfFp32 = NPUNativeFunctions::npu_dtype_cast(self, at::ScalarType::Float);
+    sourceFp32 = NPUNativeFunctions::npu_dtype_cast(source, at::ScalarType::Float);
+  }
+
+  if (!NpuUtils::check_match(&self)) {
+    at::Tensor contiguousSelf = NpuUtils::format_contiguous(selfFp32);
+    masked_scatter_out_npu_nocheck(contiguousSelf, contiguousSelf, mask, sourceFp32);
+    NpuUtils::format_fresh_view(self, contiguousSelf);
+  } else {
+    masked_scatter_out_npu_nocheck(selfFp32, selfFp32, mask, sourceFp32);
+    self.copy_(selfFp32);  // Copies the Float staging result back for Half.
+  }
+  return self;
+}
+} // namespace native
+} // namespace at_npu