diff --git a/pytorch1.8.1/src/aten/src/ATen/native/npu/contiguous/indexing_opt.cpp b/pytorch1.8.1/src/aten/src/ATen/native/npu/contiguous/indexing_opt.cpp
index 50c9a7e066dc89da331215dbf12dd26fe9379be9..6b3cbd418e7f577f2a5a7589698a4470272e45ed 100644
--- a/pytorch1.8.1/src/aten/src/ATen/native/npu/contiguous/indexing_opt.cpp
+++ b/pytorch1.8.1/src/aten/src/ATen/native/npu/contiguous/indexing_opt.cpp
@@ -27,7 +27,7 @@ public:
     SmallVector<int64_t, MAX_DIM> step;
 
     if (can_use_indexing(src, start, end, step)) {
-      RECORD_FUNCTION("npuStridedSliceD", std::vector<c10::IValue>({src}));
+      RECORD_FUNCTION("npuStridedSlice", std::vector<c10::IValue>({src}));
       indexing_to_contiguous(src, self, start, end, step);
       return true;
     }
@@ -74,7 +74,7 @@ private:
     // infer end index
     for (int64_t i = 0; i < src.dim(); i++) {
       int64_t calculate_end = start[i] + src.size(i) * step[i];
-      if (calculate_end > src.size(i)) {
+      if (calculate_end - step[i] > src_desc.base_sizes_[i]) {
         // Op StrideSlice(Slice) don't support span-axis indexing(slice).
         return false;
       }
diff --git a/pytorch1.8.1/src/aten/src/ATen/native/npu/contiguous/permute_opt.cpp b/pytorch1.8.1/src/aten/src/ATen/native/npu/contiguous/permute_opt.cpp
index e449fc52a8ec39506c9c1d8bcad8c999c276fbf2..637c4f38af2cff80b63f65cb3b3566f47239c00b 100644
--- a/pytorch1.8.1/src/aten/src/ATen/native/npu/contiguous/permute_opt.cpp
+++ b/pytorch1.8.1/src/aten/src/ATen/native/npu/contiguous/permute_opt.cpp
@@ -28,7 +28,7 @@ public:
     SmallVector<int64_t, MAX_DIM> sizes;
     if (can_use_permute(src, perm, sizes)) {
       // delete call and implementation, after more test
-      RECORD_FUNCTION("npuTransposeD", std::vector<c10::IValue>({src}));
+      RECORD_FUNCTION("npuTranspose", std::vector<c10::IValue>({src}));
       // create contiguous tensor for npu transpose
       Tensor temp_src = at::empty(sizes, src.options());
       temp_src.set_(src.storage(), temp_src.storage_offset(), temp_src.sizes(), temp_src.strides());
diff --git a/pytorch1.8.1/test/test_npu/test_trans_contiguous/__init__.py b/pytorch1.8.1/test/test_npu/test_trans_contiguous/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
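The second hunk of indexing_opt.cpp is the behavioral core of this patch: instead of comparing the inferred exclusive end index against the view's own size, the check now compares the last element actually touched on each axis (`calculate_end - step[i]`) against the size of the corresponding base-storage axis (`src_desc.base_sizes_[i]`), so that slices spanning across an axis are rejected. A minimal Python sketch of the same predicate (the standalone function and argument names are illustrative, not the NPU implementation):

    def can_use_indexing(base_sizes, view_sizes, start, step):
        # Mirrors the patched check: reject the optimization when the last
        # element read on axis i lies beyond the base axis, i.e. the slice
        # spans across axes, which the StridedSlice/Slice op cannot express.
        for i, view_size in enumerate(view_sizes):
            calculate_end = start[i] + view_size * step[i]   # exclusive end
            if calculate_end - step[i] > base_sizes[i]:      # last touched index
                return False
        return True

    # x[1:17:4] on a base axis of length 32: last touched index is 13,
    # inside the base axis, so the slice is optimizable.
    assert can_use_indexing([32], [4], [1], [4])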
diff --git a/pytorch1.8.1/test/test_npu/test_trans_contiguous/test_single_broadcast_copy_to_contiguous.py b/pytorch1.8.1/test/test_npu/test_trans_contiguous/test_single_broadcast_copy_to_contiguous.py
new file mode 100644
index 0000000000000000000000000000000000000000..141ea8b10ada34996fa3d448d0ec15e471bb7a62
--- /dev/null
+++ b/pytorch1.8.1/test/test_npu/test_trans_contiguous/test_single_broadcast_copy_to_contiguous.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import torch
+import numpy as np
+
+from common_utils import TestCase, run_tests
+from common_device_type import instantiate_device_type_tests
+from util_test import create_common_tensor_for_broadcast, check_operators_in_prof
+
+os.environ["PTCOPY_ENABLE"] = "1"
+
+# Optimized view ops include transpose, permute, narrow, strideslice, select and unfold.
+class SingleViewCopyToContiguous(TestCase):
+    def test_broadcast_copy_contiguous(self, device):
+        dtype_list = [np.float16, np.float32, np.int32, np.int8, np.uint8]
+        format_list = [-1]
+        shape_list = [
+            [[1], [5]],
+            [[1, 2], [3, 2]],
+            [[1, 2, 1], [1, 2, 3]],
+            [[1, 2, 1, 3], [4, 2, 5, 3]],
+            [[1, 3], [1, 1, 4, 3]],
+            [[1, 3], [2, 1, 4, 3]],
+            [[1, 3], [1, 2, 4, 3]],
+            [[3, 1], [2, 1, 3, 1]],
+            [[3, 1], [1, 2, 3, 1]],
+        ]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor_for_broadcast(item, 0, 100)
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out1 = npu_input.expand(item[2][1]).contiguous()
+            self.assertEqual(check_operators_in_prof(['npuBroadcast'], prof), True, "npuBroadcast is not called!")
+            cpu_out1 = cpu_input.expand(item[2][1]).contiguous()
+            self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy())
+
+instantiate_device_type_tests(SingleViewCopyToContiguous, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
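Every test file in this patch follows the same skeleton: run a single view op followed by `.contiguous()` under the NPU profiler, assert that the expected optimized copy operator appears in the trace, then compare the result against the CPU reference. A condensed sketch of that shared pattern (the `assert_view_copy` helper is hypothetical and written here only to make the pattern explicit; `check_operators_in_prof` is the utility added at the end of this patch):

    import torch
    from util_test import check_operators_in_prof

    def assert_view_copy(test, cpu_view, npu_view, expected_ops):
        # Profile only the copy triggered by .contiguous() on the NPU view.
        with torch.autograd.profiler.profile(use_npu=True) as prof:
            npu_out = npu_view.contiguous()
        # The optimized path is proven by the operator names in the profile.
        test.assertEqual(check_operators_in_prof(expected_ops, prof), True,
                         "%s is not called!" % expected_ops)
        # Numerical check against the CPU reference.
        test.assertRtolEqual(npu_out.to("cpu").numpy(),
                             cpu_view.contiguous().numpy())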
diff --git a/pytorch1.8.1/test/test_npu/test_trans_contiguous/test_single_permute_copy_to_contiguous.py b/pytorch1.8.1/test/test_npu/test_trans_contiguous/test_single_permute_copy_to_contiguous.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d496271f7c16b16bd3de4783ce66d56a26fa978
--- /dev/null
+++ b/pytorch1.8.1/test/test_npu/test_trans_contiguous/test_single_permute_copy_to_contiguous.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import torch
+import numpy as np
+
+from common_utils import TestCase, run_tests
+from common_device_type import instantiate_device_type_tests
+from util_test import create_common_tensor, check_operators_in_prof
+
+os.environ["PTCOPY_ENABLE"] = "1"
+
+# Optimized view ops include transpose, permute, narrow, strideslice, select and unfold.
+class SingleViewCopyToContiguous(TestCase):
+    def test_permute_copy_contiguous(self, device):
+        dtype_list = [np.float16, np.float32]
+        format_list = [-1]
+        shape_list = [[2, 6, 9, 4]]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out1 = npu_input.permute(1, 0, 2, 3).contiguous()
+            self.assertEqual(check_operators_in_prof(['npuTranspose'], prof), True, "npuTranspose is not called!")
+
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out2 = npu_input.permute(2, 3, 0, 1).contiguous()
+            self.assertEqual(check_operators_in_prof(['npuTranspose'], prof), True, "npuTranspose is not called!")
+
+            cpu_out1 = cpu_input.permute(1, 0, 2, 3).contiguous()
+            cpu_out2 = cpu_input.permute(2, 3, 0, 1).contiguous()
+
+            self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy())
+            self.assertRtolEqual(npu_out2.to("cpu").numpy(), cpu_out2.numpy())
+
+instantiate_device_type_tests(SingleViewCopyToContiguous, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/pytorch1.8.1/test/test_npu/test_trans_contiguous/test_single_reshape_copy_to_contiguous.py b/pytorch1.8.1/test/test_npu/test_trans_contiguous/test_single_reshape_copy_to_contiguous.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2bb4e582781cc40227502333fa21e62f39e8366
--- /dev/null
+++ b/pytorch1.8.1/test/test_npu/test_trans_contiguous/test_single_reshape_copy_to_contiguous.py
@@ -0,0 +1,197 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import torch
+import numpy as np
+
+from common_utils import TestCase, run_tests
+from common_device_type import instantiate_device_type_tests
+from util_test import create_common_tensor, check_operators_in_prof
+
+os.environ["PTCOPY_ENABLE"] = "1"
+
+# Optimized view ops include transpose, permute, narrow, strideslice, select and unfold.
+class SingleViewCopyToContiguous(TestCase):
+    def test_view_copy(self, device):
+        dtype_list = [np.float16, np.float32]
+        format_list = [0, 3, 29]
+        shape_list = [
+            # No padding for NZ format
+            [2, 3, 16, 16],
+            # Padding for NZ format
+            [2, 3, 15, 16],
+        ]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            # Directly using d2dcopy without transdata (View_d2dCopyAsync)
+            # case1. base format
+            match_case1 = (item[1] == 0)
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out1 = npu_input.view(1, 6, npu_input.size(2), npu_input.size(3)).clone()
+            # case2. The key axis remains unchanged for NZ format
+            match_case2 = (item[1] == 29)
+            if match_case1 or match_case2:
+                self.assertEqual(check_operators_in_prof(['View_d2dCopyAsync'], prof), True, "View_d2dCopyAsync is not called!")
+            else:
+                self.assertEqual(check_operators_in_prof(['d2dCopyAsync'], prof), True, "d2dCopyAsync is not called!")
+            cpu_out1 = cpu_input.view(1, 6, cpu_input.size(2), cpu_input.size(3)).clone()
+            self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy())
+
+            # The key axis changes for NZ format
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out2 = npu_input.view(1, 6, npu_input.size(2)*npu_input.size(3), 1).clone()
+            if match_case1:
+                self.assertEqual(check_operators_in_prof(['View_d2dCopyAsync'], prof), True, "View_d2dCopyAsync is not called!")
+            else:
+                self.assertEqual(check_operators_in_prof(['d2dCopyAsync'], prof), True, "d2dCopyAsync is not called!")
+            cpu_out2 = cpu_input.view(1, 6, cpu_input.size(2)*cpu_input.size(3), 1).clone()
+            self.assertRtolEqual(npu_out2.to("cpu").numpy(), cpu_out2.numpy())
+
+    def test_unsqueeze_copy(self, device):
+        dtype_list = [np.float16, np.float32]
+        format_list = [2, 3, 29]
+        shape_list = [
+            [3, 16, 16],
+            [3, 15, 16],
+        ]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+
+        for i in range(3):
+            for item in shape_format:
+                cpu_input, npu_input = create_common_tensor(item, 0, 100)
+                # Directly using d2dcopy without transdata (View_d2dCopyAsync)
+                # case1. base format
+                match_case1 = (item[1] == 2)
+                # case2. The key axis remains unchanged for NZ format
+                match_case2 = (item[1] == 29 and i < 2)
+                with torch.autograd.profiler.profile(use_npu=True) as prof:
+                    npu_out = npu_input.unsqueeze(i).clone()
+                if match_case1 or match_case2:
+                    self.assertEqual(check_operators_in_prof(['View_d2dCopyAsync'], prof), True, "View_d2dCopyAsync is not called!")
+                else:
+                    self.assertEqual(check_operators_in_prof(['d2dCopyAsync'], prof), True, "d2dCopyAsync is not called!")
+                cpu_out = cpu_input.unsqueeze(i).clone()
+                self.assertRtolEqual(npu_out.to("cpu").numpy(), cpu_out.numpy())
+
+    def test_flatten_copy(self, device):
+        dtype_list = [np.float16, np.float32]
+        format_list = [0, 3, 29]
+        shape_list = [
+            [2, 3, 16, 16],
+            [2, 3, 16, 15],
+        ]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out = torch.flatten(npu_input, 0, 1).clone()
+            if item[1] == 3:
+                # Using d2dcopy with transdata (d2dCopyAsync)
+                self.assertEqual(check_operators_in_prof(['d2dCopyAsync'], prof), True, "d2dCopyAsync is not called!")
+            else:
+                # Directly using d2dcopy without transdata (View_d2dCopyAsync)
+                self.assertEqual(check_operators_in_prof(['View_d2dCopyAsync'], prof), True, "View_d2dCopyAsync is not called!")
+
+            cpu_out = torch.flatten(cpu_input, 0, 1).clone()
+            self.assertRtolEqual(npu_out.to("cpu").numpy(), cpu_out.numpy())
+
+    def test_narrow_at_first_axis_copy(self, device):
+        # This case: slicing at the first dim; the tensor remains contiguous even with a storage offset.
+        dtype_list = [np.float16, np.float32]
+        format_list = [2, 3, 29]
+        shape_list = [
+            [20, 16, 16],
+            [20, 16, 15],
+        ]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            # Directly using d2dcopy without transdata (View_d2dCopyAsync)
+            # The key axis remains unchanged for NZ format in all cases.
+            # case1. base format
+            match_case1 = (item[1] == 2)
+            # case2. NZ format but no padding
+            match_case2 = (item[1] == 29 and item[2] == [20, 16, 16])
+
+            # contiguous and no offset
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out1 = npu_input[:10,:,:].clone()
+            # case3. NZ format with padding but no offset
+            match_case3 = (item[1] == 29)
+            if match_case1 or match_case2 or match_case3:
+                self.assertEqual(check_operators_in_prof(['View_d2dCopyAsync'], prof), True, "View_d2dCopyAsync is not called!")
+            else:
+                self.assertEqual(check_operators_in_prof(['narrow_npuSlice'], prof), True, "narrow_npuSlice is not called!")
+            cpu_out1 = cpu_input[:10,:,:].clone()
+            self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy())
+
+            # contiguous but with offset
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out2 = npu_input[1:10,:,:].clone()
+            match_case3 = False
+            if match_case1 or match_case2 or match_case3:
+                self.assertEqual(check_operators_in_prof(['View_d2dCopyAsync'], prof), True, "View_d2dCopyAsync is not called!")
+            else:
+                self.assertEqual(check_operators_in_prof(['narrow_npuSlice'], prof), True, "narrow_npuSlice is not called!")
+            cpu_out2 = cpu_input[1:10,:,:].clone()
+            self.assertRtolEqual(npu_out2.to("cpu").numpy(), cpu_out2.numpy())
+
+    def test_select_at_first_axis_to_single_element_tensor_copy(self, device):
+        dtype_list = [torch.float32]
+        format_list = [2, 3, 29]
+        shape_format = [
+            [i, j] for i in dtype_list for j in format_list
+        ]
+
+        for item in shape_format:
+            cpu_input = torch.tensor([1.0]).to(item[0])
+            npu_input = cpu_input.npu().npu_format_cast(item[1])
+
+            match_case = (item[1] == 2)
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out1 = npu_input[0].clone()
+            if match_case:
+                self.assertEqual(check_operators_in_prof(['View_d2dCopyAsync'], prof), True, "View_d2dCopyAsync is not called!")
+            else:
+                self.assertEqual(check_operators_in_prof(['narrow_npuSlice'], prof), True, "narrow_npuSlice is not called!")
+            cpu_out1 = cpu_input[0].clone()
+            self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy())
+
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out2 = npu_input[0] + 1
+            if match_case:
+                self.assertEqual(check_operators_in_prof(['memory_repoint'], prof), True, "memory_repoint is not called!")
+            else:
+                # refresh storage desc after transdata
+                self.assertEqual(check_operators_in_prof(['Identity'], prof), True, "Identity is not called!")
+            cpu_out2 = cpu_input[0] + 1
+            self.assertRtolEqual(npu_out2.to("cpu").numpy(), cpu_out2.numpy())
+
+
+instantiate_device_type_tests(SingleViewCopyToContiguous, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
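The reshape-family tests above (view, unsqueeze, flatten, narrow, select) all encode the same decision rule, inferred from their assertions: the copy may skip transdata and go through View_d2dCopyAsync when the tensor is in a base format, or when it is in NZ format (29) and the view leaves the key (tiling) axes untouched; otherwise the ordinary d2dCopyAsync path with transdata is expected. A sketch of that expectation table (format codes as passed to create_common_tensor in these tests; this summarizes the test assertions, not NPU documentation):

    def expected_copy_op(npu_format, key_axes_unchanged):
        # Format codes used by the tests: 0/2 = base formats,
        # 3 = NC1HWC0 (5HD), 29 = FRACTAL_NZ.
        if npu_format in (0, 2):
            return 'View_d2dCopyAsync'     # base format: direct d2d copy
        if npu_format == 29 and key_axes_unchanged:
            return 'View_d2dCopyAsync'     # NZ with tiling axes untouched
        return 'd2dCopyAsync'              # needs transdata first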
diff --git a/pytorch1.8.1/test/test_npu/test_trans_contiguous/test_single_slice_copy_to_contiguous.py b/pytorch1.8.1/test/test_npu/test_trans_contiguous/test_single_slice_copy_to_contiguous.py
new file mode 100644
index 0000000000000000000000000000000000000000..75835d0bf21094cb3e31725d40311dffbec2450c
--- /dev/null
+++ b/pytorch1.8.1/test/test_npu/test_trans_contiguous/test_single_slice_copy_to_contiguous.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import torch
+import numpy as np
+
+from common_utils import TestCase, run_tests
+from common_device_type import instantiate_device_type_tests
+from util_test import create_common_tensor, check_operators_in_prof
+
+os.environ["PTCOPY_ENABLE"] = "1"
+
+# Optimized view ops include transpose, permute, narrow, strideslice, select and unfold.
+class SingleViewCopyToContiguous(TestCase):
+    def test_narrow_copy_contiguous(self, device):
+        # Supported dtypes: [np.bool, np.int32, np.float16, np.float32, np.int8, np.uint8, np.int64].
+        # Considering the dtypes that transdata supports, only np.float16 and np.float32 are tested here.
+        dtype_list = [np.float16, np.float32]
+        format_list_4D = [0, 3, 29, 4]
+        shape_list_4D = [[2, 32, 16, 20]]
+        format_list_5D = [30, 32, 33]
+        shape_list_5D = [[2, 32, 16, 20, 15]]
+        shape_format_4D = [
+            [i, j, k] for i in dtype_list for j in format_list_4D for k in shape_list_4D
+        ]
+        shape_format_5D = [
+            [i, j, k] for i in dtype_list for j in format_list_5D for k in shape_list_5D
+        ]
+        shape_format = shape_format_4D + shape_format_5D
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            # Narrow with step=1 at the first axis already yields a contiguous tensor,
+            # so narrow is exercised at the remaining axes.
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out1 = npu_input[:,:16,:,:].contiguous()
+            self.assertEqual(check_operators_in_prof(['narrow_npuSlice'], prof), True, "narrow_npuSlice is not called!")
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out2 = npu_input[:,:,1:16,:].contiguous()
+            self.assertEqual(check_operators_in_prof(['narrow_npuSlice'], prof), True, "narrow_npuSlice is not called!")
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out3 = npu_input[:,:,:,2:16].contiguous()
+            self.assertEqual(check_operators_in_prof(['narrow_npuSlice'], prof), True, "narrow_npuSlice is not called!")
+
+            cpu_out1 = cpu_input[:,:16,:,:].contiguous()
+            cpu_out2 = cpu_input[:,:,1:16,:].contiguous()
+            cpu_out3 = cpu_input[:,:,:,2:16].contiguous()
+            if npu_input.dim() == 5:
+                cpu_out4 = cpu_input[:,:,:,:,3:10].contiguous()
+                npu_out4 = npu_input[:,:,:,:,3:10].contiguous()
+                self.assertRtolEqual(npu_out4.to("cpu").numpy(), cpu_out4.numpy())
+
+            self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy())
+            self.assertRtolEqual(npu_out2.to("cpu").numpy(), cpu_out2.numpy())
+            self.assertRtolEqual(npu_out3.to("cpu").numpy(), cpu_out3.numpy())
+
+    def test_strideslice_copy_contiguous(self, device):
+        dtype_list = [np.float16, np.float32, np.int8, np.int32, np.uint8, np.bool]
+        format_list = [-1]
+        shape_list = [[10,32,16,9], [10,32,16,9,10]]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            # indexing with step > 1 -- stridedSlice
+            if cpu_input.dim() == 4:
+                with torch.autograd.profiler.profile(use_npu=True) as prof:
+                    npu_out1 = npu_input[::2].contiguous()
+                self.assertEqual(check_operators_in_prof(['npuStridedSlice'], prof), True, "npuStridedSlice is not called!")
+                with torch.autograd.profiler.profile(use_npu=True) as prof:
+                    npu_out2 = npu_input[:,1:17:4].contiguous()
+                self.assertEqual(check_operators_in_prof(['npuStridedSlice'], prof), True, "npuStridedSlice is not called!")
+                with torch.autograd.profiler.profile(use_npu=True) as prof:
+                    npu_out3 = npu_input[:,:,2:16:5].contiguous()
+                self.assertEqual(check_operators_in_prof(['npuStridedSlice'], prof), True, "npuStridedSlice is not called!")
+                with torch.autograd.profiler.profile(use_npu=True) as prof:
+                    # stridedSlice does not support slicing at the last dim
+                    npu_out4 = npu_input[:,:,:,3:9:2].contiguous()
+                self.assertEqual(check_operators_in_prof(['d2dCopyWithPTCopy'], prof), True, "d2dCopyWithPTCopy is not called!")
+                with torch.autograd.profiler.profile(use_npu=True) as prof:
+                    npu_out5 = npu_input[::2,1:17:4,2:16:5,:].contiguous()
+                self.assertEqual(check_operators_in_prof(['npuStridedSlice'], prof), True, "npuStridedSlice is not called!")
+
+                cpu_out1 = cpu_input[::2].contiguous()
+                cpu_out2 = cpu_input[:,1:17:4].contiguous()
+                cpu_out3 = cpu_input[:,:,2:16:5].contiguous()
+                cpu_out4 = cpu_input[:,:,:,3:9:2].contiguous()
+                # strideslice at several axes at once
+                cpu_out5 = cpu_input[::2,1:17:4,2:16:5,:].contiguous()
+
+                self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy())
+                self.assertRtolEqual(npu_out2.to("cpu").numpy(), cpu_out2.numpy())
+                self.assertRtolEqual(npu_out3.to("cpu").numpy(), cpu_out3.numpy())
+                self.assertRtolEqual(npu_out4.to("cpu").numpy(), cpu_out4.numpy())
+                self.assertRtolEqual(npu_out5.to("cpu").numpy(), cpu_out5.numpy())
+            if cpu_input.dim() == 5:
+                cpu_out6 = cpu_input[:,:,:,:,1:7:3].contiguous()
+                npu_out6 = npu_input[:,:,:,:,1:7:3].contiguous()
+                self.assertRtolEqual(npu_out6.to("cpu").numpy(), cpu_out6.numpy())
+
+    def test_select_copy_contiguous(self, device):
+        dtype_list = [np.float16, np.float32]
+        format_list = [-1]
+        shape_list = [[2,32,16,9], [2,32,16,9,10]]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            for dim in range(1, len(item[2])):
+                with torch.autograd.profiler.profile(use_npu=True) as prof:
+                    npu_out = npu_input.select(dim, 1).contiguous()
+                self.assertEqual(check_operators_in_prof(['select_npuStridedSlice'], prof), True, "select_npuStridedSlice is not called!")
+                cpu_out = cpu_input.select(dim, 1).contiguous()
+                self.assertRtolEqual(npu_out.to("cpu").numpy(), cpu_out.numpy())
+
+    def test_span_axis_strideslice_contiguous(self, device):
+        dtype_list = [np.float16, np.float32]
+        format_list = [-1]
+        shape_list = [[32,8,2], [(8,6,2), (5,4,1), 1]]
+        shape_format = [
+            [i, j, shape_list[0]] for i in dtype_list for j in format_list
+        ]
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            # npuStridedSlice does not support span-axis strideslice, so this case cannot be optimized
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out = torch.as_strided(npu_input, shape_list[1][0], shape_list[1][1], shape_list[1][2]).contiguous()
+            self.assertEqual(check_operators_in_prof(['d2dCopyWithPTCopy'], prof, ['npuStridedSlice']), True, "d2dCopyWithPTCopy is not called or npuStridedSlice is wrongly called!")
+            cpu_out = torch.as_strided(cpu_input, shape_list[1][0], shape_list[1][1], shape_list[1][2]).contiguous()
+            self.assertRtolEqual(npu_out.to("cpu").numpy(), cpu_out.numpy())
+
+instantiate_device_type_tests(SingleViewCopyToContiguous, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/pytorch1.8.1/test/test_npu/test_trans_contiguous/util_test.py b/pytorch1.8.1/test/test_npu/test_trans_contiguous/util_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..7306e5adb136827650d0f297cbf4840ba6b254d2
--- /dev/null
+++ b/pytorch1.8.1/test/test_npu/test_trans_contiguous/util_test.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+import numpy as np
+import torch
+
+common_path = os.path.join(os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "common")
+if common_path not in sys.path:
+    sys.path.append(common_path)
+from util_test_new import create_common_tensor, test_2args_broadcast,\
+    create_dtype_tensor, create_common_tensor_for_broadcast
+
+def check_operators_in_prof(expected_operators, prof, unexpected_operators=None):
+    unexpected_operators = unexpected_operators or []
+    prof_key_averages = prof.key_averages()
+    if not prof_key_averages:
+        print("torch profiling is empty, please check it")
+        return False
+    for prof_item in prof_key_averages:
+        if prof_item.key in unexpected_operators:
+            # if an unexpected operator was called, pattern matching in trans-contiguous has failed
+            return False
+        elif prof_item.key in expected_operators:
+            # if an expected operator was called, remove it from the expected_operators list
+            expected_operators.remove(prof_item.key)
+
+    # if the expected_operators list is empty, all expected operators have been called
+    if not expected_operators:
+        return True
+    return False
\ No newline at end of file
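For reference, a minimal standalone usage of `check_operators_in_prof` outside the TestCase machinery might look like the sketch below (the shape, the slice, and the operator names are illustrative, taken from the tests above; it assumes an NPU-enabled PyTorch build where `.npu()` is available):

    import torch
    from util_test import check_operators_in_prof

    x = torch.rand(2, 32, 16, 20).npu()
    with torch.autograd.profiler.profile(use_npu=True) as prof:
        y = x[:, :16].contiguous()   # narrow at axis 1 with step 1

    # True only if narrow_npuSlice appears in the profile
    # and d2dCopyWithPTCopy does not.
    print(check_operators_in_prof(['narrow_npuSlice'], prof,
                                  unexpected_operators=['d2dCopyWithPTCopy']))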