diff --git a/test/torch_npu_schema.json b/test/torch_npu_schema.json index a5812fe4297b361b0ee51a62710931ce7b4ac17f..541f3ef61870cfb1868b5ed7bdf7ba23263c3915 100644 --- a/test/torch_npu_schema.json +++ b/test/torch_npu_schema.json @@ -1106,6 +1106,9 @@ "torch_npu.npu.check_uce_in_memory": { "signature": "(device_id)" }, + "torch_npu.npu.ipc_collect": { + "signature": "()" + }, "torch_npu.npu.clear_npu_overflow_flag": { "signature": "()" }, @@ -1649,6 +1652,9 @@ "torch_npu.npu.utils.check_uce_in_memory": { "signature": "(device_id)" }, + "torch_npu.npu.utils.ipc_collect": { + "signature": "()" + }, "torch_npu.npu.utils.clear_npu_overflow_flag": { "signature": "()" }, diff --git a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp index 6d5ff43884ddf5ab9df90b489940195680d52f52..9e2b73e6ddfb9f97e5c2de4bfd6160d30ff984e1 100644 --- a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp +++ b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp @@ -586,9 +586,7 @@ private: // cannot call c10::npu::stream_synchronize because // it might grab the GIL which can lead to a deadlock // Locking order must be GIL -> Allocator Lock - if (stream_) { - NPU_CHECK_ERROR(aclrtSynchronizeStream(*stream_)); - } else { + { c10_npu::NPUGuard device_guard(device_); c10_npu::npuSynchronizeDevice(true); } @@ -3364,15 +3362,6 @@ public: emptyCacheImpl(check_error, false); } - void clearIpcHandles() override - { - std::lock_guard lock(ipcHandleMutex); - for (auto &handle : ipcHandles) { - NPU_CHECK_ERROR(c10_npu::acl::AclrtFreePhysical(handle)); - } - ipcHandles.clear(); - } - void *getBaseAllocation(void *ptr, size_t *outSize) override { Block *block = get_allocated_block(ptr); @@ -3686,7 +3675,10 @@ public: void clear() { if (npu_ipc_ptr_) { - c10_npu::NPUGuard device_guard(device_); + { + c10_npu::NPUGuard device_guard(device_); + c10_npu::npuSynchronizeDevice(true); + } NPU_CHECK_ERROR(c10_npu::acl::AclrtIpcMemClose(handle_s.c_str())); npu_ipc_ptr_ = nullptr; } diff --git a/torch_npu/csrc/core/npu/NPUCachingAllocator.h b/torch_npu/csrc/core/npu/NPUCachingAllocator.h index 66fb00f30ba7d53611fbecc1e3a4bfc18da2de05..ed16cd24f121590ae27e538244d5dc36707f7c2c 100644 --- a/torch_npu/csrc/core/npu/NPUCachingAllocator.h +++ b/torch_npu/csrc/core/npu/NPUCachingAllocator.h @@ -205,7 +205,6 @@ public: virtual void emptyCacheImpl(bool check_error, bool free_physical) = 0; virtual void emptyCache(bool check_error) = 0; virtual void emptyVirtAddrCache(bool check_error) = 0; - virtual void clearIpcHandles() = 0; virtual void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock) = 0; virtual void* getBaseAllocation(void* ptr, size_t* size) = 0; virtual void recordStream(const c10::DataPtr& ptr, c10_npu::NPUStream stream) = 0; @@ -323,11 +322,6 @@ C10_NPU_API inline void emptyVirtAddrCache(bool check_error = true) return get()->emptyVirtAddrCache(check_error); } -inline void clearIpcHandles() -{ - return get()->clearIpcHandles(); -} - inline void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock) { return get()->cacheInfo(dev_id, cachedAndFree, largestBlock); diff --git a/torch_npu/csrc/ipc/StorageSharing.cpp b/torch_npu/csrc/ipc/StorageSharing.cpp index 02ff2f89e13352205e33b17d26062bb6f198ab0b..330ad45d57110d1a95eadd649ab843a30a3aa7c1 100644 --- a/torch_npu/csrc/ipc/StorageSharing.cpp +++ b/torch_npu/csrc/ipc/StorageSharing.cpp @@ -47,6 +47,7 @@ static PyObject* THNPStorage_shareNpu(PyObject* self, PyObject* args) } at::DeviceGuard device_guard(storage.device()); + c10_npu::LazySetDevice(storage.device().index()); THPObjectPtr tuple(PyTuple_New(8)); THPObjectPtr device(THPUtils_packInt32(storage.device().index())); THPObjectPtr _handle(Py_None); @@ -193,6 +194,7 @@ static PyObject* THNPStorage_newSharedNpu(PyObject* _unused, PyObject* args) const auto device = c10::checked_convert( THPUtils_unpackLong(_device), "c10::DeviceIndex"); c10_npu::NPUGuard device_guard(device); + c10_npu::LazySetDevice(device); if (PyObject_IsTrue(_event_sync_required)) { // TO BE DONE diff --git a/torch_npu/csrc/logging/Logger.cpp b/torch_npu/csrc/logging/Logger.cpp index 8bd0e1332d3dbb469f6ff814fffbd3f160c3cb45..57ad2edf95e356e8bcfb9f1245ff9ec68cebef29 100644 --- a/torch_npu/csrc/logging/Logger.cpp +++ b/torch_npu/csrc/logging/Logger.cpp @@ -4,7 +4,9 @@ #include #include #include +#include #include "torch_npu/csrc/logging/Logger.h" +#include "torch_npu/csrc/core/npu/npu_log.h" #include "torch_npu/csrc/core/npu/register/OptionsManager.h" namespace npu_logging { @@ -54,11 +56,23 @@ void Logger::log(LoggingLevel level, const std::string& levelStr, const int log_ if (rank != -1) { oss << "[rank:" << rank << "]:"; } - oss << "[" << timeBuffer << ":" << std::setfill('0') << std::setw(3) << nowMs << "] " << name_ << ": [" << - levelStr << "] " << buffer << std::endl; + // Keep 3 decimal places for milliseconds. + oss << "[" << getpid() << "] [" << timeBuffer << ":" << std::setfill('0') << std::setw(3) << nowMs << "] " + << name_ << ": [" << levelStr << "] [" << syscall(SYS_gettid) << "] " << buffer << std::endl; std::string s = oss.str(); std::cerr.write(s.c_str(), s.size()); std::cerr.flush(); + + // plog + if (level == LoggingLevel::DEBUG) { + ASCEND_LOGD("[%s] %s", name_.c_str(), buffer); + } else if (level == LoggingLevel::INFO) { + ASCEND_LOGI("[%s] %s", name_.c_str(), buffer); + } else if (level == LoggingLevel::WARNING) { + ASCEND_LOGW("[%s] %s", name_.c_str(), buffer); + } else { + ASCEND_LOGE("[%s] %s", name_.c_str(), buffer); + } } void Logger::debug(const char* format, ...) diff --git a/torch_npu/csrc/npu/Module.cpp b/torch_npu/csrc/npu/Module.cpp index 7e6fc6bdfea9baeda692ffc3ca462a188f8e18ab..a4b63f599d3d2b26cbfd7b53528783d8e1f770ed 100644 --- a/torch_npu/csrc/npu/Module.cpp +++ b/torch_npu/csrc/npu/Module.cpp @@ -44,6 +44,7 @@ #include "torch_npu/csrc/utils/LazyInit.h" #include "third_party/acl/inc/acl/acl.h" #include "torch_npu/csrc/profiler/python/combined_traceback.h" +#include "torch_npu/csrc/ipc/NPUIPCTypes.h" #include "torch_npu/csrc/core/npu/interface/OpInterface.h" #include "torch_npu/csrc/core/npu/GetCANNInfo.h" #include "torch_npu/csrc/core/npu/NPUWorkspaceAllocator.h" @@ -992,6 +993,14 @@ PyObject* THNPModule_emptyCache(PyObject *_unused, PyObject *noargs) Py_RETURN_NONE; } +PyObject* THNPModule_npu_ipc_collect(PyObject *_unused, PyObject *noargs) +{ + HANDLE_TH_ERRORS + torch_npu::ipc::NpuIPCCollect(); + END_HANDLE_TH_ERRORS + Py_RETURN_NONE; +} + PyObject* THNPModule_emptyVirtAddrCache(PyObject *_unused, PyObject *noargs) { HANDLE_TH_ERRORS @@ -1956,6 +1965,7 @@ static struct PyMethodDef THNPModule_methods[] = { {"_npu_is_jit_compile_false", (PyCFunction)THNPModule_is_jit_compile_false_wrap, METH_NOARGS, nullptr}, {"_npu_setMemoryFraction", (PyCFunction) THNPModule_setMemoryFraction, METH_VARARGS, nullptr}, {"_npu_emptyCache", (PyCFunction) THNPModule_emptyCache, METH_NOARGS, nullptr}, + {"_npu_ipc_collect", (PyCFunction) THNPModule_npu_ipc_collect, METH_NOARGS, nullptr}, {"_npu_emptyVirtAddrCache", (PyCFunction) THNPModule_emptyVirtAddrCache, METH_NOARGS, nullptr}, {"_npu_memoryStats", (PyCFunction) THNPModule_memoryStats, METH_O, nullptr}, {"_npu_resetAccumulatedMemoryStats", (PyCFunction) THNPModule_resetAccumulatedMemoryStats, METH_O, nullptr}, diff --git a/torch_npu/csrc/npu/NPUPluggableAllocator.cpp b/torch_npu/csrc/npu/NPUPluggableAllocator.cpp index dd03a0929b7aef6945d98dab5b09012b97ba4986..b3e1c5c961ca5ff5f9b945ac1f14d5f72f0bcaa0 100644 --- a/torch_npu/csrc/npu/NPUPluggableAllocator.cpp +++ b/torch_npu/csrc/npu/NPUPluggableAllocator.cpp @@ -208,12 +208,6 @@ void NPUPluggableAllocator::emptyVirtAddrCache(bool check_error) "If you need it, please file an issue describing your use case."); } -void NPUPluggableAllocator::clearIpcHandles() -{ - TORCH_NPU_WARN("NPUPluggableAllocator does not yet support clearIpcHandles. " - "If you need it, please file an issue describing your use case."); -} - void NPUPluggableAllocator::cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock) { TORCH_NPU_WARN("NPUPluggableAllocator does not yet support cacheInfo. " diff --git a/torch_npu/csrc/npu/NPUPluggableAllocator.h b/torch_npu/csrc/npu/NPUPluggableAllocator.h index 9d5c9965de1b614713c16ecfe38acee51d0d4947..5fa564d1b8692809870477017a0157ca6c57a697 100644 --- a/torch_npu/csrc/npu/NPUPluggableAllocator.h +++ b/torch_npu/csrc/npu/NPUPluggableAllocator.h @@ -63,7 +63,6 @@ struct NPUPluggableAllocator void emptyCacheImpl(bool check_error, bool free_physical) override; void emptyCache(bool check_error) override; void emptyVirtAddrCache(bool check_error) override; - void clearIpcHandles() override; void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock) override; void* getBaseAllocation(void* ptr, size_t* size) override; void recordStream(const c10::DataPtr&, streamType stream) override; diff --git a/torch_npu/npu/__init__.py b/torch_npu/npu/__init__.py index 957ad7e7c8b16674d68693825b5846f7e318138e..7c3ca9dab85f4ee05e17bb3b802970f09061bdf1 100644 --- a/torch_npu/npu/__init__.py +++ b/torch_npu/npu/__init__.py @@ -120,7 +120,8 @@ __all__ = [ "get_device_limit", "set_stream_limit", "reset_stream_limit", - "get_stream_limit" + "get_stream_limit", + "ipc_collect" ] from typing import Tuple, Union @@ -140,7 +141,7 @@ from .utils import (synchronize, device_count, can_device_access_peer, set_devic device, device_of, StreamContext, stream, set_stream, current_stream, default_stream, set_sync_debug_mode, get_sync_debug_mode, init_dump, current_blas_handle, is_bf16_supported, utilization, finalize_dump, set_dump, get_npu_overflow_flag, clear_npu_overflow_flag, mem_get_info, - check_uce_in_memory, stress_detect, _get_uce_addr) + check_uce_in_memory, stress_detect, _get_uce_addr, ipc_collect) from ._recovery import restart_device, stop_device from .streams import Stream, Event, SyncLaunchStream, ExternalEvent from .mstx import mstx diff --git a/torch_npu/npu/utils.py b/torch_npu/npu/utils.py index 1fa5d3848ff9eb92a5cfded983541cf845a9af6d..7c19846feb181c65e52c9411200ed46a70df8117 100644 --- a/torch_npu/npu/utils.py +++ b/torch_npu/npu/utils.py @@ -19,7 +19,7 @@ __all__ = ["synchronize", "device_count", "can_device_access_peer", "set_device" "stream", "set_stream", "current_stream", "default_stream", "set_sync_debug_mode", "get_sync_debug_mode", "init_dump", "set_dump", "finalize_dump", "is_support_inf_nan", "is_bf16_supported", "get_npu_overflow_flag", "npu_check_overflow", "clear_npu_overflow_flag", "current_blas_handle", - "check_uce_in_memory", "stress_detect", "get_cann_version"] + "check_uce_in_memory", "stress_detect", "get_cann_version", "ipc_collect"] def get_cann_version(module="CANN"): @@ -60,6 +60,19 @@ def synchronize(device=None): return torch_npu._C._npu_synchronize() +def ipc_collect(): + r"""Force collects NPU memory after it has been released by NPU IPC. + + .. note:: + Checks if any sent NPU tensors could be cleaned from the memory. Force + closes shared memory file used for reference counting if there is no + active counters. Useful when the producer process stopped actively sending + tensors and want to release unused memory. + """ + torch_npu.npu._lazy_init() + return torch_npu._C._npu_ipc_collect() + + def _parse_visible_devices() -> Union[List[int], List[str]]: r"""Parse ASCEND_RT_VISIBLE_DEVICES environment variable.""" var = os.getenv("ASCEND_RT_VISIBLE_DEVICES")