From 693bfba20bc1c16f1ec7faf27999f483615e52fa Mon Sep 17 00:00:00 2001 From: liuzihan000 Date: Tue, 2 Sep 2025 16:50:18 +0800 Subject: [PATCH] add collective add all_gather op --- inferrt/python/mrt/collective.py | 20 + inferrt/python/mrt/torch/fx_backend.py | 60 ++- inferrt/src/hardware/CMakeLists.txt | 2 + inferrt/src/hardware/ascend/CMakeLists.txt | 7 + .../hardware/ascend/collective/CMakeLists.txt | 0 .../ascend/res_manager/ascend_res_manager.cc | 25 +- .../ascend/res_manager/ascend_res_manager.h | 4 + .../mem_manager/ascend_memory_adapter.cc | 16 +- .../mem_manager/ascend_vmm_adapter.cc | 9 +- .../collective/collective_manager.cc | 81 +++ .../collective/collective_manager.h | 72 +++ .../collective/communication_group.cc | 36 ++ .../collective/communication_group.h | 51 ++ .../device_context_manager.cc | 5 +- inferrt/src/ir/tensor/storage.cc | 1 - inferrt/src/ops/ascend/CMakeLists.txt | 3 + inferrt/src/ops/ascend/hccl/CMakeLists.txt | 18 + inferrt/src/ops/ascend/hccl/hccl_adapter.cc | 472 ++++++++++++++++++ inferrt/src/ops/ascend/hccl/hccl_adapter.h | 208 ++++++++ .../src/ops/ascend/hccl/hccl_all_gather.cc | 64 +++ inferrt/src/ops/ascend/hccl/hccl_all_gather.h | 44 ++ inferrt/src/ops/ascend/hccl/hccl_kernel.cc | 32 ++ inferrt/src/ops/ascend/hccl/hccl_kernel.h | 55 ++ inferrt/src/ops/ascend/hccl/hccl_plugin.h | 89 ++++ inferrt/src/ops/ascend/hccl/hcom_utils.cc | 153 ++++++ inferrt/src/ops/ascend/hccl/hcom_utils.h | 130 +++++ inferrt/src/ops/ascend/hccl/tensor_copy.cc | 64 +++ inferrt/src/ops/ascend/hccl/tensor_copy.h | 44 ++ inferrt/src/ops/ascend/hccl/wait_tensor.cc | 64 +++ inferrt/src/ops/ascend/hccl/wait_tensor.h | 44 ++ inferrt/src/ops/op_base/op_all_gather.cc | 42 ++ inferrt/src/ops/op_base/op_all_gather.h | 37 ++ inferrt/src/ops/op_def/ops.list | 3 + inferrt/src/pybind/CMakeLists.txt | 1 - inferrt/src/pybind/mrt/CMakeLists.txt | 29 +- inferrt/src/pybind/mrt/pybind11_collective.cc | 40 ++ inferrt/src/pybind/mrt_torch/CMakeLists.txt | 11 +- setup.py | 1 + tests/st/check/check_distributed_backend.py | 107 ++++ 39 files changed, 2118 insertions(+), 26 deletions(-) create mode 100644 inferrt/python/mrt/collective.py delete mode 100644 inferrt/src/hardware/ascend/collective/CMakeLists.txt create mode 100644 inferrt/src/hardware/hardware_abstract/collective/collective_manager.cc create mode 100644 inferrt/src/hardware/hardware_abstract/collective/collective_manager.h create mode 100644 inferrt/src/hardware/hardware_abstract/collective/communication_group.cc create mode 100644 inferrt/src/hardware/hardware_abstract/collective/communication_group.h create mode 100644 inferrt/src/ops/ascend/hccl/CMakeLists.txt create mode 100644 inferrt/src/ops/ascend/hccl/hccl_adapter.cc create mode 100644 inferrt/src/ops/ascend/hccl/hccl_adapter.h create mode 100644 inferrt/src/ops/ascend/hccl/hccl_all_gather.cc create mode 100644 inferrt/src/ops/ascend/hccl/hccl_all_gather.h create mode 100644 inferrt/src/ops/ascend/hccl/hccl_kernel.cc create mode 100644 inferrt/src/ops/ascend/hccl/hccl_kernel.h create mode 100644 inferrt/src/ops/ascend/hccl/hccl_plugin.h create mode 100644 inferrt/src/ops/ascend/hccl/hcom_utils.cc create mode 100644 inferrt/src/ops/ascend/hccl/hcom_utils.h create mode 100644 inferrt/src/ops/ascend/hccl/tensor_copy.cc create mode 100644 inferrt/src/ops/ascend/hccl/tensor_copy.h create mode 100644 inferrt/src/ops/ascend/hccl/wait_tensor.cc create mode 100644 inferrt/src/ops/ascend/hccl/wait_tensor.h create mode 100644 inferrt/src/ops/op_base/op_all_gather.cc create mode 100644 
inferrt/src/ops/op_base/op_all_gather.h create mode 100644 inferrt/src/pybind/mrt/pybind11_collective.cc create mode 100644 tests/st/check/check_distributed_backend.py diff --git a/inferrt/python/mrt/collective.py b/inferrt/python/mrt/collective.py new file mode 100644 index 00000000..9abebeac --- /dev/null +++ b/inferrt/python/mrt/collective.py @@ -0,0 +1,20 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +''' +CollectiveManager is a class to manage communication groups and global distributed information. +''' +from mrt._mrt_collective import CollectiveManager + +__all__ = ['CollectiveManager'] diff --git a/inferrt/python/mrt/torch/fx_backend.py b/inferrt/python/mrt/torch/fx_backend.py index e436f686..5cab5425 100644 --- a/inferrt/python/mrt/torch/fx_backend.py +++ b/inferrt/python/mrt/torch/fx_backend.py @@ -1,16 +1,19 @@ """ A simple torch.fx backend that converts a GraphModule to a mrt GraphExecutor. """ - +import os import operator from typing import List, Dict, Any import sympy import torch from torch.fx.node import Node from torch.fx.graph_module import GraphModule +from torch._C._distributed_c10d import _resolve_process_group +from torch import distributed as dist from mrt.ir import GraphExecutor, Op from mrt.torch.utils import from_torch, to_torch, update_tensor_data +from mrt._mrt_collective import CollectiveManager _GLOBAL_GRAPH_ID = 0 @@ -21,7 +24,7 @@ def _next_unique_graph_id(): _GLOBAL_GRAPH_ID += 1 return _GLOBAL_GRAPH_ID - +# pylint: disable=protected-access # A comprehensive mapping from torch fx ops to our custom ops. 
_OP_MAP = { # torch functions @@ -38,6 +41,8 @@ _OP_MAP = { torch.rsqrt: Op.rsqrt, torch.relu: Op.relu, torch.sigmoid: Op.sigmoid, + torch.ops._c10d_functional.all_gather_into_tensor: Op.all_gather, + torch.ops._c10d_functional.wait_tensor: Op.wait_tensor, # torch.nn.functional torch.nn.functional.relu: Op.relu, torch.nn.functional.sigmoid: Op.sigmoid, @@ -66,8 +71,12 @@ _OP_MAP = { "square": Op.square, "rsqrt": Op.rsqrt, "view": Op.reshape, # view is often used like reshape + "copy_": Op.copy, } +_DIST_OP_LIST = [ + Op.all_gather +] def _get_op(target): """Get the corresponding Op enum for a given target.""" @@ -86,6 +95,44 @@ def _get_op(target): return None +def _extract_global_comm_info(): + rank = dist.get_rank() if dist.is_initialized() else 0 + local_rank = int(os.getenv('LOCAL_RANK', "0")) + world_size = dist.get_world_size() + + CollectiveManager.instance().set_global_rank_id(rank) + CollectiveManager.instance().set_local_rank_id(local_rank) + CollectiveManager.instance().set_global_rank_size(world_size) + + +def _set_communication_info(ptd): + '''Get communication info from torch and set to CollectiveManager for a given process group.''' + pg = _resolve_process_group(ptd) + rank = dist.get_rank() if dist.is_initialized() else 0 + local_rank = int(os.getenv('LOCAL_RANK', "0")) + world_size = dist.get_world_size() + + group_rank = dist.get_rank(pg) + rank_list = dist.get_process_group_ranks(pg) + + hccl_comm_handle = pg._get_backend(torch.device('npu')).get_hccl_comm(rank) + + CollectiveManager.instance().set_global_rank_id(rank) + CollectiveManager.instance().set_local_rank_id(local_rank) + CollectiveManager.instance().set_global_rank_size(world_size) + + CollectiveManager.instance().create_communication_group(f"{ptd}", rank_list, group_rank, hccl_comm_handle) + + +def _extract_and_setup_comm_groups(node_args): + ptd_arg = node_args[2] + + if CollectiveManager.instance().is_group_exist(f"{ptd_arg}"): + return + + _set_communication_info(ptd_arg) + + def _map_args(args, env, executor: GraphExecutor) -> List[Node]: """ Map torch.fx node arguments to GraphExecutor nodes. 
@@ -115,9 +162,16 @@ def backend(gm: GraphModule, example_inputs: List[torch.Tensor]): executor = GraphExecutor(f"fx_graph_{_next_unique_graph_id()}") env: Dict[Node, Any] = {} - input_iterator = iter(example_inputs) + if dist.is_initialized(): + _extract_global_comm_info() + for node in gm.graph.nodes: + if node.op in ("call_function", "call_method"): + op = _get_op(node.target) + if op in _DIST_OP_LIST: + _extract_and_setup_comm_groups(node.args) + for node in gm.graph.nodes: if node.op == "placeholder": env[node] = executor.add_value_node(from_torch(next(input_iterator))) diff --git a/inferrt/src/hardware/CMakeLists.txt b/inferrt/src/hardware/CMakeLists.txt index ed0af482..20fa242c 100644 --- a/inferrt/src/hardware/CMakeLists.txt +++ b/inferrt/src/hardware/CMakeLists.txt @@ -1,4 +1,6 @@ check_debug_log_out() + + add_subdirectory(hardware_abstract) if(ENABLE_ASCEND) add_subdirectory(ascend) diff --git a/inferrt/src/hardware/ascend/CMakeLists.txt b/inferrt/src/hardware/ascend/CMakeLists.txt index 92b22319..675e6b42 100644 --- a/inferrt/src/hardware/ascend/CMakeLists.txt +++ b/inferrt/src/hardware/ascend/CMakeLists.txt @@ -10,6 +10,7 @@ if(ENABLE_ASCEND) message("Note compile ascend path: ${ASCEND_PATH}") include_directories(${ASCEND_PATH}/latest/include/) include_directories(${ASCEND_PATH}/latest/lib64/) + include_directories(${ASCEND_PATH}/latest/aarch64-linux/include/experiment) find_library(ASCENDCL_LIB ascendcl PATHS ${ASCEND_PATH}/latest/lib64 REQUIRED @@ -20,6 +21,11 @@ if(ENABLE_ASCEND) REQUIRED NO_DEFAULT_PATH ) + find_library(HCCL hccl + PATHS ${ASCEND_PATH}/latest/lib64 + REQUIRED + NO_DEFAULT_PATH + ) link_directories(${ASCEND_PATH}/latest/lib64/) file(GLOB_RECURSE HARDWARE_ASCEND_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") @@ -28,6 +34,7 @@ if(ENABLE_ASCEND) hardware_abstract ${ASCENDCL_LIB} ${RUNTIME_LIB} + ${HCCL} dl ) endif() diff --git a/inferrt/src/hardware/ascend/collective/CMakeLists.txt b/inferrt/src/hardware/ascend/collective/CMakeLists.txt deleted file mode 100644 index e69de29b..00000000 diff --git a/inferrt/src/hardware/ascend/res_manager/ascend_res_manager.cc b/inferrt/src/hardware/ascend/res_manager/ascend_res_manager.cc index 74fac61b..39884f7b 100644 --- a/inferrt/src/hardware/ascend/res_manager/ascend_res_manager.cc +++ b/inferrt/src/hardware/ascend/res_manager/ascend_res_manager.cc @@ -39,6 +39,8 @@ #include "hardware/ascend/res_manager/ascend_hal_manager.h" #include "common/common.h" +#include "hardware/hardware_abstract/collective/collective_manager.h" + namespace mrt { namespace device { namespace ascend { @@ -56,8 +58,7 @@ void AclrtLaunchCallback(void *userData) { } // namespace void AscendResManager::Initialize() { - // use 0 temporarily. 
- deviceId_ = 0; + deviceId_ = mrt::collective::CollectiveManager::Instance().local_rank_id(); if (initialized_) { AscendHalManager::GetInstance().SetContextForce(deviceId_); return; @@ -455,6 +456,26 @@ void AscendResManager::ResetStreamAndCtx() const { AscendStreamMng::GetInstance().CreateDefaultStream(); } +bool AscendResManager::MemcpyDeviceToDevice(void *dst, size_t dst_size, const void *src, size_t src_size, + aclrtStream stream) { + auto ret = CALL_ASCEND_API(aclrtMemcpyAsync, dst, dst_size, src, src_size, ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + if (ret != ACL_SUCCESS) { + LOG_ERROR << " call aclrtMemcpyAsync failed, ret:" << static_cast(ret); + return false; + } + return true; +} + +bool AscendResManager::MemcpyDeviceToHost(void *dst, size_t dst_size, const void *src, size_t src_size, + aclrtStream stream) { + auto ret = CALL_ASCEND_API(aclrtMemcpyAsync, dst, dst_size, src, src_size, ACL_MEMCPY_DEVICE_TO_HOST, stream); + if (ret != ACL_SUCCESS) { + LOG_ERROR << " call aclrtMemcpyAsync failed, ret:" << static_cast(ret); + return false; + } + return true; +} + } // namespace ascend } // namespace device } // namespace mrt diff --git a/inferrt/src/hardware/ascend/res_manager/ascend_res_manager.h b/inferrt/src/hardware/ascend/res_manager/ascend_res_manager.h index c9c68d43..92197559 100644 --- a/inferrt/src/hardware/ascend/res_manager/ascend_res_manager.h +++ b/inferrt/src/hardware/ascend/res_manager/ascend_res_manager.h @@ -119,6 +119,10 @@ class MRT_EXPORT AscendResManager : public DeviceResManager { void ResetStreamAndCtx() const override; + // Memcpy + static bool MemcpyDeviceToDevice(void *dst, size_t dst_size, const void *src, size_t src_size, aclrtStream stream); + static bool MemcpyDeviceToHost(void *dst, size_t dst_size, const void *src, size_t src_size, aclrtStream stream); + private: bool initialized_ = false; std::shared_ptr memManager_{nullptr}; diff --git a/inferrt/src/hardware/ascend/res_manager/mem_manager/ascend_memory_adapter.cc b/inferrt/src/hardware/ascend/res_manager/mem_manager/ascend_memory_adapter.cc index 67bec5b5..bc455220 100644 --- a/inferrt/src/hardware/ascend/res_manager/mem_manager/ascend_memory_adapter.cc +++ b/inferrt/src/hardware/ascend/res_manager/mem_manager/ascend_memory_adapter.cc @@ -20,6 +20,8 @@ #include "hardware/ascend/res_manager/mem_manager/ascend_vmm_adapter.h" #include "hardware/ascend/res_manager/symbol_interface/acl_rt_symbol.h" #include "hardware/ascend/res_manager/symbol_interface/symbol_utils.h" +#include "hardware/hardware_abstract/collective/collective_manager.h" + #include "common/common.h" namespace mrt { @@ -55,7 +57,9 @@ size_t AscendMemAdapter::GetRoundUpAlignSize(size_t inputSize) { size_t AscendMemAdapter::GetDeviceMemSizeFromContext() const { size_t sizeFromContext; - float totalDeviceMemory = 32.0f; + // float totalDeviceMemory = 32.0f; + // set to 0 temporary + float totalDeviceMemory = 0.0f; auto maxDeviceMemory = totalDeviceMemory; // if (context->ascend_soc_version() == kAscendVersion910b || context->ascend_soc_version() == kAscendVersion910_93) { // totalDeviceMemory = 64.0f; @@ -90,8 +94,7 @@ bool AscendMemAdapter::Initialize() { } if (deviceHbmFreeSize_ < LongToSize(DoubleToLong(deviceHbmTotalSize_ * kHalfRatio))) { - // use 0 temporarily. - unsigned int deviceId = 0; + unsigned int deviceId = mrt::collective::CollectiveManager::Instance().local_rank_id(); LOG_OUT << "Free memory size is less " "than half of total memory size." 
<< "Device " << deviceId << " Device MOC total size:" << deviceHbmTotalSize_ @@ -100,9 +103,7 @@ bool AscendMemAdapter::Initialize() { } // get user define max backend memory - // Set the default value to 0 temporarily, and an API for configuration will be provided subsequently. - // auto userDefineMsSize = GetDeviceMemSizeFromContext(); - size_t userDefineMsSize = 0; + auto userDefineMsSize = GetDeviceMemSizeFromContext(); auto recommendMemSizeForOthers = LongToSize(DoubleToLong(deviceHbmFreeSize_ * kReservedMemoryRatio)); size_t reservedMemSizeForOthers; if (userDefineMsSize == 0) { @@ -258,8 +259,7 @@ uint8_t *AscendMemAdapter::MallocFromRts(size_t size) const { auto ret = CALL_ASCEND_API(aclrtMalloc, reinterpret_cast(&ptr), size, ACL_MEM_TYPE_HIGH_BAND_WIDTH); if (ret != ACL_RT_SUCCESS) { if (ret == ACL_ERROR_RT_MEMORY_ALLOCATION) { - // use 0 temporarily. - unsigned int deviceId = 0; + unsigned int deviceId = mrt::collective::CollectiveManager::Instance().local_rank_id(); size_t freeSize = 0; size_t total = 0; (void)CALL_ASCEND_API(aclrtGetMemInfo, ACL_HBM_MEM, &freeSize, &total); diff --git a/inferrt/src/hardware/ascend/res_manager/mem_manager/ascend_vmm_adapter.cc b/inferrt/src/hardware/ascend/res_manager/mem_manager/ascend_vmm_adapter.cc index b7e94aae..bc920124 100644 --- a/inferrt/src/hardware/ascend/res_manager/mem_manager/ascend_vmm_adapter.cc +++ b/inferrt/src/hardware/ascend/res_manager/mem_manager/ascend_vmm_adapter.cc @@ -22,6 +22,8 @@ #include "hardware/ascend/res_manager/symbol_interface/acl_rt_symbol.h" #include "common/common.h" +#include "hardware/hardware_abstract/collective/collective_manager.h" + namespace mrt { namespace device { namespace ascend { @@ -106,7 +108,8 @@ size_t AscendVmmAdapter::MmapDeviceMem(const size_t size, const DeviceMemPtr add LOG_OUT << "VMM MmapDeviceMem size:" << size << ", addr:" << addr << ", cachedHandleSets_ size : " << cachedHandleSets_.size() << "."; // use 0 temporarily - auto deviceId = 0; + auto local_rank_id = mrt::collective::CollectiveManager::Instance().local_rank_id(); + auto deviceId = local_rank_id; auto vmmStartAddr = FindVmmSegment(addr); if (vmmStartAddr == nullptr) { @@ -142,8 +145,8 @@ size_t AscendVmmAdapter::MmapDeviceMem(const size_t size, const DeviceMemPtr add cachedHandleSets_.erase(cachedHandleSets_.begin()); } else { if (physicalHandleSize_ * vmmAlignSize_ >= maxSize) { - LOG_OUT << "Mapped too much memory, physicalHandleSize_ : " << physicalHandleSize_ - << ", maxSize : " << maxSize << ", addr : " << addr << ", size : " << size << "."; + LOG_OUT << "Mapped too much memory, physicalHandleSize_ : " << physicalHandleSize_ << ", maxSize : " << maxSize + << ", addr : " << addr << ", size : " << size << "."; MoveBackMappedHandle(&mappedVmmHandle, &vmmMap_, &cachedHandleSets_); return 0; } diff --git a/inferrt/src/hardware/hardware_abstract/collective/collective_manager.cc b/inferrt/src/hardware/hardware_abstract/collective/collective_manager.cc new file mode 100644 index 00000000..904c9323 --- /dev/null +++ b/inferrt/src/hardware/hardware_abstract/collective/collective_manager.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hardware/hardware_abstract/collective/collective_manager.h" + +namespace mrt { +namespace collective { + +CollectiveManager::CollectiveManager() : globalRankId(0), localRankId(0), globalRankSize(0) {} + +CollectiveManager::~CollectiveManager() {} + +CollectiveManager &CollectiveManager::Instance() { + static CollectiveManager instance; + return instance; +} + +bool CollectiveManager::CreateCommunicationGroup(const std::string &groupName, const std::vector<uint32_t> &groupRanks, + uint32_t groupRank, int64_t communicator) { + if (communicationGroups.find(groupName) != communicationGroups.end()) { + return false; + } + communicationGroups[groupName] = std::make_shared<CommunicationGroup>(groupName, groupRanks, groupRank, communicator); + return true; +} + +bool CollectiveManager::IsGroupExist(const std::string &groupName) { + return communicationGroups.find(groupName) != communicationGroups.end(); +} + +std::shared_ptr<CommunicationGroup> CollectiveManager::GetCommunicationGroup(const std::string &groupName) { + if (communicationGroups.find(groupName) == communicationGroups.end()) { + LOG_EXCEPTION << "can not find group for given group name " << groupName; + return nullptr; + } + return communicationGroups[groupName]; +} + +uint32_t CollectiveManager::GetGroupRank(const std::string &groupName) { + if (communicationGroups.find(groupName) == communicationGroups.end()) { + LOG_ERROR << "can not find group for given group name " << groupName; + return false; + } + return communicationGroups[groupName]->group_rank(); +} + +uint32_t CollectiveManager::GetGroupSize(const std::string &groupName) { + if (communicationGroups.find(groupName) == communicationGroups.end()) { + LOG_ERROR << "can not find group for given group name " << groupName; + return false; + } + return communicationGroups[groupName]->group_size(); +} + +void CollectiveManager::SetGlobalRankId(uint32_t globalRankId) { this->globalRankId = globalRankId; } + +void CollectiveManager::SetGlobalRankSize(uint32_t globalRankSize) { this->globalRankSize = globalRankSize; } + +void CollectiveManager::SetLocalRankId(uint32_t localRankId) { this->localRankId = localRankId; } + +uint32_t CollectiveManager::global_rank_id() const { return globalRankId; } + +uint32_t CollectiveManager::local_rank_id() const { return localRankId; } + +uint32_t CollectiveManager::global_rank_size() const { return globalRankSize; } + +} // namespace collective +} // namespace mrt diff --git a/inferrt/src/hardware/hardware_abstract/collective/collective_manager.h b/inferrt/src/hardware/hardware_abstract/collective/collective_manager.h new file mode 100644 index 00000000..37843e7a --- /dev/null +++ b/inferrt/src/hardware/hardware_abstract/collective/collective_manager.h @@ -0,0 +1,72 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INFERRT_SRC_HARDWARE_CLUSTER_COLLECTIVE_MANAGER_H_ +#define INFERRT_SRC_HARDWARE_CLUSTER_COLLECTIVE_MANAGER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/logger.h" +#include "common/visible.h" +#include "hardware/hardware_abstract/collective/communication_group.h" + +namespace mrt { +namespace collective { +class DA_API CollectiveManager { + public: + ~CollectiveManager(); + static CollectiveManager &Instance(); + + void SetGlobalRankId(uint32_t globalRankId); + void SetGlobalRankSize(uint32_t globalRankSize); + void SetLocalRankId(uint32_t localRankId); + + uint32_t global_rank_id() const; + uint32_t local_rank_id() const; + uint32_t global_rank_size() const; + + bool CreateCommunicationGroup(const std::string &groupName, const std::vector &groupRanks, + uint32_t groupRank, int64_t communicator); + bool IsGroupExist(const std::string &groupName); + std::shared_ptr GetCommunicationGroup(const std::string &groupName); + + uint32_t GetGroupRank(const std::string &groupName); + uint32_t GetGroupSize(const std::string &groupName); + + private: + CollectiveManager(); + CollectiveManager(const CollectiveManager &) = delete; + CollectiveManager &operator=(const CollectiveManager &) = delete; + + uint32_t globalRankId; + uint32_t localRankId; + uint32_t globalRankSize; + std::vector globalGroupRanks; + std::unordered_map> communicationGroups; +}; + +} // namespace collective +} // namespace mrt +#endif // INFERRT_SRC_HARDWARE_CLUSTER_COLLECTIVE_MANAGER_H_ diff --git a/inferrt/src/hardware/hardware_abstract/collective/communication_group.cc b/inferrt/src/hardware/hardware_abstract/collective/communication_group.cc new file mode 100644 index 00000000..56868d37 --- /dev/null +++ b/inferrt/src/hardware/hardware_abstract/collective/communication_group.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "hardware/hardware_abstract/collective/communication_group.h" + +namespace mrt { +namespace collective { +CommunicationGroup::CommunicationGroup(const std::string &name, const std::vector &groupRanks, + uint32_t groupRank, int64_t comm) + : groupName(name), groupRanks(groupRanks), groupRank(groupRank), comm(comm) {} + +const std::string &CommunicationGroup::group_name() const { return groupName; } + +const std::vector &CommunicationGroup::group_ranks() const { return groupRanks; } + +uint32_t CommunicationGroup::group_size() const { return groupRanks.size(); } + +uint32_t CommunicationGroup::group_rank() const { return groupRank; } + +int64_t CommunicationGroup::communicator() const { return comm; } + +} // namespace collective +} // namespace mrt diff --git a/inferrt/src/hardware/hardware_abstract/collective/communication_group.h b/inferrt/src/hardware/hardware_abstract/collective/communication_group.h new file mode 100644 index 00000000..3e076017 --- /dev/null +++ b/inferrt/src/hardware/hardware_abstract/collective/communication_group.h @@ -0,0 +1,51 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef INFERRT_SRC_HARDWARE_COLLECTIVE_COMMUNICATION_GROUP_H_ +#define INFERRT_SRC_HARDWARE_COLLECTIVE_COMMUNICATION_GROUP_H_ + +#include +#include +#include +#include "common/visible.h" + +namespace mrt { +namespace collective { + +class DA_API CommunicationGroup { + public: + explicit CommunicationGroup(const std::string &name, const std::vector &groupRanks, uint32_t groupRank, + int64_t comm); + + ~CommunicationGroup() = default; + + virtual const std::string &group_name() const; + virtual const std::vector &group_ranks() const; + virtual uint32_t group_rank() const; + virtual uint32_t group_size() const; + virtual int64_t communicator() const; + + protected: + std::string groupName; + std::vector groupRanks; + uint32_t groupRank; + int64_t comm; +}; + +using CommunicationGroupPtr = std::shared_ptr; +} // namespace collective +} // namespace mrt + +#endif // INFERRT_SRC_HARDWARE_COLLECTIVE_COMMUNICATION_GROUP_H_ diff --git a/inferrt/src/hardware/hardware_abstract/device_context_manager.cc b/inferrt/src/hardware/hardware_abstract/device_context_manager.cc index 9d134fd4..0e6f43a4 100644 --- a/inferrt/src/hardware/hardware_abstract/device_context_manager.cc +++ b/inferrt/src/hardware/hardware_abstract/device_context_manager.cc @@ -28,6 +28,8 @@ #include #include "hardware/hardware_abstract/dlopen_macro.h" #include "hardware/hardware_abstract/multi_stream_controller.h" +#include "hardware/hardware_abstract/collective/collective_manager.h" + #include "common/logger.h" namespace mrt { @@ -148,8 +150,7 @@ MultiStreamControllerPtr &DeviceContextManager::GetMultiStreamController(const s return iter->second; } LOG_ERROR << "Found multi stream controller failed, and try to initialize, deviceName : " << deviceName << "."; - // use 0 temporarily. 
- uint32_t deviceId = 0; + uint32_t deviceId = mrt::collective::CollectiveManager::Instance().local_rank_id(); DeviceContextKey hostKey = {deviceName, deviceId}; const auto &realDeviceContext = GetOrCreateDeviceContext(hostKey); if (realDeviceContext == nullptr) { diff --git a/inferrt/src/ir/tensor/storage.cc b/inferrt/src/ir/tensor/storage.cc index beafafbc..f3a1f4f1 100644 --- a/inferrt/src/ir/tensor/storage.cc +++ b/inferrt/src/ir/tensor/storage.cc @@ -67,7 +67,6 @@ void Storage::AllocateMemory() { LOG_EXCEPTION << "Device memory has already been allocated, or a device memory leak has occurred, device type: " << GetDeviceNameByType(device_.type) << ", data: " << data_; } - data_ = alloc_.Allocate(sizeBytes_); CHECK_IF_NULL(data_); } diff --git a/inferrt/src/ops/ascend/CMakeLists.txt b/inferrt/src/ops/ascend/CMakeLists.txt index 3d0115d3..4987be43 100644 --- a/inferrt/src/ops/ascend/CMakeLists.txt +++ b/inferrt/src/ops/ascend/CMakeLists.txt @@ -6,6 +6,9 @@ endif() include_directories(${ASCEND_PATH}/latest/include/) include_directories(${ASCEND_PATH}/latest/lib64/) +include_directories(${ASCEND_PATH}/latest/aarch64-linux/include/experiment) + link_directories(${ASCEND_PATH}/latest/lib64/) add_subdirectory(aclnn) +add_subdirectory(hccl) diff --git a/inferrt/src/ops/ascend/hccl/CMakeLists.txt b/inferrt/src/ops/ascend/hccl/CMakeLists.txt new file mode 100644 index 00000000..1244d3ee --- /dev/null +++ b/inferrt/src/ops/ascend/hccl/CMakeLists.txt @@ -0,0 +1,18 @@ +check_debug_log_out() + +if(DEFINED ENV{ASCEND_CUSTOM_PATH}) + set(ASCEND_PATH $ENV{ASCEND_CUSTOM_PATH}) +else() + set(ASCEND_PATH /usr/local/Ascend) +endif() + +find_library(HCCL hccl + PATHS ${ASCEND_PATH}/latest/lib64 + REQUIRED + NO_DEFAULT_PATH +) + +file(GLOB_RECURSE OPS_HCCL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") + +add_library(ops_ascend_hccl SHARED ${OPS_HCCL_SRC_FILES}) +target_link_libraries(ops_ascend_hccl PRIVATE ops_base_obj hardware_ascend kernel mrt::securec ${HCCL}) diff --git a/inferrt/src/ops/ascend/hccl/hccl_adapter.cc b/inferrt/src/ops/ascend/hccl/hccl_adapter.cc new file mode 100644 index 00000000..226fd28d --- /dev/null +++ b/inferrt/src/ops/ascend/hccl/hccl_adapter.cc @@ -0,0 +1,472 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "ops/ascend/hccl/hccl_adapter.h" +#include +#include +#include +#include + +#include "hccl/hccl.h" +#include "hccl/hcom.h" + +static constexpr const auto kHcclPluginFileName = "libhccl.so"; + +#define CHECK_SYMBOL_NULL(symbol) \ + if ((symbol) == nullptr) { \ + LOG_ERROR << #symbol << " is null, hccl has not been inited, do nothing."; \ + return HcclResult::HCCL_E_RESERVED; \ + } + +namespace mrt::ops { +const char kDefaultGroup[] = "__default_group"; +constexpr uint32_t kDeviceNumOfServer = 8; + +HcclAdapter &HcclAdapter::GetInstance() { + static HcclAdapter instance; + return instance; +} + +void HcclAdapter::InitPlugin() { + if (plugin_handle_ != nullptr) { + return; + } +#ifndef ENABLE_ASAN + plugin_handle_ = dlopen(kHcclPluginFileName, RTLD_DEEPBIND | RTLD_NOW | RTLD_LOCAL); +#else + plugin_handle_ = dlopen(kHcclPluginFileName, RTLD_NOW | RTLD_LOCAL); +#endif + if (plugin_handle_ == nullptr) { + LOG_EXCEPTION << "Dlopen " << kHcclPluginFileName << " failed, result = " << GetDlErrorMsg(); + } + + launch_hccl_all_gather_ = DlsymFuncObj(HcclAllGather, plugin_handle_); +} + +void HcclAdapter::FinalizePlugin() { + if (plugin_handle_ == nullptr) { + return; + } + set_hccl_global_comm_info_ = nullptr; + init_hccl_root_info_config_ = nullptr; + init_hccl_global_comm_ranktable_ = nullptr; + init_hccl_sub_comm_ranktable_ = nullptr; + get_hccl_comm_config_capability_ = nullptr; + init_hccl_comm_ = nullptr; + finalize_hccl_comm_ = nullptr; + launch_hccl_broadcast_ = nullptr; + launch_hccl_all_reduce_ = nullptr; + launch_hccl_reduce_ = nullptr; + launch_hccl_scatter_ = nullptr; + launch_hccl_reduce_scatter_ = nullptr; + launch_hccl_all_gather_ = nullptr; + launch_hccl_send_ = nullptr; + launch_hccl_recv_ = nullptr; + launch_hccl_barrier_ = nullptr; + launch_hccl_batch_isend_irecv_ = nullptr; + hccl_create_group_ = nullptr; + hccl_destroy_group_ = nullptr; + hccl_get_rank_id_ = nullptr; + hccl_get_local_rank_id_ = nullptr; + hccl_get_local_rank_size_ = nullptr; + hccl_get_world_rank_by_group_rank_ = nullptr; + hccl_get_group_rank_by_world_rank_ = nullptr; + hccl_get_rank_size_ = nullptr; + hccl_exec_enqueue_op_ = nullptr; + hccl_exec_enqueue_all_to_all_v_ = nullptr; + hccl_comm_working_dev_nic_set_ = nullptr; + launch_hccl_all_to_allv_ = nullptr; + launch_hccl_reduce_scatterv_ = nullptr; + launch_hccl_all_gatherv_ = nullptr; + launch_hccl_comm_resume_ = nullptr; + hcom_destroy_ = nullptr; + (void)dlclose(plugin_handle_); + plugin_handle_ = nullptr; +} + +std::string HcclAdapter::GetHcclModeString(HcclMode hccl_mode) { + static std::map kHcclModeString = {{HcclMode::kGraph, "GE_MODE"}, + {HcclMode::kPynative, "PYNATIVE_MODE"}, + {HcclMode::kKernelByKernel, "KERNEL_BY_KERNEL_MODE"}}; + return kHcclModeString.at(hccl_mode); +} + +bool HcclAdapter::InitHccl(uint32_t device_id, std::string_view rank_id) { + LOG_OUT << "Start init hccl adapter."; + std::lock_guard lock(init_mutex_); + if (init_flag_) { + LOG_OUT << "Hccl has been inited, skip."; + return true; + } + InitPlugin(); + + init_flag_ = true; + LOG_OUT << "Init hccl adapter success."; + return true; +} + +bool HcclAdapter::HcclWatchdogThread(HcclComm comm, std::string *error_info, bool *disable) { + if (!init_flag_) { + LOG_OUT << "Hccl has never been inited, skip."; + return true; + } + CHECK_IF_NULL(disable); + if (hccl_get_comm_async_error_ == nullptr) { + LOG_OUT << "Hccl has never been inited, skip."; + return true; + } + if (hccl_get_error_string_ == nullptr) { + LOG_OUT << "Hccl has never been inited, skip."; + return 
true; + } + HcclResult hccl_async_error; + auto ret = hccl_get_comm_async_error_(comm, &hccl_async_error); + if (ret != HCCL_SUCCESS) { + LOG_OUT << "Call HcclGetCommAsyncError failed, close watchdog."; + *disable = true; + return true; + } + if (hccl_async_error != HCCL_SUCCESS) { + std::ostringstream oss; + oss << "Hccl get comm async error failed, error code is: " << hccl_async_error + << ", detail info: " << hccl_get_error_string_(hccl_async_error); + *error_info = oss.str(); + return false; + } + return true; +} + +bool HcclAdapter::FinalizeHccl() { + std::lock_guard lock(init_mutex_); + LOG_OUT << "Start destroy hccl adapter for " << GetHcclModeString(hccl_mode_); + if (!init_flag_) { + LOG_OUT << "Hccl has never been inited, skip."; + return true; + } + (void)FinalizeHcclExec(); + (void)FinalizeKernelInfoStore(); + (void)FinalizeHcclComm(); + if (hcom_destroy_ != nullptr) { + hcom_destroy_(); + } + FinalizePlugin(); + init_flag_ = false; + LOG_OUT << "Destroy hccl adapter success."; + return true; +} + +HcclResult HcclAdapter::HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, + aclrtStream stream, HcclComm hccl_comm) const { + HcclResult ret = launch_hccl_broadcast_(buf, count, dataType, root, hccl_comm, stream); + + return ret; +} + +HcclResult HcclAdapter::HcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, + const HcclReduceOp op, const aclrtStream stream, HcclComm hccl_comm) const { + HcclResult ret = launch_hccl_all_reduce_(send_buf, recv_buf, count, dataType, op, hccl_comm, stream); + + return ret; +} + +HcclResult HcclAdapter::HcclReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, + HcclReduceOp op, uint32_t root, const aclrtStream stream, HcclComm hccl_comm) const { + HcclResult ret = launch_hccl_reduce_(send_buf, recv_buf, count, dataType, op, root, hccl_comm, stream); + + return ret; +} + +HcclResult HcclAdapter::HcclScatter(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, + uint32_t root, HcclComm comm, aclrtStream stream) const { + HcclResult ret = launch_hccl_scatter_(send_buf, recv_buf, count, dataType, root, comm, stream); + + return ret; +} + +HcclResult HcclAdapter::HcclReduceScatter(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, + const HcclReduceOp op, const aclrtStream stream, HcclComm hccl_comm) const { + HcclResult ret = launch_hccl_reduce_scatter_(send_buf, recv_buf, count, dataType, op, hccl_comm, stream); + + return ret; +} + +HcclResult HcclAdapter::HcclAllGather(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, + const aclrtStream stream, HcclComm hccl_comm) const { + CHECK_SYMBOL_NULL(launch_hccl_all_gather_); + CHECK_IF_NULL(hccl_comm); + CHECK_IF_NULL(send_buf); + CHECK_IF_NULL(recv_buf); + HcclResult ret = launch_hccl_all_gather_(send_buf, recv_buf, count, dataType, hccl_comm, stream); + return ret; +} + +HcclResult HcclAdapter::HcclSend(void *send_buf, uint64_t count, HcclDataType dataType, uint32_t destRank, + const aclrtStream stream, HcclComm hccl_comm) const { + HcclResult ret = launch_hccl_send_(send_buf, count, dataType, destRank, hccl_comm, stream); + + return ret; +} + +HcclResult HcclAdapter::HcclRecv(void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t srcRank, + const aclrtStream stream, HcclComm hccl_comm) const { + HcclResult ret = launch_hccl_recv_(recv_buf, count, dataType, srcRank, hccl_comm, stream); + + return ret; +} + +HcclResult HcclAdapter::HcclBarrier(const 
aclrtStream stream, HcclComm hccl_comm) const { + return launch_hccl_barrier_(hccl_comm, stream); +} + +HcclResult HcclAdapter::HcclBatchISendIRecv(HcclSendRecvItem *sendRecvInfo, uint32_t itemNum, HcclComm comm, + aclrtStream stream) const { + HcclResult ret = launch_hccl_batch_isend_irecv_(sendRecvInfo, itemNum, comm, stream); + + return ret; +} + +HcclResult HcclAdapter::HcclCommResume(HcclComm comm) const { + if (launch_hccl_comm_resume_ == nullptr) { + LOG_EXCEPTION << "Dynamically load HcclCommResume failed."; + } + return launch_hccl_comm_resume_(comm); +} + +uint32_t HcclAdapter::HcclGetCommConfigCapability() { + CHECK_IF_NULL(get_hccl_comm_config_capability_); + return get_hccl_comm_config_capability_(); +} + +HcclResult HcclAdapter::HcclSetGlobalCommInfo(uint32_t masterIp, uint32_t masterPort, uint32_t totalRankSize, + uint32_t nodeId, uint32_t localRankSize) { + if (set_hccl_global_comm_info_ == nullptr) { + set_hccl_global_comm_info_ = DlsymAscendFuncObj(HcclSetGlobalCommInfo, plugin_handle_); + if (set_hccl_global_comm_info_ == nullptr) { + LOG_OUT << "Func HcclSetGlobalCommInfo is not supported in CANN package."; + return HCCL_E_NOT_SUPPORT; + } + } + return set_hccl_global_comm_info_(masterIp, masterPort, totalRankSize, nodeId, localRankSize); +} + +HcclResult HcclAdapter::HcclCommInitClusterInfoConfig(const char *rank_table, uint32_t rank_id, HcclCommConfig *config, + HcclComm *hccl_comm) { + if (init_hccl_global_comm_ranktable_ == nullptr) { + init_hccl_global_comm_ranktable_ = DlsymFuncObj(HcclCommInitClusterInfoConfig, plugin_handle_); + } + return init_hccl_global_comm_ranktable_(rank_table, rank_id, config, hccl_comm); +} + +HcclResult HcclAdapter::HcclCommInitRootInfoConfig(uint32_t n_ranks, const HcclRootInfo *root_info, uint32_t rank, + const HcclCommConfig *config, HcclComm *hccl_comm_) { + if (init_hccl_root_info_config_ == nullptr) { + init_hccl_root_info_config_ = DlsymFuncObj(HcclCommInitRootInfoConfig, plugin_handle_); + if (init_hccl_root_info_config_ == nullptr) { + // new api in CANN C20 + return HcclCommInitRootInfo(n_ranks, root_info, rank, hccl_comm_); + } + } + + return init_hccl_root_info_config_(n_ranks, root_info, rank, config, hccl_comm_); +} + +HcclResult HcclAdapter::HcclCreateSubCommConfig(HcclComm *global_comm, uint32_t rank_size, uint32_t *rank_ids, + uint64_t comm_id, uint32_t rank_id, HcclCommConfig *config, + HcclComm *hccl_comm) { + if (init_hccl_sub_comm_ranktable_ == nullptr) { + init_hccl_sub_comm_ranktable_ = DlsymFuncObj(HcclCreateSubCommConfig, plugin_handle_); + } + return init_hccl_sub_comm_ranktable_(global_comm, rank_size, rank_ids, comm_id, rank_id, config, hccl_comm); +} + +bool HcclAdapter::InitHcclComm(std::string_view rank_id, std::string_view rank_file) { + LOG_OUT << "Start init hccl comm."; + int rank_id_i = -1; + try { + rank_id_i = std::stoi(rank_id.data()); + } catch (std::invalid_argument &) { + LOG_EXCEPTION << "Invalid rank id env:" << rank_id; + } + if (rank_id_i < 0) { + LOG_ERROR << "rank_id cannot be negative"; + return false; + } + CHECK_IF_NULL(init_hccl_comm_); + auto hccl_result = init_hccl_comm_(rank_file.data(), rank_id_i, &hccl_comm_); + if (hccl_result != HCCL_SUCCESS) { + LOG_ERROR << "HcclCommInitClusterInfo failed, ret:" << hccl_result; + return false; + } + LOG_OUT << "InitHcclComm success"; + return true; +} + +bool HcclAdapter::FinalizeHcclComm() { + LOG_OUT << "Start finalize hccl comm."; + if (hccl_comm_ == nullptr) { + return true; + } + + CHECK_IF_NULL(finalize_hccl_comm_); + auto hccl_result = 
finalize_hccl_comm_(hccl_comm_); + if (hccl_result != HCCL_SUCCESS) { + LOG_ERROR << "HcclComm destroy failed, ret:" << hccl_result; + return false; + } + hccl_comm_ = nullptr; + LOG_OUT << "HcclComm destroy success"; + return true; +} + +HcclResult HcclAdapter::HcclCreateGroup(const std::string &group, uint32_t rank_num, uint32_t *rank_ids) const { + CHECK_SYMBOL_NULL(hccl_create_group_); + return hccl_create_group_(group.c_str(), rank_num, rank_ids); +} + +HcclResult HcclAdapter::HcclDestroyGroup(const std::string &group) const { + CHECK_SYMBOL_NULL(hccl_destroy_group_); + return hccl_destroy_group_(group.c_str()); +} + +HcclResult HcclAdapter::HcclGetRankId(const std::string &group, uint32_t *rank_id) const { + if (hccl_mode_ != HcclMode::kGraph) { + CHECK_SYMBOL_NULL(single_op_hccl_get_rank_id_); + return single_op_hccl_get_rank_id_(hccl_comm_, rank_id); + } else { + CHECK_SYMBOL_NULL(hccl_get_rank_id_); + return hccl_get_rank_id_(group.c_str(), rank_id); + } +} + +HcclResult HcclAdapter::HcclGetRankSize(const std::string &group, uint32_t *rank_size) const { + if (hccl_mode_ != HcclMode::kGraph) { + CHECK_SYMBOL_NULL(single_op_hccl_get_rank_size_); + return single_op_hccl_get_rank_size_(hccl_comm_, rank_size); + } else { + CHECK_SYMBOL_NULL(hccl_get_rank_size_); + return hccl_get_rank_size_(group.c_str(), rank_size); + } +} + +HcclResult HcclAdapter::HcclGetLocalRankId(const std::string &group, uint32_t *local_rank_id) const { + CHECK_SYMBOL_NULL(hccl_get_local_rank_id_); + return hccl_get_local_rank_id_(group.c_str(), local_rank_id); +} + +HcclResult HcclAdapter::HcclGetLocalRankSize(const std::string &group, uint32_t *local_rank_size) const { + if (hccl_mode_ != HcclMode::kGraph) { + LOG_ERROR << "The pynative mode doesn't support get local rank size."; + return HCCL_E_NOT_SUPPORT; + } else { + CHECK_SYMBOL_NULL(hccl_get_local_rank_size_); + return hccl_get_local_rank_size_(group.c_str(), local_rank_size); + } +} + +HcclResult HcclAdapter::HcclGetWorldRankFromGroupRank(const std::string &group, uint32_t local_rank, + uint32_t *world_rank) const { + if (hccl_mode_ != HcclMode::kGraph) { + LOG_ERROR << "The pynative mode doesn't support get world rank by group rank."; + return HCCL_E_NOT_SUPPORT; + } else { + CHECK_SYMBOL_NULL(hccl_get_world_rank_by_group_rank_); + return hccl_get_world_rank_by_group_rank_(group.c_str(), local_rank, world_rank); + } +} + +HcclResult HcclAdapter::HcclGetGroupRankFromWorldRank(uint32_t world_rank, const std::string &group, + uint32_t *local_rank) const { + if (hccl_mode_ != HcclMode::kGraph) { + LOG_ERROR << "The pynative mode doesn't support get group rank by world rank."; + return HCCL_E_NOT_SUPPORT; + } else { + CHECK_SYMBOL_NULL(hccl_get_group_rank_by_world_rank_); + return hccl_get_group_rank_by_world_rank_(world_rank, group.c_str(), local_rank); + } +} + +HcclResult HcclAdapter::HcclCommWorkingDevNicSet(HcclComm comm, uint32_t *ranks, bool *useBackup, uint32_t nRanks) { + if (hccl_comm_working_dev_nic_set_ == nullptr) { + hccl_comm_working_dev_nic_set_ = DlsymFuncObj(HcclCommWorkingDevNicSet, plugin_handle_); + } + CHECK_SYMBOL_NULL(hccl_comm_working_dev_nic_set_); + return hccl_comm_working_dev_nic_set_(comm, ranks, useBackup, nRanks); +} + +HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const { + CHECK_SYMBOL_NULL(hccl_exec_enqueue_op_); + return hccl_exec_enqueue_op_(op_info, callback); +} + +HcclResult HcclAdapter::HcclExecAlltoAllV(const ::HcomAllToAllVParams &params, const HExecCallBack 
&callback) const { + CHECK_SYMBOL_NULL(hccl_exec_enqueue_all_to_all_v_); + return hccl_exec_enqueue_all_to_all_v_(params, callback); +} + +bool HcclAdapter::UseHcclCM() const { + // This MS_HCCL_CM_INIT env is deprecated since MindSpore 2.3. + return false; +} + +HcclResult HcclAdapter::HcclAlltoAllV(void *send_buf, void *recv_buf, HcclAllToAllVParams params, HcclDataType dataType, + aclrtStream stream, HcclComm hccl_comm) const { + CHECK_SYMBOL_NULL(launch_hccl_all_to_allv_); + CHECK_IF_NULL(hccl_comm); + HcclResult ret = + launch_hccl_all_to_allv_(send_buf, params.sendcounts.data(), params.sdispls.data(), dataType, recv_buf, + params.recvcounts.data(), params.rdispls.data(), dataType, hccl_comm, stream); + + return ret; +} + +HcclResult HcclAdapter::HcclReduceScatterV(void *send_buf, void *recv_buf, HcclReduceScatterVParams params, + HcclDataType data_type, const HcclReduceOp op, const aclrtStream stream, + HcclComm hccl_comm) const { + CHECK_SYMBOL_NULL(launch_hccl_reduce_scatterv_); + CHECK_IF_NULL(hccl_comm); + HcclResult ret = launch_hccl_reduce_scatterv_(send_buf, params.send_counts.data(), params.sdispls.data(), recv_buf, + params.recv_count, data_type, op, hccl_comm, stream); + return ret; +} + +HcclResult HcclAdapter::HcclAllGatherV(void *send_buf, void *recv_buf, HcclAllGatherVParams params, + HcclDataType data_type, const aclrtStream stream, HcclComm hccl_comm) const { + CHECK_SYMBOL_NULL(launch_hccl_all_gatherv_); + CHECK_IF_NULL(hccl_comm); + HcclResult ret = launch_hccl_all_gatherv_(send_buf, params.send_count, recv_buf, params.recv_counts.data(), + params.rdispls.data(), data_type, hccl_comm, stream); + return ret; +} + +HcclResult HcclAdapter::HcclAllToAll(void *send_buf, void *recv_buf, HcclAllToAllParams params, HcclDataType dataType, + aclrtStream stream, HcclComm hccl_comm) const { + CHECK_SYMBOL_NULL(launch_hccl_all_to_all_); + CHECK_IF_NULL(hccl_comm); + + HcclResult ret = launch_hccl_all_to_all_(send_buf, params.sendcount, dataType, recv_buf, params.recvcount, dataType, + hccl_comm, stream); + + return ret; +} + +bool HcclAdapter::IsSameServer(const std::vector<uint32_t> &rank_ids) const { + auto min_iter = min_element(rank_ids.begin(), rank_ids.end()); + uint32_t min = (min_iter != rank_ids.end()) ? *min_iter : 0; + auto max_iter = max_element(rank_ids.begin(), rank_ids.end()); + uint32_t max = (max_iter != rank_ids.end()) ? *max_iter : 0; + return ((max - min < kDeviceNumOfServer) && (min / kDeviceNumOfServer == max / kDeviceNumOfServer)); +} + +} // namespace mrt::ops
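Context for the adapter above: it never calls the HCCL entry points directly. InitPlugin() dlopens libhccl.so and each wrapper forwards to a function pointer resolved from that handle (DlsymFuncObj / DlsymAscendFuncObj), erroring or falling back when a symbol is missing. The standalone sketch below only illustrates that dlopen/dlsym pattern; it assumes nothing beyond libhccl.so exporting HcclAllGather with the argument order used in HcclAdapter::HcclAllGather, and the handle typedefs are opaque stand-ins rather than the real HcclComm/aclrtStream types.

// Illustrative sketch (not part of the patch): resolving HcclAllGather from libhccl.so at runtime,
// mirroring HcclAdapter::InitPlugin(). Handle types are opaque placeholders.
#include <dlfcn.h>
#include <cstdint>
#include <cstdio>

using HcclCommHandle = void *;   // stand-in for HcclComm
using AclStreamHandle = void *;  // stand-in for aclrtStream
// Argument order matches the launch_hccl_all_gather_ call above: send, recv, count, dtype, comm, stream.
using HcclAllGatherFn = int (*)(void *, void *, uint64_t, int, HcclCommHandle, AclStreamHandle);

int main() {
  // The adapter additionally uses RTLD_DEEPBIND outside ASAN builds.
  void *handle = dlopen("libhccl.so", RTLD_NOW | RTLD_LOCAL);
  if (handle == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  auto all_gather = reinterpret_cast<HcclAllGatherFn>(dlsym(handle, "HcclAllGather"));
  if (all_gather == nullptr) {
    // The adapter's CHECK_SYMBOL_NULL reports this case as "hccl has not been inited".
    std::fprintf(stderr, "dlsym failed: %s\n", dlerror());
  } else {
    std::printf("HcclAllGather resolved at %p\n", reinterpret_cast<void *>(all_gather));
  }
  dlclose(handle);
  return 0;
}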
diff --git a/inferrt/src/ops/ascend/hccl/hccl_adapter.h b/inferrt/src/ops/ascend/hccl/hccl_adapter.h new file mode 100644 index 00000000..c716f183 --- /dev/null +++ b/inferrt/src/ops/ascend/hccl/hccl_adapter.h @@ -0,0 +1,208 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef OPS_ASCEND_HCCL_ADAPTER_H_ +#define OPS_ASCEND_HCCL_ADAPTER_H_ + +#include "ops/ascend/hccl/hccl_plugin.h" +#include +#include +#include +#include +#include + +#include "hccl/hccl_types.h" + +namespace mrt::ops { +struct HcclTaskInfo { + std::string private_def; + int64_t workspace_size; + int64_t stream_num; +}; + +struct HcclAllToAllVParams { + std::vector<uint64_t> sendcounts; + std::vector<uint64_t> sdispls; + std::vector<uint64_t> recvcounts; + std::vector<uint64_t> rdispls; +}; + +struct HcclAllGatherVParams { + uint64_t send_count; + std::vector<uint64_t> recv_counts; + std::vector<uint64_t> rdispls; +}; + +struct HcclReduceScatterVParams { + std::vector<uint64_t> send_counts; + std::vector<uint64_t> sdispls; + uint64_t recv_count; +}; + +struct HcclAllToAllParams { + uint64_t sendcount; + uint64_t recvcount; +}; + +enum HcclMode { kGraph, kPynative, kKernelByKernel }; + +class HcclAdapter { + public: + static HcclAdapter &GetInstance(); + + // common + bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file, HcclMode hccl_mode); + bool InitHccl(uint32_t device_id, std::string_view rank_id); + uint32_t HcclGetCommConfigCapability(); + HcclResult HcclSetGlobalCommInfo(uint32_t masterIp, uint32_t masterPort, uint32_t totalRankSize, uint32_t nodeId, + uint32_t localRankSize); + HcclResult HcclCommInitClusterInfoConfig(const char *rank_table, uint32_t rank_id, HcclCommConfig *config, + HcclComm *hccl_comm_); + HcclResult HcclCommInitRootInfoConfig(uint32_t n_ranks, const HcclRootInfo *root_info, uint32_t rank, + const HcclCommConfig *config, HcclComm *hccl_comm_); + HcclResult HcclCreateSubCommConfig(HcclComm *global_comm, uint32_t rank_size, uint32_t *rank_ids, uint64_t comm_id, + uint32_t rank_id, HcclCommConfig *config, HcclComm *hccl_comm_); + bool FinalizeHccl(); + bool HcclWatchdogThread(HcclComm comm, std::string *error_info, bool *ret); + const bool Inited() const { return init_flag_; } + const HcclComm get_hccl_comm() const { return hccl_comm_; } + HcclResult HcclCreateGroup(const std::string &group, uint32_t rank_num, uint32_t *rank_ids) const; + HcclResult HcclDestroyGroup(const std::string &group) const; + HcclResult HcclGetRankId(const std::string &group, uint32_t *rank_id) const; + HcclResult HcclGetRankSize(const std::string &group, uint32_t *rank_size) const; + HcclResult HcclGetLocalRankId(const std::string &group, uint32_t *local_rank_id) const; + HcclResult HcclGetLocalRankSize(const std::string &group, uint32_t *local_rank_size) const; + HcclResult HcclGetWorldRankFromGroupRank(const std::string &group, uint32_t local_rank, uint32_t *world_rank) const; + HcclResult HcclGetGroupRankFromWorldRank(uint32_t world_rank, const std::string &group, uint32_t *local_rank) const; + // for single op + HcclResult HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, aclrtStream stream, + HcclComm comm) const; + HcclResult HcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, HcclReduceOp op, + const aclrtStream stream, HcclComm comm) const; + HcclResult HcclReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, HcclReduceOp op, + uint32_t root, const aclrtStream stream, HcclComm comm) const; + HcclResult HcclScatter(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t root, + HcclComm comm, aclrtStream stream) const; + HcclResult HcclAllGather(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, + const aclrtStream stream, HcclComm comm) const; + HcclResult HcclReduceScatter(void *send_buf,
void *recv_buf, uint64_t count, HcclDataType dataType, HcclReduceOp op, + const aclrtStream stream, HcclComm comm) const; + HcclResult HcclSend(void *send_buf, uint64_t count, HcclDataType dataType, uint32_t destRank, + const aclrtStream stream, HcclComm comm) const; + HcclResult HcclRecv(void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t srcRank, const aclrtStream stream, + HcclComm comm) const; + HcclResult HcclAlltoAllV(void *send_buf, void *recv_buf, HcclAllToAllVParams params, HcclDataType dataType, + const aclrtStream stream, HcclComm comm) const; + + HcclResult HcclReduceScatterV(void *send_buf, void *recv_buf, HcclReduceScatterVParams params, HcclDataType data_type, + const HcclReduceOp op, const aclrtStream stream, HcclComm hccl_comm) const; + + HcclResult HcclAllGatherV(void *send_buf, void *recv_buf, HcclAllGatherVParams params, HcclDataType data_type, + const aclrtStream stream, HcclComm hccl_comm) const; + + HcclResult HcclAllToAll(void *send_buf, void *recv_buf, HcclAllToAllParams params, HcclDataType dataType, + const aclrtStream stream, HcclComm comm) const; + HcclResult HcclBarrier(const aclrtStream stream, HcclComm comm) const; + HcclResult HcclBatchISendIRecv(HcclSendRecvItem *sendRecvInfo, uint32_t itemNum, HcclComm comm, + aclrtStream stream) const; + + // for enqueue op + HcclResult HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const; + HcclResult HcclExecAlltoAllV(const ::HcomAllToAllVParams ¶ms, const HExecCallBack &callback) const; + + HcclResult HcclCommResume(HcclComm comm) const; + + HcclResult HcclCommWorkingDevNicSet(HcclComm comm, uint32_t *ranks, bool *useBackup, uint32_t nRanks); + + // Return whether using CM to initialize HCCL. + bool UseHcclCM() const; + static void AddCMEnvToHcclOption(std::map *hccl_opt_map); + + bool IsSameServer(const std::vector &rank_ids) const; + + private: + HcclAdapter() = default; + ~HcclAdapter() = default; + void InitPlugin(); + void FinalizePlugin(); + + bool InitKernelInfoStore(const std::map options); + bool FinalizeKernelInfoStore(); + + bool InitHcclComm(std::string_view rank_id, std::string_view rank_file); + bool FinalizeHcclComm(); + + bool InitHcclExec(); + bool FinalizeHcclExec(); + + static std::string GetHcclModeString(HcclMode hccl_mode); + + static bool IsSimulation(); + void *plugin_handle_ = nullptr; + + HcomDestroyFunObj hcom_destroy_ = nullptr; + + HcclGetCommConfigCapabilityFunObj get_hccl_comm_config_capability_ = nullptr; + HcclSetGlobalCommInfoFunObj set_hccl_global_comm_info_ = nullptr; + HcclCommInitClusterInfoFunObj init_hccl_comm_ = nullptr; + HcclCommInitClusterInfoConfigFunObj init_hccl_global_comm_ranktable_ = nullptr; + HcclCommInitRootInfoConfigFunObj init_hccl_root_info_config_ = nullptr; + HcclCreateSubCommConfigFunObj init_hccl_sub_comm_ranktable_ = nullptr; + HcclCommDestroyFunObj finalize_hccl_comm_ = nullptr; + HcclBroadcastFunObj launch_hccl_broadcast_ = nullptr; + HcclAllReduceFunObj launch_hccl_all_reduce_ = nullptr; + HcclReduceFunObj launch_hccl_reduce_ = nullptr; + HcclScatterFunObj launch_hccl_scatter_ = nullptr; + HcclReduceScatterFunObj launch_hccl_reduce_scatter_ = nullptr; + HcclAllGatherFunObj launch_hccl_all_gather_ = nullptr; + HcclSendFunObj launch_hccl_send_ = nullptr; + HcclRecvFunObj launch_hccl_recv_ = nullptr; + HcclBarrierFunObj launch_hccl_barrier_ = nullptr; + HcclGetRankIdFunObj single_op_hccl_get_rank_id_ = nullptr; + HcclGetRankSizeFunObj single_op_hccl_get_rank_size_ = nullptr; + HcclAlltoAllVFunObj 
launch_hccl_all_to_allv_ = nullptr; + HcclReduceScatterVFunObj launch_hccl_reduce_scatterv_ = nullptr; + HcclAllGatherVFunObj launch_hccl_all_gatherv_ = nullptr; + HcclAlltoAllFunObj launch_hccl_all_to_all_ = nullptr; + HcclBatchSendRecvFunObj launch_hccl_batch_isend_irecv_ = nullptr; + HcclCommResumeFunObj launch_hccl_comm_resume_ = nullptr; + HcclGetCommAsyncErrorFunObj hccl_get_comm_async_error_ = nullptr; + HcclGetErrorStringFunObj hccl_get_error_string_ = nullptr; + HcomCreateGroupFunObj hccl_create_group_ = nullptr; + HcomDestroyGroupFunObj hccl_destroy_group_ = nullptr; + HcomGetRankIdFunObj hccl_get_rank_id_ = nullptr; + HcomGetRankSizeFunObj hccl_get_rank_size_ = nullptr; + HcomGetLocalRankIdFunObj hccl_get_local_rank_id_ = nullptr; + HcomGetLocalRankSizeFunObj hccl_get_local_rank_size_ = nullptr; + HcomGetWorldRankFromGroupRankFunObj hccl_get_world_rank_by_group_rank_ = nullptr; + HcomGetGroupRankFromWorldRankFunObj hccl_get_group_rank_by_world_rank_ = nullptr; + HcclCommWorkingDevNicSetFunObj hccl_comm_working_dev_nic_set_ = nullptr; + + HcomExecInitializeFunObj hccl_exec_initialize_ = nullptr; + HcomExecFinalizeFunObj hccl_exec_finalize_ = nullptr; + HcomExecEnqueueOperationFunObj hccl_exec_enqueue_op_ = nullptr; + HcomExecEnqueueAllToAllVFunObj hccl_exec_enqueue_all_to_all_v_ = nullptr; + + HcclComm hccl_comm_ = nullptr; + + bool init_flag_ = false; + bool init_kernel_info_store_ = false; + bool init_hccl_exec_ = false; + HcclMode hccl_mode_ = HcclMode::kGraph; + std::mutex init_mutex_; +}; +} // namespace mrt::ops +#endif // OPS_ASCEND_HCCL_ADAPTER_H_ diff --git a/inferrt/src/ops/ascend/hccl/hccl_all_gather.cc b/inferrt/src/ops/ascend/hccl/hccl_all_gather.cc new file mode 100644 index 00000000..8f19913d --- /dev/null +++ b/inferrt/src/ops/ascend/hccl/hccl_all_gather.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#include "ops/ascend/hccl/hccl_all_gather.h" +#include "ops/ascend/hccl/hccl_adapter.h" +#include "hardware/hardware_abstract/collective/collective_manager.h" +#include "ops/ascend/hccl/hcom_utils.h" +#include "hccl/hccl_types.h" +#include "hccl/hccl.h" + +#include "common/logger.h" +#include "ops/op_register.h" + +#include "hardware/ascend/res_manager/ascend_stream_manager.h" + +namespace mrt { +namespace ops { +OpsErrorCode HcclAllGather::CalcWorkspace(const std::vector &input, const ir::Value *output, + size_t *workspace_size) { + LOG_OUT << "HcclAllGather CalcWorkspace"; + auto rank_id = mrt::collective::CollectiveManager::Instance().local_rank_id(); + std::string rank_id_str = std::to_string(0); + HcclAdapter::GetInstance().InitHccl(rank_id, rank_id_str); + auto [hccl_count, hccl_data_type] = HcomUtil::GetHcclCountAndTypeFromTensor(input[kIndex0]->ToTensor()); + hcclKernel.hccl_count_ = hccl_count; + hcclKernel.hccl_data_type_ = hccl_data_type; + const string &group_name = input[kIndex2]->ToString(); + hcclKernel.comm_ = HcomUtil::LoadHcclLibrary(group_name); + + return SUCCESS; +} + +OpsErrorCode HcclAllGather::Launch(const std::vector &input, void *workspace, size_t workspaceSize, + ir::Value *output, void *stream) { + LOG_OUT << "HcclAllGather launch"; + + auto hccl_result = HcclAdapter::GetInstance().HcclAllGather(const_cast(input[kIndex0]->ToTensor()->DataPtr()), + output->ToTensor()->DataPtr(), hcclKernel.hccl_count_, + hcclKernel.hccl_data_type_, stream, hcclKernel.comm_); + if (hccl_result != ::HcclResult::HCCL_SUCCESS) { + LOG_ERROR << "HcomAllGather failed, hccl_result: " << hccl_result; + } + + return SUCCESS; +} +MRT_REG_OP(all_gather, HcclAllGather, Ascend); +} // namespace ops +} // namespace mrt diff --git a/inferrt/src/ops/ascend/hccl/hccl_all_gather.h b/inferrt/src/ops/ascend/hccl/hccl_all_gather.h new file mode 100644 index 00000000..34bba835 --- /dev/null +++ b/inferrt/src/ops/ascend/hccl/hccl_all_gather.h @@ -0,0 +1,44 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef OPS_ASCEND_HCCL_ALL_GATHER_H_ +#define OPS_ASCEND_HCCL_ALL_GATHER_H_ + +#include + +#include "ops/op_base/op_all_gather.h" + +#include "ops/operator.h" +#include "ops/ascend/hccl/hccl_kernel.h" + +namespace mrt { +namespace ops { +class HcclAllGather : public OpAllGather { + public: + HcclAllGather() = default; + ~HcclAllGather() = default; + + OpsErrorCode CalcWorkspace(const std::vector &input, const ir::Value *output, + size_t *workspace_size) override; + OpsErrorCode Launch(const std::vector &input, void *workspace, size_t workspaceSize, + ir::Value *output, void *stream) override; + + private: + HcclKernel hcclKernel; +}; +} // namespace ops +} // namespace mrt +#endif // OPS_ASCEND_HCCL_ALL_GATHER_H_ diff --git a/inferrt/src/ops/ascend/hccl/hccl_kernel.cc b/inferrt/src/ops/ascend/hccl/hccl_kernel.cc new file mode 100644 index 00000000..6f1c98c3 --- /dev/null +++ b/inferrt/src/ops/ascend/hccl/hccl_kernel.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/ascend/hccl/hccl_kernel.h" + +#include +#include +#include + +#include "ops/ascend/hccl/hccl_adapter.h" +#include "ops/ascend/hccl/hcom_utils.h" + +namespace mrt { +namespace ops { + +HcclKernel::HcclKernel() : hccl_count_(0), root_id_(0), src_rank_(0), dest_rank_(0), comm_(nullptr) {} + +} // namespace ops +} // namespace mrt diff --git a/inferrt/src/ops/ascend/hccl/hccl_kernel.h b/inferrt/src/ops/ascend/hccl/hccl_kernel.h new file mode 100644 index 00000000..bb6e896f --- /dev/null +++ b/inferrt/src/ops/ascend/hccl/hccl_kernel.h @@ -0,0 +1,55 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef OPS_ASCEND_HCCL_KERNEL_H_ +#define OPS_ASCEND_HCCL_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "ops/operator.h" +#include "ops/ascend/hccl/hcom_utils.h" +#include "hccl/hcom.h" +#include "hccl/hccl_types.h" + +namespace mrt { +namespace ops { +class HcclKernel { + public: + HcclKernel(); + ~HcclKernel() = default; + + public: + HcclDataType hccl_data_type_; + uint64_t hccl_count_; + uint32_t root_id_; + uint32_t src_rank_; + uint32_t dest_rank_; + std::string group_; + HcclComm comm_; + ulong loop_size_{0}; + bool is_graph_mode_{false}; + std::string hccl_inner_comm_name_; +}; + +} // namespace ops +} // namespace mrt +#endif // OPS_ASCEND_HCCL_KERNEL_H_ diff --git a/inferrt/src/ops/ascend/hccl/hccl_plugin.h b/inferrt/src/ops/ascend/hccl/hccl_plugin.h new file mode 100644 index 00000000..bac11855 --- /dev/null +++ b/inferrt/src/ops/ascend/hccl/hccl_plugin.h @@ -0,0 +1,89 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef OPS_ASCEND_HCCL_PLUGIN_H +#define OPS_ASCEND_HCCL_PLUGIN_H + +#include +#include +#include +#include + +#include "hccl/hccl.h" +#include "hccl/hcom.h" + +#include "hardware/hardware_abstract/dlopen_macro.h" +#include "hardware/ascend/res_manager/ascend_stream_manager.h" + +extern "C" { +struct HcomOperation; +} // extern C + +using OptionsType = std::map; +using HExecCallBack = std::function; + +ORIGIN_METHOD(HcclBroadcast, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream); +ORIGIN_METHOD(HcclAllReduce, HcclResult, void *, void *, uint64_t, HcclDataType, HcclReduceOp, HcclComm, aclrtStream); +ORIGIN_METHOD(HcclReduce, HcclResult, void *, void *, uint64_t, HcclDataType, HcclReduceOp, uint32_t, HcclComm, + aclrtStream); +ORIGIN_METHOD(HcclScatter, HcclResult, void *, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream); +ORIGIN_METHOD(HcclReduceScatter, HcclResult, void *, void *, uint64_t, HcclDataType, HcclReduceOp, HcclComm, + aclrtStream); +ORIGIN_METHOD(HcclAllGather, HcclResult, void *, void *, uint64_t, HcclDataType, HcclComm, aclrtStream); +ORIGIN_METHOD(HcclSend, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream); +ORIGIN_METHOD(HcclRecv, HcclResult, void *, uint64_t, HcclDataType, uint32_t, HcclComm, aclrtStream); +ORIGIN_METHOD(HcclAlltoAllV, HcclResult, const void *, const void *, const void *, HcclDataType, const void *, + const void *, const void *, HcclDataType, HcclComm, aclrtStream); +ORIGIN_METHOD(HcclAllGatherV, HcclResult, void *, uint64_t, void *, const void *, const void *, HcclDataType, HcclComm, + aclrtStream); +ORIGIN_METHOD(HcclReduceScatterV, HcclResult, void *, const void *, const void *, void *, uint64_t, HcclDataType, + HcclReduceOp, HcclComm, aclrtStream); + +ORIGIN_METHOD(HcclAlltoAll, HcclResult, const void *, uint64_t, HcclDataType, const void *, uint64_t, HcclDataType, + HcclComm, aclrtStream); +ORIGIN_METHOD(HcclBarrier, HcclResult, 
HcclComm, aclrtStream); +ORIGIN_METHOD(HcclBatchSendRecv, HcclResult, HcclSendRecvItem *, uint32_t, HcclComm, aclrtStream); +ORIGIN_METHOD(HcclCommResume, HcclResult, HcclComm) + +ORIGIN_METHOD(HcclGetCommAsyncError, HcclResult, HcclComm, HcclResult *); +ORIGIN_METHOD(HcclGetErrorString, const char *, HcclResult); +ORIGIN_METHOD(HcclGetCommConfigCapability, uint32_t); +ORIGIN_METHOD(HcclSetGlobalCommInfo, HcclResult, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t); +ORIGIN_METHOD(HcclCommInitClusterInfo, HcclResult, const char *, uint32_t, HcclComm *); +ORIGIN_METHOD(HcclCommInitClusterInfoConfig, HcclResult, const char *, uint32_t, HcclCommConfig *, HcclComm *); +ORIGIN_METHOD(HcclCommInitRootInfoConfig, HcclResult, uint32_t, const HcclRootInfo *, uint32_t, const HcclCommConfig *, + HcclComm *); +ORIGIN_METHOD(HcclCreateSubCommConfig, HcclResult, HcclComm *, uint32_t, uint32_t *, uint64_t, uint32_t, + HcclCommConfig *, HcclComm *) +ORIGIN_METHOD(HcclCommDestroy, HcclResult, HcclComm); +ORIGIN_METHOD(HcclGetRankId, HcclResult, void *, uint32_t *); +ORIGIN_METHOD(HcclGetRankSize, HcclResult, void *, uint32_t *); +ORIGIN_METHOD(HcclGetCommName, HcclResult, HcclComm, char *) +ORIGIN_METHOD(HcomGetLocalRankId, HcclResult, const char *, uint32_t *); +ORIGIN_METHOD(HcomGetLocalRankSize, HcclResult, const char *, uint32_t *); +ORIGIN_METHOD(HcomGetWorldRankFromGroupRank, HcclResult, const char *, uint32_t, uint32_t *); +ORIGIN_METHOD(HcomGetGroupRankFromWorldRank, HcclResult, uint32_t, const char *, uint32_t *); +ORIGIN_METHOD(HcclCommWorkingDevNicSet, HcclResult, HcclComm, uint32_t *, bool *, uint32_t); + +ORIGIN_METHOD(HcomCreateGroup, HcclResult, const char *, uint32_t, uint32_t *); +ORIGIN_METHOD(HcomDestroyGroup, HcclResult, const char *); +ORIGIN_METHOD(HcomGetRankId, HcclResult, const char *, uint32_t *); +ORIGIN_METHOD(HcomGetRankSize, HcclResult, const char *, uint32_t *); +ORIGIN_METHOD(HcomExecInitialize, HcclResult); +ORIGIN_METHOD(HcomExecFinalize, HcclResult); +ORIGIN_METHOD(HcomExecEnqueueOperation, HcclResult, ::HcomOperation, HExecCallBack); +ORIGIN_METHOD(HcomExecEnqueueAllToAllV, HcclResult, ::HcomAllToAllVParams, HExecCallBack); +ORIGIN_METHOD(HcomDestroy, HcclResult); +#endif // OPS_ASCEND_HCCL_PLUGIN_H diff --git a/inferrt/src/ops/ascend/hccl/hcom_utils.cc b/inferrt/src/ops/ascend/hccl/hcom_utils.cc new file mode 100644 index 00000000..40578daa --- /dev/null +++ b/inferrt/src/ops/ascend/hccl/hcom_utils.cc @@ -0,0 +1,153 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ops/ascend/hccl/hcom_utils.h" + +#include +#include +#include +#include + +namespace mrt::ops { + +inline int64_t LongMulWithOverflowCheck(int64_t a, int64_t b) { + int64_t out = a * b; + if (a != 0) { + bool overflow = ((out / a) != b); + if (overflow) { + LOG_EXCEPTION << "Mul: a(" << a << ") * b(" << b << ") result is overflow"; + } + } + return out; +} + +inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) { + size_t out = a * b; + if (a != 0) { + if ((out / a) != b) { + LOG_EXCEPTION << "Mul: a(" << a << ") * b(" << b << ") result is overflow"; + } + } + return out; +} + +inline size_t LongToSizeClipNeg(int64_t u) { return u < 0 ? 0 : static_cast(u); } + +::HcclDataType HcomUtil::ConvertHcclType(DataType type_id) { + auto iter = kConstOpHcomDataTypeMap.find(type_id); + if (iter == kConstOpHcomDataTypeMap.end()) { + LOG_EXCEPTION << "HcomDataType can't support Current Ascend Data Type : " << type_id.ToString(); + } + return iter->second; +} + +bool HcomUtil::GetHcclOpSize(const HcclDataType &data_type, const std::vector &shape, size_t *size) { + CHECK_IF_NULL(size); + int64_t tmp_size = 1; + uint32_t type_size = 4; + for (size_t i = 0; i < shape.size(); i++) { + tmp_size = LongMulWithOverflowCheck(tmp_size, shape[i]); + } + + if (!GetHcomTypeSize(data_type, &type_size)) { + return false; + } + + *size = SizetMulWithOverflowCheck(LongToSizeClipNeg(tmp_size), type_size); + return true; +} + +bool HcomUtil::GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size) { + CHECK_IF_NULL(size); + auto iter = kConstOpHcomDataTypeSizeMap.find(data_type); + if (iter == kConstOpHcomDataTypeSizeMap.end()) { + LOG_ERROR << "HcomUtil::HcomDataTypeSize, No DataTypeSize!"; + return false; + } + *size = iter->second; + return true; +} + +bool HcomUtil::GetHcomCount(const std::vector &data_type_list, + const std::vector> &shape_list, const size_t input_tensor_num, + const std::optional rank_size_opt, uint64_t *total_count) { + CHECK_IF_NULL(total_count); + + const uint32_t align_size = 512; + const uint32_t filled_size = 32; + uint64_t total_size = 0; + size_t input_size; + uint32_t type_size = 4; + // size_t rank_size = 1; + CHECK_IF_FAIL(data_type_list.size() == shape_list.size()); + + for (size_t i = 0; i < data_type_list.size(); ++i) { + if (!GetHcomTypeSize(data_type_list[i], &type_size)) { + return false; + } + + if (!GetHcclOpSize(data_type_list[i], shape_list[i], &input_size)) { + LOG_ERROR << "Get GetHcclOpSize failed"; + return false; + } + + if (input_tensor_num > 1) { + // communication operator with dynamic input should have continuous memory. 
+ input_size = (input_size + align_size - 1 + filled_size) / align_size * align_size; + } + + bool all_dynamic = std::all_of(shape_list[i].begin(), shape_list[i].end(), [](int64_t x) { return x == -1; }); + if (!all_dynamic && (type_size == 0 || input_size % type_size != 0)) { + return false; + } + total_size += input_size / type_size; + } + *total_count = total_size; + return true; +} + +std::pair HcomUtil::GetHcclCountAndTypeFromTensor( + const ir::TensorPtr &tensor, const std::optional rank_size_opt) { + auto type_id = tensor->Dtype(); + auto shape = tensor->Shape(); + + auto hccl_type = ConvertHcclType(type_id); + + uint64_t hccl_count = 0; + constexpr size_t input_tensor_size = 1; + if (!GetHcomCount({hccl_type}, {shape}, input_tensor_size, rank_size_opt, &hccl_count)) { + LOG_EXCEPTION << "GetHcomCount fail!"; + } + return std::make_pair(hccl_count, hccl_type); +} + +CollectiveOpReduceType HcomUtil::GetCollectiveOpReduceType(const std::string &reduce_op) { + auto iter = kConstOpCollectiveOpReduceTypeMap.find(reduce_op); + if (iter == kConstOpCollectiveOpReduceTypeMap.end()) { + LOG_EXCEPTION << "HcomUtil::Get CollectiveOpReduceType fail, [" << reduce_op << "] not support!"; + } + return iter->second; +} + +HcclReduceOp HcomUtil::GetHcomReduceOpType(const std::string &reduce_op) { + auto iter = kConstOpHcomReduceOpTypeMap.find(reduce_op); + if (iter == kConstOpHcomReduceOpTypeMap.end()) { + LOG_EXCEPTION << "HcomUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [" << reduce_op << "] not support!"; + } + return iter->second; +} + +} // namespace mrt::ops diff --git a/inferrt/src/ops/ascend/hccl/hcom_utils.h b/inferrt/src/ops/ascend/hccl/hcom_utils.h new file mode 100644 index 00000000..7c81ba1b --- /dev/null +++ b/inferrt/src/ops/ascend/hccl/hcom_utils.h @@ -0,0 +1,130 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef OPS_ASCEND_HCCL_HCOM_UTILS_H_ +#define OPS_ASCEND_HCCL_HCOM_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "hardware/hardware_abstract/collective/collective_manager.h" +#include "ir/common/dtype.h" +#include "ir/tensor/tensor.h" +#include "common/common.h" + +#include "hccl/base.h" +#include "hccl/hccl_types.h" + +namespace mrt::ops { +using ir::DataType; +using mrt::ir::TensorPtr; +using std::map; +using std::string; +using std::vector; +constexpr int64_t kComplex64ConvertFloat32Num = 2; + +enum CollectiveOpReduceType : int64_t { + Reduce_Mean = 0, + Reduce_Max = 1, + Reduce_Min = 2, + Reduce_Prod = 3, + Reduce_Sum = 4, + Reduce_Sum_Square = 5, + Reduce_ASum = 6, + Reduce_All = 7 +}; + +/* Correspondence between data_type and hcom data type in Ascend */ +static const map kConstOpHcomDataTypeMap = { + {DataType::Int8, HCCL_DATA_TYPE_INT8}, + {DataType::Int16, HCCL_DATA_TYPE_INT16}, + {DataType::Int32, HCCL_DATA_TYPE_INT32}, + {DataType::Float32, HCCL_DATA_TYPE_FP32}, + {DataType::Int64, HCCL_DATA_TYPE_INT64}, + {DataType::UInt8, HCCL_DATA_TYPE_UINT8}, + {DataType::Float64, HCCL_DATA_TYPE_FP64}, + {DataType::Bool, HCCL_DATA_TYPE_INT8}, +#ifdef EXPERIMENT_A5 + {DataType::kNumberTypeHiFloat8, HCCL_DATA_TYPE_HIF8}, + {DataType::kNumberTypeFloat8E5M2, HCCL_DATA_TYPE_FP8E5M2}, + {DataType::kNumberTypeFloat8E4M3FN, HCCL_DATA_TYPE_FP8E4M3}, +#endif +}; + +/* Correspondence between data_type and occupied byte size in hcom */ +static const map kConstOpHcomDataTypeSizeMap = { + {HCCL_DATA_TYPE_INT8, sizeof(int8_t)}, {HCCL_DATA_TYPE_INT16, sizeof(int32_t) / 2}, + {HCCL_DATA_TYPE_INT32, sizeof(int32_t)}, {HCCL_DATA_TYPE_FP16, sizeof(float) / 2}, + {HCCL_DATA_TYPE_FP32, sizeof(float)}, {HCCL_DATA_TYPE_INT64, sizeof(int64_t)}, + {HCCL_DATA_TYPE_UINT64, sizeof(uint64_t)}, {HCCL_DATA_TYPE_UINT8, sizeof(uint8_t)}, + {HCCL_DATA_TYPE_UINT16, sizeof(uint32_t) / 2}, {HCCL_DATA_TYPE_UINT32, sizeof(uint32_t)}, + {HCCL_DATA_TYPE_FP64, sizeof(double)}, {HCCL_DATA_TYPE_BFP16, sizeof(float) / 2}, +#ifdef EXPERIMENT_A5 + {HCCL_DATA_TYPE_HIF8, sizeof(float) / 4}, {HCCL_DATA_TYPE_FP8E5M2, sizeof(float) / 4}, + {HCCL_DATA_TYPE_FP8E4M3, sizeof(float) / 4}, +#endif +}; + +static const std::map kHcomOpReduceTypeMap = { + {CollectiveOpReduceType::Reduce_Max, HCCL_REDUCE_MAX}, + {CollectiveOpReduceType::Reduce_Min, HCCL_REDUCE_MIN}, + {CollectiveOpReduceType::Reduce_Prod, HCCL_REDUCE_PROD}, + {CollectiveOpReduceType::Reduce_Sum, HCCL_REDUCE_SUM}}; + +/* Correspondence between reduce str and enum in hcom */ +static const std::unordered_map kConstOpHcomReduceOpTypeMap = { + {"min", HCCL_REDUCE_MIN}, + {"max", HCCL_REDUCE_MAX}, + {"prod", HCCL_REDUCE_PROD}, + {"sum", HCCL_REDUCE_SUM}, +}; + +/* Correspondence between reduce str and enum in collective op */ +static const std::unordered_map kConstOpCollectiveOpReduceTypeMap = { + {"min", CollectiveOpReduceType::Reduce_Min}, + {"max", CollectiveOpReduceType::Reduce_Max}, + {"prod", CollectiveOpReduceType::Reduce_Prod}, + {"sum", CollectiveOpReduceType::Reduce_Sum}, +}; + +class HcomUtil { + public: + static ::HcclDataType ConvertHcclType(DataType type_id); + static HcclComm LoadHcclLibrary(const std::string &group_name) { + int64_t hccl_comm = collective::CollectiveManager::Instance().GetCommunicationGroup(group_name)->communicator(); + return reinterpret_cast(static_cast(hccl_comm)); + } + // static bool GetHcomDataType(const std::string &kernel_name, const std::vector &inputs, + // const std::vector &outputs, 
std::vector *data_type_list); + static bool GetHcclOpSize(const HcclDataType &data_type, const std::vector &shape, size_t *size); + static bool GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size); + static bool GetHcomCount(const std::vector &data_type_list, + const std::vector> &shape_list, const size_t input_tensor_num, + const std::optional rank_size_opt, uint64_t *total_count); + + static std::pair GetHcclCountAndTypeFromTensor( + const ir::TensorPtr &tensor, const std::optional rank_size_opt = std::nullopt); + static CollectiveOpReduceType GetCollectiveOpReduceType(const std::string &reduce_op); + static HcclReduceOp GetHcomReduceOpType(const std::string &reduce_op); +}; +} // namespace mrt::ops + +#endif // OPS_ASCEND_HCCL_HCOM_UTILS_H_ diff --git a/inferrt/src/ops/ascend/hccl/tensor_copy.cc b/inferrt/src/ops/ascend/hccl/tensor_copy.cc new file mode 100644 index 00000000..27d1cd5f --- /dev/null +++ b/inferrt/src/ops/ascend/hccl/tensor_copy.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/ascend/hccl/tensor_copy.h" +#include "ops/ascend/hccl/hccl_adapter.h" +#include "ops/ascend/hccl/hcom_utils.h" +#include "hccl/hccl_types.h" +#include "hccl/hccl.h" + +#include "hardware/ascend/res_manager/ascend_res_manager.h" + +#include "common/logger.h" +#include "ops/op_register.h" + +namespace mrt { +namespace ops { + +OpsErrorCode HcclTensorCopy::InferShape(const std::vector &input, ir::Value *output) { + LOG_OUT << "TensorCopy InferShape"; + auto &input0Shape = input[kIndex0]->ToTensor()->Shape(); + auto &outputTensor = output->ToTensor(); + auto &outputShape = outputTensor->Shape(); + outputShape = input0Shape; + outputTensor->Resize(); + return SUCCESS; +} + +OpsErrorCode HcclTensorCopy::CalcWorkspace(const std::vector &input, const ir::Value *output, + size_t *workspace_size) { + return SUCCESS; +} + +OpsErrorCode HcclTensorCopy::Launch(const std::vector &input, void *workspace, size_t workspaceSize, + ir::Value *output, void *stream) { + LOG_OUT << "TensorCopy launch"; + auto src_tensor = input[kIndex1]->ToTensor(); + auto out_tensor = input[kIndex0]->ToTensor(); + auto dst_size = out_tensor->Numel() * out_tensor->Dtype().GetSize(); + + // host_ptr, size, device_ptr, size, ACL_MEMCPY_DEVICE_TO_HOST, stream_ptr + auto ret = mrt::device::ascend::AscendResManager::MemcpyDeviceToDevice(out_tensor->DataPtr(), dst_size, + src_tensor->DataPtr(), dst_size, stream); + if (ret == false) { + LOG_ERROR << " call aclrtMemcpyAsync in Op HcclTensorCopy failed"; + } + + return SUCCESS; +} +MRT_REG_OP(copy, HcclTensorCopy, Ascend); +} // namespace ops +} // namespace mrt diff --git a/inferrt/src/ops/ascend/hccl/tensor_copy.h b/inferrt/src/ops/ascend/hccl/tensor_copy.h new file mode 100644 index 00000000..6c13af49 --- /dev/null +++ b/inferrt/src/ops/ascend/hccl/tensor_copy.h @@ -0,0 +1,44 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 
2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef OPS_ASCEND_HCCL_TENSOR_COPY_H_ +#define OPS_ASCEND_HCCL_TENSOR_COPY_H_ + +#include + +#include "ops/operator.h" +#include "ops/ascend/hccl/hccl_kernel.h" + +namespace mrt { +namespace ops { +class HcclTensorCopy : public Operator { + public: + HcclTensorCopy() = default; + ~HcclTensorCopy() = default; + + OpsErrorCode InferShape(const std::vector &input, ir::Value *output) override; + + OpsErrorCode CalcWorkspace(const std::vector &input, const ir::Value *output, + size_t *workspace_size) override; + OpsErrorCode Launch(const std::vector &input, void *workspace, size_t workspaceSize, + ir::Value *output, void *stream) override; + + private: + HcclKernel hcclKernel; +}; +} // namespace ops +} // namespace mrt +#endif // OPS_ASCEND_HCCL_TENSOR_COPY_H_ diff --git a/inferrt/src/ops/ascend/hccl/wait_tensor.cc b/inferrt/src/ops/ascend/hccl/wait_tensor.cc new file mode 100644 index 00000000..7d9be8e2 --- /dev/null +++ b/inferrt/src/ops/ascend/hccl/wait_tensor.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "ops/ascend/hccl/wait_tensor.h"
+#include "ops/ascend/hccl/hccl_adapter.h"
+#include "ops/ascend/hccl/hcom_utils.h"
+#include "hccl/hccl_types.h"
+#include "hccl/hccl.h"
+
+#include "common/logger.h"
+#include "ops/op_register.h"
+#include "hardware/ascend/res_manager/ascend_stream_manager.h"
+#include "hardware/ascend/res_manager/ascend_res_manager.h"
+namespace mrt {
+namespace ops {
+
+OpsErrorCode HcclWaitTensor::InferShape(const std::vector &input, ir::Value *output) {
+  LOG_OUT << "WaitTensor InferShape";
+  auto &input0Shape = input[kIndex0]->ToTensor()->Shape();
+  auto &outputTensor = output->ToTensor();
+  auto &outputShape = outputTensor->Shape();
+  outputShape = input0Shape;
+  outputTensor->Resize();
+  return SUCCESS;
+}
+
+OpsErrorCode HcclWaitTensor::CalcWorkspace(const std::vector &input, const ir::Value *output,
+                                           size_t *workspace_size) {
+  return SUCCESS;
+}
+
+OpsErrorCode HcclWaitTensor::Launch(const std::vector &input, void *workspace, size_t workspaceSize,
+                                    ir::Value *output, void *stream) {
+  LOG_OUT << "WaitTensor launch";
+
+  auto src_tensor = input[kIndex0]->ToTensor();
+  auto out_tensor = output->ToTensor();
+  auto dst_size = out_tensor->Numel() * out_tensor->Dtype().GetSize();
+
+  auto ret = mrt::device::ascend::AscendResManager::MemcpyDeviceToDevice(out_tensor->DataPtr(), dst_size,
+                                                                         src_tensor->DataPtr(), dst_size, stream);
+  if (!ret) {
+    LOG_ERROR << "Call aclrtMemcpyAsync in Op HcclWaitTensor failed";
+  }
+
+  mrt::device::ascend::AscendStreamMng::GetInstance().SyncStream(stream);
+  return SUCCESS;
+}
+MRT_REG_OP(wait_tensor, HcclWaitTensor, Ascend);
+} // namespace ops
+} // namespace mrt
diff --git a/inferrt/src/ops/ascend/hccl/wait_tensor.h b/inferrt/src/ops/ascend/hccl/wait_tensor.h
new file mode 100644
index 00000000..cbd86574
--- /dev/null
+++ b/inferrt/src/ops/ascend/hccl/wait_tensor.h
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef OPS_ASCEND_HCCL_WAIT_TENSOR_H_ +#define OPS_ASCEND_HCCL_WAIT_TENSOR_H_ + +#include + +#include "ops/operator.h" +#include "ops/ascend/hccl/hccl_kernel.h" + +namespace mrt { +namespace ops { +class HcclWaitTensor : public Operator { + public: + HcclWaitTensor() = default; + ~HcclWaitTensor() = default; + + OpsErrorCode InferShape(const std::vector &input, ir::Value *output) override; + + OpsErrorCode CalcWorkspace(const std::vector &input, const ir::Value *output, + size_t *workspace_size) override; + OpsErrorCode Launch(const std::vector &input, void *workspace, size_t workspaceSize, + ir::Value *output, void *stream) override; + + private: + HcclKernel hcclKernel; +}; +} // namespace ops +} // namespace mrt +#endif // OPS_ASCEND_HCCL_WAIT_TENSOR_H_ diff --git a/inferrt/src/ops/op_base/op_all_gather.cc b/inferrt/src/ops/op_base/op_all_gather.cc new file mode 100644 index 00000000..4f26e370 --- /dev/null +++ b/inferrt/src/ops/op_base/op_all_gather.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "ops/op_base/op_all_gather.h" + +namespace mrt { +namespace ops { +OpsErrorCode OpAllGather::InferShape(const std::vector &input, ir::Value *output) { + LOG_OUT << "HcclAllGather InferShape"; + + auto &input0Shape = input[kIndex0]->ToTensor()->Shape(); + auto rank_size = input[kIndex1]->ToInt(); + + auto outputShape = input0Shape; + outputShape[0] *= rank_size; + + auto outputTensor = output->ToTensor(); + CHECK_IF_NULL(outputTensor); + outputTensor->SetShape(outputShape); + auto outputDtype = input[kIndex0]->ToTensor()->Dtype(); + outputTensor->SetDtype(outputDtype); + outputTensor->Resize(); + + return SUCCESS; +} +} // namespace ops +} // namespace mrt diff --git a/inferrt/src/ops/op_base/op_all_gather.h b/inferrt/src/ops/op_base/op_all_gather.h new file mode 100644 index 00000000..1b9886b7 --- /dev/null +++ b/inferrt/src/ops/op_base/op_all_gather.h @@ -0,0 +1,37 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __OPS_OP_BASE_OP_ALL_GATHER_H__ +#define __OPS_OP_BASE_OP_ALL_GATHER_H__ + +#include +#include +#include + +#include "ops/operator.h" + +namespace mrt { +namespace ops { +class OpAllGather : public Operator { + public: + OpAllGather() = default; + ~OpAllGather() override = default; + + OpsErrorCode InferShape(const std::vector &input, ir::Value *output) override; +}; +} // namespace ops +} // namespace mrt +#endif // __OPS_OP_BASE_OP_ALL_GATHER_H__ diff --git a/inferrt/src/ops/op_def/ops.list b/inferrt/src/ops/op_def/ops.list index 17d1fa6b..285b00bd 100644 --- a/inferrt/src/ops/op_def/ops.list +++ b/inferrt/src/ops/op_def/ops.list @@ -39,3 +39,6 @@ OP(update_state) OP(load) OP(depend) OP(return) +OP(all_gather) +OP(wait_tensor) +OP(copy) diff --git a/inferrt/src/pybind/CMakeLists.txt b/inferrt/src/pybind/CMakeLists.txt index bd3fe259..669de4b9 100644 --- a/inferrt/src/pybind/CMakeLists.txt +++ b/inferrt/src/pybind/CMakeLists.txt @@ -55,7 +55,6 @@ else() # Include pybind11 directories include_directories(${PYBIND11_PATH}/include) - add_subdirectory(${PYBIND11_PATH}) endif() diff --git a/inferrt/src/pybind/mrt/CMakeLists.txt b/inferrt/src/pybind/mrt/CMakeLists.txt index 64fe5bf7..9750b22b 100644 --- a/inferrt/src/pybind/mrt/CMakeLists.txt +++ b/inferrt/src/pybind/mrt/CMakeLists.txt @@ -6,7 +6,7 @@ target_link_libraries(_mrt_api PUBLIC inferrt) set_target_properties(_mrt_api PROPERTIES INSTALL_RPATH "$ORIGIN:$ORIGIN/lib" - BUILD_WITH_INSTALL_RPATH TRUE + BUILD_WITH_INSTALL_RPATH TRUE ) install( @@ -19,13 +19,36 @@ install( pybind11_add_module(_mrt_ir NO_EXTRAS pybind11_ir.cc) target_link_libraries(_mrt_ir PUBLIC inferrt) +# Add collective pybind11 sub module +if(ENABLE_ASCEND) + if(DEFINED ENV{ASCEND_CUSTOM_PATH}) + set(ASCEND_PATH $ENV{ASCEND_CUSTOM_PATH}) + else() + set(ASCEND_PATH /usr/local/Ascend) + endif() + include_directories(${ASCEND_PATH}/latest/include/) + include_directories(${ASCEND_PATH}/latest/lib64/) + include_directories(${ASCEND_PATH}/latest/aarch64-linux/include/experiment) + pybind11_add_module(_mrt_collective NO_EXTRAS pybind11_collective.cc) + target_link_libraries(_mrt_collective PUBLIC inferrt hardware_ascend) + set_target_properties(_mrt_collective PROPERTIES + INSTALL_RPATH "$ORIGIN:$ORIGIN/lib" + BUILD_WITH_INSTALL_RPATH TRUE + ) + install( + TARGETS _mrt_collective + LIBRARY DESTINATION . + RUNTIME DESTINATION bin + ) +endif() + set_target_properties(_mrt_ir PROPERTIES INSTALL_RPATH "$ORIGIN:$ORIGIN/lib" - BUILD_WITH_INSTALL_RPATH TRUE + BUILD_WITH_INSTALL_RPATH TRUE ) install( TARGETS _mrt_ir LIBRARY DESTINATION . RUNTIME DESTINATION bin -) \ No newline at end of file +) diff --git a/inferrt/src/pybind/mrt/pybind11_collective.cc b/inferrt/src/pybind/mrt/pybind11_collective.cc new file mode 100644 index 00000000..4235ad12 --- /dev/null +++ b/inferrt/src/pybind/mrt/pybind11_collective.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include + +#include "hardware/hardware_abstract/collective/collective_manager.h" + +namespace py = pybind11; +using CollectiveManager = mrt::collective::CollectiveManager; + +PYBIND11_MODULE(_mrt_collective, mod) { + (void)py::class_(mod, "CollectiveManager") + .def_static("instance", &CollectiveManager::Instance, py::return_value_policy::reference) + .def("create_communication_group", &CollectiveManager::CreateCommunicationGroup) + .def("is_group_exist", &CollectiveManager::IsGroupExist) + .def("get_group_rank", &CollectiveManager::GetGroupRank) + .def("get_group_size", &CollectiveManager::GetGroupSize) + .def("set_global_rank_id", &CollectiveManager::SetGlobalRankId) + .def("set_local_rank_id", &CollectiveManager::SetLocalRankId) + .def("set_global_rank_size", &CollectiveManager::SetGlobalRankSize) + .def("global_rank_id", &CollectiveManager::global_rank_id) + .def("local_rank_id", &CollectiveManager::local_rank_id) + .def("global_rank_size", &CollectiveManager::global_rank_size); +} diff --git a/inferrt/src/pybind/mrt_torch/CMakeLists.txt b/inferrt/src/pybind/mrt_torch/CMakeLists.txt index 52f9b564..a8c90742 100644 --- a/inferrt/src/pybind/mrt_torch/CMakeLists.txt +++ b/inferrt/src/pybind/mrt_torch/CMakeLists.txt @@ -1,7 +1,10 @@ check_debug_log_out() # Link against torch libraries for _mrt_torch -execute_process(COMMAND python -c "import os; import torch; print(os.path.join(os.path.dirname(torch.__file__), 'share/cmake'))" OUTPUT_VARIABLE PYTORCH_CMAKE_PATH OUTPUT_STRIP_TRAILING_WHITESPACE) +execute_process(COMMAND python -c +"import os; import torch; print(os.path.join(os.path.dirname(torch.__file__), 'share/cmake'))" +OUTPUT_VARIABLE PYTORCH_CMAKE_PATH OUTPUT_STRIP_TRAILING_WHITESPACE) + set(CMAKE_PREFIX_PATH "${PYTORCH_CMAKE_PATH}") find_package(Torch REQUIRED) find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib") @@ -12,7 +15,9 @@ target_link_libraries(_mrt_torch PUBLIC inferrt ${TORCH_LIBRARIES} ${TORCH_PYTHO if(ENABLE_ASCEND) add_compile_definitions(ENABLE_TORCH_NPU) - execute_process(COMMAND python -c "import os; import torch_npu; print(os.path.dirname(torch_npu.__file__))" OUTPUT_VARIABLE TORCH_NPU_PATH OUTPUT_STRIP_TRAILING_WHITESPACE) + execute_process(COMMAND python -c + "import os; import torch_npu; print(os.path.dirname(torch_npu.__file__))" + OUTPUT_VARIABLE TORCH_NPU_PATH OUTPUT_STRIP_TRAILING_WHITESPACE) message("TORCH_NPU_PATH: ${TORCH_NPU_PATH}") include_directories(${TORCH_NPU_PATH}/include/) @@ -22,7 +27,7 @@ endif() set_target_properties(_mrt_torch PROPERTIES INSTALL_RPATH "$ORIGIN:$ORIGIN/lib:${TORCH_INSTALL_PREFIX}/lib" - BUILD_WITH_INSTALL_RPATH TRUE + BUILD_WITH_INSTALL_RPATH TRUE ) install( diff --git a/setup.py b/setup.py index 1172c55e..ecc2ebcf 100644 --- a/setup.py +++ b/setup.py @@ -136,6 +136,7 @@ package_name = 'mrt' special_so_files_patterns = [ '_mrt_api*.so', # Matches all .so files starting with _mrt_api '_mrt_ir*.so', # Matches all .so files starting with _mrt_ir + '_mrt_collective*.so', # Matches all .so files starting with _mrt_collective '_mrt_torch*.so' # Matches all .so files starting with _mrt_torch ] diff --git a/tests/st/check/check_distributed_backend.py b/tests/st/check/check_distributed_backend.py new file mode 100644 index 00000000..755d3759 --- /dev/null +++ b/tests/st/check/check_distributed_backend.py @@ -0,0 +1,107 @@ +import os +from typing import List + +import torch +import torch.nn as nn +import torch.distributed as dist +import 
torch.distributed.distributed_c10d as c10d +from torch._C._distributed_c10d import _resolve_process_group + +import mrt +from mrt import jit + +from mrt.collective import CollectiveManager +from mrt.torch import backend + +BACKEND_HCCL = "hccl" +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = os.getenv('MASTER_ADDR', 'localhost') + os.environ['MASTER_PORT'] = os.getenv('MASTER_PORT', '6689') + + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['RANK'] = str(rank) + os.environ['RANK_ID'] = str(rank) + + dist.init_process_group( + BACKEND_HCCL, + rank=rank, + world_size=world_size, + init_method='env://' + ) + + +def cleanup(): + dist.destroy_process_group() + + +def check_allgather_info(pg=None): + ptd = pg.group_name + rank = dist.get_rank() if dist.is_initialized() else 0 + local_rank = int(os.getenv('LOCAL_RANK', "0")) + world_size = dist.get_world_size() + + group_rank = dist.get_rank(pg) + rank_list = dist.get_process_group_ranks(pg) + group_size = dist.get_world_size(pg) + + group_rank_id = CollectiveManager.instance().get_group_rank(f"{ptd}") + group_rank_size = CollectiveManager.instance().get_group_size(f"{ptd}") + + assert CollectiveManager.instance().global_rank_id() == rank, f"got global_rank_id {CollectiveManager.instance().global_rank_id()}, but expected {rank}" + assert CollectiveManager.instance().local_rank_id() == int(local_rank), f"got local_rank_id {CollectiveManager.instance().local_rank_id()}, but expected {local_rank}" + assert CollectiveManager.instance().global_rank_size() == world_size, f"got global_rank_size {CollectiveManager.instance().global_rank_size()}, but expected {world_size}" + assert group_rank_id == group_rank + assert group_rank_size == group_size + + return + +def check_allgather_output(output): + output = output.cpu() + expect_output = torch.tensor([ 100, 400, 400, 1600]) + assert (output == expect_output).all() + +class SimpleNetwork(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, gathered, pg=None): + dist.all_gather_into_tensor(gathered, x, group=pg) + output = torch.mul(gathered, gathered) + return output + + +def train(rank, world_size): + + setup(rank, world_size) + + rank = dist.get_rank() if dist.is_initialized() else 0 + new_pg = dist.new_group([0,1]) + + model = SimpleNetwork().npu() + example_input = torch.tensor([1,2]).npu() * (rank+1) * 10 + + compiled_model = torch.compile( + model, + backend=backend, + fullgraph=True, + ) + + world_size = dist.get_world_size(new_pg) + print(f"rank {rank} world size {world_size}") + gathered = [torch.zeros_like(example_input) for _ in range(world_size)] + gathered = torch.cat(gathered, dim=0) + output = compiled_model(example_input, gathered, new_pg) + + check_allgather_info(new_pg) + check_allgather_output(output) + + print(f"rank {rank} world size {world_size} output is {output}") + dist.destroy_process_group() + +if __name__ == "__main__": + rank = int(os.getenv('RANK', '0')) + world_size = int(os.getenv('WORLD_SIZE', '1')) + local_rank = int(os.getenv('LOCAL_RANK', '0')) + + torch.npu.set_device(local_rank) + train(rank, world_size) -- Gitee
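The new test reads RANK, WORLD_SIZE and LOCAL_RANK from the environment but the patch does not ship a launcher for it. Below is a minimal sketch of exercising the new all_gather path end to end, assuming a 2-rank HCCL job started with a torchrun-style launcher (for example `torchrun --nproc_per_node=2 check_all_gather.py`) and a working torch_npu install; the file name, launch command and tensor values are illustrative only and not part of the patch.

# check_all_gather.py -- minimal 2-rank sketch (assumes torchrun sets RANK/WORLD_SIZE/LOCAL_RANK)
import os

import torch
import torch.distributed as dist

from mrt.collective import CollectiveManager
from mrt.torch import backend


def run():
    rank = int(os.getenv("RANK", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    torch.npu.set_device(local_rank)
    dist.init_process_group("hccl", rank=rank, world_size=world_size)

    # all_gather_into_tensor inside the compiled graph is lowered to the new
    # `all_gather` op and dispatched to HcclAllGather on Ascend.
    def gather_square(x, gathered):
        dist.all_gather_into_tensor(gathered, x)
        return gathered * gathered

    compiled = torch.compile(gather_square, backend=backend, fullgraph=True)

    x = torch.tensor([1.0, 2.0]).npu() * (rank + 1)
    gathered = torch.zeros(world_size * x.numel(), dtype=x.dtype).npu()
    out = compiled(x, gathered)

    # The C++ CollectiveManager should mirror the torch.distributed view of the job.
    mgr = CollectiveManager.instance()
    assert mgr.global_rank_id() == rank
    assert mgr.global_rank_size() == world_size

    print(f"rank {rank}: {out.cpu()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    run()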