diff --git a/test/npu/test_expandable_segments.py b/test/npu/test_expandable_segments.py
index 7535581550d2687eb935afdf3efadd6ca1c1dd09..892952037c0b268d14726dd2679327fd0cd44d12 100644
--- a/test/npu/test_expandable_segments.py
+++ b/test/npu/test_expandable_segments.py
@@ -1,9 +1,11 @@
 import os
 import gc
+import unittest
 
 import torch
 import torch_npu
 from torch_npu.testing.testcase import TestCase, run_tests
+from torch.testing._internal.common_utils import TestCase, run_tests, TEST_PRIVATEUSE1
 
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "expandable_segments:True"
 
@@ -36,5 +38,65 @@ class Test_expandable_segments(TestCase):
         torch_npu.npu.empty_cache()
         self.assertEqual(torch_npu.npu.memory_reserved(), prev)
 
+    @unittest.skipIf(TEST_PRIVATEUSE1, "NPU not available for graph capture")
+    def test_set_segment_state_to_checkpoint_when_expandable_segments(self):
+        def tensor_metadata(x):
+            return {
+                "nbytes": x.untyped_storage().nbytes(),
+                "data_ptr": x.untyped_storage().data_ptr(),
+                "size": x.shape,
+                "stride": x.stride(),
+                "dtype": x.dtype,
+                "device": x.device,
+                "storage_offset": x.storage_offset(),
+            }
+
+        def reconstruct_from_tensor_metadata(metadata):
+            s = torch._C._construct_storage_from_data_pointer(
+                metadata["data_ptr"], metadata["device"], metadata["nbytes"])
+            t = torch.empty([0], device=metadata["device"], dtype=metadata["dtype"])
+            t.set_(source=s, storage_offset=metadata["storage_offset"],
+                   size=metadata["size"], stride=metadata["stride"])
+            return t
+
+        def cudagraphify(fn, inputs, pool, stream):
+            torch.npu.synchronize()
+            gc.collect()
+            torch.npu.empty_cache()
+
+            graph = torch.npu.NPUGraph()
+            with torch.npu.graph(graph, stream=stream, pool=pool):
+                static_outputs = fn(*inputs)
+            return graph, static_outputs
+
+        def foo(x, idx):
+            r1 = x.expand([1, 2097152 // 8]).sqrt()
+            r2 = x.expand([idx, 2097152]).clone()
+            return r1, r2
+
+        # init
+        pool_id = torch.npu.graph_pool_handle()
+        com_stream = torch.npu.Stream()
+        com_device = torch_npu.npu.current_device()
+        inp = torch.tensor([7]).npu()
+
+        # start capture graph1
+        graph1, outputs1 = cudagraphify(foo, [inp, 1], pool=pool_id, stream=com_stream)
+        graph1_state = torch_npu._C._npu_getCheckpointState(com_device, pool_id)
+        output1_metadata = [tensor_metadata(t) for t in outputs1]
+        outputs1 = None
+
+        # start capture graph2
+        graph2, outputs2 = cudagraphify(foo, [inp, 2], pool=pool_id, stream=com_stream)
+        graph2_state = torch_npu._C._npu_getCheckpointState(com_device, pool_id)
+        graph2.replay()
+        outputs2 = None
+
+        # replay graph1
+        graph1.replay()
+        reconstructed_tensors1 = [reconstruct_from_tensor_metadata(metadata) for metadata in output1_metadata]
+        output1_new_storage = [output.untyped_storage()._cdata for output in reconstructed_tensors1]
+        torch_npu._C._npu_setCheckpointPoolState(com_device, graph1_state, [], output1_new_storage)
+
 if __name__ == '__main__':
     run_tests()
diff --git a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp
index 10dc59d4a2a5d3fb8926908327c4ac3d4ab82729..fe677fc2642aa270ffe8b978e0dbe9490f2e9d91 100644
--- a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp
+++ b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp
@@ -1918,8 +1918,13 @@ public:
             // available unmapped virtual address space. We shouldn't change it but
            // instead check it is correctly formed then skip over allocating it.
            if (i == segment_len - 1 && curr_block->expandable_segment_) {
-                TORCH_CHECK(curr_block->next == nullptr, PTA_ERROR(ErrCode::PTR));
-                TORCH_CHECK(!curr_block->mapped, PTA_ERROR(ErrCode::PTR));
+                // When expandable_segments is enabled, memory blocks are merged when they are released.
+                // Consider a case where a small block is allocated first, followed by a larger block,
+                // and both are subsequently freed. If we then restore the segment state from a checkpoint
+                // taken while the small block was allocated, the next pointer of the last block is not
+                // nullptr and the last block is a mapped block. This is expected because the blocks have
+                // been merged, so these two checks are overly strict and are removed here.
+                // For more details, see https://github.com/pytorch/pytorch/issues/161356.
                 TORCH_CHECK(curr_block->allocated == false, PTA_ERROR(ErrCode::VALUE));
                 continue;
             }
@@ -1961,8 +1966,7 @@ public:
 
         for (size_t i = 0; i < segment_len; ++i, curr_block = curr_block->next) {
             if (i == segment_len - 1 && curr_block->expandable_segment_) {
-                TORCH_CHECK(curr_block->next == nullptr, PTA_ERROR(ErrCode::PTR));
-                TORCH_CHECK(!curr_block->mapped, PTA_ERROR(ErrCode::PTR));
+                // Skipped for the same reason as the checks removed above.
                 TORCH_CHECK(curr_block->allocated == false, PTA_ERROR(ErrCode::VALUE));
                 continue;
             }
@@ -1979,7 +1983,12 @@ public:
             TORCH_CHECK(curr_block->ptr == block_state.ptr, PTA_ERROR(ErrCode::PTR));
             TORCH_CHECK(curr_block->allocated == block_state.allocated, PTA_ERROR(ErrCode::VALUE));
-            TORCH_CHECK(curr_block->size == block_state.size, PTA_ERROR(ErrCode::VALUE));
+            if (!curr_block->expandable_segment_) {
+                // When expandable_segments is enabled, memory blocks are merged when they are released,
+                // so the size of curr_block may be greater than the size recorded in block_state.
+                // The size check is therefore skipped for expandable segments.
+                TORCH_CHECK(curr_block->size == block_state.size, PTA_ERROR(ErrCode::VALUE));
+            }
         }
     }
 
 /* *