diff --git a/test/torch_npu_schema.json b/test/torch_npu_schema.json index 32b24296553856a238e2a1348a808bc788ca5345..527b734bfc729f79f24cabe37340ba3664f558bd 100644 --- a/test/torch_npu_schema.json +++ b/test/torch_npu_schema.json @@ -743,6 +743,30 @@ "torch_npu.npu.Event.synchronize": { "signature": "(self)" }, + "torch_npu.npu.ExternalEvent": { + "signature": "()" + }, + "torch_npu.npu.ExternalEvent.record": { + "signature": "(self, stream=None)" + }, + "torch_npu.npu.ExternalEvent.wait": { + "signature": "(self, stream=None)" + }, + "torch_npu.npu.ExternalEvent.reset": { + "signature": "(self, stream=None)" + }, + "torch_npu.npu.graph_task_group_begin": { + "signature": "(stream)" + }, + "torch_npu.npu.graph_task_group_end": { + "signature": "(stream)" + }, + "torch_npu.npu.graph_task_update_begin": { + "signature": "(stream, handle)" + }, + "torch_npu.npu.graph_task_update_end": { + "signature": "(stream)" + }, "torch_npu.npu.FloatStorage": { "signature": "(*args, wrap_storage=None, dtype=None, device=None, _internal=False)" }, diff --git a/third_party/acl/inc/acl/acl_base.h b/third_party/acl/inc/acl/acl_base.h index 9780a01f8fbdf7b3e49d7b3dd2e7326c4f7da486..749e1d3c774915267b4b3a2d12610c807fbd7f55 100755 --- a/third_party/acl/inc/acl/acl_base.h +++ b/third_party/acl/inc/acl/acl_base.h @@ -56,6 +56,7 @@ typedef void *aclrtAllocatorDesc; typedef void *aclrtAllocator; typedef void *aclrtAllocatorBlock; typedef void *aclrtAllocatorAddr; +typedef void *aclrtTaskGrp; static const int ACL_ERROR_NONE = 0; static const int ACL_SUCCESS = 0; diff --git a/third_party/acl/inc/acl/acl_mdl.h b/third_party/acl/inc/acl/acl_mdl.h index 45a36898efd65816fe28bdeb738acaefca55bc02..f13950ab8504fcb45a80e009ef4d5f9fb6b90679 100755 --- a/third_party/acl/inc/acl/acl_mdl.h +++ b/third_party/acl/inc/acl/acl_mdl.h @@ -1545,6 +1545,44 @@ ACL_FUNC_VISIBILITY aclError aclmdlRICaptureEnd(aclrtStream stream, aclmdlRI *mo */ ACL_FUNC_VISIBILITY aclError aclmdlRIDebugPrint(aclmdlRI modelRI); +/** 
+ * @ingroup AscendCL + * @brief the start interface of the task group + * @param stream [IN] capture stream + * @retval ACL_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +ACL_FUNC_VISIBILITY aclError aclmdlRICaptureTaskGrpBegin(aclrtStream stream); + +/** + * @ingroup AscendCL + * @brief the end interface of the task group + * @param stream [IN] capture stream + * @param handle [OUT] task group handle + * @retval ACL_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +ACL_FUNC_VISIBILITY aclError aclmdlRICaptureTaskGrpEnd(aclrtStream stream, aclrtTaskGrp *handle); + +/** + * @ingroup AscendCL + * @brief begin to update the task group specified by the handle + * @param stream [IN] specify the stream used for task update + * @param handle [IN] task group handle + * @retval ACL_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +ACL_FUNC_VISIBILITY aclError aclmdlRICaptureTaskUpdateBegin(aclrtStream stream, aclrtTaskGrp handle); + +/** + * @ingroup AscendCL + * @brief end the update of the task + * @param stream [IN] specify the stream used for task update + * @retval ACL_SUCCESS The function is successfully executed. 
+ * @retval OtherValues Failure + */ +ACL_FUNC_VISIBILITY aclError aclmdlRICaptureTaskUpdateEnd(aclrtStream stream); + #ifdef __cplusplus } #endif diff --git a/third_party/acl/inc/acl/acl_rt.h b/third_party/acl/inc/acl/acl_rt.h index be607a828130aa010461aac0e98dcab704b28d32..963c29f93a9730bbe1937b05cbbb0eb66db07854 100755 --- a/third_party/acl/inc/acl/acl_rt.h +++ b/third_party/acl/inc/acl/acl_rt.h @@ -22,6 +22,7 @@ extern "C" { #define ACL_EVENT_SYNC 0x00000001u #define ACL_EVENT_CAPTURE_STREAM_PROGRESS 0x00000002u #define ACL_EVENT_TIME_LINE 0x00000008u +#define ACL_EVENT_EXTERNAL 0x00000020u #define ACL_STREAM_FAST_LAUNCH 0x00000001u #define ACL_STREAM_FAST_SYNC 0x00000002u diff --git a/torch_npu/csrc/core/npu/NPUEvent.cpp b/torch_npu/csrc/core/npu/NPUEvent.cpp index f43fd7d5dd4732b46e404d580c7b9278ad4338b8..a735b567e35bb1249ef96e4532b55138704fc684 100644 --- a/torch_npu/csrc/core/npu/NPUEvent.cpp +++ b/torch_npu/csrc/core/npu/NPUEvent.cpp @@ -91,6 +91,9 @@ void NPUEvent::record(const NPUStream& stream) void NPUEvent::block(const NPUStream& stream) { + if (!is_created_ && (flags_ == ACL_EVENT_EXTERNAL)) { + createEvent(stream.device_index()); + } if (is_created_) { NPUGuard guard(stream.device_index()); c10_npu::queue::LaunchWaitEventTask(event_, stream); @@ -162,6 +165,16 @@ void NPUEvent::synchronize() const } } +void NPUEvent::reset(const NPUStream& stream) const +{ + if (is_created_) { + TORCH_CHECK(flags_ == ACL_EVENT_EXTERNAL, + "API reset() only support ACL_EVENT_EXTERNAL flag event.", PTA_ERROR(ErrCode::INTERNAL)); + NPUGuard guard(stream.device_index()); + NPU_CHECK_ERROR_WITHOUT_UCE(aclrtResetEvent(event_, stream.stream())); + } +} + void NPUEvent::createEvent(c10::DeviceIndex device_index) { device_index_ = device_index; diff --git a/torch_npu/csrc/core/npu/NPUEvent.h b/torch_npu/csrc/core/npu/NPUEvent.h index 5eba816db69496545410a3bae53a1e3649185772..cf6e34ee9c73b9e544ca12adf625d9d27fa21f23 100644 --- a/torch_npu/csrc/core/npu/NPUEvent.h +++ 
b/torch_npu/csrc/core/npu/NPUEvent.h @@ -49,6 +49,7 @@ struct C10_NPU_API NPUEvent { float elapsed_time(const NPUEvent& other) const; uint64_t recorded_time() const; void synchronize() const; + void reset(const NPUStream& stream) const; // npu do not support IpcEventHandle until now diff --git a/torch_npu/csrc/core/npu/NPUGraph.cpp b/torch_npu/csrc/core/npu/NPUGraph.cpp index e259d1a7248cc8dcc14c0d589b6532bdb4682da8..48522306d2954dcb1fa808a480542b42fc65d021 100644 --- a/torch_npu/csrc/core/npu/NPUGraph.cpp +++ b/torch_npu/csrc/core/npu/NPUGraph.cpp @@ -25,6 +25,30 @@ MempoolId_t graph_pool_handle() return new_pool.id(); } +void graph_task_group_begin(c10_npu::NPUStream stream) +{ + NPU_CHECK_ERROR(c10_npu::acl::AclmdlRICaptureTaskGrpBegin(stream)); +} + +NPUTaskGroupHandle graph_task_group_end(c10_npu::NPUStream stream) +{ + aclrtTaskGrp group = nullptr; // defined value even if the runtime never writes the out-param + NPU_CHECK_ERROR(c10_npu::acl::AclmdlRICaptureTaskGrpEnd(stream, &group)); + NPUTaskGroupHandle handle; + handle.task_group = group; + return handle; +} + +void graph_task_update_begin(c10_npu::NPUStream stream, NPUTaskGroupHandle handle) +{ + NPU_CHECK_ERROR(c10_npu::acl::AclmdlRICaptureTaskUpdateBegin(stream, handle.task_group)); +} + +void graph_task_update_end(c10_npu::NPUStream stream) +{ + NPU_CHECK_ERROR(c10_npu::acl::AclmdlRICaptureTaskUpdateEnd(stream)); +} + /** * Note [CUDA Graph Wrapper Class] * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/torch_npu/csrc/core/npu/NPUGraph.h b/torch_npu/csrc/core/npu/NPUGraph.h index ccb8c29067c0909557b9b578ac45024d1ced941c..442ae335ccae15506f4352ef5c204e185917672b 100644 --- a/torch_npu/csrc/core/npu/NPUGraph.h +++ b/torch_npu/csrc/core/npu/NPUGraph.h @@ -14,6 +14,15 @@ namespace c10_npu { // to CUDAGraph::capture_begin TORCH_NPU_API MempoolId_t graph_pool_handle(); +struct TORCH_NPU_API NPUTaskGroupHandle { + aclrtTaskGrp task_group = nullptr; // null until filled in by graph_task_group_end +}; + +TORCH_NPU_API void graph_task_group_begin(c10_npu::NPUStream stream); +TORCH_NPU_API NPUTaskGroupHandle 
graph_task_group_end(c10_npu::NPUStream stream); +TORCH_NPU_API void graph_task_update_begin(c10_npu::NPUStream stream, NPUTaskGroupHandle handle); +TORCH_NPU_API void graph_task_update_end(c10_npu::NPUStream stream); + struct TORCH_NPU_API NPUGraph { NPUGraph(); ~NPUGraph(); diff --git a/torch_npu/csrc/core/npu/interface/AclInterface.cpp b/torch_npu/csrc/core/npu/interface/AclInterface.cpp index 80ff3654ac9d66970fdadd9192ad4f12e5a9947e..4ca0ee32a746434f8d8af8e3faa3230b78226b20 100644 --- a/torch_npu/csrc/core/npu/interface/AclInterface.cpp +++ b/torch_npu/csrc/core/npu/interface/AclInterface.cpp @@ -76,6 +76,10 @@ LOAD_FUNCTION(aclmdlRIDebugPrint) LOAD_FUNCTION(aclmdlRIExecuteAsync) LOAD_FUNCTION(aclmdlRIDestroy) LOAD_FUNCTION(aclsysGetCANNVersion) +LOAD_FUNCTION(aclmdlRICaptureTaskGrpBegin) +LOAD_FUNCTION(aclmdlRICaptureTaskGrpEnd) +LOAD_FUNCTION(aclmdlRICaptureTaskUpdateBegin) +LOAD_FUNCTION(aclmdlRICaptureTaskUpdateEnd) aclprofStepInfoPtr init_stepinfo() { typedef aclprofStepInfoPtr(*npdInitFunc)(); @@ -202,13 +206,16 @@ aclError AclrtCreateEventWithFlag(aclrtEvent *event, uint32_t flag) // 2. There is no limit on the number of events. // 3. Only support query event record status, aclrtQueryEvent and aclrtQueryEventWaitStatus are not supported. // 4. aclrtDestroyEvent change to asynchronous destroy event. 
- static AclrtCreateEventWithFlagFunc func = (AclrtCreateEventWithFlagFunc)GET_FUNC(aclrtCreateEventExWithFlag); - if (func == nullptr) { - TORCH_NPU_WARN_ONCE(func, "Failed to find function ", "aclrtCreateEventExWithFlag"); - func = (AclrtCreateEventWithFlagFunc)GET_FUNC(aclrtCreateEventWithFlag); + static AclrtCreateEventWithFlagFunc func_ex = (AclrtCreateEventWithFlagFunc)GET_FUNC(aclrtCreateEventExWithFlag); + if (func_ex == nullptr) { + TORCH_NPU_WARN_ONCE(func_ex, "Failed to find function ", "aclrtCreateEventExWithFlag"); } + static AclrtCreateEventWithFlagFunc func = (AclrtCreateEventWithFlagFunc)GET_FUNC(aclrtCreateEventWithFlag); TORCH_CHECK(func, "Failed to find function ", "aclrtCreateEventWithFlag", PROF_ERROR(ErrCode::NOT_FOUND)); - return func(event, flag); + if ((flag == ACL_EVENT_EXTERNAL) || (func_ex == nullptr)) { + return func(event, flag); + } + return func_ex(event, flag); } aclError AclQueryEventWaitStatus(aclrtEvent event, aclrtEventWaitStatus *waitStatus) @@ -846,5 +853,53 @@ bool IsCaptureSupported() return is_support; } +aclError AclmdlRICaptureTaskGrpBegin(aclrtStream stream) +{ + typedef aclError (*AclmdlRICaptureTaskGrpBegin)(aclrtStream); + static AclmdlRICaptureTaskGrpBegin func = nullptr; + if (func == nullptr) { + func = (AclmdlRICaptureTaskGrpBegin) GET_FUNC(aclmdlRICaptureTaskGrpBegin); + } + + TORCH_CHECK(func, "Failed to find function aclmdlRICaptureTaskGrpBegin", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(stream); +} + +aclError AclmdlRICaptureTaskGrpEnd(aclrtStream stream, aclrtTaskGrp *handle) +{ + typedef aclError (*AclmdlRICaptureTaskGrpEnd)(aclrtStream, aclrtTaskGrp*); + static AclmdlRICaptureTaskGrpEnd func = nullptr; + if (func == nullptr) { + func = (AclmdlRICaptureTaskGrpEnd) GET_FUNC(aclmdlRICaptureTaskGrpEnd); + } + + TORCH_CHECK(func, "Failed to find function aclmdlRICaptureTaskGrpEnd", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(stream, handle); +} + +aclError AclmdlRICaptureTaskUpdateBegin(aclrtStream stream, 
aclrtTaskGrp handle) +{ + typedef aclError (*AclmdlRICaptureTaskUpdateBegin)(aclrtStream, aclrtTaskGrp); + static AclmdlRICaptureTaskUpdateBegin func = nullptr; + if (func == nullptr) { + func = (AclmdlRICaptureTaskUpdateBegin) GET_FUNC(aclmdlRICaptureTaskUpdateBegin); + } + + TORCH_CHECK(func, "Failed to find function aclmdlRICaptureTaskUpdateBegin", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(stream, handle); +} + +aclError AclmdlRICaptureTaskUpdateEnd(aclrtStream stream) +{ + typedef aclError (*AclmdlRICaptureTaskUpdateEnd)(aclrtStream); // parameter type matches the aclrtStream declared for this API + static AclmdlRICaptureTaskUpdateEnd func = nullptr; + if (func == nullptr) { + func = (AclmdlRICaptureTaskUpdateEnd) GET_FUNC(aclmdlRICaptureTaskUpdateEnd); + } + + TORCH_CHECK(func, "Failed to find function aclmdlRICaptureTaskUpdateEnd", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(stream); +} + } // namespace acl } // namespace c10 diff --git a/torch_npu/csrc/core/npu/interface/AclInterface.h b/torch_npu/csrc/core/npu/interface/AclInterface.h index ca5c03d30ea4ddd75f3fb674de62729103f293f4..2f34bd561ec5deb43af166d2e4bcfca7437e2a7f 100644 --- a/torch_npu/csrc/core/npu/interface/AclInterface.h +++ b/torch_npu/csrc/core/npu/interface/AclInterface.h @@ -199,5 +199,13 @@ aclError AclmdlRIDestroy(aclmdlRI modelRI); bool IsCaptureSupported(); +aclError AclmdlRICaptureTaskGrpBegin(aclrtStream stream); + +aclError AclmdlRICaptureTaskGrpEnd(aclrtStream stream, aclrtTaskGrp *handle); + +aclError AclmdlRICaptureTaskUpdateBegin(aclrtStream stream, aclrtTaskGrp handle); + +aclError AclmdlRICaptureTaskUpdateEnd(aclrtStream stream); + } // namespace acl } // namespace c10_npu diff --git a/torch_npu/csrc/npu/Event.cpp b/torch_npu/csrc/npu/Event.cpp index 3c92a335395df28b6baa3cf2821f4a903cedaf40..38db7ccdf22f650db1ce92e1740bb1559b293056 100644 --- a/torch_npu/csrc/npu/Event.cpp +++ b/torch_npu/csrc/npu/Event.cpp @@ -17,10 +17,11 @@ static PyObject* THNPEvent_pynew(PyTypeObject *type, PyObject *args, PyObject *k unsigned char enable_timing = 0; 
unsigned char blocking = 0; unsigned char interprocess = 0; + unsigned char external = 0; - constexpr const char* kwlist[] = {"enable_timing", "blocking", "interprocess", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|bbb", const_cast(kwlist), - &enable_timing, &blocking, &interprocess)) { + constexpr const char* kwlist[] = {"enable_timing", "blocking", "interprocess", "graph_external", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|bbbb", const_cast(kwlist), + &enable_timing, &blocking, &interprocess, &external)) { return nullptr; } @@ -37,6 +38,9 @@ static PyObject* THNPEvent_pynew(PyTypeObject *type, PyObject *args, PyObject *k } else { flags = enable_timing ? ACL_EVENT_TIME_LINE : ACL_EVENT_DEFAULT; } + if (external) { + flags = ACL_EVENT_EXTERNAL; + } new (&self->npu_event) c10_npu::NPUEvent(flags); return (PyObject *)ptr.release(); @@ -121,6 +125,18 @@ static PyObject* THNPEvent_synchronize(THNPEvent *self, PyObject *noargs) END_HANDLE_TH_ERRORS } +static PyObject* THNPEvent_reset(THNPEvent *self, THNPStream *stream) +{ + HANDLE_TH_ERRORS + { + pybind11::gil_scoped_release no_gil; + self->npu_event.reset(stream->npu_stream); + ASCEND_LOGI("Event: reset api is successfully executed, event=%p", self->npu_event.event()); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + static struct PyGetSetDef THNPEvent_properties[] = { {"device", (getter)THNPEvent_get_device, nullptr, nullptr, nullptr}, {"npu_event", (getter)THNPEvent_get_npu_event, nullptr, nullptr, nullptr}, @@ -134,6 +150,7 @@ static PyMethodDef THNPEvent_methods[] = { {(char*)"elapsed_time", (PyCFunction)THNPEvent_elapsed_time, METH_O, nullptr}, {(char*)"recorded_time", (PyCFunction)THNPEvent_recorded_time, METH_NOARGS, nullptr}, {(char*)"synchronize", (PyCFunction)THNPEvent_synchronize, METH_NOARGS, nullptr}, + {(char*)"reset", (PyCFunction)THNPEvent_reset, METH_O, nullptr}, {nullptr} }; diff --git a/torch_npu/csrc/npu/Graph.cpp b/torch_npu/csrc/npu/Graph.cpp index 
3a471cb2aa121a6d84fbab5b09d6417bd505c31f..c8d30cfa448b07e00d7671ff5e6aa7169686ee60 100644 --- a/torch_npu/csrc/npu/Graph.cpp +++ b/torch_npu/csrc/npu/Graph.cpp @@ -7,6 +7,7 @@ #include "torch_npu/csrc/core/npu/NPUGraph.h" #include "torch_npu/csrc/core/npu/NPUGraphsUtils.h" +#include "torch_npu/csrc/npu/Stream.h" template using shared_ptr_class_ = py::class_>; @@ -16,7 +17,26 @@ void TORCH_NPU_API THNPGraph_init(PyObject* module) { // but CI linter and some builds prefer "module". auto torch_N_m = py::handle(module).cast(); - torch_N_m.def("_graph_pool_handle", &c10_npu::graph_pool_handle); + py::class_(torch_N_m, "_NPUTaskGroupHandle") + .def_readonly("task_group", &c10_npu::NPUTaskGroupHandle::task_group); + + torch_N_m.def("_graph_pool_handle", &c10_npu::graph_pool_handle) + .def("_graph_task_group_begin", [](py::object py_stream) { + auto stream = (*py_stream).ptr(); + c10_npu::graph_task_group_begin(THNPUtils_PyObject_to_NPUStream(stream)); + }) + .def("_graph_task_group_end", [](py::object py_stream) { + auto stream = (*py_stream).ptr(); + return c10_npu::graph_task_group_end(THNPUtils_PyObject_to_NPUStream(stream)); + }) + .def("_graph_task_update_begin", [](py::object py_stream, c10_npu::NPUTaskGroupHandle handle) { + auto stream = (*py_stream).ptr(); + c10_npu::graph_task_update_begin(THNPUtils_PyObject_to_NPUStream(stream), handle); + }) + .def("_graph_task_update_end", [](py::object py_stream) { + auto stream = (*py_stream).ptr(); + c10_npu::graph_task_update_end(THNPUtils_PyObject_to_NPUStream(stream)); + }); shared_ptr_class_(torch_N_m, "_NPUGraph") .def(py::init<>()) diff --git a/torch_npu/csrc/npu/Stream.cpp b/torch_npu/csrc/npu/Stream.cpp index 8059cf3447616fae27e719ca828db5fd8cb91758..180fede5ec3acb34785a91e2fa21bb48640938bf 100644 --- a/torch_npu/csrc/npu/Stream.cpp +++ b/torch_npu/csrc/npu/Stream.cpp @@ -250,3 +250,12 @@ std::vector> THNPUtils_PySequence_to_NPUStream } return streams; } + +c10_npu::NPUStream 
THNPUtils_PyObject_to_NPUStream(PyObject* stream) +{ + TORCH_CHECK(PyObject_IsInstance(stream, THNPStreamClass) == 1, "Need torch_npu.npu.Stream argument type."); // IsInstance returns -1 on error, which must not pass as true + return c10_npu::NPUStream::unpack3( + (reinterpret_cast(stream))->stream_id, + (reinterpret_cast(stream))->device_index, + static_cast((reinterpret_cast(stream))->device_type)); +} diff --git a/torch_npu/csrc/npu/Stream.h b/torch_npu/csrc/npu/Stream.h index f51479f2b06002a913b13eb840cc178a1cd4b32a..f6f084bca3be71c4d01bb126ebac21eb1991c29e 100644 --- a/torch_npu/csrc/npu/Stream.h +++ b/torch_npu/csrc/npu/Stream.h @@ -21,4 +21,6 @@ inline bool THNPStream_Check(PyObject* obj) TORCH_NPU_API std::vector> THNPUtils_PySequence_to_NPUStreamList(PyObject* obj); +c10_npu::NPUStream THNPUtils_PyObject_to_NPUStream(PyObject* py_stream); + #endif // THNP_STREAM_INC diff --git a/torch_npu/npu/__init__.py b/torch_npu/npu/__init__.py index 12d67a0bb8e88c2710b6d2124621aaeca423a2d4..337ad40fc161638057d8be1b9f636c7e35385b81 100644 --- a/torch_npu/npu/__init__.py +++ b/torch_npu/npu/__init__.py @@ -108,7 +108,12 @@ __all__ = [ "graph", "graph_pool_handle", "is_current_stream_capturing", - "make_graphed_callables" + "make_graphed_callables", + "ExternalEvent", + "graph_task_group_begin", + "graph_task_group_end", + "graph_task_update_begin", + "graph_task_update_end" ] from typing import Tuple, Union @@ -130,7 +135,7 @@ from .utils import (synchronize, device_count, can_device_access_peer, set_devic utilization, finalize_dump, set_dump, get_npu_overflow_flag, clear_npu_overflow_flag, mem_get_info, check_uce_in_memory, stress_detect) from ._recovery import restart_device, stop_device -from .streams import Stream, Event, SyncLaunchStream +from .streams import Stream, Event, SyncLaunchStream, ExternalEvent from .mstx import mstx from .npu_config import * # noqa: F403 from .autocast_utils import * # noqa: F403 @@ -144,6 +149,10 @@ from .graphs import ( graph_pool_handle, is_current_stream_capturing, make_graphed_callables, + 
graph_task_group_begin, + graph_task_group_end, + graph_task_update_begin, + graph_task_update_end, ) # init profiler diff --git a/torch_npu/npu/graphs.py b/torch_npu/npu/graphs.py index c9608906fd2e7ec9ffc40678455439e8407c05c8..dd0a34c018e59576cd7ab8e87918a472e0cc3339 100644 --- a/torch_npu/npu/graphs.py +++ b/torch_npu/npu/graphs.py @@ -13,11 +13,19 @@ if not hasattr(torch_npu._C, "_NPUStreamBase"): torch_npu._C.__dict__["_npu_isCurrentStreamCapturing"] = _dummy_type( "_npu_isCurrentStreamCapturing" ) + torch_npu._C.__dict__["_graph_task_group_begin"] = _dummy_type("_graph_task_group_begin") + torch_npu._C.__dict__["_graph_task_group_end"] = _dummy_type("_graph_task_group_end") + torch_npu._C.__dict__["_graph_task_update_begin"] = _dummy_type("_graph_task_update_begin") + torch_npu._C.__dict__["_graph_task_update_end"] = _dummy_type("_graph_task_update_end") from torch_npu._C import ( # noqa: F401 _npu_isCurrentStreamCapturing, _NPUGraph, _graph_pool_handle, + _graph_task_group_begin, + _graph_task_group_end, + _graph_task_update_begin, + _graph_task_update_end, ) @@ -41,6 +49,22 @@ def graph_pool_handle(): return _graph_pool_handle() +def graph_task_group_begin(stream): + _graph_task_group_begin(stream) + + +def graph_task_group_end(stream): + return _graph_task_group_end(stream) + + +def graph_task_update_begin(stream, handle): + _graph_task_update_begin(stream, handle) + + +def graph_task_update_end(stream): + _graph_task_update_end(stream) + + # Python shim helps Sphinx process docstrings more reliably. class NPUGraph(torch_npu._C._NPUGraph): r"""Wrapper around a NPU graph. 
diff --git a/torch_npu/npu/streams.py b/torch_npu/npu/streams.py index 0bfd6967cef5bbe74187fa3249dd9bcdfe662150..fb78aa68d867f3c64b4939965ea9248a5cfb1a5a 100644 --- a/torch_npu/npu/streams.py +++ b/torch_npu/npu/streams.py @@ -4,7 +4,7 @@ from torch._streambase import _StreamBase, _EventBase import torch_npu import torch_npu._C -__all__ = ["Stream", "Event", "SyncLaunchStream"] +__all__ = ["Stream", "Event", "SyncLaunchStream", "ExternalEvent"] class Stream(torch_npu._C._NPUStreamBase, _StreamBase): @@ -133,7 +133,8 @@ class Event(torch_npu._C._NPUEventBase, _EventBase): """ def __new__(cls, enable_timing=False, blocking=False, interprocess=False): - return super(Event, cls).__new__(cls, enable_timing=enable_timing, blocking=blocking, interprocess=interprocess) + return super(Event, cls).__new__(cls, enable_timing=enable_timing, blocking=blocking, + interprocess=interprocess, graph_external=False) def record(self, stream=None): r"""Records the event in a given stream. @@ -197,6 +198,66 @@ class Event(torch_npu._C._NPUEventBase, _EventBase): return '' +class ExternalEvent(torch_npu._C._NPUEventBase, _EventBase): + r"""Wrapper around a NPU event with graph_external=True. + + The difference from torch.npu.Event is that you can call wait() before + record(). Before reusing ExternalEvent, you need to call reset() to clear + the flag. + + Event is captured in the graph as an external event node when performing + stream capture. + + The underlying NPU events are lazily initialized when the event is first + recorded or waited. + + """ + + def __new__(cls): + return super(ExternalEvent, cls).__new__(cls, enable_timing=False, blocking=False, + interprocess=False, graph_external=True) + + def record(self, stream=None): + r"""Records the event in a given stream. + + Uses ``torch_npu.npu.current_stream()`` if no stream is specified. The + stream's device must match the event's device. 
+ """ + if stream is None: + stream = torch_npu.npu.current_stream() + super(ExternalEvent, self).record(stream) + + def wait(self, stream=None): + r"""Makes all future work submitted to the given stream wait for this + event. + + Use ``torch_npu.npu.current_stream()`` if no stream is specified. + """ + if stream is None: + stream = torch_npu.npu.current_stream() + super(ExternalEvent, self).wait(stream) + + def reset(self, stream=None): + r"""Reset an event. + + Users need to make sure to wait for the tasks in the Stream + to complete before resetting the Event. + """ + if stream is None: + stream = torch_npu.npu.current_stream() + super(ExternalEvent, self).reset(stream) + + @property + def _as_parameter_(self): + return ctypes.c_void_p(self.npu_event) + + def __repr__(self): + if self.npu_event: + return ''.format(self._as_parameter_.value) + else: + return '' + + class SyncLaunchStream(torch_npu._C._NPUStreamBase, _StreamBase): r"""Wrapper around a SyncLaunch NPU stream.