diff --git a/mindspore/ccsrc/include/common/pybind_api/api_register.h b/mindspore/ccsrc/include/common/pybind_api/api_register.h
index e1c35c96fcd8af394bc85faba636d5583cf02e3c..d515e6a0a1b590ab9897ebc80945f5ec8934a210 100644
--- a/mindspore/ccsrc/include/common/pybind_api/api_register.h
+++ b/mindspore/ccsrc/include/common/pybind_api/api_register.h
@@ -55,6 +55,7 @@ void RegStorage(py::module *m);
 namespace hal {
 void RegStream(py::module *m);
 void RegEvent(py::module *m);
+void RegResLimit(py::module *m);
 FRONTEND_EXPORT void RegCommHandle(py::module *m);
 void RegMemory(py::module *m);
 void RegUtils(py::module *m);
diff --git a/mindspore/ccsrc/plugin/ascend/graph_optimizer/stream_assign/acl_stream_assign.cc b/mindspore/ccsrc/plugin/ascend/graph_optimizer/stream_assign/acl_stream_assign.cc
index eaeb72bb193bc796b59a08b519ee227d37f6280f..73d8a268bee246266ecd165d4878dc84d550163c 100644
--- a/mindspore/ccsrc/plugin/ascend/graph_optimizer/stream_assign/acl_stream_assign.cc
+++ b/mindspore/ccsrc/plugin/ascend/graph_optimizer/stream_assign/acl_stream_assign.cc
@@ -104,7 +104,41 @@ void AssignStreamForMoveTo(const AnfNodePtr &node) {
   }
 }
 
-void AddStreamIdByGroup(const AnfNodePtr &node, DeviceResManager *device_res_manager) {
+bool IsUsersSetStreamsOp(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  if (cnode->HasAttr(kAttrStreamId)) {
+    int64_t stream_id = GetValue<int64_t>(cnode->GetAttr(kAttrStreamId));
+    if (stream_id != kDefaultStreamIndex) {
+      MS_LOG(DEBUG) << "User set attr stream id for node " << node->fullname_with_scope() << ", stream id is "
+                    << stream_id;
+      return true;
+    }
+  }
+  return false;
+}
+
+void AddStreamIdForUsersSetStreamsOp(const AnfNodePtr &node, std::map<int64_t, size_t> *stream_map) {
+  auto cnode = node->cast<CNodePtr>();
+  int64_t stream_id = GetValue<int64_t>(cnode->GetAttr(kAttrStreamId));
+  size_t new_stream_id;
+  const auto &iter = stream_map->find(stream_id);
+  if (iter != stream_map->end()) {
+    new_stream_id = iter->second;
+  } else {
+    AscendStreamMng::GetInstance().CreateStream(&new_stream_id);
+    (*stream_map)[stream_id] = new_stream_id;
+    MS_LOG(INFO) << "Create ascend stream for user set stream id " << stream_id << ", created stream id: "
+                 << new_stream_id;
+  }
+  MS_LOG(INFO) << "Set stream id " << new_stream_id << " for user set stream op " << node->fullname_with_scope();
+  AnfAlgo::SetStreamId(new_stream_id, node.get());
+  common::AnfAlgo::SetNodeAttr(kAttrStreamId, MakeValue(new_stream_id), node);
+}
+
+void AddStreamIdByGroup(const AnfNodePtr &node, DeviceResManager *device_res_manager,
+                        std::map<int64_t, size_t> *stream_map) {
   MS_EXCEPTION_IF_NULL(node);
   if (!node->isa<CNode>()) {
     MS_LOG(EXCEPTION) << "Node is not a cnode: " << node->DebugString();
@@ -113,6 +147,10 @@ void AddStreamIdByGroup(const AnfNodePtr &node, DeviceResManager *device_res_man
   if (!common::AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
     if (IsPrimitiveCNode(node, prim::kPrimMoveTo) || IsPrimitiveCNode(node, prim::kPrimMoveAssign)) {
       AssignStreamForMoveTo(node);
+    } else if (IsUsersSetStreamsOp(node)) {
+      AddStreamIdForUsersSetStreamsOp(node, stream_map);
+      MS_LOG(INFO) << "Set stream id for node " << node->fullname_with_scope()
+                   << " because the user set a stream for this operator.";
     } else {
       AnfAlgo::SetStreamId(kDefaultStreamIndex, node.get());
       common::AnfAlgo::SetNodeAttr(kAttrStreamId, MakeValue(kDefaultStreamIndex), node);
@@ -144,12 +182,14 @@ void AddStreamIdByGroup(const AnfNodePtr &node, DeviceResManager *device_res_man
     }
   }
 }
+
 }  // namespace
 
 void AclStreamAssign::AssignStream(
   const NotNull &kernel_graph, const std::vector>> &mock_exec_order, DeviceResManager
*device_res_manager) { + static std::map stream_map; // usr stream id, create stream id auto kernels = kernel_graph->execution_order(); if (kernels.empty()) { return; @@ -188,6 +228,8 @@ void AclStreamAssign::AssignStream( AddStreamIdForCommunicationOp(node); } else if (IsPrimitiveCNode(node, prim::kPrimMoveTo)) { AssignStreamForMoveTo(node); + } else if (IsUsersSetStreamsOp(node)) { + AddStreamIdForUsersSetStreamsOp(node, &stream_map); } else { AnfAlgo::SetStreamId(kDefaultStreamIndex, node.get()); common::AnfAlgo::SetNodeAttr(kAttrStreamId, MakeValue(kDefaultStreamIndex), node); @@ -195,7 +237,7 @@ void AclStreamAssign::AssignStream( } else { // Default scene, multi_stream:group, all communication op use the communication stream by group MS_LOG(INFO) << "Set stream id by group for node " << node->fullname_with_scope(); - AddStreamIdByGroup(node, device_res_manager); + AddStreamIdByGroup(node, device_res_manager, &stream_map); } max_stream_id = std::max(max_stream_id, AnfAlgo::GetStreamId(node)); } @@ -437,7 +479,7 @@ void AclStreamAssign::UpdateEventsToExecutionOrder( (void)std::copy(before_iter->second.begin(), before_iter->second.end(), std::back_inserter(new_exec_orders)); } auto process_stream_id = AnfAlgo::GetStreamId(kernel); - if (process_stream_id != kDefaultStreamIndex) { + if (process_stream_id != kDefaultStreamIndex && !IsUsersSetStreamsOp(kernel)) { AddBoundarySendRecvKernel(kernel_graph, kDefaultStreamIndex, process_stream_id, &new_exec_orders, &no_event_streams, last_kernel, kernel); auto it = producer_streams.find(kernel); @@ -547,7 +589,7 @@ void AclStreamAssign::InsertEventsForOutputs(const NotNull &kern for (auto output_exec : stream_min_exec_node_map) { MS_EXCEPTION_IF_NULL(output_exec.second); - if (output_exec.second->stream_id == process_stream_id) { + if (output_exec.second->stream_id == process_stream_id || IsUsersSetStreamsOp(kernel)) { continue; } InsertEvents(kernel_graph, kernel, kernel, kernel_send, kernel_recv, output_exec.second->node); diff --git a/mindspore/ccsrc/plugin/ascend/res_manager/ascend_res_manager.cc b/mindspore/ccsrc/plugin/ascend/res_manager/ascend_res_manager.cc index 01d143647b444d57b505c03247ff5ded1417e342..1d6337e092a5c55e86080a7d18795ae71e2bb7ad 100644 --- a/mindspore/ccsrc/plugin/ascend/res_manager/ascend_res_manager.cc +++ b/mindspore/ccsrc/plugin/ascend/res_manager/ascend_res_manager.cc @@ -1712,6 +1712,103 @@ bool AscendResManager::DestroyAllEvents() { return true; } +void AscendResManager::GetDeviceLimit(int32_t device_id, uint32_t *cube_num, uint32_t *vector_num) { + MS_EXCEPTION_IF_NULL(cube_num); + MS_EXCEPTION_IF_NULL(vector_num); + auto ret = CALL_ASCEND_API(aclrtSetDevice, static_cast(device_id)); + if (ret != ACL_SUCCESS) { + MS_LOG(EXCEPTION) << "Device " << device_id << " call aclrtSetDevice failed, ret[" << static_cast(ret) + << "]. The details refer to 'Ascend Error Message'."; + } + ret = CALL_ASCEND_API(aclrtGetDeviceResLimit, device_id, aclrtDevResLimitType::ACL_RT_DEV_RES_CUBE_CORE, cube_num); + if (ret != ACL_SUCCESS) { + MS_LOG(EXCEPTION) << "Call aclrtGetDeviceResLimit failed! Error flag is " << ret; + } + ret = + CALL_ASCEND_API(aclrtGetDeviceResLimit, device_id, aclrtDevResLimitType::ACL_RT_DEV_RES_VECTOR_CORE, vector_num); + if (ret != ACL_SUCCESS) { + MS_LOG(EXCEPTION) << "Call aclrtGetDeviceResLimit failed! 
Error flag is " << ret; + } +} + +void AscendResManager::SetDeviceLimit(int32_t device_id, int32_t cube_num, int32_t vector_num) { + enable_res_limit_ = true; + auto ret = CALL_ASCEND_API(aclrtSetDevice, static_cast(device_id)); + if (ret != ACL_SUCCESS) { + MS_LOG(EXCEPTION) << "Device " << device_id << " call aclrtSetDevice failed, ret[" << static_cast(ret) + << "]. The details refer to 'Ascend Error Message'."; + } + if (cube_num > 0) { + ret = CALL_ASCEND_API(aclrtSetDeviceResLimit, device_id, aclrtDevResLimitType::ACL_RT_DEV_RES_CUBE_CORE, + static_cast(cube_num)); + if (ret != ACL_SUCCESS) { + MS_LOG(EXCEPTION) << "Call aclrtSetDeviceResLimit failed! Error flag is " << ret; + } + } + if (vector_num > 0) { + ret = CALL_ASCEND_API(aclrtSetDeviceResLimit, device_id, aclrtDevResLimitType::ACL_RT_DEV_RES_VECTOR_CORE, + static_cast(vector_num)); + if (ret != ACL_SUCCESS) { + MS_LOG(EXCEPTION) << "Call aclrtSetDeviceResLimit failed! Error flag is " << ret; + } + } +} + +void AscendResManager::GetStreamLimit(size_t stream_id, uint32_t *cube_num, uint32_t *vector_num) { + auto stream = AscendStreamMng::GetInstance().GetStream(stream_id); + MS_EXCEPTION_IF_NULL(cube_num); + MS_EXCEPTION_IF_NULL(vector_num); + auto ret = CALL_ASCEND_API(aclrtGetStreamResLimit, stream, aclrtDevResLimitType::ACL_RT_DEV_RES_CUBE_CORE, cube_num); + if (ret != ACL_SUCCESS) { + MS_LOG(EXCEPTION) << "Call aclrtGetStreamResLimit failed! Error flag is " << ret; + } + ret = CALL_ASCEND_API(aclrtGetStreamResLimit, stream, aclrtDevResLimitType::ACL_RT_DEV_RES_VECTOR_CORE, vector_num); + if (ret != ACL_SUCCESS) { + MS_LOG(EXCEPTION) << "Call aclrtGetStreamResLimit failed! Error flag is " << ret; + } +} +void AscendResManager::SetStreamLimit(size_t stream_id, int32_t cube_num, int32_t vector_num) { + enable_res_limit_ = true; + auto stream = AscendStreamMng::GetInstance().GetStream(stream_id); + if (cube_num > 0) { + auto ret = CALL_ASCEND_API(aclrtSetStreamResLimit, stream, aclrtDevResLimitType::ACL_RT_DEV_RES_CUBE_CORE, + static_cast(cube_num)); + if (ret != ACL_SUCCESS) { + MS_LOG(EXCEPTION) << "Call aclrtSetStreamResLimit failed! Error flag is " << ret; + } + } + if (vector_num > 0) { + auto ret = CALL_ASCEND_API(aclrtSetStreamResLimit, stream, aclrtDevResLimitType::ACL_RT_DEV_RES_VECTOR_CORE, + static_cast(vector_num)); + if (ret != ACL_SUCCESS) { + MS_LOG(EXCEPTION) << "Call aclrtSetStreamResLimit failed! Error flag is " << ret; + } + } +} + +void AscendResManager::ResetStreamLimit(size_t stream_id) { + auto stream = AscendStreamMng::GetInstance().GetStream(stream_id); + auto ret = CALL_ASCEND_API(aclrtResetStreamResLimit, stream); + if (ret != ACL_SUCCESS) { + MS_LOG(EXCEPTION) << "Call aclrtResetStreamResLimit failed! Error flag is " << ret; + } +} + +void AscendResManager::UseStreamResInCurrentThread(size_t stream_id) { + if (!enable_res_limit_) { + return; + } + if (prev_set_stream_id_ == stream_id) { + return; + } + auto stream = AscendStreamMng::GetInstance().GetStream(stream_id); + auto ret = CALL_ASCEND_API(aclrtUseStreamResInCurrentThread, stream); + if (ret != ACL_SUCCESS) { + MS_LOG(EXCEPTION) << "Call aclrtUseStreamResInCurrentThread failed! 
Error flag is " << ret; + } + prev_set_stream_id_ = stream_id; +} + bool AscendResManager::GetMemUceInfo(int32_t device_id) { aclrtMemUceInfo info[MAX_MEM_UCE_INFO_ARRAY_SIZE]; size_t retSize = 0; diff --git a/mindspore/ccsrc/plugin/ascend/res_manager/ascend_res_manager.h b/mindspore/ccsrc/plugin/ascend/res_manager/ascend_res_manager.h index 7e0314b48538d27061c048790c5a5dabe5753263..48e4d491610e8c70f39c6842eea7cf505d8c76a6 100644 --- a/mindspore/ccsrc/plugin/ascend/res_manager/ascend_res_manager.h +++ b/mindspore/ccsrc/plugin/ascend/res_manager/ascend_res_manager.h @@ -155,6 +155,13 @@ class ASCEND_RES_MANAGER_EXPORT AscendResManager : public DeviceResManager { bool DestroyEvent(const DeviceEventPtr &event) override; bool DestroyAllEvents() override; + void GetDeviceLimit(int32_t device_id, uint32_t *cube_num, uint32_t *vector_num) override; + void SetDeviceLimit(int32_t device_id, int32_t cube_num, int32_t vector_num) override; + void GetStreamLimit(size_t stream_id, uint32_t *cube_num, uint32_t *vector_num) override; + void SetStreamLimit(size_t stream_id, int32_t cube_num, int32_t vector_num) override; + void ResetStreamLimit(size_t stream_id) override; + void UseStreamResInCurrentThread(size_t stream_id) override; + bool single_op_multi_stream_enable() const override; void set_single_op_multi_stream_enable(bool single_op_multi_stream_enable) override; // Only used in graph_mode with MS_DISABLE_REF_MODE, delete it when delete MS_DISABLE_REF_MODEF @@ -238,6 +245,8 @@ class ASCEND_RES_MANAGER_EXPORT AscendResManager : public DeviceResManager { std::mutex device_events_mutex_; uint32_t device_id_{0}; bool enable_memory_tracker_{false}; + std::atomic enable_res_limit_ = false; + size_t prev_set_stream_id_ = UINT32_MAX; }; } // namespace ascend } // namespace device diff --git a/mindspore/ccsrc/plugin/ascend/res_manager/symbol_interface/acl_rt_symbol.cc b/mindspore/ccsrc/plugin/ascend/res_manager/symbol_interface/acl_rt_symbol.cc index 55e4927d406f8fe1c534b1c3c26800fff543f097..5548a0dfef4b98484994ee624f91d39ba1a85a7b 100644 --- a/mindspore/ccsrc/plugin/ascend/res_manager/symbol_interface/acl_rt_symbol.cc +++ b/mindspore/ccsrc/plugin/ascend/res_manager/symbol_interface/acl_rt_symbol.cc @@ -27,6 +27,12 @@ aclrtDestroyContextFunObj aclrtDestroyContext_ = nullptr; aclrtDestroyEventFunObj aclrtDestroyEvent_ = nullptr; aclrtDestroyStreamFunObj aclrtDestroyStream_ = nullptr; aclrtDestroyStreamForceFunObj aclrtDestroyStreamForce_ = nullptr; +aclrtGetDeviceResLimitFunObj aclrtGetDeviceResLimit_ = nullptr; +aclrtSetDeviceResLimitFunObj aclrtSetDeviceResLimit_ = nullptr; +aclrtGetStreamResLimitFunObj aclrtGetStreamResLimit_ = nullptr; +aclrtSetStreamResLimitFunObj aclrtSetStreamResLimit_ = nullptr; +aclrtResetStreamResLimitFunObj aclrtResetStreamResLimit_ = nullptr; +aclrtUseStreamResInCurrentThreadFunObj aclrtUseStreamResInCurrentThread_ = nullptr; aclrtEventElapsedTimeFunObj aclrtEventElapsedTime_ = nullptr; aclrtFreeFunObj aclrtFree_ = nullptr; aclrtFreeHostFunObj aclrtFreeHost_ = nullptr; @@ -96,6 +102,12 @@ void LoadAclRtApiSymbol(const std::string &ascend_path) { aclrtCreateEventWithFlag_ = DlsymAscendFuncObj(aclrtCreateEventWithFlag, handler); aclrtCreateEventExWithFlag_ = DlsymAscendFuncObj(aclrtCreateEventExWithFlag, handler); aclrtCreateStreamWithConfig_ = DlsymAscendFuncObj(aclrtCreateStreamWithConfig, handler); + aclrtGetDeviceResLimit_ = DlsymAscendFuncObj(aclrtGetDeviceResLimit, handler); + aclrtSetDeviceResLimit_ = DlsymAscendFuncObj(aclrtSetDeviceResLimit, handler); + 
aclrtGetStreamResLimit_ = DlsymAscendFuncObj(aclrtGetStreamResLimit, handler);
+  aclrtSetStreamResLimit_ = DlsymAscendFuncObj(aclrtSetStreamResLimit, handler);
+  aclrtResetStreamResLimit_ = DlsymAscendFuncObj(aclrtResetStreamResLimit, handler);
+  aclrtUseStreamResInCurrentThread_ = DlsymAscendFuncObj(aclrtUseStreamResInCurrentThread, handler);
   aclrtDestroyContext_ = DlsymAscendFuncObj(aclrtDestroyContext, handler);
   aclrtDestroyEvent_ = DlsymAscendFuncObj(aclrtDestroyEvent, handler);
   aclrtDestroyStream_ = DlsymAscendFuncObj(aclrtDestroyStream, handler);
@@ -165,6 +177,11 @@ void LoadSimulationRtApi() {
   ASSIGN_SIMU(aclrtCreateEventWithFlag);
   ASSIGN_SIMU(aclrtCreateEventExWithFlag);
   ASSIGN_SIMU(aclrtCreateStreamWithConfig);
+  ASSIGN_SIMU(aclrtGetDeviceResLimit);
+  ASSIGN_SIMU(aclrtSetDeviceResLimit);
+  ASSIGN_SIMU(aclrtGetStreamResLimit);
+  ASSIGN_SIMU(aclrtSetStreamResLimit);
+  ASSIGN_SIMU(aclrtResetStreamResLimit);
   ASSIGN_SIMU(aclrtDestroyContext);
   ASSIGN_SIMU(aclrtDestroyEvent);
   ASSIGN_SIMU(aclrtDestroyStream);
diff --git a/mindspore/ccsrc/plugin/ascend/res_manager/symbol_interface/acl_rt_symbol.h b/mindspore/ccsrc/plugin/ascend/res_manager/symbol_interface/acl_rt_symbol.h
index a9ea2f0191c703d8ccc88389522a6958dc93ebe5..47ec542a0d639c78a323108647fac0706906d033 100644
--- a/mindspore/ccsrc/plugin/ascend/res_manager/symbol_interface/acl_rt_symbol.h
+++ b/mindspore/ccsrc/plugin/ascend/res_manager/symbol_interface/acl_rt_symbol.h
@@ -25,6 +25,12 @@ ORIGIN_METHOD_WITH_SIMU_CREATE(aclrtCreateEvent, aclError, aclrtEvent *)
 ORIGIN_METHOD_WITH_SIMU_CREATE(aclrtCreateEventWithFlag, aclError, aclrtEvent *, uint32_t)
 ORIGIN_METHOD_WITH_SIMU_CREATE(aclrtCreateEventExWithFlag, aclError, aclrtEvent *, uint32_t)
 ORIGIN_METHOD_WITH_SIMU_CREATE(aclrtCreateStreamWithConfig, aclError, aclrtStream *, uint32_t, uint32_t)
+ORIGIN_METHOD_WITH_SIMU(aclrtGetDeviceResLimit, aclError, int32_t, aclrtDevResLimitType, uint32_t *)
+ORIGIN_METHOD_WITH_SIMU(aclrtSetDeviceResLimit, aclError, int32_t, aclrtDevResLimitType, uint32_t)
+ORIGIN_METHOD_WITH_SIMU(aclrtGetStreamResLimit, aclError, aclrtStream, aclrtDevResLimitType, uint32_t *)
+ORIGIN_METHOD_WITH_SIMU(aclrtSetStreamResLimit, aclError, aclrtStream, aclrtDevResLimitType, uint32_t)
+ORIGIN_METHOD_WITH_SIMU(aclrtResetStreamResLimit, aclError, aclrtStream)
+ORIGIN_METHOD_WITH_SIMU(aclrtUseStreamResInCurrentThread, aclError, aclrtStream)
 ORIGIN_METHOD_WITH_SIMU(aclrtDestroyContext, aclError, aclrtContext)
 ORIGIN_METHOD_WITH_SIMU(aclrtDestroyEvent, aclError, aclrtEvent)
 ORIGIN_METHOD_WITH_SIMU(aclrtDestroyStream, aclError, aclrtStream)
diff --git a/mindspore/ccsrc/pybind_api/init.cc b/mindspore/ccsrc/pybind_api/init.cc
index aadf8578de0865e4dd74527de510314515c6f324..0737988bc2702e626b3c93c08ff5af05f1e4cae7 100644
--- a/mindspore/ccsrc/pybind_api/init.cc
+++ b/mindspore/ccsrc/pybind_api/init.cc
@@ -196,6 +196,7 @@ void RegModule(py::module *m) {
   mindspore::graph::RegCustomPass(m);
   mindspore::hal::RegStream(m);
   mindspore::hal::RegEvent(m);
+  mindspore::hal::RegResLimit(m);
   mindspore::hal::RegCommHandle(m);
   mindspore::hal::RegMemory(m);
   mindspore::hal::RegUtils(m);
diff --git a/mindspore/ccsrc/pybind_api/runtime/res_limit_py.cc b/mindspore/ccsrc/pybind_api/runtime/res_limit_py.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5a9d4325edefb308018f61da1481c6026bd8b521
--- /dev/null
+++ b/mindspore/ccsrc/pybind_api/runtime/res_limit_py.cc
@@ -0,0 +1,107 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "pybind_api/runtime/res_limit_py.h"
+#include
+#include
+#include
+#include "runtime/pynative/op_executor.h"
+#include "runtime/pipeline/pipeline.h"
+#include "runtime/hardware_abstract/device_context/device_context_manager.h"
+#include "runtime/hardware_abstract/stream/multi_stream_controller.h"
+#include "utils/ms_context.h"
+#include "include/common/pybind_api/api_register.h"
+#include "pybind_api/runtime/utils_py.h"
+#include "pynative/utils/pynative_utils.h"
+#include "utils/stream_guard.h"
+#include "utils/ms_exception.h"
+
+namespace mindspore {
+namespace hal {
+py::tuple GetDeviceLimit(int32_t device_id) {
+  runtime::Pipeline::Get().WaitForward();
+  uint32_t cube_num;
+  uint32_t vector_num;
+  const auto &device_name = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+  auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
+    {device::GetDeviceTypeByName(device_name), static_cast<uint32_t>(device_id)});
+  device_context->device_res_manager_->GetDeviceLimit(device_id, &cube_num, &vector_num);
+  return py::make_tuple(cube_num, vector_num);
+}
+
+void SetDeviceLimit(int32_t device_id, int32_t cube_num, int32_t vector_num) {
+  runtime::Pipeline::Get().WaitForward();
+  const auto &device_name = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+  auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
+    {device::GetDeviceTypeByName(device_name), static_cast<uint32_t>(device_id)});
+  device_context->device_res_manager_->SetDeviceLimit(device_id, cube_num, vector_num);
+}
+
+void ResetStreamLimit(const StreamPyPtr &stream) {
+  MS_EXCEPTION_IF_NULL(stream);
+  runtime::Pipeline::Get().WaitForward();
+  MS_LOG(DEBUG) << "stream_id:" << stream->stream_id();
+  stream->device_ctx()->device_res_manager_->ResetStreamLimit(stream->stream_id());
+}
+
+py::tuple GetStreamLimit(const StreamPyPtr &stream) {
+  MS_EXCEPTION_IF_NULL(stream);
+  runtime::Pipeline::Get().WaitForward();
+  MS_LOG(DEBUG) << "stream_id:" << stream->stream_id();
+  uint32_t cube_num;
+  uint32_t vector_num;
+  stream->device_ctx()->device_res_manager_->GetStreamLimit(stream->stream_id(), &cube_num, &vector_num);
+  return py::make_tuple(cube_num, vector_num);
+}
+
+void SetStreamLimit(const StreamPyPtr &stream, int32_t cube_num, int32_t vector_num) {
+  MS_EXCEPTION_IF_NULL(stream);
+  MS_LOG(DEBUG) << "stream_id:" << stream->stream_id();
+  DispatchSetStreamLimitTask(stream, cube_num, vector_num);
+}
+
+void DispatchSetStreamLimitTask(const StreamPyPtr &stream, int32_t cube_num, int32_t vector_num) {
+  // Dispatch the SetStreamLimit call asynchronously so it runs in order on the launch pipeline.
+  pynative::DispatchOp(std::make_shared([stream, cube_num, vector_num]() {
+    auto set_limit_fn = [stream, cube_num, vector_num]() {
+      auto stream_id = stream->stream_id();
+      MS_LOG(DEBUG) << "SetStreamLimit stream id:" << stream_id << ", cube_num:" << cube_num
+                    << ", vector_num:" << vector_num;
+      auto device_ctx = stream->device_ctx();
+      MS_EXCEPTION_IF_NULL(device_ctx);
+      runtime::OpExecutor::DispatchLaunchTask([stream_id, cube_num, vector_num, device_ctx]() {
+        device_ctx->device_res_manager_->SetStreamLimit(stream_id, cube_num, vector_num);
+      });
+    };
+    if (!runtime::OpExecutor::NeedSync()) {
+      runtime::OpExecutor::GetInstance().PushSimpleOpRunTask(
+        std::make_shared(set_limit_fn));
+    } else {
+      set_limit_fn();
+    }
+  }));
+}
+
+void RegResLimit(py::module *m) {
+  (void)m->def("get_device_limit", &mindspore::hal::GetDeviceLimit, "Get device limit");
+  (void)m->def("set_device_limit", &mindspore::hal::SetDeviceLimit, "Set device limit");
+  (void)m->def("get_stream_limit", &mindspore::hal::GetStreamLimit, "Get stream limit");
+  (void)m->def("set_stream_limit", &mindspore::hal::SetStreamLimit, "Set stream limit");
+  (void)m->def("reset_stream_limit", &mindspore::hal::ResetStreamLimit, "Reset stream limit");
+}
+}  // namespace hal
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/pybind_api/runtime/res_limit_py.h b/mindspore/ccsrc/pybind_api/runtime/res_limit_py.h
new file mode 100644
index 0000000000000000000000000000000000000000..1ed89cfd51b710b128e01f5afadea40af5b0e8ae
--- /dev/null
+++ b/mindspore/ccsrc/pybind_api/runtime/res_limit_py.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_PYBIND_API_HAL_RES_LIMIT_PY_H +#define MINDSPORE_CCSRC_PYBIND_API_HAL_RES_LIMIT_PY_H +#include "pybind11/pybind11.h" +#include "pybind_api/runtime/stream_py.h" +#include "runtime/hardware_abstract/device_context/device_context.h" + +namespace mindspore { +namespace hal { +namespace py = pybind11; +py::tuple GetDeviceLimit(int32_t device_id); +void SetDeviceLimit(int32_t device_id, int32_t cube_num, int32_t vector_num); +py::tuple GetStreamLimit(const StreamPyPtr &stream); +void SetStreamLimit(const StreamPyPtr &stream, int32_t cube_num, int32_t vector_num); +void ResetStreamLimit(const StreamPyPtr &stream); +void DispatchSetStreamLimitTask(const StreamPyPtr &stream, int32_t cube_num, int32_t vector_num); +} // namespace hal +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PYBIND_API_HAL_RES_LIMIT_PY_H diff --git a/mindspore/ccsrc/runtime/hardware_abstract/device_context/device_context.h b/mindspore/ccsrc/runtime/hardware_abstract/device_context/device_context.h index ce49263a9b17ae44d0abd4d551e722b32e4cb0f9..76610439e6b7030c6f2f290a52792451622b6955 100644 --- a/mindspore/ccsrc/runtime/hardware_abstract/device_context/device_context.h +++ b/mindspore/ccsrc/runtime/hardware_abstract/device_context/device_context.h @@ -294,7 +294,12 @@ class RUNTIME_HARDWARE_EXPORT DeviceResManager { virtual DeviceEventPtr CreateEventWithFlag(bool enable_timing, bool blocking, bool use_extensional_api = true) { return nullptr; } - + virtual void GetDeviceLimit(int32_t device_id, uint32_t *cube_num, uint32_t *vector_num) {} + virtual void SetDeviceLimit(int32_t device_id, int32_t cube_num, int32_t vector_num) {} + virtual void GetStreamLimit(size_t stream_id, uint32_t *cube_num, uint32_t *vector_num) {} + virtual void SetStreamLimit(size_t stream_id, int32_t cube_num, int32_t vector_num) {} + virtual void ResetStreamLimit(size_t stream_id) {} + virtual void UseStreamResInCurrentThread(size_t stream_id) {} // Destroy specified device event. 
virtual bool DestroyEvent(const DeviceEventPtr &event) { return true; } diff --git a/mindspore/ops/kernel/ascend/aclnn/pyboost_impl/aclnn_utils.h b/mindspore/ops/kernel/ascend/aclnn/pyboost_impl/aclnn_utils.h index 54bc478280c003808beb41ea8c574601dc994b4b..59019152a24de0b239388246dc18a108b67b30d1 100644 --- a/mindspore/ops/kernel/ascend/aclnn/pyboost_impl/aclnn_utils.h +++ b/mindspore/ops/kernel/ascend/aclnn/pyboost_impl/aclnn_utils.h @@ -147,6 +147,7 @@ using CacheTuple = std::tupledevice_res_manager_->UseStreamResInCurrentThread(stream_id); \ mindspore::runtime::ProfilerRecorder aclnn_profiler(mindspore::runtime::ProfilerModule::kPynative, \ mindspore::runtime::ProfilerEvent::kPyBoostLaunchAclnn, \ aclnn_name, false); \ @@ -221,6 +222,7 @@ using CacheTuple = std::tupledevice_res_manager_->UseStreamResInCurrentThread(real_stream_id); \ runtime::ProfilerRecorder aclnn_profiler(runtime::ProfilerModule::kPynative, \ runtime::ProfilerEvent::kPyBoostLaunchAclnn, aclnn_name, false); \ auto stream_ptr = device_context->device_res_manager_->GetStream(real_stream_id); \ diff --git a/mindspore/python/mindspore/runtime/__init__.py b/mindspore/python/mindspore/runtime/__init__.py index 7e7ec620c2e348713fc55792351903da9d8048c5..73c971936beea2bfb902df25f34959ee80c1efb4 100644 --- a/mindspore/python/mindspore/runtime/__init__.py +++ b/mindspore/python/mindspore/runtime/__init__.py @@ -26,15 +26,20 @@ from mindspore.runtime.memory import set_memory, memory_stats, memory_reserved, from mindspore.runtime.stream import Stream, synchronize, set_cur_stream, current_stream, \ default_stream, communication_stream, StreamCtx from mindspore.runtime.event import Event +from mindspore.runtime.res_limit import get_device_limit, set_device_limit, get_stream_limit, set_stream_limit,\ + reset_stream_limit, StreamLimitCtx from .executor import launch_blocking + __all__ = [ "launch_blocking", "dispatch_threads_num", "set_cpu_affinity", "set_kernel_launch_group", "set_kernel_launch_capture", "Stream", "communication_stream", "synchronize", "set_cur_stream", "current_stream", "default_stream", "StreamCtx", "set_memory", "memory_stats", "memory_reserved", "max_memory_reserved", "empty_cache", "memory_replay", "reset_peak_memory_stats", "memory_summary", "memory_allocated", "max_memory_allocated", - "reset_max_memory_reserved", "reset_max_memory_allocated", "Event", "PluggableAllocator", "MemPool", "use_mem_pool" + "reset_max_memory_reserved", "reset_max_memory_allocated", "Event", "PluggableAllocator", "MemPool", + "use_mem_pool", "get_device_limit", "set_device_limit", "get_stream_limit", "set_stream_limit", + "reset_stream_limit", "StreamLimitCtx" ] __all__.sort() diff --git a/mindspore/python/mindspore/runtime/res_limit.py b/mindspore/python/mindspore/runtime/res_limit.py new file mode 100644 index 0000000000000000000000000000000000000000..2d225910e10308294d3908c8fff762a6b9863ba2 --- /dev/null +++ b/mindspore/python/mindspore/runtime/res_limit.py @@ -0,0 +1,184 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Runtime resource limit interfaces."""
+from mindspore._c_expression import get_device_limit as get_device_limit_
+from mindspore._c_expression import set_device_limit as set_device_limit_
+from mindspore._c_expression import get_stream_limit as get_stream_limit_
+from mindspore._c_expression import set_stream_limit as set_stream_limit_
+from mindspore._c_expression import reset_stream_limit as reset_stream_limit_
+from .stream import Stream
+
+
+def get_device_limit(device):
+    r"""
+    Return the core limits currently configured for the device.
+
+    Args:
+        device (int): the selected device id.
+
+    Returns:
+        dict, with keys ``"cube_core_num"`` and ``"vector_core_num"``.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> ms.runtime.get_device_limit(0)
+    """
+    cube_num, vector_num = get_device_limit_(device)
+    return {"cube_core_num": cube_num, "vector_core_num": vector_num}
+
+
+def set_device_limit(device, cube_num=-1, vector_num=-1):
+    r"""
+    Set the core limits for the device.
+
+    Args:
+        device (int): the device id to configure.
+        cube_num (int): the number of cube (AI Core) cores the device may use. ``-1`` keeps the current value.
+        vector_num (int): the number of vector cores the device may use. ``-1`` keeps the current value.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> ms.runtime.set_device_limit(0, 8, 8)
+    """
+    set_device_limit_(device, cube_num, vector_num)
+
+
+def get_stream_limit(stream):
+    r"""
+    Return the core limits currently configured for the stream.
+
+    Args:
+        stream (Stream): the selected stream.
+
+    Returns:
+        dict, with keys ``"cube_core_num"`` and ``"vector_core_num"``.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> ms.runtime.get_stream_limit(ms.runtime.default_stream())
+    """
+    if not isinstance(stream, Stream):
+        raise TypeError(
+            f"For 'get_stream_limit', the argument 'stream' should be Stream,"
+            f" but got {type(stream)}."
+        )
+    cube_num, vector_num = get_stream_limit_(stream)
+    return {"cube_core_num": cube_num, "vector_core_num": vector_num}
+
+
+def set_stream_limit(stream, cube_num=-1, vector_num=-1):
+    r"""
+    Set the core limits for the stream.
+
+    Args:
+        stream (Stream): the stream to configure.
+        cube_num (int): the number of cube (AI Core) cores the stream may use. ``-1`` keeps the current value.
+        vector_num (int): the number of vector cores the stream may use. ``-1`` keeps the current value.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> s1 = ms.runtime.Stream()
+        >>> ms.runtime.set_stream_limit(s1, 8, 8)
+    """
+    if not isinstance(stream, Stream):
+        raise TypeError(
+            f"For 'set_stream_limit', the argument 'stream' should be Stream,"
+            f" but got {type(stream)}."
+        )
+    set_stream_limit_(stream, cube_num, vector_num)
+
+
+def reset_stream_limit(stream):
+    r"""
+    Reset the stream core limits to the device defaults.
+
+    Args:
+        stream (Stream): the stream to reset.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> s1 = ms.runtime.Stream()
+        >>> ms.runtime.reset_stream_limit(s1)
+    """
+    if not isinstance(stream, Stream):
+        raise TypeError(
+            f"For 'reset_stream_limit', the argument 'stream' should be Stream,"
+            f" but got {type(stream)}."
+        )
+    reset_stream_limit_(stream)
+
+
+class StreamLimitCtx:
+    r"""
+    Context-manager that limits the cube and vector cores available to a given stream.
+
+    Kernels launched on the stream inside the context run under the specified core
+    limits; the previous limits are restored on exit.
+
+    Args:
+        stream (Stream): the stream to limit.
+        cube_num (int): the number of cube (AI Core) cores the stream may use. ``-1`` keeps the current value.
+        vector_num (int): the number of vector cores the stream may use. ``-1`` keeps the current value.
+
+    Raises:
+        TypeError: If `stream` is not a :class:`mindspore.runtime.Stream`.
+ + Supported Platforms: + ``Ascend`` + + Examples: + >>> import mindspore as ms + >>> import numpy as np + >>> from mindspore import Tensor, ops + >>> ms.set_device("Ascend", 0) + >>> a = Tensor(np.ones([1024, 2048]), ms.float32) + >>> b = Tensor(np.ones([2048, 4096]), ms.float32) + >>> s1 = ms.runtime.Stream() + >>> with ms.runtime.StreamLimitCtx(s1, 8, 8): + ... c = ops.matmul(a, b) + >>> ms.runtime.synchronize() + """ + def __init__(self, stream, cube_num=-1, vector_num=-1): + if not isinstance(stream, Stream): + raise TypeError( + f"For 'StreamLimitCtx', the argument 'stream' should be Stream," + f" but got {type(stream)}." + ) + self.stream = stream + self.cube_num = cube_num + self.vector_num = vector_num + self.prev_cube_num = -1 + self.prev_vector_num = -1 + + def __enter__(self): + self.prev_cube_num, self.prev_vector_num = get_stream_limit_(self.stream) + set_stream_limit_(self.stream, self.cube_num, self.vector_num) + + def __exit__(self, exc_type, exc_val, exc_tb): + set_stream_limit_(self.stream, self.prev_cube_num, self.prev_vector_num) diff --git a/tests/st/runtime/test_res_limit.py b/tests/st/runtime/test_res_limit.py new file mode 100644 index 0000000000000000000000000000000000000000..9d5f978243cda58a9322df29b0c772006e17e45d --- /dev/null +++ b/tests/st/runtime/test_res_limit.py @@ -0,0 +1,182 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import mindspore.context as context +from mindspore import Tensor, ops +import mindspore as ms +from tests.mark_utils import arg_mark +ms.set_device("Ascend") +context.set_context(mode=context.PYNATIVE_MODE) + +@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level1', + card_mark='onecard', essential_mark='essential') +def test_runtime_get_device_limit(): + """ + Feature: runtime stream api. + Description: Test runtime.get_device_limit api. + Expectation: runtime.get_device_limit api performs as expected. + """ + ret = ms.runtime.get_device_limit(0) + print(ret) + assert ret["cube_core_num"] == 24 + assert ret["vector_core_num"] == 48 + ms.runtime.synchronize() + + +@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level1', + card_mark='onecard', essential_mark='essential') +def test_runtime_set_device_limit(): + """ + Feature: runtime stream api. + Description: Test runtime.set_device_limit api. + Expectation: runtime.set_device_limit api performs as expected. 
+ """ + ret = ms.runtime.get_device_limit(0) + assert ret["cube_core_num"] == 24 + assert ret["vector_core_num"] == 48 + + ms.runtime.set_device_limit(0, 8, 8) + ret = ms.runtime.get_device_limit(0) + assert ret["cube_core_num"] == 8 + assert ret["vector_core_num"] == 8 + + ms.runtime.set_device_limit(0, 4, 4) + ret = ms.runtime.get_device_limit(0) + assert ret["cube_core_num"] == 4 + assert ret["vector_core_num"] == 4 + + ms.runtime.set_device_limit(0, 3, -1) + ret = ms.runtime.get_device_limit(0) + assert ret["cube_core_num"] == 3 + assert ret["vector_core_num"] == 4 + + ms.runtime.set_device_limit(0, -1, 3) + ret = ms.runtime.get_device_limit(0) + assert ret["cube_core_num"] == 3 + assert ret["vector_core_num"] == 3 + + ms.runtime.set_device_limit(0, -1, -1) + ret = ms.runtime.get_device_limit(0) + assert ret["cube_core_num"] == 3 + assert ret["vector_core_num"] == 3 + ms.runtime.synchronize() + + +@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level1', + card_mark='onecard', essential_mark='essential') +def test_runtime_get_stream_limit(): + """ + Feature: runtime stream api. + Description: Test runtime.get_stream_limit api. + Expectation: runtime.get_stream_limit api performs as expected. + """ + s1 = ms.runtime.Stream() + ret = ms.runtime.get_stream_limit(s1) + assert ret["cube_core_num"] == 24 + assert ret["vector_core_num"] == 48 + ms.runtime.synchronize() + + +@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level1', + card_mark='onecard', essential_mark='essential') +def test_runtime_set_stream_limit(): + """ + Feature: runtime stream api. + Description: Test runtime.set_stream_limit api. + Expectation: runtime.set_stream_limit api performs as expected. + """ + s1 = ms.runtime.Stream() + ret = ms.runtime.get_stream_limit(s1) + assert ret["cube_core_num"] == 24 + assert ret["vector_core_num"] == 48 + + ms.runtime.set_stream_limit(s1, 8, 8) + ret = ms.runtime.get_stream_limit(s1) + assert ret["cube_core_num"] == 8 + assert ret["vector_core_num"] == 8 + + ms.runtime.set_stream_limit(s1, 4, 4) + ret = ms.runtime.get_stream_limit(s1) + assert ret["cube_core_num"] == 4 + assert ret["vector_core_num"] == 4 + + ms.runtime.set_stream_limit(s1, 3, -1) + ret = ms.runtime.get_stream_limit(s1) + assert ret["cube_core_num"] == 3 + assert ret["vector_core_num"] == 4 + + ms.runtime.set_stream_limit(s1, -1, 3) + ret = ms.runtime.get_stream_limit(s1) + assert ret["cube_core_num"] == 3 + assert ret["vector_core_num"] == 3 + + ms.runtime.set_stream_limit(s1, -1, -1) + ret = ms.runtime.get_stream_limit(s1) + assert ret["cube_core_num"] == 3 + assert ret["vector_core_num"] == 3 + ms.runtime.synchronize() + + +@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level1', + card_mark='onecard', essential_mark='essential') +def test_runtime_reset_stream_limit(): + """ + Feature: runtime stream api. + Description: Test runtime.reset_stream_limit api. + Expectation: runtime.reset_stream_limit api performs as expected. 
+ """ + s1 = ms.runtime.Stream() + ret = ms.runtime.get_stream_limit(s1) + assert ret["cube_core_num"] == 24 + assert ret["vector_core_num"] == 48 + + ms.runtime.set_stream_limit(s1, 8, 8) + ret = ms.runtime.get_stream_limit(s1) + assert ret["cube_core_num"] == 8 + assert ret["vector_core_num"] == 8 + + ms.runtime.reset_stream_limit(s1) + ret = ms.runtime.get_stream_limit(s1) + assert ret["cube_core_num"] == 24 + assert ret["vector_core_num"] == 48 + ms.runtime.synchronize() + + +@arg_mark(plat_marks=['platform_ascend910b'], level_mark='level1', + card_mark='onecard', essential_mark='essential') +def test_runtime_stream_limit_ctx(): + """ + Feature: runtime stream api. + Description: Test runtime.StreamLimitCtx api. + Expectation: runtime.StreamLimitCtx api performs as expected. + """ + a = Tensor(2.0) + s1 = ms.runtime.Stream() + ret = ms.runtime.get_stream_limit(s1) + assert ret["cube_core_num"] == 24 + assert ret["vector_core_num"] == 48 + with ms.runtime.StreamCtx(s1): + with ms.runtime.StreamLimitCtx(s1, 8, 8): + ret = ms.runtime.get_stream_limit(s1) + assert ret["cube_core_num"] == 8 + assert ret["vector_core_num"] == 8 + ops.abs(a) + ret = ms.runtime.get_stream_limit(s1) + assert ret["cube_core_num"] == 8 + assert ret["vector_core_num"] == 8 + ret = ms.runtime.get_stream_limit(s1) + assert ret["cube_core_num"] == 24 + assert ret["vector_core_num"] == 48 + s1.synchronize()
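Editorial usage sketch, not part of the patch: the snippet below strings together the Python entry points introduced above (set_device_limit, get_device_limit, StreamLimitCtx, reset_stream_limit), in the spirit of tests/st/runtime/test_res_limit.py. The concrete core counts are illustrative only; the defaults reported by the get_* calls depend on the Ascend device.

# Illustrative sketch (assumes an Ascend device and this patch applied).
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

ms.set_device("Ascend", 0)

# Cap the whole device, then read the limits back as a dict.
ms.runtime.set_device_limit(0, cube_num=8, vector_num=8)
print(ms.runtime.get_device_limit(0))  # e.g. {'cube_core_num': 8, 'vector_core_num': 8}

# Limit a single stream only while the context is active; StreamLimitCtx restores
# the previous limits on exit.
s1 = ms.runtime.Stream()
a = Tensor(np.ones([1024, 2048]), ms.float32)
b = Tensor(np.ones([2048, 4096]), ms.float32)
with ms.runtime.StreamCtx(s1), ms.runtime.StreamLimitCtx(s1, cube_num=4, vector_num=8):
    c = ops.matmul(a, b)  # launched on s1 under the 4-cube/8-vector limit
ms.runtime.reset_stream_limit(s1)  # explicit reset back to the device defaults
ms.runtime.synchronize()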