diff --git a/torch_npu/csrc/core/npu/NPUBlockHandle.h b/torch_npu/csrc/core/npu/NPUBlockHandle.h index 1641b3414c92bcb453c7b026bff8703f2bf00d2f..f2d98ead9db15c36a0d598b10b2f6465980366f4 100644 --- a/torch_npu/csrc/core/npu/NPUBlockHandle.h +++ b/torch_npu/csrc/core/npu/NPUBlockHandle.h @@ -32,5 +32,12 @@ C10_NPU_API void* GetBlockPtr(const void *handle); /// @param [in] handle: the block handle to query size /// @return size: the device memory size managed by block C10_NPU_API size_t GetBlockSize(const void *handle); + +/// @ingroup torch_npu +/// @brief Get device shut_down status of the block according to handle +/// @param [in] handle: the block handle to query device shut down status +/// @return size: true : means device is shutdown +/// false : means device is activate when query +C10_NPU_API bool GetDeviceShutDownStatusByHandle(const void *handle); } // namespace NPUCachingAllocator } // namespace c10_npu diff --git a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp index 73f2d3698a76bf71bf55ff97f436f99a4b5d3b79..e19178f6d0e0f02e2572b75f0359528d41ec24af 100644 --- a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp +++ b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp @@ -1082,10 +1082,16 @@ class DeviceCachingAllocator { release_cached_blocks(check_error); } - void devSetShutdownStats() { + void setDevShutdownStats() { + std::lock_guard lock(mutex); shutdown_stats = true; } + bool getDevShutdownStats() { + std::lock_guard lock(mutex); + return shutdown_stats; + } + /** Retrieves info (total size + largest block) of the memory cache **/ void cacheInfo(size_t* total, size_t* largest) { std::lock_guard lock(mutex); @@ -2005,7 +2011,7 @@ class THNCachingAllocator { void THNSetShutdownStats() { int count = static_cast(device_allocator.size()); for (int i = 0; i < count; i++) - device_allocator[i]->devSetShutdownStats(); + device_allocator[i]->setDevShutdownStats(); } void* getBaseAllocation(void* ptr, size_t* outSize) { @@ -2240,6 +2246,14 @@ size_t GetBlockSize(const void *handle) { return block->size; } +bool GetDeviceShutDownStatusByHandle(const void *handle) { + const Block *block = reinterpret_cast(handle); + AT_ASSERT(block); + assertValidDevice(block->device); + AT_ASSERT(caching_allocator.device_allocator[block->device]); + return caching_allocator.device_allocator[block->device]->getDevShutdownStats(); +} + std::string name() { return "native"; }