From 67ff113c12d94893128662dee3ab0e2398241673 Mon Sep 17 00:00:00 2001
From: ylw1234
Date: Fri, 12 Sep 2025 17:21:49 +0800
Subject: [PATCH] transfer_to_npu fit some apis

---
 test/contrib/test_transfer_to_npu.py | 87 +++++++++++++++++++++++++++-
 torch_npu/contrib/transfer_to_npu.py | 48 ++++++++++++++-
 2 files changed, 129 insertions(+), 6 deletions(-)

diff --git a/test/contrib/test_transfer_to_npu.py b/test/contrib/test_transfer_to_npu.py
index 8c83a6e0d7..1703f98260 100644
--- a/test/contrib/test_transfer_to_npu.py
+++ b/test/contrib/test_transfer_to_npu.py
@@ -80,7 +80,7 @@ class TestTransferToNpu(TestCase):
             c = a @ b
             c_after_autocast_kwargs = c.dtype
 
-        self.assertNotEqual(c_before_autocast, c_after_autocast_args) 
+        self.assertNotEqual(c_before_autocast, c_after_autocast_args)
         self.assertEqual(c_after_autocast_args, c_after_autocast_kwargs)
 
     def test_device_meta(self):
@@ -114,7 +114,7 @@ class TestTransferToNpu(TestCase):
     def test_replace_to_method_in_allowed_methods(self):
        for method in UninitializedTensorMixin._allowed_methods:
             if method.__name__ == "to":
-                self.assertFalse(hasattr(method, "__self__"))  # 替换后torch.Tensor.to变成普通函数,而不是原来的绑定方法 
+                self.assertFalse(hasattr(method, "__self__"))  # 替换后torch.Tensor.to变成普通函数,而不是原来的绑定方法
                 break
 
     @patch('torch.distributed.Backend')
@@ -145,7 +145,7 @@ class TestTransferToNpu(TestCase):
         transfer_to_npu._del_nccl_device_backend_map()
 
         # 没有抛出异常,测试通过
-    
+
     def test_input_validation(self):
         # Test file is a link
         with patch('os.path.islink', return_value=True):
@@ -270,6 +270,87 @@ class TestTransferToNpu(TestCase):
             self.assertIn('experimental_config', result)
             self.assertIs(result['experimental_config'], correct_config)
 
+    def test_torch_Event(self):
+        event = torch.Event(device='cuda:0', enable_timing=True)
+        self.assertEqual(event.device, 'npu')
+
+    def test_torch_get_device_module(self):
+        device_module1 = torch.get_device_module(device='cuda:1')
+        device_module2 = torch.get_device_module(device=torch.device('cuda:1'))
+        npu_device_module = torch.get_device_module(device='npu:1')
+        self.assertEqual(device_module1, npu_device_module)
+        self.assertEqual(device_module2, npu_device_module)
+
+    def test_torch_cuda_can_device_access_peer(self):
+        can_access = torch.cuda.can_device_access_peer(torch.device('cuda:0'), torch.device('cuda:1'))
+        can_access_reverse = torch.cuda.can_device_access_peer(torch.device('cuda:1'), torch.device('cuda:0'))
+        self.assertTrue(can_access)
+        self.assertTrue(can_access_reverse)
+
+    def test_torch_cuda_current_stream(self):
+        cur_stream = torch.cuda.current_stream(1)
+        cur_npu_stream = torch.npu.current_stream(1)
+        self.assertEqual(cur_stream, cur_npu_stream)
+        cur_stream = torch.cuda.current_stream(torch.device('cuda:0'))
+        cur_npu_stream = torch.npu.current_stream(torch.device('npu:0'))
+        self.assertEqual(cur_stream, cur_npu_stream)
+
+    def test_torch_cuda_utilization(self):
+        # query the device utilization
+        use = torch.cuda.utilization(1)
+        self.assertEqual(use, 0)
+        use = torch.cuda.utilization(torch.device('cuda:1'))
+        self.assertEqual(use, 0)
+
+    def test_torch_cuda_set_per_process_memory_fraction(self):
+        torch.cuda.set_per_process_memory_fraction(0.5, device=1)
+        torch.cuda.set_per_process_memory_fraction(1.0, device=torch.device('cuda:1'))
+
+    def test_torch_cuda_caching_allocator_alloc(self):
+        size = 1024 * 1024
+        ptr1 = torch.cuda.caching_allocator_alloc(size, 1)
+        ptr2 = torch.cuda.caching_allocator_alloc(size, torch.device('cuda:1'))
+        torch.cuda.caching_allocator_delete(ptr1)
+        torch.cuda.caching_allocator_delete(ptr2)
+
+    def test_torch_cuda_memory(self):
+        device = torch.device('cuda:1')
+        a = torch.randn((1000, 1000), device=device)
+        torch.cuda.memory._record_memory_history(device=device)
+        snapshot = torch.cuda.memory._snapshot(device=device)
+
+    def test_torch_fft(self):
+        freq_fft = torch.fft.fftfreq(5, device=torch.device('cuda:1'))
+        self.assertEqual(freq_fft.device, 'npu:1')
+        freq_rfft = torch.fft.rfftfreq(5, device=torch.device('cuda:1'))
+        self.assertEqual(freq_rfft.device, 'npu:1')
+
+    def test_torch_autograd_profiler_util_Kernel(self):
+        kernel = torch.autograd.profiler_util.Kernel("model_inference", 'cuda', 11)
+        self.assertEqual(kernel.device, 'npu:0')
+
+    def test_torch_sparse_compressed_tensor(self):
+        # CSR data for a 3 x 4 matrix with one non-zero element per row
+        compressed_indices = torch.tensor([0, 1, 2, 3], dtype=torch.int64)  # compressed row indices (crow_indices)
+        plain_indices = torch.tensor([0, 2, 3], dtype=torch.int64)  # plain (column) indices
+        values = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)  # non-zero values
+
+        # build the sparse compressed tensor on a "cuda" device
+        sparse_tensor = torch.sparse_compressed_tensor(
+            compressed_indices,
+            plain_indices,
+            values,
+            size=(3, 4),
+            dtype=torch.float32,
+            layout=torch.sparse_csr,
+            device=torch.device('cuda:1'),
+            requires_grad=True
+        )
+        self.assertEqual(sparse_tensor.device, 'npu:1')
+
+    def test_torch_utils_cpp_extension_include_paths(self):
+        torch.utils.cpp_extension.include_paths(device_type='cuda')
+
 if __name__ == "__main__":
     run_tests()
 
diff --git a/torch_npu/contrib/transfer_to_npu.py b/torch_npu/contrib/transfer_to_npu.py
index d9f1eb393e..8c606068b1 100644
--- a/torch_npu/contrib/transfer_to_npu.py
+++ b/torch_npu/contrib/transfer_to_npu.py
@@ -12,6 +12,8 @@ from torch.utils._device import _device_constructors
 from torch.utils._triton import has_triton
 from torch.nn.parameter import UninitializedTensorMixin
 from torch._utils import _get_device_module
+from torch.utils import cpp_extension
+from torch.autograd.profiler_util import Kernel
 import torch_npu
 
 try:
@@ -31,7 +33,8 @@ torch_fn_white_list = ['logspace', 'randint', 'hann_window', 'rand', 'full_like'
                        'eye', '_sparse_csr_tensor_unsafe', 'empty', '_sparse_coo_tensor_unsafe', 'blackman_window',
                        'zeros_like', 'range', 'sparse_csr_tensor', 'randn_like', 'from_file', '_cudnn_init_dropout_state',
                        '_empty_affine_quantized', 'linspace', 'hamming_window',
-                       'empty_quantized', '_pin_memory', 'autocast', 'load', 'set_default_device']
+                       'empty_quantized', '_pin_memory', 'autocast', 'load', 'set_default_device', 'Event',
+                       'get_device_module', 'sparse_compressed_tensor']
 torch_tensor_fn_white_list = ['new_empty', 'new_empty_strided', 'new_full', 'new_ones', 'new_tensor', 'new_zeros', 'to',
                               'pin_memory']
 torch_module_fn_white_list = ['to', 'to_empty']
@@ -39,7 +42,8 @@ torch_cuda_fn_white_list = [
     'get_device_properties', 'get_device_name', 'get_device_capability', 'list_gpu_processes', 'set_device',
     'synchronize', 'mem_get_info', 'memory_stats', 'memory_summary', 'memory_allocated', 'max_memory_allocated',
     'reset_max_memory_allocated', 'memory_reserved', 'max_memory_reserved', 'reset_max_memory_cached',
-    'reset_peak_memory_stats', 'default_stream'
+    'reset_peak_memory_stats', 'default_stream', 'can_device_access_peer', 'current_stream', 'utilization',
+    'set_per_process_memory_fraction', 'caching_allocator_alloc'
 ]
 torch_distributed_fn_white_list = ['__init__']
 device_kwargs_list = ['device', 'device_type', 'map_location', 'device_id']
@@ -342,6 +346,25 @@ def _del_nccl_device_backend_map():
         del torch.distributed.Backend.default_device_backend_map['cuda']
 
 
+def _patch_namedtuple(namedtuple_cls):
+    original__new__ = namedtuple_cls.__new__
+
+    def new_namedtuple__new__(cls, *args, **kwargs):
+        if args:
+            args_new = list(args)
+            args = _replace_cuda_to_npu_in_list(args_new, False)
+        if kwargs:
+            for device_arg in device_kwargs_list:
+                device = kwargs.get(device_arg, None)
+                if device is not None:
+                    _replace_cuda_to_npu_in_kwargs(kwargs, device_arg, device)
+            device_ids = kwargs.get('device_ids', None)
+            if isinstance(device_ids, list):
+                device_ids = _replace_cuda_to_npu_in_list(device_ids, False)
+        return original__new__(cls, *args, **kwargs)
+    namedtuple_cls.__new__ = new_namedtuple__new__
+
+
 def _init():
     _warning_fn('''
 *************************************************************************************************************
@@ -362,7 +385,13 @@ def _init():
     _device_wrapper(torch.cuda, torch_cuda_fn_white_list)
     torch.cuda.device.__init__ = _wrapper_cuda(torch.cuda.device.__init__)
 
+    # torch.cuda.memory.*
+    _device_wrapper(torch.npu.memory, ['_record_memory_history', '_snapshot'])
+    torch.cuda.memory._record_memory_history = torch.npu.memory._record_memory_history
+    torch.cuda.memory._snapshot = torch.npu.memory._snapshot
+
     # torch.profiler.*
+    _device_wrapper(torch_npu.profiler._KinetoProfile, ['export_memory_timeline'])
     _patch_profiler()
     torch.profiler.profile = _wrapper_profiler(torch.profiler.profile)
 
@@ -381,6 +410,9 @@ def _init():
     _device_wrapper(torch.nn.Module, torch_module_fn_white_list)
     torch.nn.Module.cuda = torch.nn.Module.npu
 
+    # torch.fft.*
+    _device_wrapper(torch.fft, ['fftfreq', 'rfftfreq'])
+
     # torch.distributed
     torch.distributed.init_process_group = _wrapper_hccl(torch.distributed.init_process_group)
     torch.distributed.is_nccl_available = torch.distributed.is_hccl_available
@@ -392,7 +424,12 @@ def _init():
     torch.distributed.device_mesh.init_device_mesh = _wrapper_cuda(torch.distributed.device_mesh.init_device_mesh)
     torch.distributed.distributed_c10d._new_group_with_tag = _wrapper_hccl(
         torch.distributed.distributed_c10d._new_group_with_tag)
-    
+    torch.distributed.device_mesh.DeviceMesh = _wrapper_cuda(torch.distributed.device_mesh.DeviceMesh)
+
+    # torch.distributed.pipelining.*
+    if hasattr(torch.distributed, 'pipelining'):
+        _device_wrapper(torch.distributed.pipelining.stage, ['PipelineStage', 'build_stage'])
+
     # CUDAGraph
     torch.cuda.CUDAGraph = torch.npu.NPUGraph
 
@@ -415,5 +452,10 @@
 
     _replace_to_method_in_allowed_methods()
 
+    setattr(torch.utils, 'cpp_extension', cpp_extension)
+    _device_wrapper(torch.utils.cpp_extension, ['include_paths'])
+
+    _patch_namedtuple(Kernel)
+
 
 _init()
-- 
Gitee
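For reviewers, a minimal usage sketch (not part of the patch) of how the newly covered entry points behave once the plugin is imported; it assumes torch_npu is installed with at least two NPU devices visible, and the device indices below are purely illustrative:

# Illustrative only: importing transfer_to_npu applies the CUDA -> NPU redirections added above.
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu  # noqa: F401

stream = torch.cuda.current_stream(torch.device('cuda:0'))  # same stream as torch.npu.current_stream('npu:0')
freqs = torch.fft.fftfreq(5, device='cuda:1')                # tensor is allocated on 'npu:1'
torch.cuda.set_per_process_memory_fraction(0.5, device=1)    # forwarded to the NPU caching allocator
print(stream, freqs.device)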