diff --git a/cmake/scripts/gen_kernel_tiling_data_def.py b/cmake/scripts/gen_kernel_tiling_data_def.py
index 1fe4e80cbf78869f0ba9dcc15cb4bae72d263d31..25ff78302a43772a3c7fae7bfc6910b4c006b6e0 100644
--- a/cmake/scripts/gen_kernel_tiling_data_def.py
+++ b/cmake/scripts/gen_kernel_tiling_data_def.py
@@ -29,7 +29,7 @@ def gen_tiling(tiling_header_file):
         if (line.startswith('BEGIN_TILING_DATA_DEF')):
             single_tiling_source += '#pragma pack(push, 8)\n'
             single_tiling_source += 'struct '
-            struct_def = re.findall(pattern, line)[0]
+            struct_def = re.findall(pattern, line)[0]
             single_tiling_source += struct_def + ' {\n'
         elif (line.startswith('TILING_DATA_FIELD_DEF_ARR')):
             field_params = re.findall(pattern, line)[0]
@@ -64,7 +64,7 @@ if __name__ == '__main__':
     """
     print("[LOG]: ", sys.argv[1], sys.argv[2])
     file_list = []
-    for root, dirs, files in os.walk(sys.argv[1]):
+    for root, _, files in os.walk(sys.argv[1]):
         for file in files:
             if file.endswith("tilingdata.h"):
                 file_list.append(os.path.join(root, file))
diff --git a/inc/hccl/hccl_types.h b/inc/hccl/hccl_types.h
index 096d410f8249313af56ab6ff744db0251ad932a1..976a2b5999dd757860d7dc1211300f1277090b5e 100644
--- a/inc/hccl/hccl_types.h
+++ b/inc/hccl/hccl_types.h
@@ -43,6 +43,7 @@ typedef enum {
     HCCL_E_NETWORK = 19, /**< call network api fail */
     HCCL_E_AGAIN = 20, /**< try again */
     HCCL_E_REMOTE = 21, /**< error cqe */
+    HCCL_E_SUSPENDING = 22, /**< error communicator suspending */
     HCCL_E_RESERVED /**< reserved */
 } HcclResult;
diff --git a/src/domain/collective_communication/algorithm/CMakeLists.txt b/src/domain/collective_communication/algorithm/CMakeLists.txt
index 2bbdca44557e7821aeb4cefe19475b3e7d9a3f62..16c5e243d929113c1de4673348eb52fd3c1b5eb6 100644
--- a/src/domain/collective_communication/algorithm/CMakeLists.txt
+++ b/src/domain/collective_communication/algorithm/CMakeLists.txt
@@ -31,6 +31,14 @@ target_include_directories(hccl_alg PRIVATE
     ${CMAKE_CURRENT_SOURCE_DIR}/impl/coll_executor/coll_send_receive
     ${CMAKE_CURRENT_SOURCE_DIR}/impl/coll_executor/coll_all_reduce
     ${CMAKE_CURRENT_SOURCE_DIR}/impl/coll_executor/coll_all_reduce/310P
+    ${CMAKE_CURRENT_SOURCE_DIR}/impl/coll_executor/coll_all_to_all
+    ${CMAKE_CURRENT_SOURCE_DIR}/impl/coll_executor/coll_all_gather
+    ${CMAKE_CURRENT_SOURCE_DIR}/impl/coll_executor/coll_all_gather/310P
+    ${CMAKE_CURRENT_SOURCE_DIR}/impl/coll_executor/coll_reduce_scatter
+    ${CMAKE_CURRENT_SOURCE_DIR}/impl/coll_executor/coll_reduce_scatter/310P
+    ${CMAKE_CURRENT_SOURCE_DIR}/impl/coll_executor/coll_scatter
+    ${CMAKE_CURRENT_SOURCE_DIR}/impl/coll_executor/coll_broadcast
+    ${CMAKE_CURRENT_SOURCE_DIR}/impl/coll_executor/coll_broadcast/310P
     ${HCCL_BASE_DIR}/../../../../inc/external
     ${HCCL_BASE_DIR}/../../../../inc/external/hccl
     ${HCCL_BASE_DIR}/../../../../inc/hccl
@@ -133,4 +141,4 @@ if(BUILD_OPEN_PROJECT)
     set(CPACK_EXTERNAL_PACKAGE_SCRIPT ${ASCEND_CANN_PACKAGE_PATH}/tools/ascend_project/cmake/makeself.cmake)
     set(CPACK_EXTERNAL_BUILT_PACKAGES ${CPACK_PACKAGE_DIRECTORY}/_CPack_Packages/Linux/External/${CPACK_PACKAGE_FILE_NAME}/${CPACK_PACKAGE_FILE_NAME})
     include(CPack)
-endif()
\ No newline at end of file
+endif()
diff --git a/src/domain/collective_communication/algorithm/base/communicator/CMakeLists.txt b/src/domain/collective_communication/algorithm/base/communicator/CMakeLists.txt
index e1f095b435b8f03c042c3933a1563ae692e33737..2ad5a64ce2ae4652219f06c22345c2428e72083e 100644
--- a/src/domain/collective_communication/algorithm/base/communicator/CMakeLists.txt
+++
b/src/domain/collective_communication/algorithm/base/communicator/CMakeLists.txt @@ -11,6 +11,7 @@ set(src_list ${CMAKE_CURRENT_SOURCE_DIR}/calc_partial_mesh_transport_req.cc ${CMAKE_CURRENT_SOURCE_DIR}/calc_ring_transport_req.cc ${CMAKE_CURRENT_SOURCE_DIR}/calc_transport_req_base.cc + ${CMAKE_CURRENT_SOURCE_DIR}/calc_p2p_transport_req.cc ) target_sources(hccl_alg PRIVATE diff --git a/src/domain/collective_communication/algorithm/base/communicator/calc_hd_transport_req.cc b/src/domain/collective_communication/algorithm/base/communicator/calc_hd_transport_req.cc index d46eb487ab296f18c660a1e50706c64ee34cad34..adcc3ab792d0952dc1b86b95444d211464b3174f 100644 --- a/src/domain/collective_communication/algorithm/base/communicator/calc_hd_transport_req.cc +++ b/src/domain/collective_communication/algorithm/base/communicator/calc_hd_transport_req.cc @@ -11,7 +11,7 @@ #include "calc_hd_transport_req.h" namespace hccl { -CalcHDTransportReq::CalcHDTransportReq(std::vector> &subCommPlaneVector, +CalcHDTransportReq::CalcHDTransportReq(std::vector> &subCommPlaneVector, std::vector &isBridgeVector, u32 userRank) : CalcTransportReqBase(subCommPlaneVector, isBridgeVector, userRank) { @@ -23,7 +23,7 @@ CalcHDTransportReq::~CalcHDTransportReq() HcclResult CalcHDTransportReq::CalcTransportRequest(const std::string &tag, TransportMemType inputMemType, TransportMemType outputMemType, const CommParaInfo &commParaInfo, - std::vector &commTransport) + std::vector &commTransport, u32 subUserRankRoot) { u32 ringSize = subCommPlaneVector_.size(); commTransport.resize(ringSize); @@ -46,8 +46,13 @@ HcclResult CalcHDTransportReq::CalcTransportRequest(const std::string &tag, Tran HCCL_INFO("comm base needn't to create links, rankSize_[%u].", rankSize); return HCCL_SUCCESS; } + + u32 subRoot = INVALID_VALUE_RANKID; + if (subUserRankRoot != INVALID_VALUE_RANKID) { + CHK_RET(GetRankByUserRank(subCommPlaneVector_[ringIndex], subUserRankRoot, subRoot)); + } - std::vector linkRelation = ExecutorBase::CalcLinksRelation(rank, rankSize, INVALID_VALUE_RANKID, + std::vector linkRelation = ExecutorBase::CalcLinksRelation(rank, rankSize, subRoot, HalvingDoublingType::RECURSIVE_HALVING_DOUBLING); for (u32 rankIndex = 0; rankIndex < rankSize; rankIndex++) { @@ -55,7 +60,7 @@ HcclResult CalcHDTransportReq::CalcTransportRequest(const std::string &tag, Tran if (linkRelation[rankIndex] == true) { tmpTransport.isValid = true; tmpTransport.localUserRank = userRank_; - tmpTransport.remoteUserRank = subCommPlaneVector_[ringIndex][rankIndex].userRank; + tmpTransport.remoteUserRank = subCommPlaneVector_[ringIndex][rankIndex]; tmpTransport.inputMemType = inputMemType; tmpTransport.outputMemType = outputMemType; HCCL_INFO("[CommFactory][CalcHDCommInfo] param_.tag[%s] ringIndex[%u], localRank[%u], "\ diff --git a/src/domain/collective_communication/algorithm/base/communicator/calc_mesh_transport_req.cc b/src/domain/collective_communication/algorithm/base/communicator/calc_mesh_transport_req.cc index defd92cb0beab9fe09e081e9fb034162d624ade4..d6dfc5ca5a0e645cf06042f22a9c1b7b690cb329 100644 --- a/src/domain/collective_communication/algorithm/base/communicator/calc_mesh_transport_req.cc +++ b/src/domain/collective_communication/algorithm/base/communicator/calc_mesh_transport_req.cc @@ -12,7 +12,7 @@ namespace hccl { -CalcMeshTransportReq::CalcMeshTransportReq(std::vector> &subCommPlaneVector, +CalcMeshTransportReq::CalcMeshTransportReq(std::vector> &subCommPlaneVector, std::vector &isBridgeVector, u32 userRank) : 
CalcTransportReqBase(subCommPlaneVector, isBridgeVector, userRank) { @@ -24,7 +24,7 @@ CalcMeshTransportReq::~CalcMeshTransportReq() HcclResult CalcMeshTransportReq::CalcTransportRequest(const std::string &tag, TransportMemType inputMemType, TransportMemType outputMemType, const CommParaInfo &commParaInfo, - std::vector &commTransport) + std::vector &commTransport, u32 subUserRankRoot) { u32 ringSize = subCommPlaneVector_.size(); // 910B非确定性计算场景,server内MESH组网只需要创建一个commbase平面 @@ -57,7 +57,7 @@ HcclResult CalcMeshTransportReq::CalcTransportRequest(const std::string &tag, Tr if (rankIndex != rank) { tmpTransport.isValid = true; tmpTransport.localUserRank = userRank_; - tmpTransport.remoteUserRank = subCommPlaneVector_[ringIndex][rankIndex].userRank; + tmpTransport.remoteUserRank = subCommPlaneVector_[ringIndex][rankIndex]; tmpTransport.inputMemType = inputMemType; tmpTransport.outputMemType = outputMemType; HCCL_INFO("[CommFactory][CalcMeshCommInfo] param_.tag[%s] ringIndex[%u], localRank[%u], "\ diff --git a/src/domain/collective_communication/algorithm/base/communicator/calc_p2p_transport_req.cc b/src/domain/collective_communication/algorithm/base/communicator/calc_p2p_transport_req.cc new file mode 100644 index 0000000000000000000000000000000000000000..8bccfd67c0e662fda1241988bcf0057679922eb5 --- /dev/null +++ b/src/domain/collective_communication/algorithm/base/communicator/calc_p2p_transport_req.cc @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "calc_p2p_transport_req.h" + + +namespace hccl { +CalcP2PTransportReq::CalcP2PTransportReq(std::vector> &subCommPlaneVector, + std::vector &isBridgeVector, u32 userRank) + : CalcTransportReqBase(subCommPlaneVector, isBridgeVector, userRank) +{ +} + +CalcP2PTransportReq::~CalcP2PTransportReq() +{ +} + +HcclResult CalcP2PTransportReq::CalcTransportRequest(const std::string &tag, TransportMemType inputMemType, + TransportMemType outputMemType, const CommParaInfo &commParaInfo, + std::vector &commTransport, u32 subUserRankRoot) +{ + u32 planeSize = subCommPlaneVector_.size(); + commTransport.resize(planeSize); + // 看一下是否需要循环 + + for (u32 planeIndex = 0; planeIndex < planeSize; planeIndex++) { + u32 rankSize = subCommPlaneVector_[planeIndex].size(); + SingleSubCommTransport &subCommTransport = commTransport[planeIndex]; + subCommTransport.transportRequests.resize(rankSize); + + // send,recv算子只有一张卡时报错 + if (rankSize == 1) { + HCCL_ERROR("[CommFactory][CalcP2PCommInfo] sendrecv rankSize is 1"); + } + TransportRequest &tmpTransport = subCommTransport.transportRequests[0]; + + tmpTransport.isValid = true; + tmpTransport.localUserRank = userRank_; + tmpTransport.remoteUserRank = commParaInfo.peerUserRank; + tmpTransport.inputMemType = inputMemType; + tmpTransport.outputMemType = outputMemType; + HCCL_INFO("[CommFactory][CalcP2PCommInfo] param_.tag[%s] planeIndex[%u], localRank[%u], \ + remoteRank[%u], inputMemType[%d], outputMemType[%d]", tag.c_str(), planeIndex, userRank_, + tmpTransport.remoteUserRank, inputMemType, outputMemType); + } + return HCCL_SUCCESS; +} + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/base/communicator/calc_p2p_transport_req.h b/src/domain/collective_communication/algorithm/base/communicator/calc_p2p_transport_req.h new file mode 100644 index 0000000000000000000000000000000000000000..95d3da7b511823c184f2f5b1494ea0d7e369a061 --- /dev/null +++ b/src/domain/collective_communication/algorithm/base/communicator/calc_p2p_transport_req.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+
+#ifndef CALC_P2P_TRANSPORT_REQ_H
+#define CALC_P2P_TRANSPORT_REQ_H
+
+#include "calc_transport_req_base_pub.h"
+
+namespace hccl {
+class CalcP2PTransportReq : public CalcTransportReqBase {
+public:
+    explicit CalcP2PTransportReq(std::vector> &subCommPlaneVector,
+        std::vector &isBridgeVector, u32 userRank);
+
+    ~CalcP2PTransportReq();
+
+    HcclResult CalcTransportRequest(const std::string &tag, TransportMemType inputMemType,
+        TransportMemType outputMemType, const CommParaInfo &commParaInfo,
+        std::vector &commTransport, u32 subUserRankRoot = INVALID_VALUE_RANKID) override;
+};
+} // namespace hccl
+#endif /* CALC_P2P_TRANSPORT_REQ_H */
\ No newline at end of file
diff --git a/src/domain/collective_communication/algorithm/base/communicator/calc_partial_mesh_transport_req.cc b/src/domain/collective_communication/algorithm/base/communicator/calc_partial_mesh_transport_req.cc
index 5854d3c2bcb7c2797c6774b7158cf80e776a0ae6..4fb89bfa96d22d23351c848055e94b766c973dd5 100644
--- a/src/domain/collective_communication/algorithm/base/communicator/calc_partial_mesh_transport_req.cc
+++ b/src/domain/collective_communication/algorithm/base/communicator/calc_partial_mesh_transport_req.cc
@@ -13,10 +13,9 @@
 #include "dtype_common.h"

 namespace hccl {
-CalcPartialMeshTransportReq::CalcPartialMeshTransportReq(std::vector> &subCommPlaneVector,
-    std::vector &isBridgeVector, u32 userRank, RdmaEnableCheckInfo& rdmaCheckInfo)
-    : CalcTransportReqBase(subCommPlaneVector, isBridgeVector, userRank), rankData_(rdmaCheckInfo.rankData),
-    isDiffModuleInServer_(rdmaCheckInfo.isDiffModuleInServer), isUsedRdma_(rdmaCheckInfo.isUsedRdma)
+CalcPartialMeshTransportReq::CalcPartialMeshTransportReq(std::vector> &subCommPlaneVector,
+    std::vector &isBridgeVector, u32 userRank)
+    : CalcTransportReqBase(subCommPlaneVector, isBridgeVector, userRank)
 {
 }

@@ -26,12 +25,12 @@ CalcPartialMeshTransportReq::~CalcPartialMeshTransportReq()
 HcclResult CalcPartialMeshTransportReq::CalcTransportRequest(const std::string &tag, TransportMemType inputMemType,
     TransportMemType outputMemType, const CommParaInfo &commParaInfo,
-    std::vector &commTransport)
+    std::vector &commTransport, u32 subUserRankRoot)
 {
     // send/recv分别使用一个comm
     u32 ringSize = 2;
     commTransport.resize(ringSize);
-
+
     for (u32 ringIndex = 0; ringIndex < ringSize; ringIndex++) {
         if (commParaInfo.commPlane == COMM_LEVEL1 && !isBridgeVector_.empty() && !isBridgeVector_[0]) {
             continue; // 跳出本次循环
         }
@@ -42,7 +41,7 @@ HcclResult CalcPartialMeshTransportReq::CalcTransportRequest(const std::string &
         if (rank == INVALID_VALUE_RANKID) {
             continue;
         }
-
+
         u32 rankSize = subCommPlaneVector_[0].size();
         SingleSubCommTransport &subCommTransport = commTransport[ringIndex];
         subCommTransport.transportRequests.resize(rankSize);
@@ -51,43 +50,26 @@ HcclResult CalcPartialMeshTransportReq::CalcTransportRequest(const std::string &
             HCCL_INFO("comm base needn't to create links, rankSize_[%u].", rankSize);
             return HCCL_SUCCESS;
         }
-
+
         for (u32 rankIndex = 0; rankIndex < rankSize; rankIndex++) {
             TransportRequest &tmpTransport = subCommTransport.transportRequests[rankIndex];
-            auto it = commParaInfo.batchSendRecvtargetRanks.find(subCommPlaneVector_[0][rankIndex].userRank);
-            if (rankIndex == rank || it == commParaInfo.batchSendRecvtargetRanks.end()) {
+            auto it = commParaInfo.batchSendRecvtargetRanks.find(subCommPlaneVector_[0][rankIndex]);
+            if (rankIndex != rank && it != commParaInfo.batchSendRecvtargetRanks.end()) {
+                tmpTransport.isValid = true;
+                tmpTransport.localUserRank = userRank_;
+
tmpTransport.remoteUserRank = subCommPlaneVector_[0][rankIndex]; + tmpTransport.inputMemType = inputMemType; + tmpTransport.outputMemType = outputMemType; + HCCL_INFO("[CommFactory][CalcPartialMeshCommInfo] param_.tag[%s] ringIndex[%u], localRank[%u], "\ + "remoteRank[%u], inputMemType[%d], outputMemType[%d]", tag.c_str(), ringIndex, userRank_, + tmpTransport.remoteUserRank, inputMemType, outputMemType); + } else { tmpTransport.isValid = false; continue; } - if (isDiffModuleInServer_ && !isUsedRdma_ && IsNotSupportSDMA(subCommPlaneVector_[0][rankIndex])) { - const std::string CONN_ERR = "Communication between devId[" + std::to_string(rankData_.devicePhyId) + - "] and devId[" + std::to_string(subCommPlaneVector_[0][rankIndex].devicePhyId) + "] isn't support."; - - RPT_INPUT_ERR(true, "EI0010", std::vector({"reason"}), \ - std::vector({CONN_ERR})); - CHK_PRT_RET(true, HCCL_ERROR("[CalcPartialMeshTransportReq] Communication between devId[%d] and "\ - "devId[%d] is not supported. Ensure that the NPU card is normal and entering environment "\ - "variables export HCCL_INTRA_ROCE_ENABLE=1.", rankData_.devicePhyId, - subCommPlaneVector_[0][rankIndex].devicePhyId), HCCL_E_NOT_SUPPORT); - } - tmpTransport.isValid = true; - tmpTransport.localUserRank = userRank_; - tmpTransport.remoteUserRank = subCommPlaneVector_[0][rankIndex].userRank; - tmpTransport.inputMemType = inputMemType; - tmpTransport.outputMemType = outputMemType; - HCCL_INFO("[CommFactory][CalcPartialMeshCommInfo] param_.tag[%s] ringIndex[%u], localRank[%u], "\ - "remoteRank[%u], inputMemType[%d], outputMemType[%d]", tag.c_str(), ringIndex, userRank_, - tmpTransport.remoteUserRank, inputMemType, outputMemType); } } return HCCL_SUCCESS; } -bool CalcPartialMeshTransportReq::IsNotSupportSDMA(const RankInfo &remoteRankData) -{ - return remoteRankData.serverIdx == rankData_.serverIdx && - remoteRankData.devicePhyId / DEVICE_PER_MODULE != rankData_.devicePhyId / DEVICE_PER_MODULE && - remoteRankData.devicePhyId % DEVICE_PER_MODULE != rankData_.devicePhyId % DEVICE_PER_MODULE; -} - } // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/base/communicator/calc_ring_transport_req.cc b/src/domain/collective_communication/algorithm/base/communicator/calc_ring_transport_req.cc index 253362f2bb15adf966d81989796af5c603fea7a1..6caca4158a0a139472c4712028f2be2a8e6637e7 100644 --- a/src/domain/collective_communication/algorithm/base/communicator/calc_ring_transport_req.cc +++ b/src/domain/collective_communication/algorithm/base/communicator/calc_ring_transport_req.cc @@ -11,7 +11,7 @@ #include "calc_ring_transport_req.h" namespace hccl { -CalcRingTransportReq::CalcRingTransportReq(std::vector> &subCommPlaneVector, +CalcRingTransportReq::CalcRingTransportReq(std::vector> &subCommPlaneVector, std::vector &isBridgeVector, u32 userRank) : CalcTransportReqBase(subCommPlaneVector, isBridgeVector, userRank) { @@ -23,7 +23,7 @@ CalcRingTransportReq::~CalcRingTransportReq() HcclResult CalcRingTransportReq::CalcTransportRequest(const std::string &tag, TransportMemType inputMemType, TransportMemType outputMemType, const CommParaInfo &commParaInfo, - std::vector &commTransport) + std::vector &commTransport, u32 subUserRankRoot) { u32 ringSize = subCommPlaneVector_.size(); commTransport.resize(ringSize); @@ -54,7 +54,7 @@ HcclResult CalcRingTransportReq::CalcTransportRequest(const std::string &tag, Tr rankIndex == (rank + rankSize - HCCL_RANK_OFFSET) % rankSize) { tmpTransport.isValid = true; 
                tmpTransport.localUserRank = userRank_;
-                tmpTransport.remoteUserRank = subCommPlaneVector_[ringIndex][rankIndex].userRank;
+                tmpTransport.remoteUserRank = subCommPlaneVector_[ringIndex][rankIndex];
                 tmpTransport.inputMemType = inputMemType;
                 tmpTransport.outputMemType = outputMemType;
                 HCCL_INFO("[CommFactory][CalcRingCommInfo] param_.tag[%s] ringIndex[%u], localRank[%u], "\
diff --git a/src/domain/collective_communication/algorithm/base/communicator/calc_transport_req_base.cc b/src/domain/collective_communication/algorithm/base/communicator/calc_transport_req_base.cc
index 6fcf23b4ee9c64ec38e3760ff7dff3cde6b30f54..aeca50da61d73db9c24e4cf84eb9bfaf6c1764b1 100644
--- a/src/domain/collective_communication/algorithm/base/communicator/calc_transport_req_base.cc
+++ b/src/domain/collective_communication/algorithm/base/communicator/calc_transport_req_base.cc
@@ -11,7 +11,7 @@
 #include "calc_transport_req_base.h"

 namespace hccl {
-CalcTransportReqBase::CalcTransportReqBase(std::vector> &subCommPlaneVector,
+CalcTransportReqBase::CalcTransportReqBase(std::vector> &subCommPlaneVector,
     std::vector &isBridgeVector, u32 userRank)
     : subCommPlaneVector_(subCommPlaneVector), isBridgeVector_(isBridgeVector),
     userRank_(userRank)
@@ -24,18 +24,18 @@ CalcTransportReqBase::~CalcTransportReqBase()

 HcclResult CalcTransportReqBase::CalcTransportRequest(const std::string &tag, TransportMemType inputMemType,
     TransportMemType outputMemType, const CommParaInfo &commParaInfo,
-    std::vector &commTransport)
+    std::vector &commTransport, u32 subUserRankRoot)
 {
     return HCCL_SUCCESS;
 }

-const u32 CalcTransportReqBase::GetSubCollectiveRank(const std::vector &vecPara) const
+const u32 CalcTransportReqBase::GetSubCollectiveRank(const std::vector &vecPara) const
 {
     // 在vecPara数据中,查询本user rank,查询到的vec下标就是rank值
     u32 tmpRank = INVALID_VALUE_RANKID;

     for (u32 rankIndex = 0; rankIndex < vecPara.size(); rankIndex++) {
-        if (userRank_ == vecPara[rankIndex].userRank) {
+        if (userRank_ == vecPara[rankIndex]) {
             tmpRank = rankIndex;
             break;
         }
@@ -44,4 +44,20 @@ const u32 CalcTransportReqBase::GetSubCollectiveRank(const std::vector
     return tmpRank;
 }

+HcclResult CalcTransportReqBase::GetRankByUserRank(const std::vector &vecPara,
+    const u32 userRank, u32 &rank) const
+{
+    // 在vecPara数据中,查询指定userRank,查询到的vec下标就是rank值
+    rank = INVALID_VALUE_RANKID;
+
+    for (u32 rankIndex = 0; rankIndex < vecPara.size(); rankIndex++) {
+        if (userRank == vecPara[rankIndex]) {
+            rank = rankIndex;
+            break;
+        }
+    }
+    HCCL_INFO("[Get][RankByUserRank]userRank[%u] --> rank[%u]", userRank, rank);
+    return HCCL_SUCCESS;
+}
+
 } // namespace hccl
diff --git a/src/domain/collective_communication/algorithm/base/communicator/comm_base.cc b/src/domain/collective_communication/algorithm/base/communicator/comm_base.cc
index 26e9d79ab8f0ebc2759e6f8347db50cde09dd0ec..902f8751482ced7d80dca45398faba3311b1182f 100644
--- a/src/domain/collective_communication/algorithm/base/communicator/comm_base.cc
+++ b/src/domain/collective_communication/algorithm/base/communicator/comm_base.cc
@@ -180,7 +180,7 @@ HcclResult CommBase::CreateLinks()
         }
         if (linkThreads_[index]->joinable()) {
             HCCL_DEBUG("Joining Link Thread[%u]", index);
-            linkThreads_[index]->join(); // 等待线程执行完毕
+            linkThreads_[index]->join(); // 等待线程执行完毕
         }
         if (!IsGeneralServer()) {
             CHK_RET(hrtResetDevice(deviceLogicId_)); // 防止线程里面异常退出,在进程中reset
@@ -206,7 +206,8 @@ HcclResult CommBase::CalcLink()

 u32 CommBase::GetSocketsPerLink()
 {
-    bool multiQpDevType = paraVector_[rank_].deviceType == DevType::DEV_TYPE_910B;
+    bool multiQpDevType =
paraVector_[rank_].deviceType == DevType::DEV_TYPE_910B || + paraVector_[rank_].deviceType == DevType::DEV_TYPE_910_73; if (GetExternalInputQpsPerConnection() != HCCL_QPS_PER_CONNECTION_DEFAULT && GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && multiQpDevType) { return 2; // 2:多QP方式下额外创建一个socket用于同步QP状态迁移完成状态 @@ -222,7 +223,18 @@ bool CommBase::NeedDataReceivedAck() // 获取rank间的link type HcclResult CommBase::SetTransportType(const u32 dstRank) { - if (paraVector_[rank_].serverId == paraVector_[dstRank].serverId) { + LinkTypeInServer linkType = LinkTypeInServer::RESERVED_LINK_TYPE; + if (GetExternalInputEnableRdmaSdmaConcurrent() && isUsedRdmaOuter_ + && paraVector_[rank_].deviceType == DevType::DEV_TYPE_910_73) { + auto localDeviceId = paraVector_[rank_].devicePhyId; + auto remoteDeviceId = paraVector_[dstRank].devicePhyId; + CHK_RET(hrtGetPairDeviceLinkType(static_cast(localDeviceId), + static_cast(remoteDeviceId), linkType)); + } + // 适配910_73的RDMA+SIO ring,创建RDMA类型下的SIO连接 + if (linkType == LinkTypeInServer::SIO_TYPE && paraVector_[rank_].deviceType == DevType::DEV_TYPE_910_73) { + transportType_[dstRank] = TransportType::TRANS_TYPE_P2P; + } else if (paraVector_[rank_].serverId == paraVector_[dstRank].serverId) { // 判断是否在同一个server if (isNeedHeterogP2P_) { transportType_[dstRank] = TransportType::TRANS_TYPE_HETEROG_P2P; @@ -555,6 +567,30 @@ HcclResult CommBase::CalcLinksNum(const MachineType machineType, const u32 dstRa isUsedRdmaOuter_ || isAlltoAllCommMesh_; bool isInterHccs = IsSupportInterHccs(dstRank); + if (GetExternalInputEnableRdmaSdmaConcurrent() && isUsedRdmaOuter_ && + paraVector_[rank_].deviceType == DevType::DEV_TYPE_910_73) { + auto localDeviceId = paraVector_[rank_].devicePhyId; + auto remoteDeviceId = paraVector_[dstRank].devicePhyId; + // 计算linkType + LinkTypeInServer linkType = LinkTypeInServer::HCCS_TYPE; + CHK_RET(hrtGetPairDeviceLinkType(static_cast(localDeviceId), + static_cast(remoteDeviceId), linkType)); + HCCL_DEBUG("[Calc][LinksNum]rank[%u], dstRank[%u], isInterRdma[%d], isInterHccs[%d], link type[%u]", + rank_, dstRank, isInterRdma, isInterHccs, linkType); + if (linkType == LinkTypeInServer::SIO_TYPE) { + isInterRdma = false; + isInterHccs = true; + HCCL_DEBUG("[Calc][LinksNum]EnableRdmaSdma rank[%u], rankDevId[%u], ip[%s], dstRank[%u], dstDevId[%u], "\ + "dstIp[%s] adjust to SIO.", rank_, localDeviceId, paraVector_[rank_].nicIp[0].GetReadableAddress(), + dstRank, remoteDeviceId, paraVector_[dstRank].nicIp[0].GetReadableAddress()); + } else { + isInterRdma = true; + isInterHccs = false; + HCCL_DEBUG("[Calc][LinksNum]EnableRdmaSdma rank[%u], rankDevId[%u], ip[%s], dstRank[%u], dstDevId[%u], "\ + "dstIp[%s] link type[%u].", rank_, localDeviceId, paraVector_[rank_].nicIp[0].GetReadableAddress(), + dstRank, remoteDeviceId, paraVector_[dstRank].nicIp[0].GetReadableAddress(), linkType); + } + } HCCL_DEBUG("[Calc][LinksNum]rank[%u], dstRank[%u], isInterRdma[%d], isInterHccs[%d], machineType[%d]", rank_, dstRank, isInterRdma, isInterHccs, machineType); @@ -1164,7 +1200,15 @@ HcclResult CommBase::GetSuperNodeIntraRankIPInfo(std::map & bool CommBase::IsSupportInterHccs(const u32 dstRank) { - return false; + // 仅判断超节点内, 兼容打平通信域同时有server内和server间, 因此不判断server_id + bool isInterHccs = GetExternalInputInterHccsDisable() == false && + paraVector_[rank_].deviceType == DevType::DEV_TYPE_910_73 && + paraVector_[rank_].superPodId.empty() == false && + paraVector_[rank_].superPodId == paraVector_[dstRank].superPodId; + + HCCL_INFO("[IsSupportInterHccs]rank[%u], 
superPodId[%s], dstRank[%u], dstSuperPodId[%s], isInterHccs[%d]", + rank_, paraVector_[rank_].superPodId.c_str(), dstRank, paraVector_[dstRank].superPodId.c_str(), isInterHccs); + return isInterHccs; } void CommBase::SetMachineLinkMode(MachinePara &machinePara) diff --git a/src/domain/collective_communication/algorithm/base/communicator/comm_factory.cc b/src/domain/collective_communication/algorithm/base/communicator/comm_factory.cc index 40bb89f35ed70a940334ffa0db7c2804bf1c226a..b23e23bfb526126954578a82a51fd8cdf2696e65 100644 --- a/src/domain/collective_communication/algorithm/base/communicator/comm_factory.cc +++ b/src/domain/collective_communication/algorithm/base/communicator/comm_factory.cc @@ -16,7 +16,7 @@ #include "device_capacity.h" #include "nonuniform_hierarchical_ring_v1_base_pub.h" #include "search_path.h" - +#include "calc_p2p_transport_req.h" namespace hccl { // 模组设备数量 constexpr u32 SERVER_RANK_SIZE = 8; @@ -156,14 +156,14 @@ HcclResult CommFactory::CheckCommPara(const std::string &tag, const DeviceMem &i break; } case CommType::COMM_TAG_NONUNIFORM_HIERARCHICAL_RING_V1: { - isSupport = (commParaInfo.commPlane == COMM_LEVEL1); + isSupport = (commParaInfo.commPlane == COMM_LEVEL1 && deviceType_ != DevType::DEV_TYPE_910_73); break; } case CommType::COMM_TAG_STAR: case CommType::COMM_TAG_WHOLE_NHR: case CommType::COMM_TAG_WHOLE_NHR_V1: case CommType::COMM_TAG_WHOLE_NB: { - isSupport = true; + isSupport = (deviceType_ != DevType::DEV_TYPE_910_73); break; } default: { @@ -224,6 +224,69 @@ HcclResult CommFactory::GetIsUsedRdma(const CommParaInfo &commParaInfo, bool &is return HCCL_SUCCESS; } +HcclResult CommFactory::GetIsUsedRdmaMap(std::unordered_map &isUsedRdmaMap) +{ + for (const RankInfo &dstRank : rankVector_) { + bool isInterSuperPod = false; + bool isInterServer = false; + bool isConnectedWithPcie = false; + if (rankData_.superPodId != dstRank.superPodId) { // 跨超节点场景 + isInterSuperPod = true; + } else if (rankData_.serverIdx != dstRank.serverIdx) { // 不跨超节点, 跨server场景 + isInterServer = true; + } else { // 同server, PCIE互连场景 + auto it = deviceLinkTypeMap_.find(dstRank.devicePhyId); + CHK_PRT_RET(it == deviceLinkTypeMap_.end(), + HCCL_ERROR("can't find devicePhyId[%d] in deviceLinkTypeMap_", dstRank.devicePhyId), + HCCL_E_NOT_FOUND); + isConnectedWithPcie |= (it->second == LinkTypeInServer::PXI_TYPE); + } + // 使能RDMA的场景: 1.跨超节点 2.跨server且不使能HCCS 3.PCIE连接且使能RDMA开关 + bool isUsedRdma = (isInterSuperPod) || + (isInterServer && !isUsedInterHccsMode_) || (isConnectedWithPcie && isUsedRdmaOuter_); + isUsedRdmaMap[dstRank.userRank] = isUsedRdma; + HCCL_DEBUG("[GetIsUsedRdma]isUsedRdma[%d], isInterSuperPod[%d], isInterServer[%d], isUsedInterHccsMode_[%d], "\ + "isConnectedWithPcie[%d], isUsedRdmaOuter_[%d], dstRank[%d]", isUsedRdma, isInterSuperPod, isInterServer, + isUsedInterHccsMode_, isConnectedWithPcie, isUsedRdmaOuter_, dstRank.userRank); + } + return HCCL_SUCCESS; +} + +HcclResult CommFactory::GetRankVecInfo(std::vector>> &serverAndsuperPodToRank) +{ + std::vector> serverToRank; + std::vector> superPodToRank; + serverToRank.clear(); + superPodToRank.clear(); + u32 firstIdx = 0; + + serverToRank.resize(serverToRank_.size()); + for (auto iterMap = serverToRank_.begin(); iterMap != serverToRank_.end(); iterMap++) { + serverToRank[firstIdx].resize((iterMap->second).size()); + if (!(iterMap->second).empty()) { + for (u32 i = 0; i < (iterMap->second).size(); i++) { + serverToRank[firstIdx][i] = (iterMap->second)[i].userRank; + } + } + firstIdx++; + } + + u32 podFirstIdx = 0; + 
superPodToRank.resize(superPodToRank_.size()); + for (auto iterMap = superPodToRank_.begin(); iterMap != superPodToRank_.end(); iterMap++) { + if (!(iterMap->second).empty()) { + superPodToRank[podFirstIdx].resize((iterMap->second).size()); + for (u32 i = 0; i < (iterMap->second).size(); i++) { + superPodToRank[podFirstIdx][i] = (iterMap->second)[i].userRank; + } + } + podFirstIdx++; + } + serverAndsuperPodToRank.push_back(serverToRank); + serverAndsuperPodToRank.push_back(superPodToRank); + return HCCL_SUCCESS; +} + HcclResult CommFactory::CreateCommPlane(const std::string &tag, const DeviceMem &inputMem, const DeviceMem &outputMem, const CommParaInfo &commParaInfo, std::vector > &commVec) { @@ -235,6 +298,9 @@ HcclResult CommFactory::CreateCommPlane(const std::string &tag, const DeviceMem CHK_RET(CheckCommPara(tag, inputMem, outputMem, commParaInfo)); bool isUsedRdma = false; CHK_RET(GetIsUsedRdma(commParaInfo, isUsedRdma)); + if (GetExternalInputEnableRdmaSdmaConcurrent() && deviceType_ == DevType::DEV_TYPE_910_73) { + isUsedRdma = commParaInfo.forceRdma; + } switch (commParaInfo.commType) { case CommType::COMM_TAG_RING_INNER: @@ -1096,27 +1162,6 @@ HcclResult CommFactory::SetSingleOuter() return HCCL_SUCCESS; } -bool CheckRankNeighbors(const std::vector &nicList) -{ - // 组成ROH环路必须偶数个,且2节点不能组成双环? - if (nicList.size() % 2 != 0 || nicList.size() < HCCL_DEVICE_NUM_FOUR) { - return false; - } - - std::vector tmpNicList(nicList); - std::sort(tmpNicList.begin(), tmpNicList.end()); - u32 halfNum = 2; - for (u32 i = 0; i < tmpNicList.size() / halfNum; i++) { - auto nicIndex = i * halfNum; - // 检查相邻下标的节点,devID是否相邻 - if (tmpNicList[nicIndex] + 1 != tmpNicList[nicIndex + 1]) { - return false; - } - } - - return true; -} - // 适配ROH平面网段隔离,奇数rank互通,偶数rank互通,奇偶不通 bool CheckSdmaWithRohTopo(const std::vector &nicList, std::vector &topoList) { @@ -1252,7 +1297,7 @@ HcclResult CommFactory::SetTopoInfoForLevel0() CHK_RET(SetSingleOuter()); } else { // 8p-ring/np ring 环场景 u32 ringNum = multiOuterOrder_.size(); - CHK_RET(SetMultiOuter(ringNum)); // 8P_RING场景下,外层拓扑中有四个环; + CHK_RET(SetMultiOuter(ringNum)); // 8P_RING场景下,外层拓扑中有四个环; 910_73场景中适配双环 } return HCCL_SUCCESS; } @@ -1471,8 +1516,8 @@ HcclResult CommFactory::CheckInitInfo() } // 入参组合有效性检查:不支持4P_RING - if ((deviceType_ == DevType::DEV_TYPE_910 || deviceType_ == DevType::DEV_TYPE_910B) && - (topoFlag_ == TopoType::TOPO_TYPE_4P_RING)) { + if ((deviceType_ == DevType::DEV_TYPE_910 || deviceType_ == DevType::DEV_TYPE_910B || + deviceType_ == DevType::DEV_TYPE_910_73) && (topoFlag_ == TopoType::TOPO_TYPE_4P_RING)) { HCCL_ERROR("[Check][InitInfo]Not support the scenes: TopoType[%d] with deviceType[%d] is invalid.", topoFlag_, deviceType_); return HCCL_E_PARA; @@ -1774,93 +1819,6 @@ bool CommFactory::IsDiffDeviceModuleInServer() const return deviceType_ == DevType::DEV_TYPE_910B && isDiffAggregation_; } -/* - * ********************************************************************************* - * comm_factory后续不承担建链功能,只进行通信关系推导 - * ********************************************************************************* -*/ - -HcclResult CommFactory::CalcCommPlaneInfo(const std::string &tag, const CommParaInfo &commParaInfo, - std::vector &commTransport, TransportMemType inputMemType, - TransportMemType outputMemType) -{ - HcclUs startut = TIME_NOW(); - HcclResult ret = HCCL_SUCCESS; - HCCL_INFO("[Calc][CommPlane]tag[%s], identifier[%s], commPlane[%d], commType[%d]", - tag.c_str(), identifier_.c_str(), commParaInfo.commPlane, commParaInfo.commType); - - bool isUsedRdma = 
false; - CHK_RET(GetIsUsedRdma(commParaInfo, isUsedRdma)); - std::unique_ptr calcTransportReq; - switch (commParaInfo.commType) { - case CommType::COMM_TAG_RING_INNER: - case CommType::COMM_TAG_RING_COMBINED: { - calcTransportReq.reset(new (std::nothrow) CalcRingTransportReq(CommPlaneVector_[commParaInfo.commPlane], - isBridgeVector_, userRank_)); - ret = calcTransportReq->CalcTransportRequest(tag, inputMemType, outputMemType, commParaInfo, commTransport); - break; - } - case CommType::COMM_TAG_HALVING_DOUBLING: { - calcTransportReq.reset(new (std::nothrow) CalcHDTransportReq(CommPlaneVector_[commParaInfo.commPlane], - isBridgeVector_, userRank_)); - ret = calcTransportReq->CalcTransportRequest(tag, inputMemType, outputMemType, commParaInfo, commTransport); - break; - } - case CommType::COMM_TAG_NONUNIFORM_HIERARCHICAL_RING: - case CommType::COMM_TAG_WHOLE_NHR:{ - calcTransportReq.reset(new (std::nothrow) CalcNHRTransportReq(CommPlaneVector_[commParaInfo.commPlane], - isBridgeVector_, userRank_)); - ret = calcTransportReq->CalcTransportRequest(tag, inputMemType, outputMemType, commParaInfo, commTransport); - break; - } - case CommType::COMM_TAG_NONUNIFORM_HIERARCHICAL_RING_V1: - case CommType::COMM_TAG_WHOLE_NHR_V1: { - calcTransportReq.reset(new (std::nothrow) CalcNHRV1TransportReq(CommPlaneVector_[commParaInfo.commPlane], - isBridgeVector_, userRank_)); - ret = calcTransportReq->CalcTransportRequest(tag, inputMemType, outputMemType, commParaInfo, commTransport); - break; - } - case CommType::COMM_TAG_NONUNIFORM_BRUCK: - case CommType::COMM_TAG_WHOLE_NB: { - calcTransportReq.reset(new (std::nothrow) CalcNBTransportReq(CommPlaneVector_[commParaInfo.commPlane], - isBridgeVector_, userRank_)); - ret = calcTransportReq->CalcTransportRequest(tag, inputMemType, outputMemType, commParaInfo, commTransport); - break; - } - case CommType::COMM_TAG_MESH: { - calcTransportReq.reset(new (std::nothrow) CalcMeshTransportReq(CommPlaneVector_[commParaInfo.commPlane], - isBridgeVector_, userRank_)); - ret = calcTransportReq->CalcTransportRequest(tag, inputMemType, outputMemType, commParaInfo, commTransport); - break; - } - case CommType::COMM_TAG_PARTIAL_MESH_COMBINED: { - RdmaEnableCheckInfo rdmaEnableCheckInfo; - rdmaEnableCheckInfo.isUsedRdma = isUsedRdma; - rdmaEnableCheckInfo.isDiffModuleInServer = IsDiffDeviceModuleInServer(); - rdmaEnableCheckInfo.rankData = rankData_; - calcTransportReq.reset(new (std::nothrow) CalcPartialMeshTransportReq - (CommPlaneVector_[commParaInfo.commPlane], isBridgeVector_, userRank_, rdmaEnableCheckInfo)); - ret = calcTransportReq->CalcTransportRequest(tag, inputMemType, outputMemType, commParaInfo, commTransport); - break; - } - default: { - HCCL_ERROR("[Calc][CommPlane]commType[%d] is invalid", commParaInfo.commType); - return HCCL_E_PARA; - } - } - CHK_RET(SetIsUsedRdma(commParaInfo, commTransport, isUsedRdma)); - - CHK_RET(GetRankMap(commParaInfo, commTransport)); - - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[Calc][CommPlane]failed, tag[%s], commPlane[%d], commType[%d]", - tag.c_str(), commParaInfo.commPlane, commParaInfo.commType), ret); - - HCCL_INFO("complete commPlane[%d] commType[%d] Calculation, Time:%lld us", - commParaInfo.commPlane, commParaInfo.commType, DURATION_US(TIME_NOW() - startut)); - return HCCL_SUCCESS; -} - HcclResult CommFactory::SetIsUsedRdma(const CommParaInfo &commParaInfo, std::vector &commTransport, bool isUsedRdma) { @@ -1930,4 +1888,26 @@ HcclResult CommFactory::GetUserRank2SubMap(CommPlane commPlane, u32 ringIndex, return 
HCCL_SUCCESS; } +HcclResult CommFactory::GetCommPlaneRanks(std::vector>> &CommPlaneRanks) +{ + CommPlaneRanks.resize(CommPlaneVector_.size()); + for (u32 level = 0; level < CommPlaneVector_.size(); level ++) { + u32 ringSize = CommPlaneVector_[level].size(); + CommPlaneRanks[level].resize(ringSize); + for (u32 ringIndex = 0 ; ringIndex < ringSize; ringIndex ++) { + u32 rankSize = CommPlaneVector_[level][ringIndex].size(); + CommPlaneRanks[level][ringIndex].resize(rankSize); + for (u32 rankIndex = 0 ; rankIndex < rankSize; rankIndex ++) { + u32 userRank = CommPlaneVector_[level][ringIndex][rankIndex].userRank; + CommPlaneRanks[level][ringIndex][rankIndex] = userRank; + } + } + } + return HCCL_SUCCESS; +} +HcclResult CommFactory::GetIsBridgeVector(std::vector &isBridgeVector) +{ + isBridgeVector = isBridgeVector_; + return HCCL_SUCCESS; +} } // namespace hccl diff --git a/src/domain/collective_communication/algorithm/base/communicator/comm_factory_pub.h b/src/domain/collective_communication/algorithm/base/communicator/comm_factory_pub.h index 8b1c09f81debffc0bd76d4f0c2ae7420fc6e787a..d3dbcd53010bae73c69e15510c2abc59be43dfeb 100644 --- a/src/domain/collective_communication/algorithm/base/communicator/comm_factory_pub.h +++ b/src/domain/collective_communication/algorithm/base/communicator/comm_factory_pub.h @@ -83,8 +83,6 @@ bool Ascending(const RankInfo &first, const RankInfo &second); // 排序规则 bool CompareWithUserRankAscend(const RankInfo &left, const RankInfo &right); // 按UserRank升序 // 生成多个ring环的设备物理ID排序 std::vector> GetRingsOrderByTopoType(u32 ranksSize, TopoType topoType, std::vector &nicList); -bool CheckRankNeighbors(const std::vector &nicList); -bool CheckSdmaWithRohTopo(const std::vector &nicList, std::vector &topoList); class ExchangerNetwork; class CommFactory { @@ -130,14 +128,10 @@ public: std::unordered_map> &rankDevicePhyIdNicInfoMap, std::vector &ranksPort, bool isSetHDCModeInfo, bool isUseRankPort); - /* - * ********************************************************************************* - * comm_factory后续不承担建链功能,只进行通信关系推导 - * ********************************************************************************* - */ - HcclResult CalcCommPlaneInfo(const std::string &tag, const CommParaInfo &commParaInfo, - std::vector &commTransport, TransportMemType inPutMemType, - TransportMemType outPutMemType); + HcclResult GetCommPlaneRanks(std::vector>> &CommPlaneRanks); + HcclResult GetIsBridgeVector(std::vector &isBridgeVector); + HcclResult GetIsUsedRdmaMap(std::unordered_map &isUsedRdmaMap); + HcclResult GetRankVecInfo(std::vector>> &serverAndsuperPodToRank); protected: /* 禁止用户对工厂类的实体做拷贝构造或拷贝赋值的操作,内部有指针成员变量 */ @@ -164,7 +158,7 @@ private: HcclResult CreateCommNHR(const std::string &tag, const DeviceMem &inputMem, const DeviceMem &outputMem, const CommParaInfo &commParaInfo, const std::vector > &commPlaneVec, bool isUsedRdma, std::vector > &commVec); - + HcclResult CreateCommNHRV1(const std::string &tag, const DeviceMem &inputMem, const DeviceMem &outputMem, const CommParaInfo &commParaInfo, const std::vector > &commPlaneVec, bool isUsedRdma, std::vector > &commVec); diff --git a/src/domain/collective_communication/algorithm/base/communicator/comm_p2p.cc b/src/domain/collective_communication/algorithm/base/communicator/comm_p2p.cc index 2383fd9534db39d77eb5d93a2e877a64a352db00..9286ee1c3e49700207393af766283261930b6f15 100644 --- a/src/domain/collective_communication/algorithm/base/communicator/comm_p2p.cc +++ b/src/domain/collective_communication/algorithm/base/communicator/comm_p2p.cc @@ 
-54,6 +54,7 @@ HcclResult CommP2P::CalcLink() bool isSupportP2P = false; if ((paraVector_[rank_].deviceType == DevType::DEV_TYPE_910) || (paraVector_[rank_].deviceType == DevType::DEV_TYPE_910B) || + (paraVector_[rank_].deviceType == DevType::DEV_TYPE_910_73) || (paraVector_[rank_].deviceType == DevType::DEV_TYPE_NOSOC)) { isSupportP2P = true; } else if (paraVector_[rank_].deviceType == DevType::DEV_TYPE_310P3) { diff --git a/src/domain/collective_communication/algorithm/base/communicator/comm_ring.cc b/src/domain/collective_communication/algorithm/base/communicator/comm_ring.cc index 8f784f426cf101e4a8722cd8668a155f15a9171a..892d96aa70dab8324ee401f7aa33efb615116f90 100644 --- a/src/domain/collective_communication/algorithm/base/communicator/comm_ring.cc +++ b/src/domain/collective_communication/algorithm/base/communicator/comm_ring.cc @@ -107,7 +107,8 @@ HcclResult CommRing::CalcLink() // 获取每个 link 需要的 socket 数量 u32 CommRing::GetSocketsPerLink() { - bool multiQpDevType = paraVector_[rank_].deviceType == DevType::DEV_TYPE_910B; + bool multiQpDevType = paraVector_[rank_].deviceType == DevType::DEV_TYPE_910B || + paraVector_[rank_].deviceType == DevType::DEV_TYPE_910_73; if (GetExternalInputQpsPerConnection() != HCCL_QPS_PER_CONNECTION_DEFAULT && GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && multiQpDevType) { return 2; // 2:多QP方式下额外创建一个socket用于同步QP状态迁移完成状态 diff --git a/src/domain/collective_communication/algorithm/base/communicator/comm_star.cc b/src/domain/collective_communication/algorithm/base/communicator/comm_star.cc index 2fed9087841886d0b82a436ab51a0d57a2f65381..514e8fc602a0520684e0b2fdfd420f90ff108823 100644 --- a/src/domain/collective_communication/algorithm/base/communicator/comm_star.cc +++ b/src/domain/collective_communication/algorithm/base/communicator/comm_star.cc @@ -13,8 +13,8 @@ constexpr s32 NORMAL_QP_MODE = 0; constexpr s32 OFFLINE_QP_MODE = 1; constexpr s32 OPBASE_QP_MODE = 2; -constexpr s32 OFFLINE_QP_MODE_EXT = 3; // 下沉模式(910B)QP -constexpr s32 OPBASE_QP_MODE_EXT = 4; // 单算子模式(910B)的QP +constexpr s32 OFFLINE_QP_MODE_EXT = 3; // 下沉模式(910B/91073)QP +constexpr s32 OPBASE_QP_MODE_EXT = 4; // 单算子模式(910B/91073)的QP namespace hccl { constexpr s32 MODULE_TYPE_SYSTEM = 0; diff --git a/src/domain/collective_communication/algorithm/base/communicator/search_path.h b/src/domain/collective_communication/algorithm/base/communicator/search_path.h index 4532b88819c33db93547d1122fde14afbf6ceb7b..4307a824c99f5ec9dcafc0ea77ea4128d4a57568 100644 --- a/src/domain/collective_communication/algorithm/base/communicator/search_path.h +++ b/src/domain/collective_communication/algorithm/base/communicator/search_path.h @@ -65,7 +65,7 @@ private: std::vector result_; std::set nicSet_; - + // 适配910_73设备双轨组网,可通过SIO串联 std::map> reachableRank_ = { {0, {1, 2, 4, 6, 8, 10, 12, 14}}, {1, {0, 3, 5, 7, 9, 11, 13, 15}}, diff --git a/src/domain/collective_communication/algorithm/base/executor/alltoallv_staged_mesh.cc b/src/domain/collective_communication/algorithm/base/executor/alltoallv_staged_mesh.cc index 5be5cb1c19eb347e8f496f772ce672117099ef76..67e2a2c522394f339748a761f3575ebdd6b97534 100644 --- a/src/domain/collective_communication/algorithm/base/executor/alltoallv_staged_mesh.cc +++ b/src/domain/collective_communication/algorithm/base/executor/alltoallv_staged_mesh.cc @@ -38,7 +38,7 @@ HcclResult AlltoAllVStagedMesh::Prepare(DeviceMem &sendMem, DeviceMem &recvMem, subStreams_ = subStreams; isAlltoAllZCopyMode_ = isAlltoAllZCopyMode; - HCCL_DEBUG("[AlltoAllVStagedMesh][Prepare] 
finished"); + HCCL_DEBUG("[AlltoAllVStagedMesh][Prepare] finished and isAlltoAllZCopyMode_[%d]", isAlltoAllZCopyMode_); return HCCL_SUCCESS; } diff --git a/src/domain/collective_communication/algorithm/base/executor/reduce_ring.cc b/src/domain/collective_communication/algorithm/base/executor/reduce_ring.cc index 5d11e7e5382369d87bdfbf3e54960992db90c9ce..7cbd42302b9d97e8f16d7cb3e2f293919fed7fa1 100644 --- a/src/domain/collective_communication/algorithm/base/executor/reduce_ring.cc +++ b/src/domain/collective_communication/algorithm/base/executor/reduce_ring.cc @@ -31,8 +31,8 @@ HcclResult ReduceRing::RunAsync(const u32 rank, const u32 rankSize, HCCL_E_PARA); HcclResult ret = HCCL_SUCCESS; - HCCL_INFO("ReduceRing run: rank[%u] totalrank[%u] inputmem[%p] output[%p] count[%llu]", \ - rank, rankSize, inputMem_.ptr(), outputMem_.ptr(), count_); + HCCL_INFO("ReduceRing run: rank[%u] totalrank[%u] root[%u] inputmem[%p] output[%p] count[%llu]", \ + rank, rankSize, root_, inputMem_.ptr(), outputMem_.ptr(), count_); // 如果ranksize为1, inline reduce和普通跨片reduce操作一致,从input->output if (rankSize == 1) { diff --git a/src/domain/collective_communication/algorithm/base/inc/all_gather_pipeline_pub.h b/src/domain/collective_communication/algorithm/base/inc/all_gather_pipeline_pub.h index 0a9c8a24a378ceb9bd5552b248c6ea4c81e0f09f..c38f88ae553392528f9601a3fea882e1cc157b26 100644 --- a/src/domain/collective_communication/algorithm/base/inc/all_gather_pipeline_pub.h +++ b/src/domain/collective_communication/algorithm/base/inc/all_gather_pipeline_pub.h @@ -7,7 +7,7 @@ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. * See LICENSE in the root of the software repository for the full text of the License. */ - + #ifndef ALL_GATHER_PIPELINE_PUB_H #define ALL_GATHER_PIPELINE_PUB_H @@ -21,50 +21,57 @@ #include "mem_device_pub.h" #include "stream_pub.h" #include "executor_base_pub.h" - +#include "coll_alg_param.h" + namespace hccl { - + class AllGatherPipeline : public ExecutorBase { public: explicit AllGatherPipeline(const HcclDispatcher dispatcher); ~AllGatherPipeline() override; - + HcclResult Prepare(HcomCollOpInfo *opInfo, u32 userRank, u64 &count, DeviceMem &cclBufferPartOne, DeviceMem &cclBufferPartTwo, std::unique_ptr &commOuter, std::unique_ptr &commInner, Stream &mainStream, std::vector &subStream, std::vector> ¬ifyMain, std::vector> ¬ifySub); - + + // 适配新CollExecutor接口 + HcclResult Prepare(HcomCollOpInfo *opInfo, u32 userRank, u64 &count, DeviceMem &cclBufferPartOne, + DeviceMem &cclBufferPartTwo, SubCommInfo &outerCommInfo, SubCommInfo &innerCommInfo, + Stream &mainStream, std::vector &subStream, + std::vector> ¬ifyMain, std::vector> ¬ifySub); + HcclResult RunAsync(); - + protected: - + private: HcclResult MainRecordSub(); HcclResult SubWaitMain(); HcclResult MainWaitSub(); HcclResult SubRecordMain(); - + HcomCollOpInfo *opInfo_; u64 memSliceCount_; u32 userRank_; - + void* usrInMemAddr_; void* usrOutMemAddr_; std::vector dmaMem_; - + std::vector subStream_; - + std::vector> streamNotifyMain_; std::vector> streamNotifySub_; - + u32 intraRankSize_; u32 interRankSize_; u32 intraRankId_; u32 interRankId_; - + std::vector intraLinks_; std::vector interLinks_; }; } // namespace hccl - + #endif /* ALL_GATHER_PIPELINE_PUB_H */ \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/base/inc/allltoall_pipeline_base_pub.h b/src/domain/collective_communication/algorithm/base/inc/allltoall_pipeline_base_pub.h index 
7d16681cd4944c4bc3edb368646368c306507472..bb54c11085a01dc6663c8310b6482408f908bf24 100644 --- a/src/domain/collective_communication/algorithm/base/inc/allltoall_pipeline_base_pub.h +++ b/src/domain/collective_communication/algorithm/base/inc/allltoall_pipeline_base_pub.h @@ -22,6 +22,7 @@ #include "mem_device_pub.h" #include "stream_pub.h" #include "executor_base_pub.h" +#include "coll_alg_param.h" namespace hccl { class A2aPipelineMemory { @@ -46,6 +47,11 @@ public: std::unique_ptr &commOuter, std::unique_ptr &commInner, Stream &mainStream, std::vector &subStream, std::vector> ¬ifyMain, std::vector> ¬ifySub); + // 适配新CollExecutor接口 + virtual HcclResult Prepare(u32 userRank, A2aPipelineMemory A2aPipelineMemory, + const SubCommInfo &outerCommInfo, const SubCommInfo &innerCommInfo, + Stream &mainStream, std::vector &subStream, + std::vector> ¬ifyMain, std::vector> ¬ifySub); HcclResult RunAsync(); diff --git a/src/domain/collective_communication/algorithm/base/inc/allltoall_pipeline_mesh_pairwise_ccl_enough_pub.h b/src/domain/collective_communication/algorithm/base/inc/allltoall_pipeline_mesh_pairwise_ccl_enough_pub.h index f1ff199b0830bc661e2f3572c5eb4013c1446caa..46a380575046dd4b742bb566a20d5bbb6158bfec 100644 --- a/src/domain/collective_communication/algorithm/base/inc/allltoall_pipeline_mesh_pairwise_ccl_enough_pub.h +++ b/src/domain/collective_communication/algorithm/base/inc/allltoall_pipeline_mesh_pairwise_ccl_enough_pub.h @@ -27,6 +27,11 @@ public: std::unique_ptr &commOuter, std::unique_ptr &commInner, Stream &mainStream, std::vector &subStream, std::vector> ¬ifyMain, std::vector> ¬ifySub); + // 适配新CollExecutor接口 + virtual HcclResult Prepare(u32 userRank, A2aPipelineMemory A2aPipelineMemory, + const SubCommInfo &outerCommInfo, const SubCommInfo &innerCommInfo, + Stream &mainStream, std::vector &subStream, + std::vector> ¬ifyMain, std::vector> ¬ifySub); HcclResult RunAsync(); diff --git a/src/domain/collective_communication/algorithm/base/inc/allltoall_pipeline_mesh_pairwise_ping_pong_pub.h b/src/domain/collective_communication/algorithm/base/inc/allltoall_pipeline_mesh_pairwise_ping_pong_pub.h index f2760b16669001376e1c584c96ed2ca1d12f22ae..cd56d883e6423e3e9523d20ed5c358037493a20d 100644 --- a/src/domain/collective_communication/algorithm/base/inc/allltoall_pipeline_mesh_pairwise_ping_pong_pub.h +++ b/src/domain/collective_communication/algorithm/base/inc/allltoall_pipeline_mesh_pairwise_ping_pong_pub.h @@ -27,6 +27,11 @@ public: std::unique_ptr &commOuter, std::unique_ptr &commInner, Stream &mainStream, std::vector &subStream, std::vector> ¬ifyMain, std::vector> ¬ifySub); + // 适配新CollExecutor接口 + virtual HcclResult Prepare(u32 userRank, A2aPipelineMemory A2aPipelineMemory, + const SubCommInfo &outerCommInfo, const SubCommInfo &innerCommInfo, + Stream &mainStream, std::vector &subStream, + std::vector> ¬ifyMain, std::vector> ¬ifySub); HcclResult RunAsync(); diff --git a/src/domain/collective_communication/algorithm/base/inc/calc_hd_transport_req_pub.h b/src/domain/collective_communication/algorithm/base/inc/calc_hd_transport_req_pub.h index 7d070570aaaf4c54a89942a686cf9b21e0ab394b..7c720676d7fbc59bbe82c92d3878230c7513ad59 100644 --- a/src/domain/collective_communication/algorithm/base/inc/calc_hd_transport_req_pub.h +++ b/src/domain/collective_communication/algorithm/base/inc/calc_hd_transport_req_pub.h @@ -16,14 +16,14 @@ namespace hccl { class CalcHDTransportReq : public CalcTransportReqBase { public: - explicit CalcHDTransportReq(std::vector> &subCommPlaneVector, + explicit 
CalcHDTransportReq(std::vector<std::vector<u32>> &subCommPlaneVector,
std::vector<bool> &isBridgeVector, u32 userRank);
~CalcHDTransportReq();
HcclResult CalcTransportRequest(const std::string &tag, TransportMemType inputMemType,
TransportMemType outputMemType, const CommParaInfo &commParaInfo,
- std::vector<SingleSubCommTransport> &commTransport) override;
+ std::vector<SingleSubCommTransport> &commTransport, u32 subUserRankRoot = INVALID_VALUE_RANKID) override;
};
} // namespace hccl
#endif /* CALC_HD_TRANSPORT_REQ_PUB_H */
\ No newline at end of file
diff --git a/src/domain/collective_communication/algorithm/base/inc/calc_mesh_transport_req_pub.h b/src/domain/collective_communication/algorithm/base/inc/calc_mesh_transport_req_pub.h
index 71ed68d70b50eddc83f0e3df0f4f793b1dcb4e93..8e44febca88a7638225d9480d609d17dd2011769 100644
--- a/src/domain/collective_communication/algorithm/base/inc/calc_mesh_transport_req_pub.h
+++ b/src/domain/collective_communication/algorithm/base/inc/calc_mesh_transport_req_pub.h
@@ -16,14 +16,14 @@ namespace hccl {
class CalcMeshTransportReq : public CalcTransportReqBase {
public:
- explicit CalcMeshTransportReq(std::vector<std::vector<RankInfo>> &subCommPlaneVector,
+ explicit CalcMeshTransportReq(std::vector<std::vector<u32>> &subCommPlaneVector,
std::vector<bool> &isBridgeVector, u32 userRank);
~CalcMeshTransportReq();
HcclResult CalcTransportRequest(const std::string &tag, TransportMemType inputMemType,
TransportMemType outputMemType, const CommParaInfo &commParaInfo,
- std::vector<SingleSubCommTransport> &commTransport) override;
+ std::vector<SingleSubCommTransport> &commTransport, u32 subUserRankRoot = INVALID_VALUE_RANKID) override;
};
} // namespace hccl
#endif /* CALC_MESH_TRANSPORT_REQ_PUB_H */
\ No newline at end of file
diff --git a/src/domain/collective_communication/algorithm/base/inc/calc_nb_transport_req_pub.h b/src/domain/collective_communication/algorithm/base/inc/calc_nb_transport_req_pub.h
index 8414c64bb9f6531e8b4b2434cf9acba44f676f76..63cadf97f879c27a7db487e8205bf39e2dbc6c79 100644
--- a/src/domain/collective_communication/algorithm/base/inc/calc_nb_transport_req_pub.h
+++ b/src/domain/collective_communication/algorithm/base/inc/calc_nb_transport_req_pub.h
@@ -16,14 +16,14 @@ namespace hccl {
class CalcNBTransportReq : public CalcTransportReqBase {
public:
- explicit CalcNBTransportReq(std::vector<std::vector<RankInfo>> &subCommPlaneVector,
+ explicit CalcNBTransportReq(std::vector<std::vector<u32>> &subCommPlaneVector,
std::vector<bool> &isBridgeVector, u32 userRank);
~CalcNBTransportReq();
HcclResult CalcTransportRequest(const std::string &tag, TransportMemType inputMemType,
TransportMemType outputMemType, const CommParaInfo &commParaInfo,
- std::vector<SingleSubCommTransport> &commTransport) override;
+ std::vector<SingleSubCommTransport> &commTransport, u32 subUserRankRoot = INVALID_VALUE_RANKID) override;
};
} // namespace hccl
#endif /* CALC_NB_TRANSPORT_REQ_PUB_H */
\ No newline at end of file
diff --git a/src/domain/collective_communication/algorithm/base/inc/calc_nhr_transport_req_pub.h b/src/domain/collective_communication/algorithm/base/inc/calc_nhr_transport_req_pub.h
index 5bde66bd668d5e9fd58a25bcd41fbd7a172c34bb..5ad905de6a773dcbab1a7d749f1e33f9384f4cc3 100644
--- a/src/domain/collective_communication/algorithm/base/inc/calc_nhr_transport_req_pub.h
+++ b/src/domain/collective_communication/algorithm/base/inc/calc_nhr_transport_req_pub.h
@@ -16,14 +16,14 @@ namespace hccl {
class CalcNHRTransportReq : public CalcTransportReqBase {
public:
- explicit CalcNHRTransportReq(std::vector<std::vector<RankInfo>> &subCommPlaneVector,
+ explicit CalcNHRTransportReq(std::vector<std::vector<u32>> &subCommPlaneVector,
std::vector<bool> &isBridgeVector, u32 userRank);
~CalcNHRTransportReq();
HcclResult CalcTransportRequest(const std::string &tag, TransportMemType inputMemType,
TransportMemType outputMemType, const CommParaInfo &commParaInfo,
- std::vector<SingleSubCommTransport> &commTransport) override;
+ std::vector<SingleSubCommTransport> &commTransport, u32 subUserRankRoot = INVALID_VALUE_RANKID) override;
};
} // namespace hccl
#endif /* CALC_NHR_TRANSPORT_REQ_PUB_H */
\ No newline at end of file
diff --git a/src/domain/collective_communication/algorithm/base/inc/calc_nhr_v1_transport_req_pub.h b/src/domain/collective_communication/algorithm/base/inc/calc_nhr_v1_transport_req_pub.h
index 3f9c58a73a6337cb4e4a75596ddf53df73ebc434..3a1f35248e37872eff45bf20c3deef3f6e2eb308 100644
--- a/src/domain/collective_communication/algorithm/base/inc/calc_nhr_v1_transport_req_pub.h
+++ b/src/domain/collective_communication/algorithm/base/inc/calc_nhr_v1_transport_req_pub.h
@@ -16,14 +16,14 @@ namespace hccl {
class CalcNHRV1TransportReq : public CalcTransportReqBase {
public:
- explicit CalcNHRV1TransportReq(std::vector<std::vector<RankInfo>> &subCommPlaneVector,
+ explicit CalcNHRV1TransportReq(std::vector<std::vector<u32>> &subCommPlaneVector,
std::vector<bool> &isBridgeVector, u32 userRank);
~CalcNHRV1TransportReq();
HcclResult CalcTransportRequest(const std::string &tag, TransportMemType inputMemType,
TransportMemType outputMemType, const CommParaInfo &commParaInfo,
- std::vector<SingleSubCommTransport> &commTransport) override;
+ std::vector<SingleSubCommTransport> &commTransport, u32 subUserRankRoot = INVALID_VALUE_RANKID) override;
};
} // namespace hccl
#endif /* CALC_NHR_V1_TRANSPORT_REQ_PUB_H */
\ No newline at end of file
diff --git a/src/domain/collective_communication/algorithm/base/inc/calc_partial_mesh_transport_req_pub.h b/src/domain/collective_communication/algorithm/base/inc/calc_partial_mesh_transport_req_pub.h
index 7ebbc2e462223d8597503b0858e69f6a22a72906..de647fd84a6b4598499ac89a4888f0dbae4d55a4 100644
--- a/src/domain/collective_communication/algorithm/base/inc/calc_partial_mesh_transport_req_pub.h
+++ b/src/domain/collective_communication/algorithm/base/inc/calc_partial_mesh_transport_req_pub.h
@@ -14,26 +14,16 @@ #include "calc_transport_req_base_pub.h"
namespace hccl {
-struct RdmaEnableCheckInfo {
- RankInfo rankData;
- bool isDiffModuleInServer;
- bool isUsedRdma;
-};
class CalcPartialMeshTransportReq : public CalcTransportReqBase {
public:
- explicit CalcPartialMeshTransportReq(std::vector<std::vector<RankInfo>> &subCommPlaneVector,
- std::vector<bool> &isBridgeVector, u32 userRank, RdmaEnableCheckInfo& rdmacheckInfo);
+ explicit CalcPartialMeshTransportReq(std::vector<std::vector<u32>> &subCommPlaneVector,
+ std::vector<bool> &isBridgeVector, u32 userRank);
~CalcPartialMeshTransportReq();
HcclResult CalcTransportRequest(const std::string &tag, TransportMemType inputMemType,
TransportMemType outputMemType, const CommParaInfo &commParaInfo,
- std::vector<SingleSubCommTransport> &commTransport) override;
- bool IsNotSupportSDMA(const RankInfo &remoteRankData);
-private:
- RankInfo rankData_;
- bool isDiffModuleInServer_;
- bool isUsedRdma_;
+ std::vector<SingleSubCommTransport> &commTransport, u32 subUserRankRoot = INVALID_VALUE_RANKID) override;
};
} // namespace hccl
#endif /* CALC_PARTIAL_MESH_TRANSPORT_REQ_PUB_H */
\ No newline at end of file
diff --git a/src/domain/collective_communication/algorithm/base/inc/calc_ring_transport_req_pub.h b/src/domain/collective_communication/algorithm/base/inc/calc_ring_transport_req_pub.h
index dddd93f1b33ffe08f5bf6a9f17ccfefd825ec72d..cd5ae9150fa61291354228e80f02dafc1f6d7993 100644
--- a/src/domain/collective_communication/algorithm/base/inc/calc_ring_transport_req_pub.h
+++ b/src/domain/collective_communication/algorithm/base/inc/calc_ring_transport_req_pub.h
@@ -16,14 +16,14 @@ namespace hccl {
class CalcRingTransportReq : public CalcTransportReqBase {
public:
- explicit CalcRingTransportReq(std::vector<std::vector<RankInfo>> &subCommPlaneVector,
+ explicit CalcRingTransportReq(std::vector<std::vector<u32>> &subCommPlaneVector,
std::vector<bool> &isBridgeVector, u32 userRank);
~CalcRingTransportReq();
HcclResult CalcTransportRequest(const std::string &tag, TransportMemType inputMemType,
TransportMemType outputMemType, const CommParaInfo &commParaInfo,
- std::vector<SingleSubCommTransport> &commTransport) override;
+ std::vector<SingleSubCommTransport> &commTransport, u32 subUserRankRoot = INVALID_VALUE_RANKID) override;
};
} // namespace hccl
#endif /* CALC_RING_TRANSPORT_REQ_PUB_H */
\ No newline at end of file
diff --git a/src/domain/collective_communication/algorithm/base/inc/calc_transport_req_base_pub.h b/src/domain/collective_communication/algorithm/base/inc/calc_transport_req_base_pub.h
index 665c4345e3d5611473c9b7662803c06d9d77abb4..16bb137990fef5b2cc303da8b8f3e4e1eb378c1c 100644
--- a/src/domain/collective_communication/algorithm/base/inc/calc_transport_req_base_pub.h
+++ b/src/domain/collective_communication/algorithm/base/inc/calc_transport_req_base_pub.h
@@ -20,20 +20,21 @@ namespace hccl {
class CalcTransportReqBase {
public:
- explicit CalcTransportReqBase(std::vector<std::vector<RankInfo>> &subCommPlaneVector,
+ explicit CalcTransportReqBase(std::vector<std::vector<u32>> &subCommPlaneVector,
std::vector<bool> &isBridgeVector, u32 userRank);
virtual ~CalcTransportReqBase();
virtual HcclResult CalcTransportRequest(const std::string &tag, TransportMemType inputMemType,
TransportMemType outputMemType, const CommParaInfo &commParaInfo,
- std::vector<SingleSubCommTransport> &commTransport);
+ std::vector<SingleSubCommTransport> &commTransport, u32 subUserRankRoot = INVALID_VALUE_RANKID);
protected:
// 获取本rank在子通信域(多平面)内当前平面的rank号
- const u32 GetSubCollectiveRank(const std::vector<RankInfo> &vecPara) const;
+ const u32 GetSubCollectiveRank(const std::vector<u32> &vecPara) const;
+ HcclResult GetRankByUserRank(const std::vector<u32> &vecPara, const u32 userRank, u32 &rank) const;
- const std::vector<std::vector<RankInfo>> &subCommPlaneVector_;
+ const std::vector<std::vector<u32>> &subCommPlaneVector_;
const std::vector<bool> &isBridgeVector_;
const u32 userRank_;
};
diff --git a/src/domain/collective_communication/algorithm/base/inc/executor_base_pub.h b/src/domain/collective_communication/algorithm/base/inc/executor_base_pub.h
index 12dbfea77e15643a85efeed38862e62f397d1870..8e5fe85eb2729789c71ceb9b171e1fb2217abf84 100644
--- a/src/domain/collective_communication/algorithm/base/inc/executor_base_pub.h
+++ b/src/domain/collective_communication/algorithm/base/inc/executor_base_pub.h
@@ -37,6 +37,7 @@ constexpr u32 DMA_REDUCE_THREE_OFFSET = 3;
constexpr u64 HCCL_CHUNK_SIZE = 1024 * 1024 * 1024; // 1024*1024*1024的size
constexpr u64 HCCL_MIN_PIPLINE_SLICE_ALIGN = 512;
constexpr u64 HCCL_MIN_SLICE_ALIGN_910B = 16384;
+constexpr u64 HCCL_MIN_SLICE_ALIGN_910_73 = 16384;
constexpr u64 HCCL_SDMA_RDMA_SPLIT_SIZE = 67108864;
constexpr u64 HCCL_MIN_SLICE_ALIGN_ONCHIP = 512;
constexpr u64 HCCL_MIN_SLICE_ALIGN = 128;
diff --git a/src/domain/collective_communication/algorithm/base/inc/reduce_scatter_pipeline_pub.h b/src/domain/collective_communication/algorithm/base/inc/reduce_scatter_pipeline_pub.h
index ae9bf338cb6c71555837cbd91b09f2633a015371..917049b7032a181b2852fd1ed0cbb163a040873f 100644
--- a/src/domain/collective_communication/algorithm/base/inc/reduce_scatter_pipeline_pub.h
+++ b/src/domain/collective_communication/algorithm/base/inc/reduce_scatter_pipeline_pub.h
@@ -23,6 +23,7 @@ #include "executor_base_pub.h"
#include "reducer_pub.h"
#include "sender_pub.h"
+#include "coll_alg_param.h"
namespace hccl {
constexpr u32
PIPELINE_DEPTH = 3; @@ -32,13 +33,14 @@ public: explicit ReduceScatterPipeline (const HcclDispatcher dispatcher, const u64 reduceAttrBitMap); ~ReduceScatterPipeline() override; + // 适配新CollExecutor接口 HcclResult Prepare(HcomCollOpInfo *opInfo, DeviceMem &cclBuffer, const u64 count, const u64 bufferSize, const u64 offset, - std::unique_ptr &commOuter, - std::unique_ptr &commInner, + const SubCommInfo &outerCommInfo, + const SubCommInfo &innerCommInfo, Stream &mainStream, std::vector &subStream, std::vector> ¬ifyMain, diff --git a/src/domain/collective_communication/algorithm/impl/CMakeLists.txt b/src/domain/collective_communication/algorithm/impl/CMakeLists.txt index 1e55f666af2c7c11c4c2c86e4186f3e08a73855e..767d63dd03db5f0860b3fde15885af8f54148e2f 100644 --- a/src/domain/collective_communication/algorithm/impl/CMakeLists.txt +++ b/src/domain/collective_communication/algorithm/impl/CMakeLists.txt @@ -2,6 +2,7 @@ set(src_list ${CMAKE_CURRENT_SOURCE_DIR}/hccl_impl.cc ${CMAKE_CURRENT_SOURCE_DIR}/hccl_alg.cc ${CMAKE_CURRENT_SOURCE_DIR}/hccl_pipeline.cc + ${CMAKE_CURRENT_SOURCE_DIR}/topo_matcher.cc ) target_sources(hccl_alg PRIVATE diff --git a/src/domain/collective_communication/algorithm/impl/coll_alg_param.h b/src/domain/collective_communication/algorithm/impl/coll_alg_param.h index 4ce1e1181d84fa6f7bec83649902798d71f573ae..a9f823402aec79174c2546a7244b7be9ebd709ee 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_alg_param.h +++ b/src/domain/collective_communication/algorithm/impl/coll_alg_param.h @@ -21,6 +21,7 @@ #include "stream_pub.h" #include "local_notify.h" #include "hccl_opbase_atrace_info_pub.h" +#include "common.h" namespace hccl { using RankId = u32; @@ -64,6 +65,8 @@ struct SingleSubCommTransport { bool supportDataReceivedAck = false; LinkMode linkMode = LinkMode::LINK_DUPLEX_MODE; bool enableUseOneDoorbell = false; + bool needVirtualLink =false; // for alltoall 多线程性能提升使用 + std::vector virtualLinks; // for alltoall 多线程性能提升使用 }; using LevelNSubCommTransport = std::vector; @@ -103,9 +106,9 @@ struct OpParam { u64 inputSize; void* outputPtr; u64 outputSize; - HcclReduceOp reduceType; - SyncMode syncMode; - RankId root; + HcclReduceOp reduceType = HcclReduceOp::HCCL_REDUCE_RESERVED; + SyncMode syncMode = SyncMode::DEFAULT_TIMEWAITSYNCMODE; + RankId root = INVALID_VALUE_RANKID; RankId dstRank; RankId srcRank; HcclOpBaseAtraceInfo* opBaseAtraceInfo = nullptr; @@ -128,12 +131,14 @@ struct OpParam { u32 itemNum; } BatchSendRecvDataDes; }; + HcclCMDType opType = HcclCMDType::HCCL_CMD_INVALID; }; struct SubCommInfo { u32 localRank; u32 localRankSize; std::vector links; + std::vector virtualLinks; // for alltoall 多线程性能提升使用 }; } // namespace hccl diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/CMakeLists.txt b/src/domain/collective_communication/algorithm/impl/coll_executor/CMakeLists.txt index 3a55c7254cbe69bd8d6277daba8567fdd8147229..3f14d5e2bd57e16f64faa5a091fe133c023ff9f2 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/CMakeLists.txt +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/CMakeLists.txt @@ -8,6 +8,12 @@ target_sources(hccl_alg PRIVATE ${src_list} ) +add_subdirectory(coll_all_gather) add_subdirectory(coll_all_reduce) +add_subdirectory(coll_reduce_scatter) +add_subdirectory(coll_reduce) add_subdirectory(coll_send_receive) +add_subdirectory(coll_all_to_all) +add_subdirectory(coll_scatter) +add_subdirectory(coll_broadcast) add_subdirectory(registry) \ No newline at end of file diff 
--git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/310P/CMakeLists.txt b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/310P/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..f659bd59b1c5b2c3c6a91bb82267aa49d5275ee1 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/310P/CMakeLists.txt @@ -0,0 +1,7 @@ +set(src_list + ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_gather_for_310p_executor.cc +) + +target_sources(hccl_alg PRIVATE + ${src_list} +) diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/310P/coll_all_gather_for_310p_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/310P/coll_all_gather_for_310p_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..3069d2047f6fa8a725c894ac0199a491b310689a --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/310P/coll_all_gather_for_310p_executor.cc @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "coll_all_gather_for_310p_executor.h" + +namespace hccl { +CollAllGatherFor310PExecutor::CollAllGatherFor310PExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllGatherExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = false; +} + +HcclResult CollAllGatherFor310PExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherFor310PExecutor::CalcTransportMemType(TransportMemType &inputType, + TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_OUTPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_OUTPUT; + } + HCCL_INFO("[CollAllGatherFor310PExecutor][CalcTransportMemType]" \ + "tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherFor310PExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + HCCL_INFO("[CollAllGatherFor310PExecutor][CalcOuterCommInfo]tag[%s]start", tag_.c_str()); + CommParaInfo commParaInfo(COMM_LEVEL0, CommType::COMM_TAG_MESH); + CHK_RET(CalcCommPlaneInfo(tag_, commParaInfo, opTransport[COMM_LEVEL0], inputType, outputType)); + HCCL_INFO("[CollAllGatherFor310PExecutor][CalcOuterCommInfo]tag[%s] Calc RingComm finish", + tag_.c_str()); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherFor310PExecutor::KernelRun(const 
OpParam &param, ExecMem &execMem)
+{
+ CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1));
+ SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0);
+
+ std::unique_ptr<ExecutorBase> executor;
+ executor.reset(new (std::nothrow) AllGatherRing(dispatcher_));
+ CHK_SMART_PTR_NULL(executor);
+
+ CHK_RET(executor->Prepare(execMem.inputMem, execMem.outputMem, execMem.outputMem, execMem.count,
+ param.DataDes.dataType, param.stream, param.reduceType));
+
+ u32 rankSize = outerCommInfo.localRankSize;
+ CHK_RET(executor->RegisterProfiler(
+ (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + outerCommInfo.localRank,
+ PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, param.stream));
+
+ CHK_RET(RunTemplate(executor, outerCommInfo));
+ HCCL_INFO("allgather for 310P run success");
+
+ return HCCL_SUCCESS;
+}
+
+REGISTER_EXEC("AllGatherFor310PExecutor", AllGatherFor310P, CollAllGatherFor310PExecutor);
+
+} // namespace hccl
\ No newline at end of file
diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/310P/coll_all_gather_for_310p_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/310P/coll_all_gather_for_310p_executor.h
new file mode 100644
index 0000000000000000000000000000000000000000..74e7dd6a9d7d592dc59a7fd010a4e247a7855a0b
--- /dev/null
+++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/310P/coll_all_gather_for_310p_executor.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */ + +#ifndef COLL_ALLGATHER_FOR_310P_RING_EXECUTOR_H +#define COLL_ALLGATHER_FOR_310P_RING_EXECUTOR_H +#include "coll_all_gather_executor.h" +namespace hccl { +class CollAllGatherFor310PExecutor : public CollAllGatherExecutor { + +public: + explicit CollAllGatherFor310PExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollAllGatherFor310PExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + HcclResult CalcLevel0CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + + /* *************** 算法编排 *************** */ + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/CMakeLists.txt b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..3747098555a447feb5f03ab250d857a902a513be --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/CMakeLists.txt @@ -0,0 +1,17 @@ +set(src_list + ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_gather_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_gather_comm_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_gather_ring_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_gather_ring_for_910_73_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_gather_mesh_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_gather_mesh_opbase_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_gather_mesh_opbase_pipeline_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_gather_single_rank_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_gather_double_ring_concurrent_executor.cc +) + +target_sources(hccl_alg PRIVATE + ${src_list} +) + +add_subdirectory(310P) \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_comm_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_comm_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..319863d6e78e1e70e5e92d9df5ec5e3843a7d1eb --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_comm_executor.cc @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "coll_all_gather_comm_executor.h" + +namespace hccl { +CollAllGatherCommExecutor::CollAllGatherCommExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllGatherExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = false; +} + +HcclResult CollAllGatherCommExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcCombinedCommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherCommExecutor::CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_OUTPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_OUTPUT; + } + HCCL_INFO("[CollAllGatherCommExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherCommExecutor::CalcCombinedCommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaInfo(COMM_COMBINE, CommType::COMM_TAG_MAX); + if (UseInterServerNHRAlgo(algType_)) { + commParaInfo.commType = CommType::COMM_TAG_NONUNIFORM_HIERARCHICAL_RING; + } else if (UseInterServerNHRV1Algo(algType_)) { + commParaInfo.commType = CommType::COMM_TAG_NONUNIFORM_HIERARCHICAL_RING_V1; + } else if (UseInterServerNBAlgo(algType_)) { + commParaInfo.commType = CommType::COMM_TAG_NONUNIFORM_BRUCK; + } else { + commParaInfo.commType = CommType::COMM_TAG_RING_INNER; + } + CHK_RET(CalcCommPlaneInfo(tag_, commParaInfo, opTransport[COMM_COMBINE], inputType, outputType)); + + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherCommExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + CHK_RET(CheckCommSize(COMM_COMBINE, COMM_INDEX_0 + 1)); + SubCommInfo combinedCommInfo = GetSubCommInfo(COMM_COMBINE, COMM_INDEX_0); + + // 构造ring algorithm对应的all_gather实例 + std::unique_ptr executor; + if (UseInterServerNHRAlgo(algType_)) { + executor.reset(new (std::nothrow) AllGatherNHR(dispatcher_)); + HCCL_INFO("algather comm: using nhr algo inter-server."); + } else if (UseInterServerNHRV1Algo(algType_)) { + executor.reset(new (std::nothrow) AllGatherNHRV1(dispatcher_)); + HCCL_INFO("algather comm: using nhr_v1 algo inter-server."); + } else if (UseInterServerNBAlgo(algType_)) { + executor.reset(new (std::nothrow) AllGatherNB(dispatcher_)); + HCCL_INFO("algather comm: using nonuniform-bruck algo inter-server."); + } else { + executor.reset(new (std::nothrow) AllGatherRing(dispatcher_)); + HCCL_INFO("algather comm: ring algo inter-server."); + } + CHK_SMART_PTR_NULL(executor); + + CHK_RET(executor->Prepare(execMem.inputMem, execMem.outputMem, execMem.outputMem, execMem.count, + param.DataDes.dataType, param.stream, HCCL_REDUCE_RESERVED, INVALID_VALUE_RANKID)); + + CHK_RET(RunTemplate(executor, combinedCommInfo)); + + return HCCL_SUCCESS; +} + +REGISTER_EXEC("AllGatherComm", AllGatherComm, CollAllGatherCommExecutor); + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_comm_executor.h 
b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_comm_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..0bf62ff7fc43571e22baed557eb8cb6955f0a37a --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_comm_executor.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_ALLGATHER_COMM_EXECUTOR_H +#define COLL_ALLGATHER_COMM_EXECUTOR_H +#include "coll_all_gather_executor.h" +namespace hccl { +class CollAllGatherCommExecutor : public CollAllGatherExecutor { +public: + explicit CollAllGatherCommExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollAllGatherCommExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcCombinedCommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport); + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_double_ring_concurrent_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_double_ring_concurrent_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..aa02a0d73c0fbc954a7eadaed09c12d94e2f349b --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_double_ring_concurrent_executor.cc @@ -0,0 +1,332 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "coll_all_gather_double_ring_concurrent_executor.h" + +namespace hccl { + +CollAllGatherDoubleRingConcurrentExecutor::CollAllGatherDoubleRingConcurrentExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllGatherExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = false; +} + +HcclResult CollAllGatherDoubleRingConcurrentExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = 0U; + // DoubleRing只支持910_73场景 + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING; + } else { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE; + } + if (GetExternalInputEnableRdmaSdmaConcurrent()) { + totalStreamNum += RDMA_PLANE_NUM_IN_NPRING_DOUBLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING; + } + streamNum = totalStreamNum - 1; + HCCL_INFO("[CollAllGatherDoubleRingConcurrentExecutor][CalcStreamNum] tag[%s] streamNum_[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherDoubleRingConcurrentExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel1CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel2CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherDoubleRingConcurrentExecutor::CalcTransportMemType(TransportMemType &inputType, + TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_OUTPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_OUTPUT; + } + HCCL_INFO("[CollAllGatherDoubleRingConcurrentExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherDoubleRingConcurrentExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_RING_INNER); + commParaLevel0.forceRdma = false; + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + if (GetExternalInputEnableRdmaSdmaConcurrent()) { + CommParaInfo commParaLevel0Rdma(COMM_LEVEL0_RDMA, CommType::COMM_TAG_RING_INNER); + commParaLevel0Rdma.forceRdma = true; + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0Rdma, opTransport[COMM_LEVEL0_RDMA], + inputType, outputType)); + } + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherDoubleRingConcurrentExecutor::CalcLevel2CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel2(COMM_LEVEL2, CommType::COMM_TAG_MAX); + if (UseLevel2RingAlgo(algType_)) { + commParaLevel2.commType = CommType::COMM_TAG_RING_INNER; + } else { + commParaLevel2.commType = CommType::COMM_TAG_HALVING_DOUBLING; + } + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel2, opTransport[COMM_LEVEL2], inputType, outputType)); + return HCCL_SUCCESS; +} +u64 CollAllGatherDoubleRingConcurrentExecutor::CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize) +{ + u64 maxCountPerLoop = cclBuffSize / (topoAttr_.userRankSize * unitSize); + 
return maxCountPerLoop; +} + +u32 CollAllGatherDoubleRingConcurrentExecutor::IsDataSplit(const u64 curSize) +{ + u32 dataSplit = 0; + u64 dataValue = curSize * topoAttr_.userRankSize; + if ((topoAttr_.serverNum > 1) && ((dataValue / topoAttr_.serverNum) <= HCCL_SDMA_RDMA_SPLIT_SIZE)) { + dataSplit = 1; + } else if (dataValue <= HCCL_SDMA_RDMA_SPLIT_SIZE) { + dataSplit = HCCL_SPLIT_FLAG; + } + return dataSplit; +} + +HcclResult CollAllGatherDoubleRingConcurrentExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + HCCL_INFO("[CollAllGatherDoubleRingConcurrentExecutor][KernelRun]AllGatherDoubleRingConcurrentExecutor starts."); + + u32 perDataSize = 0; + CHK_RET(SalGetDataTypeSize(param.DataDes.dataType, perDataSize)); + CHK_PRT_RET(perDataSize == 0, + HCCL_ERROR("[CollAllGatherDoubleRingConcurrentExecutor][KernelRun]errNo[0x%016llx] datatype[%d] is invalid", + HCCL_ERROR_CODE(HCCL_E_PARA), param.DataDes.dataType), HCCL_E_PARA); + + // 获取子通信域信息 + auto nicList = topoAttr_.nicList; + u32 ringNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE; + CHK_RET(CheckCommSize(COMM_LEVEL0, ringNum)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + u32 outerRankSize = outerCommInfo.localRankSize; + u32 commIndex = outerCommInfo.localRank; + commIndex = RefreshCommIdx(commIndex, nicList, topoAttr_.devicePhyId); + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + u32 serverIndex = GetSubCommInfo(COMM_LEVEL1, commIndex).localRank; + + // 第一步,将数据从input内存拷贝到output内存的对应位置 + u64 inputMemSize = execMem.inputMem.size(); + u64 baseOffset = serverIndex * inputMemSize * outerRankSize; + u64 outerOffset = commIndex * inputMemSize; + DeviceMem dstMem = execMem.outputMem.range(baseOffset + outerOffset, inputMemSize); + CHK_SMART_PTR_NULL(dstMem); + + HcomCollOpInfo opInfo = { + "", execMem.inputPtr, execMem.outputPtr, param.DataDes.count, param.DataDes.dataType, 0, HCCL_REDUCE_RESERVED + }; + HcomCollOpInfo *opInfoPtr = nullptr; + + if (!DMAReduceFlag_) { + HcclResult ret = HcclD2DMemcpyAsync(dispatcher_, dstMem, execMem.inputMem, const_cast(param.stream)); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollAllGatherDoubleRingExecutor][KernelRun]all gather double " + "ring memcpy Failed, Offset[%llu], Size[%llu]", + baseOffset + outerOffset, inputMemSize), ret); + } else { + opInfoPtr = &opInfo; + } + + // 第二步,各个AI Server 内 multi ring all gather + std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 + std::vector> multRingsSliceZero; // 数据基于该rank上环0的偏移 + u32 sliceNum = outerRankSize; + CHK_RET(PrepareAllgatherSlice(sliceNum, inputMemSize, dataSegsSlice)); + + // 多环数据切分 + auto mult2RingsSlice = PrepareMultiRingSlice(dataSegsSlice, param.tag, false, nicList); + std::vector>> mult4RingsSlice; + // 基于2环数据切分2环SDMA+2环ROH; bool = true表示SDMA; + u32 syncTrans = BEST_SPLIT_VALUE; + u64 totalDataSize = inputMemSize * dataSegsSlice.size(); + if (totalDataSize <= HCCL_SDMA_RDMA_SPLIT_SIZE) { + syncTrans = MAX_SPLIT_VALUE; + } + mult4RingsSlice.resize(mult2RingsSlice.size() * SLICES_FACTOR); + for (u32 ringIndex = 0; ringIndex < mult2RingsSlice.size(); ringIndex++) { + std::vector sdmaSlice; + std::vector rdmaSlice; + for (u32 segsIndex = 0; segsIndex < mult2RingsSlice[ringIndex].size(); segsIndex++) { + auto totalSize = mult2RingsSlice[ringIndex][segsIndex].size; + auto sdmaSliceOffset = mult2RingsSlice[ringIndex][segsIndex].offset; + auto sdmaSliceSize = (totalSize <= HCCL_MIN_SLICE_ALIGN_910_73) ? 
totalSize: + ((syncTrans * totalSize / MAX_SPLIT_VALUE) / HCCL_MIN_SLICE_ALIGN_910_73) * HCCL_MIN_SLICE_ALIGN_910_73; + Slice sdmaSliceTmp; + sdmaSliceTmp.offset = sdmaSliceOffset; + sdmaSliceTmp.size = sdmaSliceSize; + Slice rdmaSliceTmp; + rdmaSliceTmp.offset = sdmaSliceOffset + sdmaSliceSize; + rdmaSliceTmp.size = totalSize - sdmaSliceSize; + sdmaSlice.push_back(sdmaSliceTmp); + rdmaSlice.push_back(rdmaSliceTmp); + HCCL_DEBUG("Ring index:%u, segId:%u, Orignal [offset %llu, size %llu], sdma [offset %llu, size %llu], "\ + "rdma [offset %llu, size %llu]", ringIndex, segsIndex, sdmaSliceOffset, totalSize, + sdmaSliceTmp.offset, sdmaSliceTmp.size, rdmaSliceTmp.offset, rdmaSliceTmp.size); + } + mult4RingsSlice[ringIndex] = std::make_pair(true, sdmaSlice); // true表示使用sdma + mult4RingsSlice[ringIndex + mult2RingsSlice.size()] = std::make_pair(false, rdmaSlice); // false表示rdma + } + if (syncTrans == MAX_SPLIT_VALUE) { + mult4RingsSlice.erase(mult4RingsSlice.end() - mult2RingsSlice.size(), mult4RingsSlice.end()); + } + // 抽取当前用于多环all gather 的output内存数据 + DeviceMem currentOutputMem = execMem.outputMem.range(baseOffset, inputMemSize * outerRankSize); + CHK_SMART_PTR_NULL(currentOutputMem); + CHK_RET(ActiveSlaveStreams(param.stream)); + + CHK_RET(MultiRingAllGatherConcurrent(param.tag, execMem.inputMem, currentOutputMem, execMem.count, param.DataDes.dataType, + mult4RingsSlice, param.stream, PROF_STAGE_1, baseOffset, opInfoPtr)); + + HCCL_INFO("all gather double ring outer run success"); + + // 第三步, AI server 间 recursive halving doubling all gather + u64 hdSize = 0; + std::vector::iterator iterNic = std::find(nicList.begin(), nicList.end(), topoAttr_.devicePhyId); + if (iterNic != nicList.end()) { + hdSize = inputMemSize * outerRankSize; + } + + u64 hdCount = hdSize / perDataSize; + std::unique_ptr innerExecutor; + u64 firstCommInnerSize = ((syncTrans * hdSize / MAX_SPLIT_VALUE) / HCCL_MIN_SLICE_ALIGN_910_73) * + HCCL_MIN_SLICE_ALIGN_910_73; + std::vector sendSize{firstCommInnerSize, hdSize - firstCommInnerSize}; + std::vector sendOffset{0, firstCommInnerSize}; + for (int innerCommIndex = 0; innerCommIndex < RDMA_PLANE_NUM_IN_NPRING_DOUBLE; ++innerCommIndex) { + if (sendSize[innerCommIndex] == 0 || (!GetExternalInputEnableRdmaSdmaConcurrent() && innerCommIndex > 0)) { + continue; + } + if (GetExternalInputEnableRdmaSdmaConcurrent() || UseInterServerRingAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllGatherRing(dispatcher_)); + HCCL_INFO("allgather ring: using ring algo inter-server."); + } else if (UseInterServerNHRAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllGatherNHR(dispatcher_)); + HCCL_INFO("allgather ring: using nonuniform-hierarchical-ring algo inter-server."); + } else if (UseInterServerNBAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllGatherNB(dispatcher_)); + HCCL_INFO("allgather ring: using nonuniform-bruck algo inter-server."); + } else { + innerExecutor.reset(new (std::nothrow) AllGatherRecursiveHalvingDoubling(dispatcher_)); + HCCL_INFO("allgather ring: using halving-doubling algo inter-server."); + } + CHK_SMART_PTR_NULL(innerExecutor); + + CHK_RET(CheckCommSize(COMM_LEVEL1_RDMA, commIndex + 1)); + SubCommInfo innerCommInfo = (innerCommIndex == 0 ? 
+ GetSubCommInfo(COMM_LEVEL1, commIndex) : GetSubCommInfo(COMM_LEVEL1_RDMA, commIndex)); + + if (topoAttr_.devNumInLevel2 <= 1) { + // 此处虽然带入inputMem作为scratch mem, 但inputMem 不能被使用 + u32 rankSize = innerCommInfo.localRankSize; + std::vector inputSlices(rankSize, Slice()); + for (u32 i = 0; i < rankSize; i++) { + inputSlices[i].size = sendSize[innerCommIndex]; + inputSlices[i].offset = hdSize * i + sendOffset[innerCommIndex]; + } + auto &innerCommStream = streamInfo_.ringStreams[innerCommIndex]; + auto ret = streamInfo_.ringSignalAux[innerCommIndex]->Wait(innerCommStream, dispatcher_, PROF_STAGE_2); + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[CollAllGatherDoubleRingConcurrentExecutor][KernelRun] "\ + " inner wait main [%u] failed", innerCommIndex), ret); + + CHK_RET(innerExecutor->Prepare(execMem.outputMem, execMem.outputMem, execMem.inputMem, hdCount, param.DataDes.dataType, innerCommStream, + HCCL_REDUCE_RESERVED, INVALID_VALUE_RANKID, inputSlices, 0)); + + CHK_RET(innerExecutor->RegisterProfiler((rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + + innerCommInfo.localRank, PROF_STAGE_2, HCCL_EXEC_STEP_NOT_SET, innerCommStream)); + + CHK_RET(RunTemplate(innerExecutor, innerCommInfo)); + + ret = streamInfo_.ringSignal[innerCommIndex]->Post(innerCommStream, dispatcher_, PROF_STAGE_2); + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[CollAllGatherDoubleRingConcurrentExecutor][KernelRun] "\ + " inner post mains [%u] failed", innerCommIndex), ret); + + ret = streamInfo_.ringSignalAux[innerCommIndex]->Post(const_cast(param.stream), + dispatcher_, PROF_STAGE_2); + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[CollAllGatherDoubleRingConcurrentExecutor][KernelRun] "\ + " main post inner [%u] failed", innerCommIndex), ret); + } else { + u32 innerRankSize = GetSubCommInfo(COMM_LEVEL1, COMM_INDEX_0).localRankSize; + u64 innerBaseOffset = baseOffset * innerRankSize; + DeviceMem innerInputMem = execMem.outputMem.range(innerBaseOffset, inputMemSize * outerRankSize); + DeviceMem innerOutputMem = execMem.outputMem.range(innerBaseOffset, + inputMemSize * outerRankSize * innerRankSize); + + std::vector inputSlices(innerRankSize, Slice()); + for (u32 i = 0; i < innerRankSize; i++) { + inputSlices[i].size = sendSize[innerCommIndex]; + inputSlices[i].offset = hdSize * i + sendOffset[innerCommIndex]; + } + + auto &innerCommStream = streamInfo_.ringStreams[innerCommIndex]; + auto ret = streamInfo_.ringSignalAux[innerCommIndex]->Wait(innerCommStream, dispatcher_, PROF_STAGE_2); + + // 此处虽然带入inputMem作为scratch mem, 但inputMem 不能被使用 + CHK_RET(innerExecutor->Prepare(innerInputMem, innerOutputMem, execMem.inputMem, hdCount, + param.DataDes.dataType, innerCommStream, HCCL_REDUCE_RESERVED, INVALID_VALUE_RANKID, + inputSlices, 0)); + + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + u32 rankSize = innerCommInfo.localRankSize; + CHK_RET(innerExecutor->RegisterProfiler((rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_2, HCCL_EXEC_STEP_NOT_SET, innerCommStream)); + + CHK_RET(RunTemplate(innerExecutor, innerCommInfo)); + ret = streamInfo_.ringSignal[innerCommIndex]->Post(innerCommStream, dispatcher_, PROF_STAGE_2); + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[CollAllGatherDoubleRingConcurrentExecutor][KernelRun] "\ + "inner post mains [%u] failed", innerCommIndex), ret); + + ret = streamInfo_.ringSignalAux[innerCommIndex]->Post(const_cast(param.stream), + dispatcher_, PROF_STAGE_2); + CHK_PRT_RET(ret != HCCL_SUCCESS, 
HCCL_ERROR("[CollAllGatherDoubleRingConcurrentExecutor][KernelRun] "\ + "main post inner [%u] failed", innerCommIndex), ret); + + // 超节点间做allgather + ret = AllGatherLevel2(param.tag, execMem.inputMem, execMem.outputMem, execMem.count, + param.DataDes.dataType, const_cast(param.stream)); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollAllGatherDoubleRingConcurrentExecutor][KernelRun]tag[%s], all_gather failed, "\ + "return[%d]", param.tag.c_str(), ret), ret); + } + if (sendSize[innerCommIndex] == 0 || (!GetExternalInputEnableRdmaSdmaConcurrent() && innerCommIndex > 0)) { + continue; + } + + auto ret = streamInfo_.ringSignal[innerCommIndex]->Wait(const_cast(param.stream), + dispatcher_, PROF_STAGE_2); + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[CollAllGatherDoubleRingConcurrentExecutor][KernelRun] "\ + "main wait inner [%u] failed", innerCommIndex), ret); + } + HCCL_INFO("all gather double ring inner run success"); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("AllGatherDoubleRingConcurrentExecutor", AllGatherDoubleRingConcurrent, + CollAllGatherDoubleRingConcurrentExecutor); + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_double_ring_concurrent_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_double_ring_concurrent_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..fff28d51c3e36fe9a1279841d4b62083d3d95978 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_double_ring_concurrent_executor.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef COLL_ALL_GATHER_DOUBLE_RING_CONCURRENT_EXECUTOR_H +#define COLL_ALL_GATHER_DOUBLE_RING_CONCURRENT_EXECUTOR_H +#include "coll_all_gather_executor.h" +namespace hccl { +class CollAllGatherDoubleRingConcurrentExecutor : public CollAllGatherExecutor { + +public: + explicit CollAllGatherDoubleRingConcurrentExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollAllGatherDoubleRingConcurrentExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcStreamNum(u32& streamNum) override; + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcLevel2CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + u64 CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize) override; + u32 IsDataSplit(const u64 curSize) override; + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..8a5fdfb38eff0fb75eca85be0163e153bc7acd56 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_executor.cc @@ -0,0 +1,359 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "coll_all_gather_executor.h" + +namespace hccl { +CollAllGatherExecutor::CollAllGatherExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollCommExecutor(dispatcher, topoMatcher) +{ +} + +HcclResult CollAllGatherExecutor::Orchestrate(const OpParam& param, + const AlgResourceResponse& algRes) +{ + HcclUs startut = TIME_NOW(); + tag_ = param.tag; + algResResp_ = &algRes; + GetStreamInfo(algRes); + auto rtStream = param.stream.ptr(); + HCCL_PROFILER_ADD_TAG(param.tag, algoAttr_.identifier, GetWorkflowMode()); + HCCL_PROFILER_ADD_STREAM(rtStream, param.tag, 0, algType_); + CHK_RET(AddSubStreamToProfiling()); + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && !is310P3Common_) { + HCCL_PROFILER_ADD_OPDATA(param.tag, param.DataDes.count, param.inputPtr, param.outputPtr, + param.DataDes.dataType, INVALID_VALUE_RANKID, algoAttr_.identifier); + HCCL_PROFILER_ADD_GROUPRANK(algoAttr_.identifier, topoAttr_.userRankSize, topoAttr_.userRank); + } + + HcclResult ret = HCCL_SUCCESS; + // 图模式和单卡场景下不需要Loop + ExecMem execMem; + execMem.count = param.DataDes.count; + execMem.inputPtr = param.inputPtr; + execMem.outputPtr = param.outputPtr; + if (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + u64 totalSize = param.DataDes.count * SIZE_TABLE[param.DataDes.dataType]; + execMem.inputMem = DeviceMem::create(algRes.paramInputMem.ptr(), totalSize); + execMem.outputMem = DeviceMem::create(algRes.paramOutputMem.ptr(), totalSize * topoAttr_.userRankSize); + execMem.scratchMem = algRes.scratchMem; + HCCL_DEBUG("[CollAllGatherExecutor][Orchestrate]offload inputMem[%p][%u], outputMem[%p][%u]," \ + "scratchMem[%p][%u], inputPtr[%p] outputPtr[%p], count[%llu]", + execMem.inputMem.ptr(), execMem.inputMem.size(), execMem.outputMem.ptr(), execMem.outputMem.size(), + execMem.scratchMem.ptr(), execMem.scratchMem.size(), execMem.inputPtr, execMem.outputPtr, execMem.count); + ret = KernelRun(param, execMem); + } else if (topoAttr_.userRankSize == 1) { + execMem.inputMem = algRes.cclInputMem; + execMem.outputMem = algRes.cclOutputMem; + execMem.scratchMem = algRes.scratchMem; + ret = KernelRun(param, execMem); + } else { + ret = RunLoop(param, algRes); + } + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollAllGatherExecutor][Orchestrate]errNo[0x%016llx]all reudce excutor kernel run failed", + HCCL_ERROR_CODE(ret)), ret); + + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && !is310P3Common_) { + HCCL_PROFILER_DEL_STREAM(rtStream); + HCCL_PROFILER_DEL_TAG(param.tag); + HCCL_PROFILER_DEL_OPDATA(param.tag); + HCCL_PROFILER_DEL_GROUPRANK(param.tag); + } + HCCL_INFO("tag[%s], Allgather executor orchestrate success, take time [%lld]us.", + param.tag.c_str(), DURATION_US(TIME_NOW() - startut)); + return HCCL_SUCCESS; +} + + +u64 CollAllGatherExecutor::CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize) +{ + // 中转内存单次最多能够接受的output count + u64 maxCountPerLoop = cclBuffSize / (unitSize * topoAttr_.userRankSize); + HCCL_WARNING("[CollAllGatherExecutor][CalcLoopMaxCount]" \ + "using default maxCountPerLoop[%llu] as CCLBuffSize / unitSize.", maxCountPerLoop); + return maxCountPerLoop; +} + +bool CollAllGatherExecutor::IsHugeData(const u64 curSize) +{ + bool hugeData = curSize * topoAttr_.userRankSize / HCCL_INTERNODE_MAX_DATA_RATE > RDMA_SEND_MAX_SIZE || + curSize > SDMA_SEND_MAX_SIZE; + return hugeData; +} + +u32 CollAllGatherExecutor::IsDataSplit(const u64 curSize) +{ + 
HCCL_INFO("[CollAllGatherExecutor][IsDataSplit]opMeta is using the default option: not data split."); + return 0; +} + +HcclResult CollAllGatherExecutor::RunLoop(const OpParam ¶m, const AlgResourceResponse &algRes) +{ + u32 unitSize = SIZE_TABLE[param.DataDes.dataType]; + + u8 *curInputPtr = static_cast(param.inputPtr); + u8 *curOutputPtr = static_cast(param.outputPtr); + CHK_PTR_NULL(curInputPtr); + CHK_PTR_NULL(curOutputPtr); + + u64 maxCountPerLoop = CalcLoopMaxCount(algRes.cclInputMem.size(), unitSize); // override + HCCL_DEBUG("[CollAllGatherExecutor][RunLoop]tag[%s], userRankSize is [%llu], maxCountPerLoop is [%llu].", + param.tag.c_str(), topoAttr_.userRankSize, maxCountPerLoop); + + for (u64 countLeft = param.DataDes.count, curCount = 0, inputOffset = 0, outputOffset = 0; + countLeft > 0; countLeft -= curCount) { + curInputPtr += inputOffset; + curOutputPtr += outputOffset; + // 判断剩余数据量对应的output size是否大于中转output size + curCount = (countLeft > maxCountPerLoop) ? maxCountPerLoop : countLeft; + u64 curSize = curCount * unitSize; // 单位:字节 + + HCCL_DEBUG("[CollAllGatherExecutor][RunLoop]tag[%s], inputOffset[%llu], outputOffset[%llu], " \ + "sendBuf[%p], recvBuf[%p], sendCount[%llu], dataType[%d].", + param.tag.c_str(), inputOffset, outputOffset, curInputPtr, curOutputPtr, curCount, param.DataDes.dataType); + + ExecMem execMem; + execMem.count = curCount; + execMem.inputMem = algRes.cclInputMem; + execMem.outputMem = algRes.cclOutputMem; + execMem.scratchMem = algRes.scratchMem; + // 使用当前Loop偏移到的地址作为当前的inputPtr和outputPtr + execMem.inputPtr = curInputPtr; + execMem.outputPtr = curOutputPtr; + CHK_RET(RunLoopInner(param, execMem)); + + inputOffset = curSize; + outputOffset = curSize; + } + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherExecutor::RunLoopInner(const OpParam ¶m, ExecMem &execMem) +{ + u32 unitSize = SIZE_TABLE[param.DataDes.dataType]; + u64 curSize = execMem.count * unitSize; // 单位:字节 + void *commInputPtr = execMem.inputMem.ptr(); + void *commOutputPtr = execMem.outputMem.ptr(); + CHK_PRT_RET((execMem.count == 0), + HCCL_ERROR("[CollAllGatherExecutor][RunLoop]In OP_BASE curCount is zero."), HCCL_E_PARA); + + if (!is310P3Common_) { + /* 记录指令信息用于一致性校验 */ + CHK_RET(RankConsistent::GetInstance().RecordOpPara(HcclCMDType::HCCL_CMD_ALLGATHER, + param.tag, execMem.count, param.DataDes.dataType, execMem.inputMem.size(), execMem.outputMem.size(), + HCCL_WORLD_GROUP)); + /* 设置子图复用标志 */ + auto autoSelectedAlgTypeLevel1 = static_cast(algType_) >> HCCL_LEVEL_ALGO_WIDTH; + bool hugeData = IsHugeData(curSize); // override + u32 dataSplit = IsDataSplit(curSize); + auto opMeta = HcclOpMetaInfo::GetOneForAllGather(autoSelectedAlgTypeLevel1, hugeData); + opMeta.dataSplit = dataSplit; + CHK_RET(InitTask(dispatcher_, const_cast(param.stream), opMeta.isEnableCache, opMeta.GetCacheKey())); + } + + execMem.inputMem = DeviceMem::create(commInputPtr, curSize); + execMem.outputMem = DeviceMem::create(commOutputPtr, curSize * topoAttr_.userRankSize); + HCCL_DEBUG("[CollAllGatherExecutor][RunLoopInner]inputMem[%p][%llu], outputMem[%p][%llu], " \ + "intputPtr[%p], outputPtr[%p], curCount[%llu], curSize[%llu]", + execMem.inputMem.ptr(), execMem.inputMem.size(), execMem.outputMem.ptr(), execMem.outputMem.size(), + execMem.inputPtr, execMem.outputPtr, execMem.count, curSize); + + // 执行 + if (!DMAReduceFlag_) { + // 如果使用in CCL buffer,需要将user buffer in中的结果拷贝到CCL buffer in + DeviceMem srcMem = DeviceMem::create(execMem.inputPtr, curSize); + DeviceMem dstMem = DeviceMem::create(commInputPtr, curSize); + 
CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, const_cast(param.stream))); + HCCL_DEBUG("[CollAllGatherExecutor][RunLoop]copy from user in to ccl in."); + } + HcclResult ret = KernelRun(param, execMem); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollAllGatherExecutor][RunLoop]errNo[0x%016llx]kernel run error, tag[%s], " \ + "inputMem ptr[%p], outputMem ptr[%p], count[%llu], dataType[%d], reduce op type[%d]", + HCCL_ERROR_CODE(ret), param.tag.c_str(), execMem.inputMem.ptr(), execMem.outputMem.ptr(), + execMem.count, param.DataDes.dataType), + ret); + + if (!DMAReduceFlag_) { + // 如果使用CCL buffer,需要将CCL buffer out中的结果拷贝到user buffer out + for (u32 i = 0; i < topoAttr_.userRankSize; i++) { + // 拷贝中转output上每个slice的数据到output内存,目的端中每个slice的size固定为output的size + u8 *curOutputPtr = static_cast(execMem.outputPtr); + DeviceMem dstMem = DeviceMem::create(curOutputPtr + param.DataDes.count * unitSize * i, curSize); + DeviceMem srcMem = DeviceMem::create(static_cast(commOutputPtr) + curSize * i, curSize); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, const_cast(param.stream))); + } + } + + if (!is310P3Common_) { + CHK_RET(RankConsistent::GetInstance().DelOpPara(param.tag)); + CHK_RET(LaunchTask(dispatcher_, const_cast(param.stream))); + } + return ret; +} + +HcclResult CollAllGatherExecutor::PrepareAllgatherSlice(u32 sliceNum, u64 inputMemSize, + std::vector &dataSegsSlice) const +{ + Slice sliceTemp; + for (u32 i = 0; i < sliceNum; i++) { // 根据数据量计算每个环上数据的偏移和大小 + sliceTemp.size = inputMemSize; + sliceTemp.offset = inputMemSize * i; + dataSegsSlice.push_back(sliceTemp); + } + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherExecutor::CalculateLevel1AllgatherSlice(u64 inputMemSize, u32 level0RankSize, u32 level1RankSize, + std::vector> multRingsSliceZero, std::vector> &multRingsSlice) const +{ + for (u32 ringIndex = 0; ringIndex < multRingsSliceZero.size(); ringIndex++) { + std::vector level1DataSlice; + for (u32 level0Idx = 0; level0Idx < level0RankSize; level0Idx++) { + for (u32 level1Idx = 0; level1Idx < level1RankSize; level1Idx++) { + Slice tmpSlice; + tmpSlice.size = multRingsSliceZero[ringIndex][level0Idx].size; + tmpSlice.offset = + multRingsSliceZero[ringIndex][level0Idx].offset + level1Idx * level0RankSize * inputMemSize; + level1DataSlice.push_back(tmpSlice); + } + } + multRingsSlice.push_back(level1DataSlice); + } + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherExecutor::CalculateLevel2AllgatherSlice(u64 inputMemSize, u32 level0RankSize, u32 level1RankSize, + u32 level2RankSize, std::vector dataSegsSlice, std::vector &level0DataSlice) const +{ + for (u32 i = 0; i < level0RankSize; i++) { + for (u32 j = 0; j < level1RankSize; j++) { + for (u32 z = 0; z < level2RankSize; z++) { + Slice rankSliceTemp; + rankSliceTemp.size = dataSegsSlice[i].size; + rankSliceTemp.offset = dataSegsSlice[i].offset + + (j * level0RankSize * level1RankSize + z * level1RankSize) * inputMemSize; + level0DataSlice.push_back(rankSliceTemp); + } + } + } + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherExecutor::AllGatherLevel2(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, + u64 count, HcclDataType dataType, Stream &stream, const HcomCollOpInfo *opInfo) +{ + u32 perDataSize = 0; + CHK_RET(SalGetDataTypeSize(dataType, perDataSize)); + + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + u32 commIndex = outerCommInfo.localRank; + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + CHK_RET(CheckCommSize(COMM_LEVEL2, commIndex + 
1)); + SubCommInfo level2CommInfo = GetSubCommInfo(COMM_LEVEL2, commIndex); + + u64 inputMemSize = inputMem.size(); + u32 level0RankSize = outerCommInfo.localRankSize; + u32 level1RankSize = innerCommInfo.localRankSize; + u32 level2RankSize = level2CommInfo.localRankSize; + u32 level0ServerIndex = outerCommInfo.localRank; + u32 level1ServerIndex = innerCommInfo.localRank; + + std::unique_ptr level2AGExecutor; + if (UseLevel2RingAlgo(algType_)) { + level2AGExecutor.reset(new (std::nothrow) AllGatherRing(dispatcher_)); + HCCL_INFO("allgather ring: using ring algo inter-server."); + } else { + level2AGExecutor.reset(new (std::nothrow) AllGatherRecursiveHalvingDoubling(dispatcher_)); + HCCL_INFO("allgather ring: using halving-doubling algo inter-server."); + } + + // 计算slice, 不同超节点相同slice + std::vector level2DataSegsSlice; + Slice sliceTemp; + for (u32 i = 0; i < level2RankSize; i++) { + sliceTemp.size = inputMemSize; + sliceTemp.offset = i * level1RankSize * level0RankSize * inputMemSize; + level2DataSegsSlice.push_back(sliceTemp); + } + // outputMem传整块,通过baseOffset偏移 + u64 level2BaseOffset = (level0ServerIndex + level1ServerIndex * level1RankSize) * inputMemSize; + CHK_RET(level2AGExecutor->Prepare(outputMem, outputMem, inputMem, count, dataType, stream, + HCCL_REDUCE_RESERVED, INVALID_VALUE_RANKID, level2DataSegsSlice, level2BaseOffset)); + + CHK_RET(level2AGExecutor->RegisterProfiler(( + level2RankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + level2CommInfo.localRank, + PROF_STAGE_2, HCCL_EXEC_STEP_NOT_SET, stream)); + + CHK_RET(RunTemplate(level2AGExecutor, level2CommInfo)); + HCCL_INFO("allgather double ring [superpod] level2 allgather run success"); + + // 第二步,各个AI Server 间 all gather (ring/NHR) + HCCL_INFO("commIdx:%u Tag[%s].commInner.size():%u", commIndex, tag.c_str(), + level1RankSize); + + std::unique_ptr level1AGExecutor; + if (UseInterServerRingAlgo(algType_)) { + level1AGExecutor.reset(new (std::nothrow) AllGatherRing(dispatcher_)); + HCCL_INFO("allgather ring: using ring algo inter-server."); + } else { + level1AGExecutor.reset(new (std::nothrow) AllGatherNHR(dispatcher_)); + HCCL_INFO("allgather ring: using nonuniform-hierarchical-ring algo inter-server."); + } + + // 计算slice, 不同超节点相同slice + std::vector level1DataSegsSlice; + for (u32 j = 0; j < level1RankSize; j++) { + for (u32 i = 0; i < level2RankSize; i++) { + sliceTemp.size = inputMemSize; + sliceTemp.offset = + j * level0RankSize *inputMemSize + i * level1RankSize * level0RankSize * inputMemSize; + level1DataSegsSlice.push_back(sliceTemp); + } + } + // outputMem传整块,通过baseOffset偏移? 
+ u64 level1BaseOffset = level0ServerIndex * inputMemSize; + CHK_RET(level1AGExecutor->Prepare(outputMem, outputMem, inputMem, count, dataType, stream, + HCCL_REDUCE_RESERVED, INVALID_VALUE_RANKID, level1DataSegsSlice, level1BaseOffset)); + + CHK_RET(level1AGExecutor->RegisterProfiler(( + level1RankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + level2CommInfo.localRank, + PROF_STAGE_2, HCCL_EXEC_STEP_NOT_SET, stream)); + + CHK_RET(RunTemplate(level1AGExecutor, innerCommInfo)); + HCCL_INFO("allgather double ring [superpod] level1 allgather run success"); + + // 节点内做all gather double ring + std::vector dataSegsSlice; + std::vector> multRingsSliceZero; // 数据基于该rank上环0的偏移 + CHK_RET(PrepareAllgatherSlice(level0RankSize, inputMemSize, dataSegsSlice)); + + // 多环数据切分 + multRingsSliceZero = PrepareMultiRingSlice(dataSegsSlice, tag, false, topoAttr_.nicList); + + // 计算slice + std::vector> multRingsSlice; + for (u32 ringIndex = 0; ringIndex < multRingsSliceZero.size(); ringIndex++) { + std::vector level0DataSlice; + CHK_RET(CalculateLevel2AllgatherSlice(inputMemSize, level0RankSize, level1RankSize, + level2RankSize, dataSegsSlice, level0DataSlice)); + multRingsSlice.push_back(level0DataSlice); + } + + CHK_RET(ActiveSlaveStreams(stream)); + CHK_RET(MultiRingAllGather(tag, inputMem, outputMem, count, dataType, + multRingsSliceZero, stream, PROF_STAGE_1, 0, opInfo)); + HCCL_INFO("allgather double ring [superpod] level2 allgather run success"); + return HCCL_SUCCESS; +} + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..8401c90ac35c4723813ed09c3a5108d679178ed7 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_executor.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef COLL_ALLGATHER_EXECUTOR_H +#define COLL_ALLGATHER_EXECUTOR_H +#include "coll_comm_executor.h" +namespace hccl { +class CollAllGatherExecutor : public CollCommExecutor { + +public: + explicit CollAllGatherExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollAllGatherExecutor() = default; + + HcclResult Orchestrate(const OpParam& param, const AlgResourceResponse& algRes) override; +protected: + // AllGather Loop Executor公共接口 + virtual u64 CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize); + virtual bool IsHugeData(const u64 curSize); + virtual u32 IsDataSplit(const u64 curSize); + HcclResult RunLoop(const OpParam ¶m, const AlgResourceResponse &algRes); + + // 工具类 + HcclResult PrepareAllgatherSlice(u32 sliceNum, u64 inputMemSize, std::vector &dataSegsSlice) const; + + HcclResult CalculateLevel1AllgatherSlice(u64 inputMemSize, u32 level0RankSize, u32 level1RankSize, + std::vector> multRingsSliceZero, std::vector> &multRingsSlice) const; + + HcclResult CalculateLevel2AllgatherSlice(u64 inputMemSize, u32 level0RankSize, u32 level1RankSize, + u32 level2RankSize, std::vector dataSegsSlice, std::vector &level0DataSlice) const; + + HcclResult AllGatherLevel2(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, + u64 count, HcclDataType dataType, Stream &stream, const HcomCollOpInfo *opInfo = nullptr); + + bool DMAReduceFlag_{false}; // 是否DMA消减的标志 +private: + HcclResult RunLoopInner(const OpParam ¶m, ExecMem &execMem); +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_mesh_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_mesh_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..6d61e55333104776cd5b9b963cb830f15df1ddda --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_mesh_executor.cc @@ -0,0 +1,170 @@ + +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "coll_all_gather_mesh_executor.h" + +namespace hccl { +CollAllGatherMeshExecutor::CollAllGatherMeshExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllGatherExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = false; +} + +HcclResult CollAllGatherMeshExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = topoAttr_.deviceNumPerAggregation > 1U ? 
topoAttr_.deviceNumPerAggregation - 1U : 1U; + streamNum = totalStreamNum - 1U; + HCCL_INFO("[CollAllGatherMeshExecutor][CalcStreamNum] tag[%s] streamNum[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherMeshExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel1CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherMeshExecutor::CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_OUTPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_OUTPUT; + } + HCCL_INFO("[CollAllGatherMeshExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherMeshExecutor::CalcLevel0CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_MESH); + commParaLevel0.meshSinglePlane = (topoAttr_.deviceType == DevType::DEV_TYPE_910B) && + !topoMatcher_->GetExternalInputHcclDeterministic() && (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + return HCCL_SUCCESS; +} + +u64 CollAllGatherMeshExecutor::CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize) +{ + u64 maxCountPerLoop = cclBuffSize / (topoAttr_.userRankSize * unitSize); + return maxCountPerLoop; +} + +HcclResult CollAllGatherMeshExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + u32 perDataSize = SIZE_TABLE[param.DataDes.dataType]; + + // 获取子通信域信息 + CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + u32 outerRankSize = outerCommInfo.localRankSize; + u32 commIndex = outerCommInfo.localRank; + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + u32 serverIndex = innerCommInfo.localRank; + + u64 inputMemSize = execMem.inputMem.size(); + u64 baseOffset = serverIndex * inputMemSize * outerRankSize; + u64 outerOffset = commIndex * inputMemSize; + DeviceMem dstMem = execMem.outputMem.range(baseOffset + outerOffset, inputMemSize); + CHK_SMART_PTR_NULL(dstMem); + // 第一步,将数据从input内存拷贝到output内存的对应位置 + HcclResult ret = HcclD2DMemcpyAsync(dispatcher_, dstMem, execMem.inputMem, const_cast(param.stream)); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollAllGatherMeshExecutor][KernelRun]all gather 4PmeshHD memcpy Failed, Offset[%llu], Size[%llu].", + baseOffset + outerOffset, inputMemSize), ret); + + // 第二步,各个AI Server 内 multi stream mesh all gather + std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 + std::vector> multiStreamSlice; // 每个stream使用的数据基于用户buffer的偏移 + u32 sliceNum = outerRankSize; + CHK_RET(PrepareAllgatherSlice(sliceNum, inputMemSize, dataSegsSlice)); + + // mesh算法stream数量为server内rank数减1 + CHK_RET(ExecutorBase::PrepareSliceMeshStreams(dataSegsSlice, sliceNum - 1, multiStreamSlice)); + + 
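For orientation only (not part of the patch): the KernelRun above places each rank's block at serverIndex * inputMemSize * outerRankSize + commIndex * inputMemSize in the allgather output, and dataSegsSlice is assumed to hold one equal inputMemSize slice per local rank. A minimal standalone C++ sketch of that layout arithmetic, with illustrative names:

#include <cstdint>
#include <cstdio>
#include <vector>

struct Slice {
    uint64_t offset = 0;
    uint64_t size = 0;
};

// Offset of the block contributed by (serverIndex, localRank) in the gathered output.
uint64_t BlockOffset(uint32_t serverIndex, uint32_t localRank, uint32_t localRankSize, uint64_t blockSize)
{
    return static_cast<uint64_t>(serverIndex) * localRankSize * blockSize +
           static_cast<uint64_t>(localRank) * blockSize;
}

// One equal slice per local rank, analogous to what dataSegsSlice is assumed to contain.
std::vector<Slice> MakePerRankSlices(uint32_t rankCount, uint64_t blockSize)
{
    std::vector<Slice> slices(rankCount);
    for (uint32_t i = 0; i < rankCount; ++i) {
        slices[i].offset = i * blockSize;
        slices[i].size = blockSize;
    }
    return slices;
}

int main()
{
    const uint64_t blockSize = 1024;  // bytes contributed by one rank (stand-in for inputMemSize)
    const uint32_t localRankSize = 8; // ranks per server (stand-in for outerRankSize)
    printf("server 1, local rank 3 -> offset %llu\n",
           static_cast<unsigned long long>(BlockOffset(1, 3, localRankSize, blockSize)));
    for (const Slice &s : MakePerRankSlices(localRankSize, blockSize)) {
        printf("slice offset=%llu size=%llu\n",
               static_cast<unsigned long long>(s.offset), static_cast<unsigned long long>(s.size));
    }
    return 0;
}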
CHK_RET(ActiveSlaveStreams(param.stream)); + + // 抽取当前用于多环all gather 的output内存数据 + DeviceMem currentOutputMem = execMem.outputMem.range(baseOffset, inputMemSize * outerRankSize); + CHK_SMART_PTR_NULL(currentOutputMem); + + std::unique_ptr outerExecutor; + if (topoAttr_.deviceType == DevType::DEV_TYPE_910B) { + outerExecutor.reset( + new (std::nothrow) AllGatherMeshAtomic(dispatcher_, streamInfo_.ringStreams, + streamInfo_.ringSignal, streamInfo_.ringSignalAux, commIndex, outerRankSize, + topoAttr_.userRank)); + } else { + outerExecutor.reset( + new (std::nothrow) AllGatherMesh(dispatcher_, streamInfo_.ringStreams, streamInfo_.ringSignal, + streamInfo_.ringSignalAux, commIndex, outerRankSize, topoAttr_.userRank)); + } + CHK_SMART_PTR_NULL(outerExecutor); + CHK_RET(outerExecutor->Prepare(currentOutputMem, currentOutputMem, execMem.inputMem, + execMem.count * outerRankSize, param.DataDes.dataType, param.stream, HCCL_REDUCE_RESERVED, + OUTER_BRIDGE_RANK_ID, dataSegsSlice, baseOffset)); + u32 rankSize = outerRankSize; + CHK_RET(outerExecutor->RegisterProfiler((rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + commIndex, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(outerExecutor, outerCommInfo)); + HCCL_INFO("all gather mesh HD outer run success"); + + // 第三步, AI server 间 recursive halving doubling all gather + u64 hdSize = inputMemSize * outerRankSize; + u64 hdCount = hdSize / perDataSize; + + std::unique_ptr innerExecutor; + if (UseInterServerRingAlgo(algType_) || (topoAttr_.isDiffDeviceModule && topoAttr_.serverNum == 1)) { + // 1-单server-SDMA + innerExecutor.reset(new (std::nothrow) AllGatherRing(dispatcher_)); + HCCL_INFO("allgather mesh: using ring algo inter-server."); + } else if (UseInterServerNHRAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllGatherNHR(dispatcher_)); + HCCL_INFO("allgather mesh: using nhr algo inter-server."); + } else if (UseInterServerNHRV1Algo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllGatherNHRV1(dispatcher_)); + HCCL_INFO("allgather mesh: using nhr_v1 algo inter-server."); + } else if (UseInterServerNBAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllGatherNB(dispatcher_)); + HCCL_INFO("allgather mesh: using nonuniform-bruck algo inter-server."); + } else { + innerExecutor.reset(new (std::nothrow) AllGatherRecursiveHalvingDoubling(dispatcher_)); + HCCL_INFO("allgather mesh: using halving-doubling algo inter-server."); + } + CHK_SMART_PTR_NULL(innerExecutor); + + // 此处虽然带入inputMem作为scratch mem, 但inputMem 不能被使用 + CHK_RET(innerExecutor->Prepare(execMem.outputMem, execMem.outputMem, execMem.inputMem, hdCount, + param.DataDes.dataType, param.stream, HcclReduceOp::HCCL_REDUCE_RESERVED, INVALID_VALUE_RANKID, + std::vector(COMM_INDEX_0), 0)); + + rankSize = innerCommInfo.localRankSize; + CHK_RET(innerExecutor->RegisterProfiler((rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + serverIndex, + PROF_STAGE_2, HCCL_EXEC_STEP_NOT_SET, param.stream)); + + CHK_RET(RunTemplate(innerExecutor, innerCommInfo)); + + HCCL_INFO("all gather mesh HD inner run success"); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("AllGatherMeshExecutor", AllGatherMesh, CollAllGatherMeshExecutor); +} // namespace hccl diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_mesh_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_mesh_executor.h new file mode 100644 index 
0000000000000000000000000000000000000000..c9ff81f9b9393d0765d42df185691e392c3ff83c --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_mesh_executor.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_ALLGATHER_MESH_EXECUTOR_H +#define COLL_ALLGATHER_MESH_EXECUTOR_H +#include "coll_all_gather_executor.h" +namespace hccl { +class CollAllGatherMeshExecutor : public CollAllGatherExecutor { +public: + explicit CollAllGatherMeshExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollAllGatherMeshExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcStreamNum(u32& streamNum) override; + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + u64 CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize) override; + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_mesh_opbase_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_mesh_opbase_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..fdecee02fceeb4bb3165783918f7597722f16490 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_mesh_opbase_executor.cc @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "coll_all_gather_mesh_opbase_executor.h" + +namespace hccl { +CollAllGatherMeshOpbaseExecutor::CollAllGatherMeshOpbaseExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllGatherExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = true; +} + +HcclResult CollAllGatherMeshOpbaseExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = topoAttr_.deviceNumPerAggregation; + streamNum = totalStreamNum - 1U; + HCCL_INFO("[CollAllGatherMeshOpbaseExecutor][CalcStreamNum] tag[%s] streamNum[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherMeshOpbaseExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherMeshOpbaseExecutor::CalcTransportMemType(TransportMemType &inputType, + TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_OUTPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_OUTPUT; + } + HCCL_INFO("[CollAllGatherMeshOpbaseExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherMeshOpbaseExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_MESH); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + return HCCL_SUCCESS; +} + +u64 CollAllGatherMeshOpbaseExecutor::CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize) +{ + u64 maxCountPerLoop = (cclBuffSize - HCCL_MIN_SLICE_ALIGN_910B) / unitSize; + return maxCountPerLoop; +} + +bool CollAllGatherMeshOpbaseExecutor::IsHugeData(const u64 curSize) +{ + bool hugeData = curSize > SDMA_SEND_MAX_SIZE; + return hugeData; +} + +HcclResult CollAllGatherMeshOpbaseExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + u8 *curInputPtr = static_cast(execMem.inputPtr); + u8 *curOutputPtr = static_cast(execMem.outputPtr); + CHK_PTR_NULL(curInputPtr); + CHK_PTR_NULL(curOutputPtr); + + // 获取子通信域信息 + CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + + u64 inputMemSize = execMem.inputMem.size(); + u64 baseOffset = 0; + std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 + + CHK_RET(ActiveSlaveStreams(param.stream)); + + // 抽取当前用于多环all gather 的output内存数据 + DeviceMem currentOutputMem = execMem.outputMem.range(baseOffset, inputMemSize); // 减少dma out大小 + CHK_SMART_PTR_NULL(currentOutputMem); + + // DMA消减场景,打包opInfo + HcomCollOpInfo opInfo = { + "", execMem.inputPtr, execMem.outputPtr, param.DataDes.count, param.DataDes.dataType, 0, HCCL_REDUCE_RESERVED + }; + + std::unique_ptr outerExecutor; + outerExecutor.reset( + new (std::nothrow) AllgatherMeshDirect(dispatcher_, streamInfo_.ringStreams, + streamInfo_.ringSignal, streamInfo_.ringSignalAux, outerCommInfo.localRank, outerCommInfo.localRankSize, + topoAttr_.userRank, &opInfo)); + CHK_SMART_PTR_NULL(outerExecutor); + CHK_RET(outerExecutor->Prepare(currentOutputMem, currentOutputMem, 
execMem.inputMem, execMem.count, + param.DataDes.dataType, param.stream, HCCL_REDUCE_RESERVED, OUTER_BRIDGE_RANK_ID, + dataSegsSlice, baseOffset)); + + u32 rankSize = outerCommInfo.localRankSize; + CHK_RET(outerExecutor->RegisterProfiler((rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + outerCommInfo.localRank, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + + CHK_RET(RunTemplate(outerExecutor, outerCommInfo)); + + HCCL_INFO("all gather mesh outer run success"); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("AllGatherMeshOpbaseExecutor", AllGatherOpbase, CollAllGatherMeshOpbaseExecutor); +} // namespace hccl diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_mesh_opbase_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_mesh_opbase_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..a526796574dcb27aed133e93ae5a9d50e7f34df1 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_mesh_opbase_executor.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_ALLGATHER_OPBASE_EXECUTOR_H +#define COLL_ALLGATHER_OPBASE_EXECUTOR_H +#include "coll_all_gather_executor.h" +namespace hccl { +class CollAllGatherMeshOpbaseExecutor : public CollAllGatherExecutor { +public: + explicit CollAllGatherMeshOpbaseExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollAllGatherMeshOpbaseExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcStreamNum(u32& streamNum) override; + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport); + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + u64 CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize) override; + bool IsHugeData(const u64 curSize) override; + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_mesh_opbase_pipeline_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_mesh_opbase_pipeline_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..c1a0fb431d86bfdac9bba74bf70740baf2c8018a --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_mesh_opbase_pipeline_executor.cc @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "coll_all_gather_mesh_opbase_pipeline_executor.h" + +namespace hccl { +CollAllGatherMeshOpbasePipelineExecutor::CollAllGatherMeshOpbasePipelineExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllGatherExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = true; +} + +HcclResult CollAllGatherMeshOpbasePipelineExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = topoAttr_.deviceNumPerAggregation + 1U; + streamNum = totalStreamNum - 1U; + HCCL_INFO("[CollAllGatherMeshOpbasePipelineExecutor][CalcStreamNum] tag[%s] streamNum[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherMeshOpbasePipelineExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel1CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherMeshOpbasePipelineExecutor::CalcTransportMemType(TransportMemType &inputType, + TransportMemType &outputType) +{ + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_OUTPUT; + HCCL_INFO("[CollAllGatherMeshOpbasePipelineExecutor][CalcTransportMemType]" \ + "tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherMeshOpbasePipelineExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaInfo(COMM_LEVEL0, CommType::COMM_TAG_MESH); + CHK_RET(CalcCommPlaneInfo(tag_, commParaInfo, opTransport[COMM_LEVEL0], inputType, outputType)); + return HCCL_SUCCESS; +} + +// PipeLine模式下使用Ring算法 +HcclResult CollAllGatherMeshOpbasePipelineExecutor::CalcLevel1CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaInfo(COMM_LEVEL1, CommType::COMM_TAG_RING_INNER); + CHK_RET(CalcCommPlaneInfo(tag_, commParaInfo, opTransport[COMM_LEVEL1], inputType, outputType)); + return HCCL_SUCCESS; +} + +u64 CollAllGatherMeshOpbasePipelineExecutor::CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize) +{ + u64 maxCountPerLoop = (cclBuffSize - HCCL_MIN_SLICE_ALIGN_910B) / unitSize; + return maxCountPerLoop; +} + +bool CollAllGatherMeshOpbasePipelineExecutor::IsHugeData(const u64 curSize) +{ + bool hugeData = curSize > RDMA_SEND_MAX_SIZE || curSize > SDMA_SEND_MAX_SIZE; + return hugeData; +} + +HcclResult CollAllGatherMeshOpbasePipelineExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + HCCL_INFO("[CollAllGatherMeshOpbasePipelineExecutor][KernelRun]AllGatherMeshOpbasePipelineExecutor begins."); + + // step 1 先获取 comm inner \ comm outer 的value + CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + u32 commIndex = 
outerCommInfo.localRank; + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + + // DMA消减场景,打包opInfo + HcomCollOpInfo opInfo = { + "", execMem.inputPtr, execMem.outputPtr, param.DataDes.count, param.DataDes.dataType, 0, HCCL_REDUCE_RESERVED + }; + + std::unique_ptr executor; + executor.reset(new (std::nothrow) AllGatherPipeline(dispatcher_)); + CHK_SMART_PTR_NULL(executor); + CHK_RET(executor->Prepare(&opInfo, topoAttr_.userRank, execMem.count, execMem.inputMem, execMem.outputMem, + outerCommInfo, innerCommInfo, const_cast(param.stream), + streamInfo_.ringStreams, streamInfo_.ringSignal, streamInfo_.ringSignalAux)); + CHK_RET(executor->RunAsync()); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("AllGatherMeshOpbasePipelineExecutor", AllGatherOpbasePipeline, CollAllGatherMeshOpbasePipelineExecutor); + +} // namespace hccl diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_mesh_opbase_pipeline_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_mesh_opbase_pipeline_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..c7135170db1fba225a28de2f0d9084bd3ce67796 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_mesh_opbase_pipeline_executor.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
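For context, the CalcLoopMaxCount overrides in these executors size one loop iteration from the CCL buffer: the mesh/ring variants divide the buffer across all userRankSize blocks, while the opbase and pipeline variants keep back an HCCL_MIN_SLICE_ALIGN_910B reserve and stage only the local block. A standalone sketch of that arithmetic, assuming the loop count then follows by ceiling division (the buffer size and reserve values below are made-up examples):

#include <cstdint>
#include <cstdio>

// Whole gathered result must fit in the CCL buffer each loop (mesh/ring style).
uint64_t LoopMaxCountWholeBuffer(uint64_t cclBuffSize, uint32_t rankSize, uint32_t unitSize)
{
    return cclBuffSize / (static_cast<uint64_t>(rankSize) * unitSize);
}

// Only the local block is staged, minus an alignment reserve (opbase/pipeline style).
uint64_t LoopMaxCountLocalBlock(uint64_t cclBuffSize, uint64_t alignReserve, uint32_t unitSize)
{
    return (cclBuffSize - alignReserve) / unitSize;
}

int main()
{
    const uint64_t cclBuffSize = 200ULL * 1024 * 1024; // example CCL buffer: 200 MiB
    const uint32_t unitSize = 4;                       // e.g. 4-byte elements
    const uint64_t totalCount = 300ULL * 1024 * 1024;  // elements each rank contributes
    uint64_t perLoopShared = LoopMaxCountWholeBuffer(cclBuffSize, 8, unitSize);
    uint64_t perLoopLocal = LoopMaxCountLocalBlock(cclBuffSize, 16 * 1024, unitSize); // reserve is an example
    uint64_t loops = (totalCount + perLoopShared - 1) / perLoopShared; // ceiling division
    printf("perLoopShared=%llu perLoopLocal=%llu loops=%llu\n",
           (unsigned long long)perLoopShared, (unsigned long long)perLoopLocal, (unsigned long long)loops);
    return 0;
}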
+ */ + +#ifndef COLL_ALLGATHER_MESH_OPBASE_PIPELINE_EXECUTOR_H +#define COLL_ALLGATHER_MESH_OPBASE_PIPELINE_EXECUTOR_H +#include "coll_all_gather_executor.h" +namespace hccl { +class CollAllGatherMeshOpbasePipelineExecutor : public CollAllGatherExecutor { +public: + explicit CollAllGatherMeshOpbasePipelineExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollAllGatherMeshOpbasePipelineExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcStreamNum(u32& streamNum) override; + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcLevel1CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + u64 CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize) override; + bool IsHugeData(const u64 curSize) override; + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_ring_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_ring_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..ad78f923d70cf0d6039d4d2e3e25e7c37dd0773a --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_ring_executor.cc @@ -0,0 +1,220 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "coll_all_gather_ring_executor.h" + +namespace hccl { +CollAllGatherRingExecutor::CollAllGatherRingExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllGatherExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = false; +} + +HcclResult CollAllGatherRingExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = 1U; + switch (algType_) { + case AlgType::ALG_8P_RING_PLUS_HD: + case AlgType::ALG_8P_RING_PLUS_RING: + case AlgType::ALG_8P_RING_PLUS_NHR: + case AlgType::ALG_8P_RING_PLUS_NHR_V1: + case AlgType::ALG_8P_RING_PLUS_NB: + case AlgType::ALG_8P_RING_PLUS_PIPELINE: + totalStreamNum = OUTER_PLANE_NUM_IN_8PRING; + break; + case AlgType::ALG_NP_SINGLE_RING_PLUS_RING: + case AlgType::ALG_NP_SINGLE_RING_PLUS_HD: + if (topoAttr_.deviceType == DevType::DEV_TYPE_910_73) { + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_SINGLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING; + } else { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_SINGLE; + } + } + break; + default: + break; + } + streamNum = totalStreamNum - 1; + HCCL_INFO("[CollAllGatherRingExecutor][CalcStreamNum] tag[%s] streamNum[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherRingExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel1CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherRingExecutor::CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_OUTPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_OUTPUT; + } + HCCL_INFO("[CollAllGatherRingExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherRingExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + HCCL_INFO("[CollAllGatherRingExecutor][CalcOuterCommInfo]tag[%s ]start", tag_.c_str()); + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_RING_INNER); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + HCCL_INFO("[CollAllGatherRingExecutor][CalcOuterCommInfo]tag[%s] Calc RingComm finish", tag_.c_str()); + return HCCL_SUCCESS; +} + +u64 CollAllGatherRingExecutor::CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize) +{ + u64 maxCountPerLoop = cclBuffSize / (topoAttr_.userRankSize * unitSize); + return maxCountPerLoop; +} + +HcclResult CollAllGatherRingExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + HCCL_INFO("[CollAllGatherRingExecutor][KernelRun]The AllGatherRingExecutor starts."); + u32 perDataSize = 0; + CHK_RET(SalGetDataTypeSize(param.DataDes.dataType, perDataSize)); + CHK_PRT_RET(perDataSize == 0, + HCCL_ERROR("[CollAllGatherRingExecutor][KernelRun]errNo[0x%016llx] datatype[%d] is invalid", + HCCL_ERROR_CODE(HCCL_E_PARA), param.DataDes.dataType), HCCL_E_PARA); + + // 获取子通信域的信息 + u32 ringNum = (topoType_ == TopoType::TOPO_TYPE_8P_RING) ? 
OUTER_PLANE_NUM_IN_8PRING : + OUTER_PLANE_NUM_IN_NPRING_SINGLE; + + CHK_RET(CheckCommSize(COMM_LEVEL0, ringNum)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + u32 commIndex = (ringNum == OUTER_PLANE_NUM_IN_8PRING) ? topoAttr_.devicePhyId : outerCommInfo.localRank; + u32 outerRankSize = outerCommInfo.localRankSize; + + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + u32 serverIndex = innerCommInfo.localRank; + + // 第一步,如果非DMA消减,将数据从input内存拷贝到output内存的对应位置 + u64 inputMemSize = execMem.inputMem.size(); + u64 baseOffset = serverIndex * inputMemSize * outerRankSize; + u64 outerOffset = commIndex * inputMemSize; + DeviceMem dstMem = execMem.outputMem.range(baseOffset + outerOffset, inputMemSize); + CHK_SMART_PTR_NULL(dstMem); + + HcclResult ret = HcclD2DMemcpyAsync(dispatcher_, dstMem, execMem.inputMem, const_cast(param.stream)); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollAllGatherRingExecutor][KernelRun]all gather 8PringHD memcpy Failed, " + "Offset[%llu], Size[%llu]", baseOffset + outerOffset, inputMemSize), ret); + + // 第二步,各个AI Server 内 multi ring all gather + std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 + std::vector> multRingsSliceZero; // 数据基于该rank上环0的偏移 + u32 sliceNum = outerRankSize; + CHK_RET(PrepareAllgatherSlice(sliceNum, inputMemSize, dataSegsSlice)); + + // 多环数据切分 + if (ringNum == OUTER_PLANE_NUM_IN_8PRING) { + multRingsSliceZero = PrepareMultiRingSlice(dataSegsSlice, param.tag); + } else { + multRingsSliceZero.push_back(dataSegsSlice); + } + CHK_PRT_RET(multRingsSliceZero.size() != ringNum, + HCCL_ERROR("[CollAllGatherRingExecutor][KernelRun]ringNum[%u] != multRingsSliceZero size[%llu]", + ringNum, multRingsSliceZero.size()), HCCL_E_INTERNAL); + + // 抽取当前用于多环all gather 的output内存数据 + DeviceMem currentOutputMem = execMem.outputMem.range(baseOffset, inputMemSize * outerRankSize); + CHK_SMART_PTR_NULL(currentOutputMem); + + CHK_RET(ActiveSlaveStreams(param.stream)); + + CHK_RET(MultiRingAllGather(param.tag, execMem.inputMem, currentOutputMem, execMem.count, param.DataDes.dataType, + multRingsSliceZero, param.stream, PROF_STAGE_1, baseOffset, nullptr)); + + HCCL_INFO("all gather 8PringHD outer run success"); + + // 第三步, AI server 间 recursive halving doubling all gather + u64 hdSize = 0; + std::vector nicList = const_cast&>(topoAttr_.nicList); + std::vector::iterator iterNic = std::find(nicList.begin(), nicList.end(), topoAttr_.devicePhyId); + if (iterNic != nicList.end()) { + hdSize = inputMemSize * outerRankSize; + } + u64 hdCount = hdSize / perDataSize; + + bool isMultiNic = topoType_ == TopoType::TOPO_TYPE_8P_RING && nicList.size() != DEVICE_EIGHT; + bool innRunRet = isMultiNic && (iterNic == nicList.end()); + if (!innRunRet) { // 满足以下条件, 不做server间通信: 1. 8P ring的拓扑 2. 网口不满配 3. 
当前device不出网口 + std::unique_ptr innerExecutor; + if (UseInterServerRingAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllGatherRing(dispatcher_)); + HCCL_INFO("allgather ring: using ring algo inter-server."); + } else if (UseInterServerNHRAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllGatherNHR(dispatcher_)); + HCCL_INFO("allgather ring: using nhr algo inter-server."); + } else if (UseInterServerNHRV1Algo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllGatherNHRV1(dispatcher_)); + HCCL_INFO("allgather ring: using nhr_v1 algo inter-server."); + } else if (UseInterServerNBAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllGatherNB(dispatcher_)); + HCCL_INFO("allgather ring: using nonuniform-bruck algo inter-server."); + } else { + innerExecutor.reset(new (std::nothrow) AllGatherRecursiveHalvingDoubling(dispatcher_)); + HCCL_INFO("allgather ring: using halving-doubling algo inter-server."); + } + CHK_SMART_PTR_NULL(innerExecutor); + + // 此处虽然带入inputMem作为scratch mem, 但inputMem 不能被使用 + CHK_RET(innerExecutor->Prepare(execMem.outputMem, execMem.outputMem, execMem.inputMem, hdCount, + param.DataDes.dataType, param.stream, HCCL_REDUCE_RESERVED, INVALID_VALUE_RANKID, + std::vector(COMM_INDEX_0), 0)); + + u32 rankSize = innerCommInfo.localRankSize; + CHK_RET(innerExecutor->RegisterProfiler((rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + serverIndex, + PROF_STAGE_2, HCCL_EXEC_STEP_NOT_SET, param.stream)); + + CHK_RET(RunTemplate(innerExecutor, innerCommInfo)); + } + HCCL_INFO("all gather 8PringHD inner run success"); + + // 网口裁剪:AI server 内多网口的allgather + if (topoType_ == TopoType::TOPO_TYPE_8P_RING && nicList.size() != DEVICE_EIGHT) { + CHK_RET(ActiveSlaveStreams(param.stream)); // 为什么要active两遍 + + u32 perDataSize = 0; + CHK_RET(SalGetDataTypeSize(param.DataDes.dataType, perDataSize)); + u64 tempCount = execMem.outputMem.size() / perDataSize; + CHK_RET(ExecutorBase::PrepareSliceData(tempCount, perDataSize, sliceNum, 0, dataSegsSlice)); + multRingsSliceZero = PrepareMultiRingSlice(dataSegsSlice, param.tag, false, nicList); + CHK_PRT_RET(multRingsSliceZero.size() != ringNum, HCCL_ERROR("[CollAllGatherRingExecutor][KernelRun]"\ + "ringNum[%u] != multRingsSliceZero size[%llu]", ringNum, multRingsSliceZero.size()), HCCL_E_INTERNAL); + + CHK_RET(MultiRingAllGather(param.tag, execMem.outputMem, execMem.outputMem, tempCount / DEVICE_EIGHT, + param.DataDes.dataType, multRingsSliceZero, param.stream, PROF_STAGE_1)); + + HCCL_INFO("all gather 8PringHD inner chunk run success"); + } + return HCCL_SUCCESS; +} + +REGISTER_EXEC("AllGatherRingExecutor", AllGatherRing, CollAllGatherRingExecutor); + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_ring_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_ring_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..108f5e06ef3a745b53380dd39b78b10b44135475 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_ring_executor.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
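For context, the inter-server stage above chooses its executor template from the configured algorithm type and falls back to recursive halving-doubling. A standalone sketch of that selection as a small factory; the enum and function below are illustrative stand-ins, not the HCCL API:

#include <cstdio>

enum class InterServerAlgo { RING, NHR, NHR_V1, NB, HD };

// Mirrors the if/else chain: ring, NHR, NHR v1, nonuniform-bruck, else halving-doubling.
const char *PickInterServerAllGather(InterServerAlgo algo)
{
    switch (algo) {
        case InterServerAlgo::RING:   return "AllGatherRing";
        case InterServerAlgo::NHR:    return "AllGatherNHR";
        case InterServerAlgo::NHR_V1: return "AllGatherNHRV1";
        case InterServerAlgo::NB:     return "AllGatherNB";
        default:                      return "AllGatherRecursiveHalvingDoubling";
    }
}

int main()
{
    printf("selected: %s\n", PickInterServerAllGather(InterServerAlgo::NB));
    return 0;
}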
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_ALLGATHER_RING_EXECUTOR_H +#define COLL_ALLGATHER_RING_EXECUTOR_H +#include "coll_all_gather_executor.h" +namespace hccl { +class CollAllGatherRingExecutor : public CollAllGatherExecutor { +public: + explicit CollAllGatherRingExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollAllGatherRingExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcStreamNum(u32& streamNum) override; + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + u64 CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize) override; + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_ring_for_910_73_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_ring_for_910_73_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..82dac8b9c152a0a77544992613f1d85fc9397a0f --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_ring_for_910_73_executor.cc @@ -0,0 +1,229 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "coll_all_gather_ring_for_910_73_executor.h" + +namespace hccl { +CollAllGatherRingFor91073Executor::CollAllGatherRingFor91073Executor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllGatherExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE; +} + +HcclResult CollAllGatherRingFor91073Executor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = (topoType_ == TopoType::TOPO_TYPE_NP_DOUBLE_RING ? 
OUTER_PLANE_NUM_IN_NPRING_DOUBLE : + OUTER_PLANE_NUM_IN_NPRING_SINGLE); + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + totalStreamNum *= STREAM_NUM_FOR_DMAREDUCE_ONE_RING; + } + streamNum = totalStreamNum - 1; + HCCL_INFO("[CollAllGatherRingFor91073Executor][CalcStreamNum] tag[%s] streamNum_[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherRingFor91073Executor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel1CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel2CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherRingFor91073Executor::CalcTransportMemType(TransportMemType &inputType, + TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_OUTPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_OUTPUT; + } + HCCL_INFO("[CollAllGatherRingFor91073Executor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherRingFor91073Executor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_RING_INNER); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherRingFor91073Executor::CalcLevel2CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel2(COMM_LEVEL2, CommType::COMM_TAG_MAX); + if (UseLevel2RingAlgo(algType_)) { + commParaLevel2.commType = CommType::COMM_TAG_RING_INNER; + } else { + commParaLevel2.commType = CommType::COMM_TAG_HALVING_DOUBLING; + } + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel2, opTransport[COMM_LEVEL2], inputType, outputType)); + return HCCL_SUCCESS; +} + +u64 CollAllGatherRingFor91073Executor::CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize) +{ + u64 maxCountPerLoop = cclBuffSize / (topoAttr_.userRankSize * unitSize); + return maxCountPerLoop; +} + +HcclResult CollAllGatherRingFor91073Executor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + HCCL_INFO("[CollAllGatherRingFor91073Executor][KernelRun] The AllGatherDoubleRingExecutor starts."); + + u32 perDataSize = 0; + CHK_RET(SalGetDataTypeSize(param.DataDes.dataType, perDataSize)); + CHK_PRT_RET(perDataSize == 0, + HCCL_ERROR("[CollAllGatherRingFor91073Executor][KernelRun]errNo[0x%016llx] datatype[%s] is invalid", + HCCL_ERROR_CODE(HCCL_E_PARA), GetDataTypeEnumStr(param.DataDes.dataType).c_str()), HCCL_E_PARA); + + CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + u32 commIndex = outerCommInfo.localRank; + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + + // 第一步,将数据从input内存拷贝到output内存的对应位置 + u32 level0ServerIndex = outerCommInfo.localRank; + u32 level1ServerIndex = innerCommInfo.localRank; + u32 level0RankSize = 
outerCommInfo.localRankSize; + u32 level1RankSize = innerCommInfo.localRankSize; + + u64 inputMemSize = execMem.inputMem.size(); + u64 baseOffset = level1ServerIndex * inputMemSize * level0RankSize; + u64 outerOffset = commIndex * inputMemSize; + DeviceMem dstMem = execMem.outputMem.range(baseOffset + outerOffset, inputMemSize); + CHK_SMART_PTR_NULL(dstMem); + + HcomCollOpInfo opInfo = { + "", execMem.inputPtr, execMem.outputPtr, param.DataDes.count, param.DataDes.dataType, 0, HCCL_REDUCE_RESERVED + }; + HcomCollOpInfo *opInfoPtr = nullptr; + + // 图模式opinfo为空,需要将数据从ccl input拷贝到ccl output上 + HcclResult ret = HCCL_SUCCESS; + if (!DMAReduceFlag_) { + ret = HcclD2DMemcpyAsync(dispatcher_, dstMem, execMem.inputMem, const_cast(param.stream)); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollAllGatherRingFor91073Executor][KernelRun]all gather double " + "ring memcpy Failed, Offset[%llu], Size[%llu]", + baseOffset + outerOffset, inputMemSize), ret); + } else { + opInfoPtr = &opInfo; + // 先做server间算法,带有消减拷贝场景数据需要从user input取,拷贝到ccl output上 + if (level1RankSize > 1) { + DeviceMem srcMem = DeviceMem::create(static_cast(execMem.inputPtr), inputMemSize); + ret = HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, const_cast(param.stream)); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollAllGatherRingFor91073Executor][KernelRun]all gather double " + "ring user memcpy Failed, Offset[%llu], Size[%llu]", + baseOffset + outerOffset, inputMemSize), ret); + } + } + if (topoAttr_.devNumInLevel2 > 1) { + // 超节点间做allgather + ret = AllGatherLevel2(param.tag, execMem.inputMem, execMem.outputMem, execMem.count, param.DataDes.dataType, + const_cast(param.stream), opInfoPtr); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollAllGatherRingFor91073Executor][KernelRun]tag[%s], all_gather failed, return[%d]", + param.tag.c_str(), ret), ret); + } else { + // 无超节点间场景 + if (level1RankSize > 1) { + std::unique_ptr level1AGExecutor; + if (UseInterServerRingAlgo(algType_)) { + level1AGExecutor.reset(new (std::nothrow) AllGatherRing(dispatcher_)); + HCCL_INFO("allgather ring: using ring algo inter-server."); + } else if (UseInterServerNBAlgo(algType_)) { + level1AGExecutor.reset(new (std::nothrow) AllGatherNB(dispatcher_)); + HCCL_INFO("allgather ring: using nonuniform-bruck algo inter-server."); + } else { + level1AGExecutor.reset(new (std::nothrow) AllGatherNHR(dispatcher_)); + HCCL_INFO("allgather ring: using nonuniform-hierarchical-ring algo inter-server."); + } + + // 计算slice, 不同超节点相同slice + std::vector level1DataSegsSlice; + Slice sliceTemp; + for (u32 i = 0; i < level1RankSize; i++) { + sliceTemp.size = inputMemSize; + sliceTemp.offset = (i * level0RankSize + level0ServerIndex) * inputMemSize; + level1DataSegsSlice.push_back(sliceTemp); + } + CHK_RET(level1AGExecutor->Prepare(execMem.outputMem, execMem.outputMem, execMem.inputMem, execMem.count, + param.DataDes.dataType, param.stream, + HCCL_REDUCE_RESERVED, INVALID_VALUE_RANKID, level1DataSegsSlice, 0)); + + CHK_RET(level1AGExecutor->RegisterProfiler(( + level1RankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + level1ServerIndex, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + + CHK_RET(RunTemplate(level1AGExecutor, innerCommInfo)); + HCCL_INFO("allgather double ring [superpod] level1 allgather run success"); + } + // 节点内做all gather double ring + std::vector dataSegsSlice; + std::vector> multRingsSliceZero; // 数据基于该rank上环0的偏移 + CHK_RET(PrepareAllgatherSlice(level0RankSize, inputMemSize, dataSegsSlice)); + + // 多环数据切分 + if (topoType_ == 
TopoType::TOPO_TYPE_NP_DOUBLE_RING) { + multRingsSliceZero = PrepareMultiRingSlice(dataSegsSlice, param.tag, false, topoAttr_.nicList); + } else { + multRingsSliceZero.push_back(dataSegsSlice); + } + std::vector> multRingsSlice; + CHK_RET(CalculateLevel1AllgatherSlice(inputMemSize, level0RankSize, level1RankSize, + multRingsSliceZero, multRingsSlice)); + + std::vector> multRingsUserMemSlice; + if (!DMAReduceFlag_) { + multRingsUserMemSlice = multRingsSlice; + } else { + for (u32 ringIndex = 0; ringIndex < multRingsSlice.size(); ringIndex++) { + std::vector level1UserMemSlice; + for (auto &cclSlice : multRingsSlice[ringIndex]) { + Slice tmpSlice; + tmpSlice.size = cclSlice.size; + tmpSlice.offset = + (cclSlice.offset / inputMemSize) * opInfo.count * perDataSize + + multRingsSliceZero[ringIndex][0].offset; + level1UserMemSlice.push_back(tmpSlice); + HCCL_DEBUG("rank[%u], ringIndex[%u], tmpSlice.offset=[%llu], size=[%llu]", + topoAttr_.userRank, ringIndex, tmpSlice.offset, tmpSlice.size); + } + multRingsUserMemSlice.push_back(level1UserMemSlice); + } + } + CHK_RET(ActiveSlaveStreams(param.stream)); + if (DMAReduceFlag_ && level1RankSize > 1) { + // allgather输入放在CCL buffer上,通过设置nullptr指示要从CCL buffer获取输入 + opInfo.inputAddr = nullptr; + } + CHK_RET(MultiRingAllGather(param.tag, execMem.inputMem, execMem.outputMem, execMem.count, + param.DataDes.dataType, multRingsSlice, param.stream, PROF_STAGE_2, 0, opInfoPtr, multRingsUserMemSlice)); + } + HCCL_INFO("all gather double ring inner run success"); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("AllGatherRingFor91073Executor", AllGatherRingFor91073, CollAllGatherRingFor91073Executor); + +} // namespace hccl diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_ring_for_910_73_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_ring_for_910_73_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..6671a85df4d76861f901703a7e45a0afe05ceda2 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_ring_for_910_73_executor.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef COLL_ALLGATHER_RING_FOR_910_73_EXECUTOR_H +#define COLL_ALLGATHER_RING_FOR_910_73_EXECUTOR_H +#include "coll_all_gather_executor.h" +namespace hccl { +class CollAllGatherRingFor91073Executor : public CollAllGatherExecutor { +public: + explicit CollAllGatherRingFor91073Executor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollAllGatherRingFor91073Executor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcStreamNum(u32& streamNum) override; + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcLevel2CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + u64 CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize) override; + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_single_rank_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_single_rank_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..c66d833498aa396e5cc98e090d35e462b2681fc4 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_single_rank_executor.cc @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "coll_all_gather_single_rank_executor.h" + +namespace hccl { +CollAllGatherSingleRankExecutor::CollAllGatherSingleRankExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllGatherExecutor(dispatcher, topoMatcher) +{ +} + +HcclResult CollAllGatherSingleRankExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + u32 unitSize = SIZE_TABLE[param.DataDes.dataType]; + auto originalAlgTypeLevel1 = static_cast(algType_) >> HCCL_LEVEL_ALGO_WIDTH; + bool hugeData = (execMem.count * unitSize) > SDMA_SEND_MAX_SIZE; + if (execMem.inputPtr == execMem.outputPtr) { + // 通过CopyPattern字段区分不同的子图 + auto opMeta = HcclOpMetaInfo::GetOneForAllGather(originalAlgTypeLevel1, hugeData, CopyPattern::ZCOPY); + CHK_RET(InitTask(dispatcher_, const_cast(param.stream), opMeta.isEnableCache, opMeta.GetCacheKey())); + } else { + auto opMeta = HcclOpMetaInfo::GetOneForAllGather(originalAlgTypeLevel1, hugeData, CopyPattern::BCOPY); + CHK_RET(InitTask(dispatcher_, const_cast(param.stream), opMeta.isEnableCache, opMeta.GetCacheKey())); + // ranksize = 1; intput、output地址不同,input->output + DeviceMem srcMem(execMem.inputPtr, execMem.count * unitSize); + DeviceMem dstMem(execMem.outputPtr, execMem.count * unitSize); + HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, const_cast(param.stream)); + } + CHK_RET(LaunchTask(dispatcher_, const_cast(param.stream))); + + return HCCL_SUCCESS; +} + +REGISTER_EXEC("AllGatherSingleExecutor", AllGatherSingleRank, CollAllGatherSingleRankExecutor); + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_single_rank_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_single_rank_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..73ca9e9e0050c424ab053d8d12d0c9e54a075331 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_single_rank_executor.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef COLL_ALLGATHER_SINGLE_RANK_EXECUTOR_H +#define COLL_ALLGATHER_SINGLE_RANK_EXECUTOR_H +#include "coll_all_gather_executor.h" +namespace hccl { +class CollAllGatherSingleRankExecutor : public CollAllGatherExecutor { + +public: + explicit CollAllGatherSingleRankExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollAllGatherSingleRankExecutor() = default; + +private: + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_doubling_direct_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_doubling_direct_executor.cc index def2f0e7719219f19bad7f377be513e3532611d7..1eed8f83909611ca9afe6341ac3ffeb744aa7123 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_doubling_direct_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_doubling_direct_executor.cc @@ -11,8 +11,8 @@ #include "coll_all_reduce_for_310p_doubling_direct_executor.h" namespace hccl { -CollAllReduceFor310PDoublingDirectExecutor::CollAllReduceFor310PDoublingDirectExecutor(std::unique_ptr &pImpl) - : CollAllReduceExecutor(pImpl) +CollAllReduceFor310PDoublingDirectExecutor::CollAllReduceFor310PDoublingDirectExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher): CollAllReduceExecutor(dispatcher, topoMatcher) { DMAReduceFlag_ = true; } diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_doubling_direct_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_doubling_direct_executor.h index e40d4fbe3c369dc3c0e2edc3e98c1bfbd5898db8..1048d69286a2de14ea2bde6edcef7a8d5bbc9c8d 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_doubling_direct_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_doubling_direct_executor.h @@ -15,7 +15,8 @@ namespace hccl { class CollAllReduceFor310PDoublingDirectExecutor : public CollAllReduceExecutor { public: - CollAllReduceFor310PDoublingDirectExecutor(std::unique_ptr &pImpl); + CollAllReduceFor310PDoublingDirectExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); ~CollAllReduceFor310PDoublingDirectExecutor() = default; private: diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_doubling_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_doubling_executor.cc index 6f2167897f848e96dbab0d514ecc45dc977cff70..19e44b8d3da4214ca229352f74b5b8b3ce67142c 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_doubling_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_doubling_executor.cc @@ -11,8 +11,8 @@ #include "coll_all_reduce_for_310p_doubling_executor.h" namespace hccl { 
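For context, the surrounding hunks migrate the 310P all-reduce executors from a single pImpl-style constructor argument to direct injection of the dispatcher and topology matcher. A toy sketch of that injection pattern; the type names are stand-ins, since the real template arguments are not visible in this rendering of the diff:

#include <cstdio>
#include <memory>

struct Dispatcher { int id = 0; };      // stand-in for the HcclDispatcher handle
struct TopoMatcher { int planes = 1; }; // stand-in for the topology matcher

class ExecutorBaseSketch {
public:
    ExecutorBaseSketch(const Dispatcher &dispatcher, std::unique_ptr<TopoMatcher> &topoMatcher)
        : dispatcher_(dispatcher), topoMatcher_(topoMatcher) {}
protected:
    Dispatcher dispatcher_;                      // copied per-executor dispatcher handle
    std::unique_ptr<TopoMatcher> &topoMatcher_;  // shared matcher, owned elsewhere
};

class DoublingExecutorSketch : public ExecutorBaseSketch {
public:
    DoublingExecutorSketch(const Dispatcher &d, std::unique_ptr<TopoMatcher> &t)
        : ExecutorBaseSketch(d, t) {}
};

int main()
{
    Dispatcher dispatcher;
    auto topoMatcher = std::make_unique<TopoMatcher>();
    DoublingExecutorSketch exec(dispatcher, topoMatcher);
    (void)exec;
    printf("constructed with dispatcher + topoMatcher\n");
    return 0;
}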
-CollAllReduceFor310PDoublingExecutor::CollAllReduceFor310PDoublingExecutor(std::unique_ptr &pImpl) - : CollAllReduceExecutor(pImpl) +CollAllReduceFor310PDoublingExecutor::CollAllReduceFor310PDoublingExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) : CollAllReduceExecutor(dispatcher, topoMatcher) { DMAReduceFlag_ = false; } diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_doubling_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_doubling_executor.h index 4d2078cdc53c0cec8ab962918c1188c3b69c255e..3fbc1ae48553769e1453b620f49ddc37f59d2122 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_doubling_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_doubling_executor.h @@ -15,7 +15,7 @@ namespace hccl { class CollAllReduceFor310PDoublingExecutor : public CollAllReduceExecutor { public: - CollAllReduceFor310PDoublingExecutor(std::unique_ptr &pImpl); + CollAllReduceFor310PDoublingExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollAllReduceFor310PDoublingExecutor() = default; private: diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_ring_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_ring_executor.cc index b02ca13e6a9a9d8161eac93b9526150252dec981..00632ef16d5aa7e9a20f37072a74d8529402ed86 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_ring_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_ring_executor.cc @@ -11,8 +11,8 @@ #include "coll_all_reduce_for_310p_ring_executor.h" namespace hccl { -CollAllReduceFor310PRingExecutor::CollAllReduceFor310PRingExecutor(std::unique_ptr &pImpl) - : CollAllReduceExecutor(pImpl) +CollAllReduceFor310PRingExecutor::CollAllReduceFor310PRingExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) : CollAllReduceExecutor(dispatcher, topoMatcher) { DMAReduceFlag_ = false; } diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_ring_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_ring_executor.h index 32e2047f1c9629432ee519b375d606edde47a6a8..3171d73ad3063e3089e20b6f9a1f18a10c8d1802 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_ring_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/310P/coll_all_reduce_for_310p_ring_executor.h @@ -15,7 +15,7 @@ namespace hccl { class CollAllReduceFor310PRingExecutor : public CollAllReduceExecutor { public: - CollAllReduceFor310PRingExecutor(std::unique_ptr &pImpl); + CollAllReduceFor310PRingExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollAllReduceFor310PRingExecutor() = default; private: diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/CMakeLists.txt 
b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/CMakeLists.txt index 142b130d910e2c9dcbf0b315ee3438b730e98f23..030baee6896555f56e8d9e4ab24dd2d952657c89 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/CMakeLists.txt +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/CMakeLists.txt @@ -3,6 +3,8 @@ set(src_list ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_reduce_single_rank_executor.cc ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_reduce_ring_executor.cc ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_reduce_comm_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_reduce_double_ring_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_reduce_double_ring_concurrent_executor.cc ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_reduce_mesh_executor.cc ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_reduce_mesh_mid_count_executor.cc ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_reduce_mesh_oneshot_executor.cc diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_comm_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_comm_executor.cc index 990c014c9a12b501b3c39d4277558c77b85a0bad..a7a87705f71b24c7d302a2ee3501636cc19d505a 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_comm_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_comm_executor.cc @@ -11,8 +11,10 @@ #include "coll_all_reduce_comm_executor.h" namespace hccl { -CollAllReduceCommExecutor::CollAllReduceCommExecutor(std::unique_ptr &pImpl) - : CollAllReduceExecutor(pImpl) + +CollAllReduceCommExecutor::CollAllReduceCommExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllReduceExecutor(dispatcher, topoMatcher) { DMAReduceFlag_ = false; } diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_comm_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_comm_executor.h index 658cd6f23ed6d5dcb1f8ea807a42d59222239084..24ee42963ec6be436b59532963f7989d6d134812 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_comm_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_comm_executor.h @@ -17,7 +17,7 @@ namespace hccl { class CollAllReduceCommExecutor : public CollAllReduceExecutor { public: - CollAllReduceCommExecutor(std::unique_ptr &pImpl); + CollAllReduceCommExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollAllReduceCommExecutor() = default; private: diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_double_ring_concurrent_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_double_ring_concurrent_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..15232e47a009ba985898b08182fd931219198f60 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_double_ring_concurrent_executor.cc @@ -0,0 +1,352 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "coll_all_reduce_double_ring_concurrent_executor.h" + +namespace hccl { + +CollAllReduceDoubleRingConcurrentExecutor::CollAllReduceDoubleRingConcurrentExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher): CollAllReduceExecutor(dispatcher, topoMatcher) +{ + if (!topoMatcher_->GetExternalInputEnableRdmaSdmaConcurrent() && + GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + DMAReduceFlag_ = true; + } else { + DMAReduceFlag_ = false; + } +} + +HcclResult CollAllReduceDoubleRingConcurrentExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = 0U; + // DoubleRing只支持910_73场景 + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING; + } else { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE; + } + if (topoMatcher_->GetExternalInputEnableRdmaSdmaConcurrent()) { + totalStreamNum += RDMA_PLANE_NUM_IN_NPRING_DOUBLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING; + } + streamNum = totalStreamNum - 1; + HCCL_INFO("[CollAllReduceDoubleRingConcurrentExecutor][CalcStreamNum] tag[%s] streamNum_[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollAllReduceDoubleRingConcurrentExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel1CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel2CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollAllReduceDoubleRingConcurrentExecutor::CalcTransportMemType(TransportMemType &inputType, + TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_OUTPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_OUTPUT; + } + HCCL_INFO("[CollAllReduceDoubleRingConcurrentExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollAllReduceDoubleRingConcurrentExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_RING_INNER); + commParaLevel0.forceRdma = false; + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + if (topoMatcher_->GetExternalInputEnableRdmaSdmaConcurrent()) { + CommParaInfo commParaLevel0Rdma(COMM_LEVEL0_RDMA, CommType::COMM_TAG_RING_INNER); + commParaLevel0Rdma.forceRdma = true; + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0Rdma, opTransport[COMM_LEVEL0_RDMA], + inputType, outputType)); + } + return HCCL_SUCCESS; +} + +HcclResult 
CollAllReduceDoubleRingConcurrentExecutor::CalcLevel2CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel2(COMM_LEVEL2, CommType::COMM_TAG_MAX); + if (UseLevel2RingAlgo(algType_)) { + commParaLevel2.commType = CommType::COMM_TAG_RING_INNER; + } else { + commParaLevel2.commType = CommType::COMM_TAG_HALVING_DOUBLING; + } + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel2, opTransport[COMM_LEVEL2], inputType, outputType)); + return HCCL_SUCCESS; +} + +bool CollAllReduceDoubleRingConcurrentExecutor::IsHugeData(const u64 curSize) +{ + bool hugeData = curSize / topoAttr_.deviceNumPerAggregation / HCCL_INTERNODE_MAX_DATA_RATE > RDMA_SEND_MAX_SIZE || + curSize > SDMA_SEND_MAX_SIZE; + return hugeData; +} + +bool CollAllReduceDoubleRingConcurrentExecutor::IsSmallData(const u64 totalSize, const u64 curSize) +{ + return false; +} + +HcclResult CollAllReduceDoubleRingConcurrentExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + HCCL_INFO("[CollAllReduceDoubleRingConcurrentExecutor][Run]The CollAllReduceDoubleRingConcurrentExecutor starts."); + u32 perDataSize = 0; + CHK_RET(SalGetDataTypeSize(param.DataDes.dataType, perDataSize)); + std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 + std::vector > multi2RingsSlice; // 数据基于该rank上环0的偏移 + std::vector>> multi4RingsSlice; // 基于2环数据切分2环SDMA+2环ROH bool = true表示SDMA + u32 ringNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE; + CHK_RET(CheckCommSize(COMM_LEVEL0, ringNum)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + u32 sliceNum = outerCommInfo.localRankSize; + // 根据数据量计算每个环上数据的偏移和大小 + CHK_RET(ExecutorBase::PrepareSliceData(execMem.count, perDataSize, sliceNum, 0, dataSegsSlice)); + + /* 三步算法step1:外层 - 节点内 reduce-scatter */ + // 构造ring algorithm对应的reduce-scatter实例 + multi2RingsSlice = PrepareMultiRingSlice(dataSegsSlice, param.tag, false, topoAttr_.nicList); + CHK_PRT_RET(multi2RingsSlice.size() != ringNum, HCCL_ERROR("[CollAllReduceDoubleRingConcurrentExecutor][Run]"\ + "ringNum[%u] != multRingsSliceZero size[%llu]", ringNum, multi2RingsSlice.size()), + HCCL_E_INTERNAL); + + // 根据数据量计算每个环上数据的偏移和大小 + u32 syncTrans = BEST_SPLIT_VALUE; + u64 totalDataSize = execMem.count * perDataSize; + if (totalDataSize <= HCCL_SDMA_RDMA_SPLIT_SIZE) { + syncTrans = MAX_SPLIT_VALUE; + } + multi4RingsSlice.resize(multi2RingsSlice.size() * SLICES_FACTOR); + for (u32 ringIndex = 0; ringIndex < multi2RingsSlice.size(); ringIndex++) { + std::vector sdmaSlice; + std::vector rdmaSlice; + for (u32 segsIndex = 0; segsIndex < multi2RingsSlice[ringIndex].size(); segsIndex++) { + auto totalSize = multi2RingsSlice[ringIndex][segsIndex].size; + auto sdmaSliceOffset = multi2RingsSlice[ringIndex][segsIndex].offset; + auto sdmaSliceSize = (totalSize <= HCCL_MIN_SLICE_ALIGN_910_73) ? 
totalSize: + ((syncTrans * totalSize / MAX_SPLIT_VALUE) / HCCL_MIN_SLICE_ALIGN_910_73) * HCCL_MIN_SLICE_ALIGN_910_73; + Slice sdmaSliceTmp; + sdmaSliceTmp.offset = sdmaSliceOffset; + sdmaSliceTmp.size = sdmaSliceSize; + Slice rdmaSliceTmp; + rdmaSliceTmp.offset = sdmaSliceOffset + sdmaSliceSize; + rdmaSliceTmp.size = totalSize - sdmaSliceSize; + sdmaSlice.push_back(sdmaSliceTmp); + rdmaSlice.push_back(rdmaSliceTmp); + HCCL_DEBUG("Ring index:%u, segId:%u, Orignal [offset %llu, size %llu], sdma " + "[offset %llu, size %llu], rdma [offset %llu, size %llu]", + ringIndex, segsIndex, sdmaSliceOffset, totalSize, sdmaSliceTmp.offset, + sdmaSliceTmp.size, rdmaSliceTmp.offset, rdmaSliceTmp.size); + } + multi4RingsSlice[ringIndex] = std::make_pair(true, sdmaSlice); // true表示使用sdma + multi4RingsSlice[ringIndex + multi2RingsSlice.size()] = std::make_pair(false, rdmaSlice); // false表示rdma + } + if (syncTrans == MAX_SPLIT_VALUE) { + multi4RingsSlice.erase(multi4RingsSlice.end() - multi2RingsSlice.size(), multi4RingsSlice.end()); + } + + HcomCollOpInfo *reduceScatterOpInfoPtr = nullptr; + // 第一步的reducescatter输出放在CCL buffer上,通过设置nullptr指示不做最后一步的DMA削减动作 + HcomCollOpInfo reduceScatterOpInfo = { + "", execMem.inputPtr, nullptr, execMem.count, param.DataDes.dataType, param.root, param.reduceType + }; + if (DMAReduceFlag_) { + reduceScatterOpInfoPtr = &reduceScatterOpInfo; + } + CHK_RET(MultiRingReduceScatterConcurrent(param.tag, execMem.inputMem, execMem.outputMem, execMem.count, + param.DataDes.dataType, param.reduceType, multi4RingsSlice, param.stream, + PROF_STAGE_0, 0, reduceScatterOpInfoPtr)); + HCCL_INFO("allreduce double ring stage0 run success"); + + /* 三步算法step2: 内层 - 节点间 allreduce */ + u64 hdSize; + u32 segmentIdx; + u32 commIndex; + CHK_RET(PrepareInnerCommInfo(segmentIdx, commIndex, hdSize, + outerCommInfo, multi2RingsSlice, param.tag)); + auto nicList = topoAttr_.nicList; + auto devicePhyId = topoAttr_.devicePhyId; + commIndex = RefreshCommIdx(commIndex, nicList, devicePhyId); + u64 hdCount = hdSize / perDataSize; + if (topoAttr_.devNumInLevel2 <= 1) { + DeviceMem allreduceInput = execMem.inputMem.range(dataSegsSlice[segmentIdx].offset, hdSize); + CHK_SMART_PTR_NULL(allreduceInput); + DeviceMem allreduceOutput = execMem.outputMem.range(dataSegsSlice[segmentIdx].offset, hdSize); + CHK_SMART_PTR_NULL(allreduceOutput); + + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + + u64 reduceAttr = GetReduceAttr(allreduceInput, allreduceOutput, param.DataDes.dataType, param.reduceType); + std::unique_ptr innerExecutor; + if (UseInterServerRingAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllReduceRing(dispatcher_, reduceAttr)); + HCCL_INFO("allreduce ring: using ring algo inter-server."); + } else if (UseInterServerNHRAlgo(algType_)) { + u64 curSize = execMem.count * SIZE_TABLE[param.DataDes.dataType]; // 单位 byte + HCCL_DEBUG("allreduce ring: curSize[%llu] deviceNumPerAggregation[%u] commOuterSize[%u]", + curSize, topoAttr_.deviceNumPerAggregation, outerCommInfo.localRankSize); + if (curSize / topoAttr_.deviceNumPerAggregation <= NHR_ALLREDUCE_SMALL_SIZE) { + innerExecutor.reset(new (std::nothrow) AllReduceNHROneshot(dispatcher_, reduceAttr)); + } else { + innerExecutor.reset(new (std::nothrow) AllReduceNHR(dispatcher_, reduceAttr)); + } + HCCL_INFO("allreduce ring: using nhr algo inter-server."); + } else if (UseInterServerNHRV1Algo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllReduceNHRV1(dispatcher_, 
reduceAttr)); + HCCL_INFO("allreduce ring: using nhr_v1 algo inter-server."); + } else if (UseInterServerNBAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllReduceNB(dispatcher_, reduceAttr)); + HCCL_INFO("allreduce ring: using nonuniform-bruck algo inter-server."); + } else { + innerExecutor.reset(new (std::nothrow) AllReduceRecursiveHalvingDoubling(dispatcher_, reduceAttr)); + HCCL_INFO("allreduce ring: using Recursive halving-doubling algo inter-server."); + } + CHK_SMART_PTR_NULL(innerExecutor); + u32 rankSize = innerCommInfo.localRankSize; + // 节点间的hd 使用环0来记录 + CHK_RET(innerExecutor->Prepare( + allreduceInput, allreduceOutput, allreduceOutput, hdCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, + std::vector(0), dataSegsSlice[segmentIdx].offset)); + CHK_RET(innerExecutor->RegisterProfiler( + (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(innerExecutor, innerCommInfo)); + + HCCL_INFO("allreduce double ring stage1 run success"); + } else { + // 超节点内做reducescatter + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + SubCommInfo innerZeroCommInfo = GetSubCommInfo(COMM_LEVEL1, COMM_INDEX_0); + u32 sliceNum = innerZeroCommInfo.localRankSize; + // 根据数据量计算每个环上数据的偏移和大小 + CHK_RET(ExecutorBase::PrepareSliceData(hdCount, perDataSize, sliceNum, 0, dataSegsSlice)); + DeviceMem reducescatterInput = execMem.inputMem.range(dataSegsSlice[segmentIdx].offset, hdSize); + DeviceMem reducescatterOutput = execMem.outputMem.range(dataSegsSlice[segmentIdx].offset, hdSize); + + u64 reduceAttr = GetReduceAttr(reducescatterInput, reducescatterOutput, + param.DataDes.dataType, param.reduceType); + std::unique_ptr level1RSExecutor; + if (UseInterServerRingAlgo(algType_)) { + level1RSExecutor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); + CHK_SMART_PTR_NULL(level1RSExecutor); + CHK_RET(level1RSExecutor->Prepare( + reducescatterInput, reducescatterInput, reducescatterOutput, hdCount, + param.DataDes.dataType, param.stream, param.reduceType, + OUTER_BRIDGE_RANK_ID, dataSegsSlice, dataSegsSlice[segmentIdx].offset)); + HCCL_INFO("reducescatter ring: using ring algo inter-server."); + } else { + level1RSExecutor.reset(new (std::nothrow) ReduceScatterRecursiveHalvingDoubling(dispatcher_, reduceAttr)); + CHK_SMART_PTR_NULL(level1RSExecutor); + CHK_RET(level1RSExecutor->Prepare( + reducescatterInput, reducescatterOutput, reducescatterOutput, hdCount, + param.DataDes.dataType, param.stream, param.reduceType, + OUTER_BRIDGE_RANK_ID, dataSegsSlice, dataSegsSlice[segmentIdx].offset)); + HCCL_INFO("reducescatter ring: using halving-doubling algo inter-server."); + } + CHK_RET(level1RSExecutor->RegisterProfiler( + (sliceNum << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(level1RSExecutor, innerCommInfo)); + HCCL_INFO("allreduce double ring [superpod] level1 reducescatter run success"); + + // 超节点间做allreduce + u64 arSize; + std::vector > rdSlice; + rdSlice.push_back(dataSegsSlice); + CHK_RET(PrepareInnerCommInfo(segmentIdx, commIndex, arSize, innerZeroCommInfo, rdSlice, param.tag)); + auto nicList = topoAttr_.nicList; + auto devicePhyId = topoAttr_.devicePhyId; + commIndex = RefreshCommIdx(commIndex, nicList, devicePhyId); + u64 arCount = arSize / perDataSize; + + 
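+ // Level-2 spans the superpods: the allreduce below runs on this rank's bridge segment over the COMM_LEVEL2 plane, with rankSize taken from ring 0 of that plane.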
CHK_RET(CheckCommSize(COMM_LEVEL2, commIndex + 1)); + SubCommInfo level2CommInfo = GetSubCommInfo(COMM_LEVEL2, commIndex); + u32 rankSize = GetSubCommInfo(COMM_LEVEL2, COMM_INDEX_0).localRankSize; + + DeviceMem allreduceInput = execMem.inputMem.range(dataSegsSlice[segmentIdx].offset, arSize); + DeviceMem allreduceOutput = execMem.outputMem.range(dataSegsSlice[segmentIdx].offset, arSize); + reduceAttr = GetReduceAttr(allreduceInput, allreduceOutput, param.DataDes.dataType, param.reduceType); + + std::unique_ptr level2ARExecutor; + if (UseLevel2RingAlgo(algType_)) { + level2ARExecutor.reset(new (std::nothrow) AllReduceRing(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using ring algo inter-server."); + } else { + level2ARExecutor.reset(new (std::nothrow) AllReduceRecursiveHalvingDoubling(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using halving-doubling algo inter-server."); + } + CHK_RET(level2ARExecutor->Prepare( + allreduceInput, allreduceOutput, allreduceOutput, arCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, + std::vector(0), dataSegsSlice[segmentIdx].offset)); + CHK_SMART_PTR_NULL(level2ARExecutor); + CHK_RET(level2ARExecutor->RegisterProfiler( + (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + level2CommInfo.localRank, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(level2ARExecutor, level2CommInfo)); + HCCL_INFO("allreduce double ring [superpod] level2 allreduce run success"); + // 超节点内做allgather + std::unique_ptr level1AGExecutor; + DeviceMem allgatherInput = execMem.outputMem.range(dataSegsSlice[segmentIdx].offset, arSize); + DeviceMem allgatherOutput = execMem.outputMem.range(dataSegsSlice[segmentIdx].offset, arSize*sliceNum); + if (UseInterServerRingAlgo(algType_)) { + level1AGExecutor.reset(new (std::nothrow) AllGatherRing(dispatcher_)); + } else { + level1AGExecutor.reset(new (std::nothrow) AllGatherRecursiveHalvingDoubling(dispatcher_)); + } + CHK_SMART_PTR_NULL(level1AGExecutor); + CHK_RET(level1AGExecutor->Prepare(allgatherOutput, allgatherOutput, allgatherOutput, arCount, + param.DataDes.dataType, param.stream, + HcclReduceOp::HCCL_REDUCE_RESERVED, OUTER_BRIDGE_RANK_ID, dataSegsSlice, + dataSegsSlice[segmentIdx].offset)); + CHK_RET(level1AGExecutor->RegisterProfiler( + (sliceNum << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(level1AGExecutor, innerCommInfo)); + HCCL_INFO("allreduce double ring [superpod] level1 allgather run success"); + } + + /* 三步算法step3:外层 - 节点内 allgather */ + HcomCollOpInfo *allgatherOpInfoPtr = nullptr; + // 第三步的allgather输入放在CCL buffer上,通过设置nullptr指示要从CCL buffer获取输入 + HcomCollOpInfo allgatherOpInfo = { + "", nullptr, execMem.outputPtr, execMem.count, param.DataDes.dataType, param.root, param.reduceType + }; + if (DMAReduceFlag_) { + allgatherOpInfoPtr = &allgatherOpInfo; + } + CHK_RET(MultiRingAllGatherConcurrent(param.tag, execMem.inputMem, execMem.outputMem, hdCount, + param.DataDes.dataType, multi4RingsSlice, param.stream, + PROF_STAGE_2, 0, allgatherOpInfoPtr)); + HCCL_INFO("allreduce double ring stage2 run success"); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("AllReduceDoubleRingConcurrentExecutor", AllReduceDoubleRingConcurrent, + CollAllReduceDoubleRingConcurrentExecutor); + +} // namespace hccl \ No newline at end of file diff --git 
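The distinctive step of the concurrent executor above is splitting every ring slice into an SDMA head and an RDMA tail according to the syncTrans ratio, with the SDMA part rounded down to HCCL_MIN_SLICE_ALIGN_910_73. A minimal standalone sketch of that split, with Slice and the constants reduced to stand-ins purely for illustration:

#include <cstdint>
#include <utility>

struct Slice { uint64_t offset = 0; uint64_t size = 0; };  // stand-in for the HCCL Slice

// Split one slice into an aligned SDMA head and an RDMA tail.
// ratio/maxRatio mirror syncTrans/MAX_SPLIT_VALUE; align mirrors HCCL_MIN_SLICE_ALIGN_910_73.
std::pair<Slice, Slice> SplitSdmaRdma(const Slice &in, uint64_t ratio, uint64_t maxRatio, uint64_t align)
{
    Slice sdma;
    Slice rdma;
    sdma.offset = in.offset;
    // Slices no larger than one alignment unit stay entirely on SDMA; otherwise take
    // ratio/maxRatio of the bytes, rounded down to the alignment unit.
    sdma.size = (in.size <= align) ? in.size
                                   : ((ratio * in.size / maxRatio) / align) * align;
    rdma.offset = in.offset + sdma.size;  // RDMA covers the remainder of the slice
    rdma.size = in.size - sdma.size;
    return {sdma, rdma};
}

When the total payload is at or below HCCL_SDMA_RDMA_SPLIT_SIZE, syncTrans is set to MAX_SPLIT_VALUE and the RDMA halves are dropped from multi4RingsSlice, so the transfer runs as a pure SDMA schedule.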
a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_double_ring_concurrent_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_double_ring_concurrent_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..4eac4ae86265ecaa549f5a56ca0dd6928c567a57 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_double_ring_concurrent_executor.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_ALL_REDUCE_DOUBLE_RING_CONCURRENT_EXECUTOR_H +#define COLL_ALL_REDUCE_DOUBLE_RING_CONCURRENT_EXECUTOR_H + +#include "coll_all_reduce_executor.h" + +namespace hccl { +class CollAllReduceDoubleRingConcurrentExecutor : public CollAllReduceExecutor { + +public: + CollAllReduceDoubleRingConcurrentExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollAllReduceDoubleRingConcurrentExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcStreamNum(u32& streamNum) override; + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcLevel2CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + bool IsHugeData(const u64 curSize) override; + bool IsSmallData(const u64 totalSize, const u64 curSize) override; + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_double_ring_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_double_ring_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..c58dbfec210228799e1c5cd2a47199460059cd59 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_double_ring_executor.cc @@ -0,0 +1,298 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "coll_all_reduce_double_ring_executor.h" + +namespace hccl { + +CollAllReduceDoubleRingExecutor::CollAllReduceDoubleRingExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllReduceExecutor(dispatcher, topoMatcher) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + DMAReduceFlag_ = true; + } else { + DMAReduceFlag_ = false; + } +} + +HcclResult CollAllReduceDoubleRingExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = 0U; + // DoubleRing只支持910_73场景 + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING; + } else { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE; + } + streamNum = totalStreamNum - 1; + HCCL_INFO("[CollAllReduceDoubleRingExecutor][CalcStreamNum] tag[%s] streamNum_[%u].", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollAllReduceDoubleRingExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel1CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel2CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollAllReduceDoubleRingExecutor::CalcTransportMemType(TransportMemType &inputType, + TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_OUTPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_OUTPUT; + } + HCCL_INFO("[CollAllReduceDoubleRingExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d].", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollAllReduceDoubleRingExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_RING_INNER); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + return HCCL_SUCCESS; +} + +HcclResult CollAllReduceDoubleRingExecutor::CalcLevel2CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel2(COMM_LEVEL2, CommType::COMM_TAG_MAX); + if (UseLevel2RingAlgo(algType_)) { + commParaLevel2.commType = CommType::COMM_TAG_RING_INNER; + } else { + commParaLevel2.commType = CommType::COMM_TAG_HALVING_DOUBLING; + } + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel2, opTransport[COMM_LEVEL2], inputType, outputType)); + return HCCL_SUCCESS; +} + +bool CollAllReduceDoubleRingExecutor::IsHugeData(const u64 curSize) +{ + bool hugeData = curSize / topoAttr_.deviceNumPerAggregation / HCCL_INTERNODE_MAX_DATA_RATE > RDMA_SEND_MAX_SIZE || + curSize > SDMA_SEND_MAX_SIZE; + return hugeData; +} + +bool CollAllReduceDoubleRingExecutor::IsSmallData(const u64 totalSize, const u64 curSize) +{ + return false; +} + +HcclResult CollAllReduceDoubleRingExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + HCCL_INFO("[CollAllReduceDoubleRingExecutor][Run]The CollAllReduceDoubleRingExecutor starts"); + u32 perDataSize = 0; + CHK_RET(SalGetDataTypeSize(param.DataDes.dataType, perDataSize)); + 
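+ // Three-step flow: (1) intra-server reduce-scatter over the double ring, (2) inter-server allreduce on this rank's segment, (3) intra-server allgather over the double ring.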
std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 + std::vector > multRingsSliceZero; // 数据基于该rank上环0的偏移 + u32 ringNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE; + CHK_RET(CheckCommSize(COMM_LEVEL0, ringNum)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + u32 sliceNum = outerCommInfo.localRankSize; + // 根据数据量计算每个环上数据的偏移和大小 + CHK_RET(ExecutorBase::PrepareSliceData(execMem.count, perDataSize, sliceNum, 0, dataSegsSlice)); + + /* 三步算法step1:外层 - 节点内 reduce-scatter */ + // 构造ring algorithm对应的reduce-scatter实例 + multRingsSliceZero = PrepareMultiRingSlice(dataSegsSlice, param.tag, false, topoAttr_.nicList); + CHK_PRT_RET(multRingsSliceZero.size() != ringNum, HCCL_ERROR("[CollAllReduceDoubleRingExecutor][Run]"\ + "ringNum[%u] != multRingsSliceZero size[%llu]", ringNum, multRingsSliceZero.size()), + HCCL_E_INTERNAL); + + HcomCollOpInfo *reduceScatterOpInfoPtr = nullptr; + // 第一步的reducescatter输出放在CCL buffer上,通过设置nullptr指示不做最后一步的DMA削减动作 + HcomCollOpInfo reduceScatterOpInfo = { + "", execMem.inputPtr, nullptr, execMem.count, param.DataDes.dataType, param.root, param.reduceType + }; + if (DMAReduceFlag_) { + reduceScatterOpInfoPtr = &reduceScatterOpInfo; + } + CHK_RET(MultiRingReduceScatter(param.tag, execMem.inputMem, execMem.outputMem, execMem.count, + param.DataDes.dataType, param.reduceType, multRingsSliceZero, param.stream, + PROF_STAGE_0, 0, reduceScatterOpInfoPtr)); + HCCL_INFO("allreduce double ring stage0 run success."); + + /* 三步算法step2: 内层 - 节点间 allreduce */ + u64 hdSize; + u32 segmentIdx; + u32 commIndex; + CHK_RET(PrepareInnerCommInfo(segmentIdx, commIndex, hdSize, outerCommInfo, multRingsSliceZero, param.tag)); + u64 hdCount = hdSize / perDataSize; + if (topoAttr_.devNumInLevel2 <= 1) { + DeviceMem allreduceInput = execMem.inputMem.range(dataSegsSlice[segmentIdx].offset, hdSize); + CHK_SMART_PTR_NULL(allreduceInput); + DeviceMem allreduceOutput = execMem.outputMem.range(dataSegsSlice[segmentIdx].offset, hdSize); + CHK_SMART_PTR_NULL(allreduceOutput); + + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + + u64 reduceAttr = GetReduceAttr(allreduceInput, allreduceOutput, param.DataDes.dataType, param.reduceType); + std::unique_ptr innerExecutor; + if (UseInterServerRingAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllReduceRing(dispatcher_, reduceAttr)); + HCCL_INFO("allreduce ring: using ring algo inter-server."); + } else if (UseInterServerNHRAlgo(algType_)) { + u64 curSize = execMem.count * SIZE_TABLE[param.DataDes.dataType]; // 单位 byte + HCCL_DEBUG("allreduce ring: curSize[%llu] deviceNumPerAggregation[%u] commOuterSize[%u]", + curSize, topoAttr_.deviceNumPerAggregation, outerCommInfo.localRankSize); + if (curSize / topoAttr_.deviceNumPerAggregation <= NHR_ALLREDUCE_SMALL_SIZE) { + innerExecutor.reset(new (std::nothrow) AllReduceNHROneshot(dispatcher_, reduceAttr)); + } else { + innerExecutor.reset(new (std::nothrow) AllReduceNHR(dispatcher_, reduceAttr)); + } + HCCL_INFO("allreduce ring: using nhr algo inter-server."); + } else if (UseInterServerNHRV1Algo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllReduceNHRV1(dispatcher_, reduceAttr)); + HCCL_INFO("allreduce ring: using nhr_v1 algo inter-server."); + } else if (UseInterServerNBAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllReduceNB(dispatcher_, reduceAttr)); + HCCL_INFO("allreduce ring: using nonuniform-bruck algo inter-server."); + } else { + innerExecutor.reset(new (std::nothrow) 
AllReduceRecursiveHalvingDoubling(dispatcher_, reduceAttr)); + HCCL_INFO("allreduce ring: using Recursive halving-doubling algo inter-server."); + } + CHK_SMART_PTR_NULL(innerExecutor); + u32 rankSize = innerCommInfo.localRankSize; + // 节点间的hd 使用环0来记录 + CHK_RET(innerExecutor->Prepare( + allreduceInput, allreduceOutput, allreduceOutput, hdCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, + std::vector(0), dataSegsSlice[segmentIdx].offset)); + CHK_RET(innerExecutor->RegisterProfiler( + (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(innerExecutor, innerCommInfo)); + + HCCL_INFO("allreduce double ring stage1 run success"); + } else { + // 超节点内做reducescatter + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + SubCommInfo innerZeroCommInfo = GetSubCommInfo(COMM_LEVEL1, COMM_INDEX_0); + u32 sliceNum = innerZeroCommInfo.localRankSize; + // 根据数据量计算每个环上数据的偏移和大小 + CHK_RET(ExecutorBase::PrepareSliceData(hdCount, perDataSize, sliceNum, 0, dataSegsSlice)); + DeviceMem reducescatterInput = execMem.inputMem.range(dataSegsSlice[segmentIdx].offset, hdSize); + DeviceMem reducescatterOutput = execMem.outputMem.range(dataSegsSlice[segmentIdx].offset, hdSize); + + u64 reduceAttr = GetReduceAttr(reducescatterInput, reducescatterOutput, + param.DataDes.dataType, param.reduceType); + std::unique_ptr level1RSExecutor; + if (UseInterServerRingAlgo(algType_)) { + level1RSExecutor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); + CHK_SMART_PTR_NULL(level1RSExecutor); + CHK_RET(level1RSExecutor->Prepare( + reducescatterInput, reducescatterInput, reducescatterOutput, hdCount, + param.DataDes.dataType, param.stream, param.reduceType, + OUTER_BRIDGE_RANK_ID, dataSegsSlice, dataSegsSlice[segmentIdx].offset)); + HCCL_INFO("reducescatter ring: using ring algo inter-server."); + } else { + level1RSExecutor.reset(new (std::nothrow) ReduceScatterRecursiveHalvingDoubling(dispatcher_, reduceAttr)); + CHK_SMART_PTR_NULL(level1RSExecutor); + CHK_RET(level1RSExecutor->Prepare( + reducescatterInput, reducescatterOutput, reducescatterOutput, hdCount, + param.DataDes.dataType, param.stream, param.reduceType, + OUTER_BRIDGE_RANK_ID, dataSegsSlice, dataSegsSlice[segmentIdx].offset)); + HCCL_INFO("reducescatter ring: using halving-doubling algo inter-server."); + } + CHK_RET(level1RSExecutor->RegisterProfiler( + (sliceNum << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(level1RSExecutor, innerCommInfo)); + HCCL_INFO("allreduce double ring [superpod] level1 reducescatter run success"); + + // 超节点间做allreduce + u64 arSize; + std::vector > rdSlice; + rdSlice.push_back(dataSegsSlice); + CHK_RET(PrepareInnerCommInfo(segmentIdx, commIndex, arSize, innerZeroCommInfo, rdSlice, param.tag)); + u64 arCount = arSize / perDataSize; + + CHK_RET(CheckCommSize(COMM_LEVEL2, commIndex + 1)); + SubCommInfo level2CommInfo = GetSubCommInfo(COMM_LEVEL2, commIndex); + u32 rankSize = GetSubCommInfo(COMM_LEVEL2, COMM_INDEX_0).localRankSize; + + DeviceMem allreduceInput = execMem.inputMem.range(dataSegsSlice[segmentIdx].offset, arSize); + DeviceMem allreduceOutput = execMem.outputMem.range(dataSegsSlice[segmentIdx].offset, arSize); + reduceAttr = GetReduceAttr(allreduceInput, allreduceOutput, param.DataDes.dataType, param.reduceType); + + 
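+ // Choose the cross-superpod allreduce template: ring when the level-2 ring algorithm is configured, recursive halving-doubling otherwise.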
std::unique_ptr level2ARExecutor; + if (UseLevel2RingAlgo(algType_)) { + level2ARExecutor.reset(new (std::nothrow) AllReduceRing(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using ring algo inter-server."); + } else { + level2ARExecutor.reset(new (std::nothrow) AllReduceRecursiveHalvingDoubling(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using halving-doubling algo inter-server."); + } + CHK_RET(level2ARExecutor->Prepare( + allreduceInput, allreduceOutput, allreduceOutput, arCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, + std::vector(0), dataSegsSlice[segmentIdx].offset)); + CHK_SMART_PTR_NULL(level2ARExecutor); + CHK_RET(level2ARExecutor->RegisterProfiler( + (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + level2CommInfo.localRank, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(level2ARExecutor, level2CommInfo)); + HCCL_INFO("allreduce double ring [superpod] level2 allreduce run success"); + // 超节点内做allgather + std::unique_ptr level1AGExecutor; + DeviceMem allgatherInput = execMem.outputMem.range(dataSegsSlice[segmentIdx].offset, arSize); + DeviceMem allgatherOutput = execMem.outputMem.range(dataSegsSlice[segmentIdx].offset, arSize*sliceNum); + if (UseInterServerRingAlgo(algType_)) { + level1AGExecutor.reset(new (std::nothrow) AllGatherRing(dispatcher_)); + } else { + level1AGExecutor.reset(new (std::nothrow) AllGatherRecursiveHalvingDoubling(dispatcher_)); + } + CHK_SMART_PTR_NULL(level1AGExecutor); + CHK_RET(level1AGExecutor->Prepare(allgatherOutput, allgatherOutput, allgatherOutput, arCount, + param.DataDes.dataType, param.stream, + HcclReduceOp::HCCL_REDUCE_RESERVED, OUTER_BRIDGE_RANK_ID, dataSegsSlice, + dataSegsSlice[segmentIdx].offset)); + CHK_RET(level1AGExecutor->RegisterProfiler( + (sliceNum << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(level1AGExecutor, innerCommInfo)); + HCCL_INFO("allreduce double ring [superpod] level1 allgather run success"); + } + + /* 三步算法step3:外层 - 节点内 allgather */ + HcomCollOpInfo *allgatherOpInfoPtr = nullptr; + // 第三步的allgather输入放在CCL buffer上,通过设置nullptr指示要从CCL buffer获取输入 + HcomCollOpInfo allgatherOpInfo = { + "", nullptr, execMem.outputPtr, execMem.count, param.DataDes.dataType, param.root, param.reduceType + }; + if (DMAReduceFlag_) { + allgatherOpInfoPtr = &allgatherOpInfo; + } + CHK_RET(MultiRingAllGather(param.tag, execMem.inputMem, execMem.outputMem, hdCount, + param.DataDes.dataType, multRingsSliceZero, param.stream, + PROF_STAGE_2, 0, allgatherOpInfoPtr)); + HCCL_INFO("allreduce double ring stage2 run success"); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("AllReduceDoubleRingExecutor", AllReduceDoubleRing, CollAllReduceDoubleRingExecutor); + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_double_ring_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_double_ring_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..271b21f0e4805f47ea891c3f4b5c98acceb6d24e --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_double_ring_executor.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_ALL_REDUCE_DOUBLE_RING_EXECUTOR_H +#define COLL_ALL_REDUCE_DOUBLE_RING_EXECUTOR_H + +#include "coll_all_reduce_executor.h" + +namespace hccl { +class CollAllReduceDoubleRingExecutor : public CollAllReduceExecutor { + +public: + CollAllReduceDoubleRingExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollAllReduceDoubleRingExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcStreamNum(u32& streamNum) override; + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcLevel2CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + bool IsHugeData(const u64 curSize) override; + bool IsSmallData(const u64 totalSize, const u64 curSize) override; + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_executor.cc index 339739b20c343693a1d3cbb222e97eddae7273e7..47cb64ce78fc1483bf455bd1329508ff8d4537fc 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_executor.cc @@ -11,8 +11,10 @@ #include "coll_all_reduce_executor.h" namespace hccl { -CollAllReduceExecutor::CollAllReduceExecutor(std::unique_ptr &pImpl) - : CollCommExecutor(pImpl) + +CollAllReduceExecutor::CollAllReduceExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollCommExecutor(dispatcher, topoMatcher) { } diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_executor.h index 735ab199cdb8e4bbf8d8cd275494a0799ba95e05..592b9e41cccc7f1f9ba3c63f539f8fdd211291af 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_executor.h @@ -16,7 +16,7 @@ namespace hccl { class CollAllReduceExecutor : public CollCommExecutor { public: - CollAllReduceExecutor(std::unique_ptr &pImpl); + CollAllReduceExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollAllReduceExecutor() = default; HcclResult Orchestrate(const OpParam& param, const AlgResourceResponse& algRes) override; diff --git 
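The same mechanical change runs through every executor in this diff: the single pImpl handle gives way to a dispatcher plus a topo-matcher reference that is forwarded up the CollAllReduceExecutor -> CollCommExecutor chain. A minimal sketch of the pattern with a hypothetical CollFooExecutor (not part of this change, and assuming the topoMatcher handle is a std::unique_ptr<TopoMatcher>):

// Hypothetical derived executor, shown only to illustrate the new construction chain.
class CollFooExecutor : public CollAllReduceExecutor {
public:
    CollFooExecutor(const HcclDispatcher dispatcher, std::unique_ptr<TopoMatcher> &topoMatcher)
        : CollAllReduceExecutor(dispatcher, topoMatcher)  // base forwards both to CollCommExecutor
    {
        DMAReduceFlag_ = false;  // each concrete executor still sets its own flags in the body
    }
    ~CollFooExecutor() = default;
};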
a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_aiv_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_aiv_executor.cc index fd0ee9506d05ba82057fa159164f3f114041ace7..2ce2806c1ba5fd6474fb966a5bf991bb4a7a64a1 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_aiv_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_aiv_executor.cc @@ -11,8 +11,10 @@ #include "coll_all_reduce_mesh_aiv_executor.h" namespace hccl { -CollAllReduceMeshAivExecutor::CollAllReduceMeshAivExecutor(std::unique_ptr &pImpl) - : CollAllReduceExecutor(pImpl) + +CollAllReduceMeshAivExecutor::CollAllReduceMeshAivExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllReduceExecutor(dispatcher, topoMatcher) { DMAReduceFlag_ = false; } diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_aiv_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_aiv_executor.h index 098183df7ca9e9513ace5bdc3bd59612646a0c72..5c0e5d9c51dfc7c5daa0e4adc672d380f0aab379 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_aiv_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_aiv_executor.h @@ -17,7 +17,7 @@ namespace hccl { class CollAllReduceMeshAivExecutor : public CollAllReduceExecutor { public: - CollAllReduceMeshAivExecutor(std::unique_ptr &pImpl); + CollAllReduceMeshAivExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollAllReduceMeshAivExecutor() = default; HcclResult Orchestrate(const OpParam& param, const AlgResourceResponse& algRes) override; diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_executor.cc index 17b60d7a603ee62b311f6586d8d1ff56b43242b8..b1da5e526cd92512d9e5ea9f3bb3019c2e6dc168 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_executor.cc @@ -11,8 +11,10 @@ #include "coll_all_reduce_mesh_executor.h" namespace hccl { -CollAllReduceMeshExecutor::CollAllReduceMeshExecutor(std::unique_ptr &pImpl) - : CollAllReduceExecutor(pImpl) + +CollAllReduceMeshExecutor::CollAllReduceMeshExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllReduceExecutor(dispatcher, topoMatcher) { DMAReduceFlag_ = false; } @@ -23,8 +25,8 @@ void CollAllReduceMeshExecutor::ParseParam(const OpParam& param) bool isInlineReduce = IsSupportSDMAReduce(param.inputPtr, param.outputPtr, param.DataDes.dataType, param.reduceType); meshSinglePlane_ = (topoAttr_.deviceType == DevType::DEV_TYPE_910B) && - hcclImpl_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE && isInlineReduce && - (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE); + !topoMatcher_->GetExternalInputHcclDeterministic() && + isInlineReduce && (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE); } HcclResult 
CollAllReduceMeshExecutor::CalcStreamNum(u32& streamNum) @@ -103,8 +105,7 @@ HcclResult CollAllReduceMeshExecutor::KernelRun(const OpParam ¶m, ExecMem &e CHK_RET(ActiveSlaveStreams(param.stream)); - if (hcclImpl_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE && - (param.DataDes.dataType != HCCL_DATA_TYPE_INT64) && + if (!topoMatcher_->GetExternalInputHcclDeterministic() && (param.DataDes.dataType != HCCL_DATA_TYPE_INT64) && ((topoAttr_.deviceType == DevType::DEV_TYPE_910B && param.reduceType != HCCL_REDUCE_PROD) || (IsSupportHighPerf() && param.reduceType == HCCL_REDUCE_SUM))) { CHK_RET(MultiStreamReduceScatterMeshAtomic(param.tag, execMem.inputMem, execMem.outputMem, execMem.count, @@ -211,7 +212,7 @@ HcclResult CollAllReduceMeshExecutor::KernelRun(const OpParam ¶m, ExecMem &e bool CollAllReduceMeshExecutor::IsSupportHighPerf() { - return ((GetExternalInputHcclHighPerfEnable() != 0) && + return ((topoMatcher_->GetExternalInputHcclHighPerfEnable() != 0) && (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB)); } diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_executor.h index 1dbd4ccb6e650b252d28832b8d50c5ddfd8398ad..778f61d5b267add1ce07f7a668f5db09eeabc5af 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_executor.h @@ -16,7 +16,7 @@ namespace hccl { class CollAllReduceMeshExecutor : public CollAllReduceExecutor { public: - CollAllReduceMeshExecutor(std::unique_ptr &pImpl); + CollAllReduceMeshExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollAllReduceMeshExecutor() = default; private: diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_mid_count_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_mid_count_executor.cc index 2f9ff8f3f4a011e95808e1616f42de714d12bec9..99104b749648298c23eb8b8afd9bf50e9293acc4 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_mid_count_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_mid_count_executor.cc @@ -11,8 +11,10 @@ #include "coll_all_reduce_mesh_mid_count_executor.h" namespace hccl { -CollAllReduceMeshMidCountExecutor::CollAllReduceMeshMidCountExecutor(std::unique_ptr &pImpl) - : CollAllReduceExecutor(pImpl) + +CollAllReduceMeshMidCountExecutor::CollAllReduceMeshMidCountExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllReduceExecutor(dispatcher, topoMatcher) { CCLMemSlice_ = false; DMAReduceFlag_ = true; diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_mid_count_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_mid_count_executor.h index b8d904d66c431073040d18bdade01b513c70fe71..4d0a56cbae89423985f167cc8c2a68fc1fc5cd4d 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_mid_count_executor.h +++ 
b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_mid_count_executor.h @@ -16,7 +16,7 @@ namespace hccl { class CollAllReduceMeshMidCountExecutor : public CollAllReduceExecutor { public: - CollAllReduceMeshMidCountExecutor(std::unique_ptr &pImpl); + CollAllReduceMeshMidCountExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollAllReduceMeshMidCountExecutor() = default; private: diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_oneshot_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_oneshot_executor.cc index d3eb48d79642637ef473039a1a6aa729a3ebcb7e..df5cbd294ecec9b1521da2664b9b4a1fc42fed8d 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_oneshot_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_oneshot_executor.cc @@ -11,8 +11,10 @@ #include "coll_all_reduce_mesh_oneshot_executor.h" namespace hccl { -CollAllReduceMeshOneshotExecutor::CollAllReduceMeshOneshotExecutor(std::unique_ptr &pImpl) - : CollAllReduceExecutor(pImpl) + +CollAllReduceMeshOneshotExecutor::CollAllReduceMeshOneshotExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllReduceExecutor(dispatcher, topoMatcher) { CCLMemSlice_ = false; DMAReduceFlag_ = true; diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_oneshot_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_oneshot_executor.h index 85d0528b90740ead727b17c1145ce696c7ed2d1e..9558dccc240eae5c6845c7c9340d0da884d13eb3 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_oneshot_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_oneshot_executor.h @@ -16,7 +16,7 @@ namespace hccl { class CollAllReduceMeshOneshotExecutor : public CollAllReduceExecutor { public: - CollAllReduceMeshOneshotExecutor(std::unique_ptr &pImpl); + CollAllReduceMeshOneshotExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollAllReduceMeshOneshotExecutor() = default; private: diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_big_count_aiv_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_big_count_aiv_executor.cc index 97fb31461c04467e13bdb887b1aef290d02d6c43..14b5b0a1b278db30a3dfd48c66f8d8ce2aae810f 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_big_count_aiv_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_big_count_aiv_executor.cc @@ -11,8 +11,9 @@ #include "coll_all_reduce_mesh_opbase_big_count_aiv_executor.h" namespace hccl { -CollAllReduceMeshOpbaseBigCountAivExecutor::CollAllReduceMeshOpbaseBigCountAivExecutor(std::unique_ptr &pImpl) - : CollAllReduceExecutor(pImpl) + +CollAllReduceMeshOpbaseBigCountAivExecutor::CollAllReduceMeshOpbaseBigCountAivExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher): CollAllReduceExecutor(dispatcher, 
topoMatcher) { DMAReduceFlag_ = false; } diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_big_count_aiv_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_big_count_aiv_executor.h index 7c19f5c93202ee97f04a4b438255498bd708049d..cda4600276fde94b8f7d78577994a89579dc9b75 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_big_count_aiv_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_big_count_aiv_executor.h @@ -17,7 +17,8 @@ namespace hccl { class CollAllReduceMeshOpbaseBigCountAivExecutor : public CollAllReduceExecutor { public: - CollAllReduceMeshOpbaseBigCountAivExecutor(std::unique_ptr &pImpl); + CollAllReduceMeshOpbaseBigCountAivExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); ~CollAllReduceMeshOpbaseBigCountAivExecutor() = default; HcclResult Orchestrate(const OpParam& param, const AlgResourceResponse& algRes) override; diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_executor.cc index 202d7225472ae6979141ca5ee6c10c3f95d27814..8d9b82a461a4a555e354e420016e05c2418b4cf1 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_executor.cc @@ -11,8 +11,10 @@ #include "coll_all_reduce_mesh_opbase_executor.h" namespace hccl { -CollAllReduceMeshOpbaseExecutor::CollAllReduceMeshOpbaseExecutor(std::unique_ptr &pImpl) - : CollAllReduceExecutor(pImpl) + +CollAllReduceMeshOpbaseExecutor::CollAllReduceMeshOpbaseExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllReduceExecutor(dispatcher, topoMatcher) { CCLMemSlice_ = false; DMAReduceFlag_ = true; diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_executor.h index 7bdbdaabd6b4064ce19f035137a5857d1d1b5d46..1bf718a2ce7f29422aca5840d17984bcdbf8dd14 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_executor.h @@ -16,7 +16,7 @@ namespace hccl { class CollAllReduceMeshOpbaseExecutor : public CollAllReduceExecutor { public: - CollAllReduceMeshOpbaseExecutor(std::unique_ptr &pImpl); + CollAllReduceMeshOpbaseExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollAllReduceMeshOpbaseExecutor() = default; private: diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_pipeline_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_pipeline_executor.cc index b0042d3ac2c5c7d8037c88607a3c2eeb471ba937..86bc6b2012e3404afb11fe26c153ccd32502227b 100644 --- 
a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_pipeline_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_pipeline_executor.cc @@ -12,8 +12,8 @@ namespace hccl { // 准入条件: pipeLine && 910B && 单算子 && sdmaReduce && rdmaReduce && 多Mesh && MeshTopo && 非确定性 -CollAllReduceMeshOpbasePipelineExecutor::CollAllReduceMeshOpbasePipelineExecutor(std::unique_ptr &pImpl) - : CollAllReduceExecutor(pImpl) +CollAllReduceMeshOpbasePipelineExecutor::CollAllReduceMeshOpbasePipelineExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher): CollAllReduceExecutor(dispatcher, topoMatcher) { DMAReduceFlag_ = true; } diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_pipeline_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_pipeline_executor.h index e08a791ef3c854d69c6b676c08a32ea22c86be5a..bf590b17583937ef7753a797a9b4d860afa71e8c 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_pipeline_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_opbase_pipeline_executor.h @@ -16,7 +16,7 @@ namespace hccl { class CollAllReduceMeshOpbasePipelineExecutor : public CollAllReduceExecutor { public: - CollAllReduceMeshOpbasePipelineExecutor(std::unique_ptr &pImpl); + CollAllReduceMeshOpbasePipelineExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollAllReduceMeshOpbasePipelineExecutor() = default; private: diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_small_count_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_small_count_executor.cc index f91eb543c88262688664b872a73536579f61d151..c1e94025ad9759d5fbcf8c1f1c7630ea44bc90df 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_small_count_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_small_count_executor.cc @@ -11,8 +11,10 @@ #include "coll_all_reduce_mesh_small_count_executor.h" namespace hccl { -CollAllReduceMeshSmallCountExecutor::CollAllReduceMeshSmallCountExecutor(std::unique_ptr &pImpl) - : CollAllReduceExecutor(pImpl) + +CollAllReduceMeshSmallCountExecutor::CollAllReduceMeshSmallCountExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllReduceExecutor(dispatcher, topoMatcher) { DMAReduceFlag_ = true; } @@ -27,7 +29,7 @@ bool CollAllReduceMeshSmallCountExecutor::CalcScratchMemFlag(const u64 totalSize { return GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB && topoAttr_.deviceType == DevType::DEV_TYPE_910B && - hcclImpl_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_ENABLE && + topoMatcher_->GetExternalInputHcclDeterministic() && topoAttr_.deviceNumPerAggregation > DEVICE_TWO && topoAttr_.deviceNumPerAggregation < DEVICE_EIGHT && totalSize <= HCCL_SMALL_COUNT_GRAPH_64_KB; @@ -175,7 +177,7 @@ HcclResult CollAllReduceMeshSmallCountExecutor::KernelRun(const OpParam ¶m, }; std::unique_ptr outer2Executor; - if (hcclImpl_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE) { + if 
(!topoMatcher_->GetExternalInputHcclDeterministic()) { outer2Executor.reset(new (std::nothrow) AllReduceReduceBcast(dispatcher_, reduceAttr, streamInfo_.ringStreams, streamInfo_.ringSignal, streamInfo_.ringSignalAux, outerCommInfo.localRank, outerCommInfo.localRankSize, topoAttr_.userRank, &opInfo)); diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_small_count_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_small_count_executor.h index d15c2cfe971c07952e8ea5fbb2d8225a365544d3..793de7a7c443eab0c6d98cd2c0fb16d91735cf78 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_small_count_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mesh_small_count_executor.h @@ -16,7 +16,7 @@ namespace hccl { class CollAllReduceMeshSmallCountExecutor : public CollAllReduceExecutor { public: - CollAllReduceMeshSmallCountExecutor(std::unique_ptr &pImpl); + CollAllReduceMeshSmallCountExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollAllReduceMeshSmallCountExecutor() = default; HcclResult Orchestrate(const OpParam& param, const AlgResourceResponse& algRes); diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mid_count_aiv_rdma_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mid_count_aiv_rdma_executor.cc index 1eb4c8e48f03b2153ae17d655cbf1fa2846dac46..8b92521888412db13ef0012842c0f3806c296674 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mid_count_aiv_rdma_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mid_count_aiv_rdma_executor.cc @@ -13,9 +13,10 @@ namespace hccl { constexpr s32 INTRA_RS_STEP = 0; constexpr s32 INTRA_AG_STEP = 2; - -CollAllReduceMidCountAivRdmaExecutor::CollAllReduceMidCountAivRdmaExecutor(std::unique_ptr &pImpl) - : CollAllReduceExecutor(pImpl) + +CollAllReduceMidCountAivRdmaExecutor::CollAllReduceMidCountAivRdmaExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllReduceExecutor(dispatcher, topoMatcher) { DMAReduceFlag_ = false; } diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mid_count_aiv_rdma_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mid_count_aiv_rdma_executor.h index e6d9a997caee95afaec50d26563f7fad5e83c204..75e1f27b64c2b2310815214cae77189bba986073 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mid_count_aiv_rdma_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_mid_count_aiv_rdma_executor.h @@ -19,7 +19,7 @@ namespace hccl { class CollAllReduceMidCountAivRdmaExecutor : public CollAllReduceExecutor { public: - CollAllReduceMidCountAivRdmaExecutor(std::unique_ptr &pImpl); + CollAllReduceMidCountAivRdmaExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollAllReduceMidCountAivRdmaExecutor() = default; HcclResult Orchestrate(const OpParam& param, const AlgResourceResponse& algRes) override; diff --git 
a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_reduce_plus_bcast_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_reduce_plus_bcast_executor.cc index 9f9ec9eb0378e4bb89ea0938336e45f838f4929f..e960ac1d9e2ef18f979238ac687eb5b62ceb5cec 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_reduce_plus_bcast_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_reduce_plus_bcast_executor.cc @@ -11,8 +11,10 @@ #include "coll_all_reduce_reduce_plus_bcast_executor.h" namespace hccl { -CollAllReduceReducePlusBcastExecutor::CollAllReduceReducePlusBcastExecutor(std::unique_ptr &pImpl) - : CollAllReduceExecutor(pImpl) + +CollAllReduceReducePlusBcastExecutor::CollAllReduceReducePlusBcastExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllReduceExecutor(dispatcher, topoMatcher) { } diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_reduce_plus_bcast_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_reduce_plus_bcast_executor.h index b45efcc5d693fda2a9cf876a80adc38423db38ad..89162cc0e7096d86acdae0e1a86a9f11240252fb 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_reduce_plus_bcast_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_reduce_plus_bcast_executor.h @@ -16,7 +16,7 @@ namespace hccl { class CollAllReduceReducePlusBcastExecutor : public CollAllReduceExecutor { public: - CollAllReduceReducePlusBcastExecutor(std::unique_ptr &pImpl); + CollAllReduceReducePlusBcastExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollAllReduceReducePlusBcastExecutor() = default; private: diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_ring_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_ring_executor.cc index 2d4a72261a690ff39c0ce554d8d49e8e0fbe9b88..6635b756240580182c69895a84e7d314520164bb 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_ring_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_ring_executor.cc @@ -11,10 +11,17 @@ #include "coll_all_reduce_ring_executor.h" namespace hccl { -CollAllReduceRingExecutor::CollAllReduceRingExecutor(std::unique_ptr &pImpl) - : CollAllReduceExecutor(pImpl) + +CollAllReduceRingExecutor::CollAllReduceRingExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllReduceExecutor(dispatcher, topoMatcher) { - DMAReduceFlag_ = false; + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + topoAttr_.deviceType == DevType::DEV_TYPE_910_73) { + DMAReduceFlag_ = true; + } else { + DMAReduceFlag_ = false; + } } HcclResult CollAllReduceRingExecutor::CalcStreamNum(u32& streamNum) @@ -29,6 +36,16 @@ HcclResult CollAllReduceRingExecutor::CalcStreamNum(u32& streamNum) case AlgType::ALG_8P_RING_PLUS_PIPELINE: totalStreamNum = OUTER_PLANE_NUM_IN_8PRING; break; + case AlgType::ALG_NP_SINGLE_RING_PLUS_RING: + case AlgType::ALG_NP_SINGLE_RING_PLUS_HD: + if (topoAttr_.deviceType == 
DevType::DEV_TYPE_910_73) { + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_SINGLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING; + } else { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_SINGLE; + } + } + break; default: break; } @@ -132,8 +149,7 @@ HcclResult CollAllReduceRingExecutor::KernelRun(const OpParam ¶m, ExecMem &e u64 hdSize; u32 segmentIdx; u32 commIndex; - CHK_RET(hcclImpl_->PrepareInnerCommInfo(segmentIdx, commIndex, hdSize, - outerCommInfo, multRingsSliceZero, param.tag)); + CHK_RET(PrepareInnerCommInfo(segmentIdx, commIndex, hdSize, outerCommInfo, multRingsSliceZero, param.tag)); u64 hdCount = hdSize / perDataSize; auto nicList = topoAttr_.nicList; diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_ring_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_ring_executor.h index 9af71e015be001c0e272e80b3b545e7dea0ec0bf..07ebf138a98142037525f8e560b1ba980e62342a 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_ring_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_ring_executor.h @@ -16,7 +16,7 @@ namespace hccl { class CollAllReduceRingExecutor : public CollAllReduceExecutor { public: - CollAllReduceRingExecutor(std::unique_ptr &pImpl); + CollAllReduceRingExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollAllReduceRingExecutor() = default; private: diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_single_rank_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_single_rank_executor.cc index 943e7a25c0cdbcef5bd1701c948e75bb97baa2b3..ee30f023a86281506d28722ce6f6f221f174fa77 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_single_rank_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_single_rank_executor.cc @@ -11,8 +11,10 @@ #include "coll_all_reduce_single_rank_executor.h" namespace hccl { -CollAllReduceSingleRankExecutor::CollAllReduceSingleRankExecutor(std::unique_ptr &pImpl) - : CollAllReduceExecutor(pImpl) + +CollAllReduceSingleRankExecutor::CollAllReduceSingleRankExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllReduceExecutor(dispatcher, topoMatcher) { } diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_single_rank_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_single_rank_executor.h index 8df793b10dadb8a4201ee04baf63d0207e7e9399..c2782691eaff1fd551c487e89a66157f71f3edfe 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_single_rank_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_single_rank_executor.h @@ -16,7 +16,7 @@ namespace hccl { class CollAllReduceSingleRankExecutor : public CollAllReduceExecutor { public: - CollAllReduceSingleRankExecutor(std::unique_ptr &pImpl); + CollAllReduceSingleRankExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollAllReduceSingleRankExecutor() = default; private: diff --git 
a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_small_count_aiv_rdma_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_small_count_aiv_rdma_executor.cc index e0f59890388fb6ece323a09ef62bafee26b91a30..dbe8e2e08b85c5aa0d89a9773b82e7117c3634cf 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_small_count_aiv_rdma_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_small_count_aiv_rdma_executor.cc @@ -18,9 +18,10 @@ constexpr u32 A_X_AGGR_SIZE = 2; constexpr u64 HALF_OFFSET = 16 * 1024 * 1024; u64 CollAllReduceSmallCountAivRdmaExecutor::allreduceSmallDataAivRdmaCount_ = 0; - -CollAllReduceSmallCountAivRdmaExecutor::CollAllReduceSmallCountAivRdmaExecutor(std::unique_ptr &pImpl) - : CollAllReduceExecutor(pImpl) + +CollAllReduceSmallCountAivRdmaExecutor::CollAllReduceSmallCountAivRdmaExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllReduceExecutor(dispatcher, topoMatcher) { DMAReduceFlag_ = false; } @@ -51,7 +52,7 @@ HcclResult CollAllReduceSmallCountAivRdmaExecutor::CalcCommInfo(std::vectorGetExternalInputIntraRoceSwitch() == 0) { std::vector &commTransportLevel1 = opTransport[COMM_LEVEL1]; for (u32 ringIndex = 0; ringIndex < commTransportLevel1.size(); ringIndex++) { commTransportLevel1[ringIndex].isUsedRdma = false; diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_small_count_aiv_rdma_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_small_count_aiv_rdma_executor.h index 8f340ae9281c20a666955aa61a64f9037ed80866..8a5fd02cc3e20f11ec262d504e898240ecdd01f4 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_small_count_aiv_rdma_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_reduce/coll_all_reduce_small_count_aiv_rdma_executor.h @@ -19,7 +19,7 @@ namespace hccl { class CollAllReduceSmallCountAivRdmaExecutor : public CollAllReduceExecutor { public: - CollAllReduceSmallCountAivRdmaExecutor(std::unique_ptr &pImpl); + CollAllReduceSmallCountAivRdmaExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollAllReduceSmallCountAivRdmaExecutor() = default; HcclResult Orchestrate(const OpParam& param, const AlgResourceResponse& algRes) override; diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/CMakeLists.txt b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7e53ae27bb1c4c159bc0cd70e7ccafcb88154498 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/CMakeLists.txt @@ -0,0 +1,10 @@ +set(src_list + ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_to_all_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_to_all_v_2level_pipeline_excecutor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_to_all_v_fullmesh_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_to_all_v_staged_executor.cc +) + +target_sources(hccl_alg PRIVATE + ${src_list} +) \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_executor.cc 
b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..6bea5f3086e8c56e0833d559759d758ccb1360d2 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_executor.cc @@ -0,0 +1,411 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + + +#include "coll_all_to_all_executor.h" + +namespace hccl { + +CollAlltoAllExecutor::CollAlltoAllExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollNativeExecutorBase(dispatcher, topoMatcher) +{ +} + +HcclResult CollAlltoAllExecutor::Orchestrate(const OpParam& param, + const AlgResourceResponse& algRes) +{ + HcclUs startut = TIME_NOW(); + tag_ = param.tag; + algRes_ = algRes; + algResResp_ = &algRes; + AlltoAllVParam_ = param; + GetStreamInfo(algRes); + auto rtStream = param.stream.ptr(); + + HCCL_PROFILER_ADD_STREAM(rtStream, param.tag, 0, algType_); + + ExecMem execMem; + execMem.count = 0; + execMem.inputPtr = param.inputPtr; + execMem.outputPtr = param.outputPtr; + + HcclResult ret = HCCL_SUCCESS; + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + execMem.inputMem = algRes.cclInputMem; + execMem.outputMem = algRes.cclOutputMem; + execMem.scratchMem = algRes.scratchMem; + auto opMeta = GetOpMeta(param.opType, algRes.paramInputMem.size()); // override + CHK_RET(InitTask(dispatcher_, const_cast(param.stream), opMeta.isEnableCache, opMeta.GetCacheKey())); + bool massTasks = HasMassTasks(allMeshAggregationSendRecvInfo_); + if (massTasks) { + CHK_RET(SetNormalMode(dispatcher_)); + } + ret = KernelRun(param, execMem); + CHK_RET(LaunchTask(dispatcher_, const_cast(param.stream))); + } else { + execMem.inputMem = algRes.paramInputMem; + execMem.outputMem = algRes.paramOutputMem; + execMem.scratchMem = algRes.scratchMem; + ret = KernelRun(param, execMem); + } + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollRunAlltoAllVFullMesh][Orchestrate]errNo[0x%016llx]excutor run failed", + HCCL_ERROR_CODE(ret)), ret); + + HCCL_PROFILER_DEL_STREAM(rtStream); + + HCCL_INFO("tag[%s], AlltoAll executor orchestrate success, take time [%lld]us.", + param.tag.c_str(), DURATION_US(TIME_NOW() - startut)); + + return HCCL_SUCCESS; +} + +// override----------------------资源计算接口---------------------- +HcclResult CollAlltoAllExecutor::CalcResRequest(const OpParam& param, AlgResourceRequest& resourceRequest) +{ + (void)ParseParam(param); + + u64 scratchMemSize = 0U; + u32 streamNum = 0U; + u32 notifyNum = 0U; + bool needAivBuffer = false; + std::vector opTransport { + std::vector(static_cast(COMM_LEVEL_RESERVED)) + }; + + CHK_RET(CalcScratchMemSize(scratchMemSize)); + CHK_RET(CalcStreamNum(streamNum)); + CHK_RET(CalcNotifyNum(streamNum, notifyNum)); + CHK_RET(GetIfNeedAivBuffer(needAivBuffer)); + CHK_RET(CalcCommInfo(opTransport)); + + CHK_RET(BuildResourceRequest(scratchMemSize, streamNum, 
notifyNum, needAivBuffer, opTransport, resourceRequest)); + HCCL_INFO("streamNum[%u], notifyNum[%u], sctrachMemSize[%llu], needAivBuffer[%u]", + resourceRequest.streamNum, resourceRequest.notifyNum, resourceRequest.scratchMemSize, + resourceRequest.needAivBuffer); + // 打印建链诉求 + for (u32 levelIndex = 0; levelIndex < COMM_LEVEL_RESERVED; levelIndex++) { + LevelNSubCommTransport &levelTransport = resourceRequest.opTransport[levelIndex]; + u32 ringSize = levelTransport.size(); + for (u32 ringIndex = 0; ringIndex < ringSize; ringIndex++) { + SingleSubCommTransport &subCommTransport = levelTransport[ringIndex]; + u32 rankSize = subCommTransport.transportRequests.size(); + for (u32 rankIndex = 0; rankIndex < rankSize; rankIndex++) { + if (subCommTransport.transportRequests[rankIndex].isValid == true) { + HCCL_INFO("[CollAlltoAllExecutor][CalcResRequest]" \ + "levelIndex[%u], ringIndex[%u], rankIndex[%u], userRank[%u], remoteRank[%u]", + levelIndex, ringIndex, rankIndex, subCommTransport.transportRequests[rankIndex].localUserRank, + subCommTransport.transportRequests[rankIndex].remoteUserRank); + } + } + } + } + CHK_RET(CheckNeedCreateVirtualLinks(resourceRequest)); + return HCCL_SUCCESS; +} + +HcclResult CollAlltoAllExecutor::CheckNeedCreateVirtualLinks(AlgResourceRequest &resourceRequest) +{ + return HCCL_SUCCESS; +} + +HcclResult CollAlltoAllExecutor::SetExcutorExtraInfo(const std::vector &allMeshAggregationSendRecvInfo) +{ + allMeshAggregationSendRecvInfo_.clear(); + allMeshAggregationSendRecvInfo_ = allMeshAggregationSendRecvInfo; + UpdateAlltoAllZCopyMode(allMeshAggregationSendRecvInfo_); + + return HCCL_SUCCESS; +} + +void CollAlltoAllExecutor::UpdateAlltoAllZCopyMode(std::vector &allMeshAggregationSendRecvInfo) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + u64 maxSendSize = 0; + u64 maxRecvSize = 0; + for (auto &sendRecvInfo : allMeshAggregationSendRecvInfo) { + for (u32 i = 0; i < topoAttr_.userRankSize; i++) { + u64 curSendSize = sendRecvInfo.sendLength[i] + sendRecvInfo.sendOffset[i]; + maxSendSize = std::max(maxSendSize, curSendSize); + u64 curRecvSize = sendRecvInfo.recvLength[i] + sendRecvInfo.recvOffset[i]; + maxRecvSize = std::max(maxRecvSize, curRecvSize); + } + } + bool isAlltoAllZCopyMode = (maxSendSize <= GetExternalInputCCLBuffSize()) && + (maxRecvSize <= GetExternalInputCCLBuffSize()); + if (isAlltoAllZCopyMode) { + isAlltoAllZCopyMode_ = true; + } + HCCL_INFO("[CollAlltoAllExecutor][UpdateAlltoAllCopyMode] maxSendSize[%llu], maxRecvSize[%llu], "\ + "cclBufferSize[%llu]", maxSendSize, maxRecvSize, GetExternalInputCCLBuffSize()); + } else { + // 图模式走ZCopy实现 + isAlltoAllZCopyMode_ = true; + } + HCCL_DEBUG("UpdateAlltoAllZCopyMode isAlltoAllZCopyMode_[%d]", isAlltoAllZCopyMode_); +} + +void CollAlltoAllExecutor::CalcIntraMeshAggregationSendInfo(const AlltoAllUserRankInfo &userRankInfo, + const SendRecvInfo &mySendRecvInfo, const std::vector &myMeshAggregationSendRecvInfo, + u32 rankInMeshAggregation, u32 infoIndex, OneSendRecvAddrInfo &curSendInfo, u32 meshAggregationRankSize, + const bool &isSingleMesh) +{ + (void)userRankInfo; + if (infoIndex >= mySendRecvInfo.sendOffset.size() || infoIndex >= mySendRecvInfo.sendLength.size()) { + HCCL_ERROR("[CalcIntraMeshAggregationSendInfo] Invalid infoIndex[%u]", infoIndex); + return; + } + curSendInfo.localOffset = mySendRecvInfo.sendOffset[infoIndex]; + curSendInfo.localLength = mySendRecvInfo.sendLength[infoIndex]; + u64 remoteOffset = 0; + + if (isSingleMesh) { + remoteOffset = 
myMeshAggregationSendRecvInfo[infoIndex].recvOffset[userRankInfo.userRank]; + } else { + for (u32 j = infoIndex % meshAggregationRankSize; j <= infoIndex; j += meshAggregationRankSize) { + for (u32 k = 0; k < meshAggregationRankSize; k++) { + if (j == infoIndex && k == rankInMeshAggregation) { + break; + } + if (k < myMeshAggregationSendRecvInfo.size() && j < + myMeshAggregationSendRecvInfo[k].sendLength.size()) { + remoteOffset += myMeshAggregationSendRecvInfo[k].sendLength[j]; + } else { + HCCL_ERROR("[AlltoAllVStagedCalculator] invalid MeshAggregationSendRecvInfo size[%u]", + myMeshAggregationSendRecvInfo.size()); + return; + } + } + } + } + + curSendInfo.remoteOffset = remoteOffset; + curSendInfo.remoteLength = curSendInfo.localLength; + HCCL_DEBUG("[CalcIntraMeshAggregationSendInfo] localOffset[%llu], localLength[%llu], "\ + "remoteOffset[%llu], remoteLength[%llu]", curSendInfo.localOffset, curSendInfo.localLength, + curSendInfo.remoteOffset, curSendInfo.remoteLength); +} + +void CollAlltoAllExecutor::CalcIntraMeshAggregationRecvInfoInMeshAggregation(u32 rankIndex, u32 infoIndex, + const std::vector &myMeshAggregationSendRecvInfo, u64 &localOffset, u32 &offsetCounter, + u64 &localLength, u64 &remoteOffset, u32 meshAggregationRankSize) +{ + // 这里的判断在外部已经保证了,为了应对coverity sc + if (myMeshAggregationSendRecvInfo.size() < meshAggregationRankSize) { + HCCL_ERROR("[CalcIntraMeshAggregationSendInfo] Invalid myMeshAggregationSendRecvInfo[%u]", + myMeshAggregationSendRecvInfo.size()); + return; + } + if (myMeshAggregationSendRecvInfo[0].sendLength.size() == 0 || + myMeshAggregationSendRecvInfo[0].sendOffset.size() == 0) { + HCCL_ERROR("[CalcIntraMeshAggregationSendInfo] Invalid sendLength size[%u] or sendOffset size[%u]", + myMeshAggregationSendRecvInfo[0].sendLength.size(), myMeshAggregationSendRecvInfo[0].sendOffset.size()); + return; + } + for (u32 k = 0; k < meshAggregationRankSize; k++) { + if (infoIndex == 0) { + localOffset = 0; + localLength = myMeshAggregationSendRecvInfo[k].sendLength[rankIndex]; + remoteOffset = myMeshAggregationSendRecvInfo[k].sendOffset[rankIndex]; + break; + } + + localOffset += myMeshAggregationSendRecvInfo[k].sendLength[rankIndex]; + offsetCounter++; + if (offsetCounter == infoIndex) { + if (k == meshAggregationRankSize - 1) { + localLength = myMeshAggregationSendRecvInfo[0].sendLength[rankIndex + meshAggregationRankSize]; + remoteOffset = myMeshAggregationSendRecvInfo[0].sendOffset[rankIndex + meshAggregationRankSize]; + } else { + localLength = myMeshAggregationSendRecvInfo[k + 1].sendLength[rankIndex]; + remoteOffset = myMeshAggregationSendRecvInfo[k + 1].sendOffset[rankIndex]; + } + break; + } + } +} + +void CollAlltoAllExecutor::CalcIntraMeshAggregationRecvInfo(const AlltoAllUserRankInfo &userRankInfo, + const std::vector &myMeshAggregationSendRecvInfo, u32 infoIndex, OneSendRecvAddrInfo &curRecvInfo, + u32 meshAggregationRankSize, const bool &isSingleMesh) +{ + u64 localOffset = 0, localLength = 0, remoteLength = 0, remoteOffset = 0; + u32 offsetCounter = 0; + + if (isSingleMesh) { + localOffset = myMeshAggregationSendRecvInfo[userRankInfo.userRank].recvOffset[infoIndex]; + localLength = myMeshAggregationSendRecvInfo[userRankInfo.userRank].recvLength[infoIndex]; + remoteLength = myMeshAggregationSendRecvInfo[infoIndex].sendLength[userRankInfo.userRank]; + remoteOffset = myMeshAggregationSendRecvInfo[infoIndex].sendOffset[userRankInfo.userRank]; + } else { + for (u32 j = userRankInfo.userRank % meshAggregationRankSize; j < userRankInfo.userRankSize; + j += 
meshAggregationRankSize) { + CalcIntraMeshAggregationRecvInfoInMeshAggregation(j, infoIndex, myMeshAggregationSendRecvInfo, localOffset, + offsetCounter, localLength, remoteOffset, meshAggregationRankSize); + if (offsetCounter == infoIndex || infoIndex == 0) { + break; + } + } + remoteLength = localLength; + } + curRecvInfo.localOffset = localOffset; + curRecvInfo.localLength = localLength; + + curRecvInfo.remoteOffset = remoteOffset; + curRecvInfo.remoteLength = remoteLength; + HCCL_DEBUG("[CalcIntraMeshAggregationRecvInfo] localOffset[%llu], localLength[%llu], "\ + "remoteOffset[%llu], remoteLength[%llu]", localOffset, localLength, remoteOffset, remoteLength); +} + +void CollAlltoAllExecutor::CalcIntraMeshAggregationAlltoAllMemInfo(const AlltoAllUserRankInfo &userRankInfo, + const std::vector<SendRecvInfo> &allSendRecvInfo, + std::map<u32, std::list<OneSendRecvAddrInfo>> &sendAddrInfosIntra, + std::map<u32, std::list<OneSendRecvAddrInfo>> &recvAddrInfosIntra, u32 meshAggregationRankSize, + const bool &isSingleMesh) +{ + sendAddrInfosIntra.clear(); + recvAddrInfosIntra.clear(); + if (allSendRecvInfo.size() != userRankInfo.userRankSize) { + HCCL_ERROR("Invalid All send recv info size[%u], should be[%u]", allSendRecvInfo.size(), + userRankInfo.userRankSize); + return; + } + SendRecvInfo mySendRecvInfo = allSendRecvInfo[userRankInfo.userRank]; + u32 rankInMeshAggregation = userRankInfo.userRank % meshAggregationRankSize; + u32 cluserIndex = userRankInfo.userRank / meshAggregationRankSize; + auto itBegin = allSendRecvInfo.begin(); + auto itEnd = allSendRecvInfo.begin(); + std::advance(itBegin, cluserIndex * meshAggregationRankSize); + std::advance(itEnd, (cluserIndex + 1) * meshAggregationRankSize); + std::vector<SendRecvInfo> myMeshAggregationSendRecvInfo(itBegin, itEnd); + + for (u32 i = 0; i < userRankInfo.userRankSize; i++) { + // Calculate sendInfo + OneSendRecvAddrInfo curSendInfo; + u32 remoteRankInMeshAggregation = i % meshAggregationRankSize; + CalcIntraMeshAggregationSendInfo(userRankInfo, mySendRecvInfo, myMeshAggregationSendRecvInfo, + rankInMeshAggregation, i, curSendInfo, meshAggregationRankSize, isSingleMesh); + sendAddrInfosIntra[remoteRankInMeshAggregation].push_back(curSendInfo); + + // Calculate recvInfo + OneSendRecvAddrInfo curRecvInfo; + CalcIntraMeshAggregationRecvInfo(userRankInfo, myMeshAggregationSendRecvInfo, i, + curRecvInfo, meshAggregationRankSize, isSingleMesh); + recvAddrInfosIntra[remoteRankInMeshAggregation].push_back(curRecvInfo); + } +} + +HcclOpMetaInfo CollAlltoAllExecutor::GetOpMeta(HcclCMDType opType, const u64 size) +{ + bool hugeData = size > SDMA_SEND_MAX_SIZE; + + HcclOpMetaInfoDef opMeta; + if (isAlltoAllZCopyMode_) { + /* Before zcopy splits SDMA tasks larger than 4GB, prepare the flag that marks the subgraph as non-reusable */ + if (opType == HcclCMDType::HCCL_CMD_ALLTOALLV) { + opMeta = HcclOpMetaInfo::GetOneForAllToAllV(CopyPattern::ZCOPY, size, hugeData); + } else { + opMeta = HcclOpMetaInfo::GetOneForAllToAllVC(CopyPattern::ZCOPY, size, hugeData); + } + } else { + /* bcopy regenerates the subgraph on every run */ + if (opType == HcclCMDType::HCCL_CMD_ALLTOALLV) { + opMeta = HcclOpMetaInfo::GetOneForAllToAllV(CopyPattern::BCOPY, size, hugeData); + } else { + opMeta = HcclOpMetaInfo::GetOneForAllToAllVC(CopyPattern::BCOPY, size, false); + } + } + + return opMeta; +} + +u64 CollAlltoAllExecutor::CalAlltoAllVScratchMemSize(u64 &workSpaceMemSize) // to be double-checked (zlj) +{ + u64 scratchMemSize = 0U; + if (workSpaceMemSize == 0) { + scratchMemSize = TINY_MEM_SIZE; + } else { + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + scratchMemSize = std::max(std::max(workSpaceMemSize, GetExternalInputCCLBuffSize()), TINY_MEM_SIZE); + } else { +
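+ // Sizing summary: a zero workspace request falls back to the 2 MB TINY_MEM_SIZE buffer;
+ // in single-operator (OP_BASE) mode the scratch is the maximum of the workspace size, the CCL buffer size
+ // and TINY_MEM_SIZE; in graph mode the requested workspace size is used unchanged (next statement).
+ // Illustrative example only: a 1 MB workspace with a 200 MB CCL buffer resolves to 200 MB in OP_BASE mode.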
scratchMemSize = workSpaceMemSize; + } + } + return scratchMemSize; +} + +bool CollAlltoAllExecutor::NAFullmeshSatisfyHighPerfAlltoallMeshCondition(DevType deviceType, u32 rankSize) +{ + bool rankSizeSupport = (rankSize <= MAX_ALLTOALL_MESH_ALGO_RANK_INTRA_MESH); + bool isDevice91073 = (deviceType == DevType::DEV_TYPE_910_73); + bool oneLevelUseMesh = + (GetExternalInputHcclAlgoConfig(HcclCMDType::HCCL_CMD_ALLTOALL)[0] == HcclAlgoType::HCCL_ALGO_TYPE_NA && + GetExternalInputHcclAlgoConfig(HcclCMDType::HCCL_CMD_ALLTOALL)[1] == HcclAlgoType::HCCL_ALGO_TYPE_FULLMESH); + bool isHCCS = !GetExternalInputInterHccsDisable(); + HCCL_DEBUG("[CollAlltoAllExecutor][AlltoAllVCOutPlace]isDevice91073 %u oneLevelUseMesh %u isHCCS %u", + isDevice91073, oneLevelUseMesh, isHCCS); + CHK_PRT_CONT(!(oneLevelUseMesh && !isDevice91073), + HCCL_WARNING("[CollAlltoAllExecutor][NAFullmeshSatisfyHighPerfAlltoallMeshCondition] alltoall read only algorithm only " + "support 91073 device type, use default algorithm type")); + CHK_PRT_CONT(!(oneLevelUseMesh && !isHCCS), + HCCL_WARNING("[CollAlltoAllExecutor][NAFullmeshSatisfyHighPerfAlltoallMeshCondition] alltoall read only algorithm depends " + "on HCCS, use default algorithm type")); + return (isDevice91073 && oneLevelUseMesh && rankSizeSupport && isHCCS); +} + +bool CollAlltoAllExecutor::FullmeshPairwiseSatisfyHighPerfAlltoallMeshCondition(DevType deviceType, u32 rankSize) +{ + bool rankSizeSupport = (rankSize <= MAX_ALLTOALL_MESH_ALGO_RANK_INTRA_MESH); + bool isDevice91073 = (deviceType == DevType::DEV_TYPE_910_73); + bool twoLevelIntraUseMesh = + (GetExternalInputHcclAlgoConfig(HcclCMDType::HCCL_CMD_ALLTOALL)[0] == HcclAlgoType::HCCL_ALGO_TYPE_FULLMESH && + GetExternalInputHcclAlgoConfig(HcclCMDType::HCCL_CMD_ALLTOALL)[1] == HcclAlgoType::HCCL_ALGO_TYPE_PAIRWISE); + bool isHCCS = !GetExternalInputInterHccsDisable(); + HCCL_DEBUG("[CollAlltoAllExecutor][AlltoAllVCOutPlace]isDevice91073 %u twoLevelIntraUseMesh %u isHCCS %u", + isDevice91073, twoLevelIntraUseMesh, isHCCS); + CHK_PRT_CONT(!(twoLevelIntraUseMesh && !isDevice91073), + HCCL_WARNING("[CollAlltoAllExecutor][FullmeshPairwiseSatisfyHighPerfAlltoallMeshCondition] alltoall read only algorithm only " + "support 91073 device type, use default algorithm type")); + CHK_PRT_CONT(!(twoLevelIntraUseMesh && !isHCCS), + HCCL_WARNING("[CollAlltoAllExecutor][FullmeshPairwiseSatisfyHighPerfAlltoallMeshCondition] alltoall read only algorithm depends " + "on HCCS, use default algorithm type")); + return (isDevice91073 && twoLevelIntraUseMesh && rankSizeSupport && isHCCS); +} + +bool CollAlltoAllExecutor::HasMassTasks(std::vector &allMeshAggregationSendRecvInfo) +{ + if (isAlltoAllZCopyMode_) { + return false; + } + + u64 maxSendTimes = 0; + u64 maxRecvTimes = 0; + const u64 cclBufferSize = algRes_.cclInputMem.size(); + for (auto &sendRecvInfo : allMeshAggregationSendRecvInfo) { + u64 sendTimes = 0; + u64 recvTimes = 0; + for (u32 i = 0; i < topoAttr_.userRankSize; i++) { + sendTimes += (sendRecvInfo.sendLength[i] + cclBufferSize - 1) / cclBufferSize; + recvTimes += (sendRecvInfo.recvLength[i] + cclBufferSize - 1) / cclBufferSize; + } + maxSendTimes = (maxSendTimes > sendTimes) ? maxSendTimes : sendTimes; + maxRecvTimes = (maxRecvTimes > recvTimes) ? 
maxRecvTimes : recvTimes; + } + const u64 massThreshold = 65535; // 65535: 单个ffts+任务中,最多承载64K个task + const u64 maxTasksPerStep = 10; // BCOPY中每次和远端通信最多消耗task数 + const u64 maxTasksBaseCost = 50; // BCOPY中除每步和远端通信外,最多消耗的task数 + u64 maxTasks = (maxSendTimes + maxRecvTimes) * maxTasksPerStep + maxTasksBaseCost; + HCCL_DEBUG("[AlltoAllV] bcopy maxSendTimes[%lu], maxRecvTimes[%lu], maxTasks[%lu], hasMassTask[%u]", maxSendTimes, + maxRecvTimes, maxTasks, (maxTasks > massThreshold)); + return (maxTasks > massThreshold); +} + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..ebb464d50b27cab3f8449da206d9bf7981a30e11 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_executor.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_ALLTOALL_COMM_EXECUTOR_H +#define COLL_ALLTOALL_COMM_EXECUTOR_H +#include "coll_comm_executor.h" +namespace hccl { +constexpr u64 MAX_ALLTOALL_MESH_ALGO_RANK_INTRA_MESH = 16; +constexpr u64 TINY_MEM_SIZE = 2 * 1024 * 1024; // tinyMem size +constexpr u32 MINORS_NUM_TWO = 2; + +class CollAlltoAllExecutor : public CollNativeExecutorBase { +public: + CollAlltoAllExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollAlltoAllExecutor() = default; + + HcclResult Orchestrate(const OpParam& param, const AlgResourceResponse& algRes) override; + HcclResult SetExcutorExtraInfo(const std::vector &allMeshAggregationSendRecvInfo) override; + HcclResult CalcResRequest(const OpParam& param, AlgResourceRequest &resourceRequest) override; + virtual HcclResult CheckNeedCreateVirtualLinks(AlgResourceRequest &resourceRequest); +protected: + /* *************** 算法编排 *************** */ + // 公共接口 + HcclOpMetaInfo GetOpMeta(HcclCMDType opType, const u64 size); + void UpdateAlltoAllZCopyMode(std::vector &allMeshAggregationSendRecvInfo); + void CalcIntraMeshAggregationAlltoAllMemInfo(const AlltoAllUserRankInfo &userRankInfo, + const std::vector &allSendRecvInfo, + std::map> &sendAddrInfosIntra, + std::map> &recvAddrInfosIntra, u32 meshAggregationRankSize, + const bool &isSingleMesh); + void CalcIntraMeshAggregationSendInfo(const AlltoAllUserRankInfo &userRankInfo, + const SendRecvInfo &mySendRecvInfo, const std::vector &myMeshAggregationSendRecvInfo, + u32 rankInMeshAggregation, u32 infoIndex, OneSendRecvAddrInfo &curSendInfo, u32 meshAggregationRankSize, + const bool &isSingleMesh); + void CalcIntraMeshAggregationRecvInfo(const AlltoAllUserRankInfo &userRankInfo, + const std::vector &myMeshAggregationSendRecvInfo, u32 infoIndex, OneSendRecvAddrInfo &curRecvInfo, + u32 meshAggregationRankSize, const bool &isSingleMesh); + void 
CalcIntraMeshAggregationRecvInfoInMeshAggregation(u32 rankIndex, u32 infoIndex, + const std::vector &myMeshAggregationSendRecvInfo, u64 &localOffset, u32 &offsetCounter, + u64 &localLength, u64 &remoteOffset, u32 meshAggregationRankSize); + u64 CalAlltoAllVScratchMemSize(u64 &workSpaceMemSize); + bool NAFullmeshSatisfyHighPerfAlltoallMeshCondition(DevType deviceType, u32 rankSize); + bool FullmeshPairwiseSatisfyHighPerfAlltoallMeshCondition(DevType deviceType, u32 rankSize); + bool HasMassTasks(std::vector &allMeshAggregationSendRecvInfo); + + OpParam AlltoAllVParam_; + AlgResourceResponse algRes_; + bool DMAReduceFlag_{false}; // 是否DMA消减 + std::vector allMeshAggregationSendRecvInfo_; + bool isAlltoAllZCopyMode_ = false; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_v_2level_pipeline_excecutor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_v_2level_pipeline_excecutor.cc new file mode 100644 index 0000000000000000000000000000000000000000..1ffe8c2e21a5f6155db27f6da08a05d7cfb04817 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_v_2level_pipeline_excecutor.cc @@ -0,0 +1,179 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + + +#include "coll_all_to_all_v_2level_pipeline_excecutor.h" +namespace hccl { + +CollRunAlltoAllVTwoLevelPipeline::CollRunAlltoAllVTwoLevelPipeline(const HcclDispatcher dispatcher, + std::unique_ptr<TopoMatcher> &topoMatcher) + : CollAlltoAllExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = false; +} + +// Calculate the scratch size this rank needs for the alltoall pipeline 910B two-level pipeline algorithm (required in graph mode) +u64 CollRunAlltoAllVTwoLevelPipeline::GetAlltoall2LevelPipelineScratchSize910B( + u32 rank, std::vector<SendRecvInfo> &allMeshAggregationSendRecvInfo) +{ + u64 scratchSize = 0; + u32 meshRankStart = (rank / topoAttr_.meshAggregationRankSize) * topoAttr_.meshAggregationRankSize; + u32 meshRankEnd = meshRankStart + topoAttr_.meshAggregationRankSize - 1; + u32 rankIntraMesh = rank - meshRankStart; + for (u32 sendRank = rankIntraMesh, userRankSize = allMeshAggregationSendRecvInfo.size(); + sendRank < userRankSize; sendRank += topoAttr_.meshAggregationRankSize) { + const std::vector<u64>& remoteSendLength = allMeshAggregationSendRecvInfo[sendRank].sendLength; + const std::vector<u64>& remoteSendOffset = allMeshAggregationSendRecvInfo[sendRank].sendOffset; + scratchSize += (remoteSendOffset[meshRankEnd] + remoteSendLength[meshRankEnd] - + remoteSendOffset[meshRankStart]); + } + return scratchSize; +} + +// Calculate the maximum scratch size needed by any rank for the alltoall pipeline 910B two-level pipeline algorithm (required in single-operator mode) +u64 CollRunAlltoAllVTwoLevelPipeline::GetAlltoall2LevelPipelineMaxScratchSize910B( + std::vector<SendRecvInfo> &allMeshAggregationSendRecvInfo) +{ + u64 maxScratchSize = 0; + for (u32 rank = 0, userRankSize = allMeshAggregationSendRecvInfo.size(); rank < userRankSize; rank++) { + u64 currRankScratchSize = GetAlltoall2LevelPipelineScratchSize910B(rank, allMeshAggregationSendRecvInfo); + maxScratchSize = (currRankScratchSize > maxScratchSize ? currRankScratchSize : maxScratchSize); + } + return maxScratchSize; +} + +HcclResult CollRunAlltoAllVTwoLevelPipeline::CalcScratchMemSize(u64& scratchMemSize) +{ + scratchMemSize = 0U; + u64 tmpMemSize = 0U; + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { + // Scratch memory is only allocated in graph mode; only the scratchMem size is calculated here + tmpMemSize = GetAlltoall2LevelPipelineMaxScratchSize910B(allMeshAggregationSendRecvInfo_); + } + scratchMemSize = CalAlltoAllVScratchMemSize(tmpMemSize); + HCCL_INFO("[CollRunAlltoAllVTwoLevelPipeline][CalcScratchMemSize] tag_[%s] scratchMemSize[%u]", + tag_.c_str(), scratchMemSize); + return HCCL_SUCCESS; +} + +HcclResult CollRunAlltoAllVTwoLevelPipeline::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = topoAttr_.deviceNumPerAggregation + 1U; + streamNum = totalStreamNum - 1U; + HCCL_INFO("[CollRunAlltoAllVTwoLevelPipeline][CalcStreamNum] tag_[%s] streamNum[%u]", tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollRunAlltoAllVTwoLevelPipeline::CalcLevel0CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector<LevelNSubCommTransport>& opTransport) +{ + CommParaInfo commParaLevel0(COMM_MESH_L0, CommType::COMM_TAG_MESH); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_MESH_L0], inputType, outputType)); + return HCCL_SUCCESS; +} + +HcclResult CollRunAlltoAllVTwoLevelPipeline::CalcLevel1CommInfo(TransportMemType inputType, + TransportMemType outputType, std::vector<LevelNSubCommTransport>& opTransport) +{ + CommParaInfo commParaInfo(COMM_MESH_L1, CommType::COMM_TAG_MESH); + CHK_RET(CalcCommPlaneInfo(tag_, commParaInfo, opTransport[COMM_MESH_L1], inputType, outputType)); + return HCCL_SUCCESS; +} + +HcclResult CollRunAlltoAllVTwoLevelPipeline::CalcCommInfo(std::vector<LevelNSubCommTransport>& opTransport) +{ + TransportMemType inputType = 
TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + + CalNoScratchAlltoallCommInfo(inputType, outputType, opTransport); + return HCCL_SUCCESS; +} + +HcclResult CollRunAlltoAllVTwoLevelPipeline::CalNoScratchAlltoallCommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + CalcLevel0CommInfo(TransportMemType::CCL_OUTPUT, TransportMemType::CCL_OUTPUT, opTransport); + CalcLevel1CommInfo(TransportMemType::CCL_INPUT, TransportMemType::CCL_OUTPUT, opTransport); + } else { + CalcLevel0CommInfo(TransportMemType::SCRATCH, TransportMemType::CCL_OUTPUT, opTransport); + CalcLevel1CommInfo(TransportMemType::CCL_INPUT, TransportMemType::SCRATCH, opTransport); + } + + return HCCL_SUCCESS; +} + +HcclResult CollRunAlltoAllVTwoLevelPipeline::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + HCCL_INFO("[CollRunAlltoAllVTwoLevelPipeline][KernelRun] alltoall two level pipeline start"); + + // 子图 + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + bool hugeData = algRes_.paramInputMem.size() > SDMA_SEND_MAX_SIZE; + bool alltoallPingPong = (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + !topoAttr_.multiModuleDiffDeviceNumMode && + GetAlltoall2LevelPipelineMaxScratchSize910B(allMeshAggregationSendRecvInfo_) > + execMem.inputMem); + if (AlltoAllVParam_.opType == HcclCMDType::HCCL_CMD_ALLTOALLV) { + auto opMeta = HcclOpMetaInfo::GetOneForAllToAllV((isAlltoAllZCopyMode_ ? + CopyPattern::ZCOPY : CopyPattern::BCOPY), algRes_.paramInputMem.size(), + hugeData || alltoallPingPong); + CHK_RET(InitTask(dispatcher_, const_cast(param.stream), opMeta.isEnableCache, opMeta.GetCacheKey())); + } else { + auto opMeta = HcclOpMetaInfo::GetOneForAllToAllV(CopyPattern::BCOPY, + algRes_.paramInputMem.size(), hugeData || alltoallPingPong); + CHK_RET(InitTask(dispatcher_, const_cast(param.stream), opMeta.isEnableCache, opMeta.GetCacheKey())); + } + } + + bool cclEnough = true; + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + GetAlltoall2LevelPipelineMaxScratchSize910B(allMeshAggregationSendRecvInfo_) > + execMem.inputMem) { + cclEnough = false; + } + HCCL_DEBUG("[CollRunAlltoAllVTwoLevelPipeline][KernelRun] alltoall pipeline run %s algo", + cclEnough ? 
"cclEnough" : "ping pong"); + A2aPipelineMemory a2aPipelineMemory; + a2aPipelineMemory.userInput = algRes_.paramInputMem; + a2aPipelineMemory.userOutput = algRes_.paramOutputMem; + // 具体传入 A2aPipelineMemory 对象的 alltoall pipeline executor 会根据图模式还是单算子模式 + // 选择使用 ccl 还是 scratch,不会访问空指针 + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + a2aPipelineMemory.cclInBuffer = execMem.inputMem; + a2aPipelineMemory.cclOutBuffer = execMem.outputMem; + } else { + a2aPipelineMemory.scratchMem = execMem.scratchMem; + } + + std::unique_ptr alltoallPipe = nullptr; + if (cclEnough) { + alltoallPipe.reset(new (std::nothrow)AlltoallPipelineMeshPairwiseCCLEnough(dispatcher_, + allMeshAggregationSendRecvInfo_, GetWorkflowMode())); + } else { + alltoallPipe.reset(new (std::nothrow)AlltoallPipelineMeshPairwisePingPong(dispatcher_, + allMeshAggregationSendRecvInfo_, GetWorkflowMode())); + } + + CHK_RET(CheckCommSize(COMM_MESH_L0, COMM_INDEX_0 + 1)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_MESH_L0, COMM_INDEX_0); + CHK_RET(CheckCommSize(COMM_MESH_L1, COMM_INDEX_0 + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_MESH_L1, COMM_INDEX_0); // 待确认 zlj + + alltoallPipe->Prepare(topoAttr_.userRank, a2aPipelineMemory, outerCommInfo, innerCommInfo, + const_cast(param.stream), streamInfo_.ringStreams, streamInfo_.ringSignal, streamInfo_.ringSignalAux); + alltoallPipe->RunAsync(); + HCCL_INFO("[CollRunAlltoAllVTwoLevelPipeline][kernelRun] alltoall two level pipeline end"); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("RunAlltoAllVTwoLevelPipeline", AlltoAllVTwoLevelPipeline, CollRunAlltoAllVTwoLevelPipeline); +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_v_2level_pipeline_excecutor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_v_2level_pipeline_excecutor.h new file mode 100644 index 0000000000000000000000000000000000000000..bd6088ec0591273715f14dd33230b67716886c2a --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_v_2level_pipeline_excecutor.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef COLL_RUN_ALLTOALLV_TWO_LEVEL_PIPELINE_H +#define COLL_RUN_ALLTOALLV_TWO_LEVEL_PIPELINE_H +#include "coll_all_to_all_executor.h" +namespace hccl { + +class CollRunAlltoAllVTwoLevelPipeline : public CollAlltoAllExecutor { +public: + CollRunAlltoAllVTwoLevelPipeline(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollRunAlltoAllVTwoLevelPipeline() = default; + +private: + u64 GetAlltoall2LevelPipelineScratchSize910B(u32 rank, std::vector &allMeshAggregationSendRecvInfo); + u64 GetAlltoall2LevelPipelineMaxScratchSize910B(std::vector &allMeshAggregationSendRecvInfo); + HcclResult CalcScratchMemSize(u64& scratchMemSize) override; + HcclResult CalcStreamNum(u32& streamNum) override; + + HcclResult CalcLevel0CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcLevel1CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalNoScratchAlltoallCommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport); + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_v_fullmesh_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_v_fullmesh_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..57cfd7854ae6740800d762cfe49529f5ee4bf2c5 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_v_fullmesh_executor.cc @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + + +#include "coll_all_to_all_v_fullmesh_executor.h" + +namespace hccl { + +CollRunAlltoAllVFullMesh::CollRunAlltoAllVFullMesh(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAlltoAllExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = false; +} + +HcclResult CollRunAlltoAllVFullMesh::CalcStreamNum(u32& streamNum) +{ + if (NAFullmeshSatisfyHighPerfAlltoallMeshCondition(topoAttr_.deviceType, topoAttr_.userRankSize)) { + streamNum = topoAttr_.userRankSize - 1; + } else { + streamNum = 0; + } + HCCL_INFO("[CollRunAlltoAllVFullMesh][CalcStreamNum] tag[%s] streamNum[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +// level0-level1 打平fullmesh +HcclResult CollRunAlltoAllVFullMesh::CalcLevel0CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commCombinePara(COMM_COMBINE_ORDER, CommType::COMM_TAG_MESH); + CHK_RET(CalcCommPlaneInfo(tag_, commCombinePara, opTransport[COMM_COMBINE_ORDER], inputType, outputType)); + return HCCL_SUCCESS; +} + +HcclResult CollRunAlltoAllVFullMesh::CalcLevel2CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel2(COMM_LEVEL2, CommType::COMM_TAG_MESH); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel2, opTransport[COMM_LEVEL2], inputType, outputType)); + return HCCL_SUCCESS; +} + + +HcclResult CollRunAlltoAllVFullMesh::CalAlltoAllFullMeshCommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + // A+X单机双module启用下,未使能RDMA不能进行一层pairWise。 + bool isDifModule = topoAttr_.serverNum == 1 && topoAttr_.isDiffDeviceModule && + topoAttr_.userRankSize > HCCL_ALLTOALLV_P2P_SIZE; + CHK_PRT_RET(isDifModule && !algoAttr_.isUsedRdmaOuter, + HCCL_ERROR("[CalAlltoAllFullMeshCommInfo] not support dual modules in a single server" \ + " when RDMA disabled "), HCCL_E_NOT_SUPPORT); + + // 将网卡初始化判断,提到上层调用,减少无必要的循环依赖。 + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + // level0 - level1 全连接通信域 + CalcLevel0CommInfo(TransportMemType::CCL_INPUT, TransportMemType::CCL_OUTPUT, opTransport); + // level2 层通信域 + CalcLevel2CommInfo(TransportMemType::CCL_INPUT, TransportMemType::CCL_OUTPUT, opTransport); + } else { + // level0 - level1 全连接通信域 + CalcLevel0CommInfo(TransportMemType::PARAM_INPUT, TransportMemType::PARAM_OUTPUT, opTransport); + // level2 层通信域 + CalcLevel2CommInfo(TransportMemType::PARAM_INPUT, TransportMemType::PARAM_OUTPUT, opTransport); + } + return HCCL_SUCCESS; +} + +HcclResult CollRunAlltoAllVFullMesh::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + + CalAlltoAllFullMeshCommInfo(inputType, outputType, opTransport); + return HCCL_SUCCESS; +} + +HcclResult CollRunAlltoAllVFullMesh::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + HCCL_INFO("[CollRunAlltoAllVFullMesh][KernelRun] alltoall two level pipeline start"); + bool opbaseCopyMode = GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + isAlltoAllZCopyMode_; + + // 构造入参 + AlltoAllVBufferInfo sendInfo; + sendInfo.mem = opbaseCopyMode ? 
execMem.inputMem : algRes_.paramInputMem; + sendInfo.counts = &allMeshAggregationSendRecvInfo_[topoAttr_.userRank].sendCounts[0]; + sendInfo.displs = &allMeshAggregationSendRecvInfo_[topoAttr_.userRank].sendDispls[0]; + sendInfo.dataType = param.All2AllDataDes.sendType; + + AlltoAllVBufferInfo recvInfo; + recvInfo.mem = opbaseCopyMode ? execMem.outputMem : algRes_.paramOutputMem; + recvInfo.counts = &allMeshAggregationSendRecvInfo_[topoAttr_.userRank].recvCounts[0]; + recvInfo.displs = &allMeshAggregationSendRecvInfo_[topoAttr_.userRank].recvDispls[0]; + recvInfo.dataType = param.All2AllDataDes.recvType; + + CHK_RET(CheckCommSize(COMM_COMBINE_ORDER, COMM_INDEX_0 + 1)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_COMBINE_ORDER, COMM_INDEX_0); + + if (NAFullmeshSatisfyHighPerfAlltoallMeshCondition(topoAttr_.deviceType, topoAttr_.userRankSize)) { + HCCL_INFO("[CollRunAlltoAllVFullMesh] one level read only algo"); + + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { + ActiveSlaveStreams(param.stream); + } + CHK_RET(AddSubStreamToProfiling()); + + std::unique_ptr alltoallReadOnly = nullptr; + alltoallReadOnly.reset(new (std::nothrow) AlltoAllVMeshReadOnly(dispatcher_, const_cast(param.stream), + streamInfo_.ringStreams, streamInfo_.ringSignal, streamInfo_.ringSignalAux, topoAttr_.userRank, + topoAttr_.userRankSize, outerCommInfo.links, allMeshAggregationSendRecvInfo_)); + CHK_SMART_PTR_NULL(alltoallReadOnly); + + AlltoAllUserRankInfo userRankInfo; + userRankInfo.userRank = topoAttr_.userRank; + userRankInfo.userRankSize = topoAttr_.userRankSize; + std::map> sendAddrInfosIntra; + std::map> recvAddrInfosIntra; + CalcIntraMeshAggregationAlltoAllMemInfo(userRankInfo, allMeshAggregationSendRecvInfo_, sendAddrInfosIntra, + recvAddrInfosIntra, topoAttr_.userRankSize, true); + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + CHK_RET(alltoallReadOnly->Prepare(algRes_.paramInputMem, algRes_.paramOutputMem, execMem.inputMem, + execMem.outputMem, sendAddrInfosIntra, recvAddrInfosIntra, GetWorkflowMode())); + } else { + CHK_RET(alltoallReadOnly->Prepare(algRes_.paramInputMem, algRes_.paramOutputMem, algRes_.paramInputMem, + algRes_.paramOutputMem, sendAddrInfosIntra, recvAddrInfosIntra, GetWorkflowMode())); + } + alltoallReadOnly->RunAsync(); + return HCCL_SUCCESS; + } + + // 执行算法 + std::map> rankSendDisplsMap; + std::map> rankRecvDisplsMap; + if (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE || isAlltoAllZCopyMode_) { + for (u32 i = 0; i < topoAttr_.userRankSize; i++) { + rankSendDisplsMap.insert(std::pair>(i, allMeshAggregationSendRecvInfo_[i].sendOffset)); + rankRecvDisplsMap.insert(std::pair>(i, allMeshAggregationSendRecvInfo_[i].recvOffset)); + } + } + + std::unique_ptr pairWisePtr = nullptr; + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + !isAlltoAllZCopyMode_) { // 单算子 && Buffer Copy模式 + pairWisePtr.reset(new (std::nothrow)AlltoAllVPairWise(dispatcher_)); + CHK_SMART_PTR_NULL(pairWisePtr); + CHK_RET(pairWisePtr->Prepare(sendInfo, recvInfo, execMem.inputMem, execMem.outputMem, isAlltoAllZCopyMode_, + const_cast(param.stream))); + CHK_RET(RunAlltoAllTemplate(pairWisePtr, outerCommInfo)); // 怎么知道是内层还是外层mesh zlj + } else if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + isAlltoAllZCopyMode_) { + pairWisePtr.reset(new (std::nothrow)AlltoAllVPairWise(dispatcher_, rankSendDisplsMap, rankRecvDisplsMap, + HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE)); + 
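+ // Single-operator ZCOPY path: stage the user input into the CCL input buffer with a device-to-device copy,
+ // run the pairwise exchange on the CCL buffers, then copy the CCL output back to the user output buffer,
+ // which is exactly what the D2D memcpy / Prepare / RunAlltoAllTemplate calls below perform.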
CHK_SMART_PTR_NULL(pairWisePtr); + DeviceMem dstMem = execMem.inputMem.range(0, algRes_.paramInputMem.size()); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstMem, algRes_.paramInputMem, const_cast(param.stream))); + + CHK_RET(pairWisePtr->Prepare(sendInfo, recvInfo, execMem.inputMem, execMem.outputMem, + isAlltoAllZCopyMode_, const_cast(param.stream))); + CHK_RET(RunAlltoAllTemplate(pairWisePtr, outerCommInfo)); // inputMem_ -> outputMem_ + + DeviceMem srcMem = execMem.outputMem.range(0, algRes_.paramOutputMem.size()); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, algRes_.paramOutputMem, srcMem, const_cast(param.stream))); + } else if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { + pairWisePtr.reset(new (std::nothrow)AlltoAllVPairWise(dispatcher_, rankSendDisplsMap, rankRecvDisplsMap, + HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB)); + CHK_SMART_PTR_NULL(pairWisePtr); + CHK_RET(pairWisePtr->Prepare(sendInfo, recvInfo, isAlltoAllZCopyMode_, const_cast(param.stream))); + // 保证最新的commMesh是为该次alltoallv创建(不支持多线程) + CHK_RET(RunAlltoAllTemplate(pairWisePtr, outerCommInfo)); + } else { + HCCL_ERROR("[hcclImpl][RunAlltoAllVFullMesh]work flow mode is invalid"); + return HCCL_E_PARA; + } + HCCL_INFO("[CollRunAlltoAllVFullMesh] excutor run success"); + + return HCCL_SUCCESS; +} + +REGISTER_EXEC("RunAlltoAllVFullMesh", AlltoAllVFullMesh, CollRunAlltoAllVFullMesh); +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_v_fullmesh_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_v_fullmesh_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..db329bb5f101aca33761b14d4c719e54518c93b0 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_v_fullmesh_executor.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef COLL_RUN_ALLTOALLV_FULLMESH_EXECUTOR_H +#define COLL_RUN_ALLTOALLV_FULLMESH_EXECUTOR_H +#include "coll_all_to_all_executor.h" +namespace hccl { +class CollRunAlltoAllVFullMesh : public CollAlltoAllExecutor { + +public: + CollRunAlltoAllVFullMesh(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollRunAlltoAllVFullMesh() = default; + +private: + HcclResult CalcStreamNum(u32& streamNum) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcLevel2CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalAlltoAllFullMeshCommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport); + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_v_staged_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_v_staged_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..133c84c5757117b6ece8f902a96f9ca5f958bab5 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_v_staged_executor.cc @@ -0,0 +1,489 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
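The header above only declares hooks; the driver that calls them lives in the CollAlltoAllExecutor/CollExecutor base classes, which this diff does not contain. As an assumed sequence, sketched below as comments: the framework first runs the Calc* overrides to size streams, scratch memory and transport links, allocates those resources, and only then enters the execution path that ends in KernelRun.

// Assumed orchestration of the executor hooks declared above (schematic only).
// 1. Resource calculation:
//      CalcStreamNum(streamNum);        // extra slave streams needed
//      CalcCommInfo(opTransport);       // which comm planes / links must exist
//      -> the framework fills an AlgResourceRequest and creates streams, buffers and links
// 2. Execution:
//      Orchestrate(param, algRes);      // tag/profiling setup, loop vs. direct run (see the
//                                       // broadcast executor's Orchestrate later in this patch)
//      KernelRun(param, execMem);       // the algorithm body implemented in the .cc file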
+ */ + + +#include "coll_all_to_all_v_staged_executor.h" +namespace hccl { + +CollRunAlltoAllVStaged::CollRunAlltoAllVStaged(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAlltoAllExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = false; +} + +HcclResult CollRunAlltoAllVStaged::ParallelTaskLoaderProcess(const std::string &tag, Stream &stream, + SubCommInfo &outerCommInfo, std::vector &ringStreams) +{ + u32 streamIndex; + std::vector streamsPtr; + streamsPtr.resize(ringStreams.size() + 1); + + for (streamIndex = 0; streamIndex < ringStreams.size(); streamIndex++) { // StreamInfo_.ringStreams + streamsPtr[streamIndex] = &ringStreams[streamIndex]; + } + streamsPtr[streamIndex] = &stream; + + HCCL_INFO("[ParallelTaskLoaderProcess]main stream[%p], streams size[%u]", stream.ptr(), streamsPtr.size()); + + // 准备多线程启动参数 + CHK_RET(parallelTaskLoader_->Prepare(streamsPtr, outerCommInfo)); + + // 启动多线程处理 + CHK_RET(parallelTaskLoader_->StartTaskLoad()); + + // 等待多线程处理结果 + CHK_RET(parallelTaskLoader_->WaitTaskLoadFinish()); + + // 销毁通信域 + CHK_RET(parallelTaskLoader_->ClearTagCommInfo()); + return HCCL_SUCCESS; +} + +HcclResult CollRunAlltoAllVStaged::CalcStreamNum(u32& streamNum) +{ + streamNum = 0U; + if (FullmeshPairwiseSatisfyHighPerfAlltoallMeshCondition(topoAttr_.deviceType, + topoAttr_.userRankSize)) { + streamNum = topoAttr_.meshAggregationRankSize - 1; + } else { + if (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE || isAlltoAllZCopyMode_) { + if ((GetExternalInputHcclAlgoConfig()[0] != HcclAlgoType::HCCL_ALGO_TYPE_PAIRWISE || + GetExternalInputHcclAlgoConfig()[1] != HcclAlgoType::HCCL_ALGO_TYPE_PAIRWISE) && + const_cast(topoAttr_).pairLinkCounter[static_cast( + LinkTypeInServer::HCCS_SW_TYPE)] == 0 && topoAttr_.meshAggregationRankSize != 1) { + streamNum = topoAttr_.meshAggregationRankSize - MINORS_NUM_TWO; + } + } + } + + HCCL_INFO("[CollRunAlltoAllVStaged][CalcStreamNum] tag[%s] streamNum[%u]", tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +void CollRunAlltoAllVStaged::CalcWorkSpaceMemSize(const AlltoAllUserRankInfo &userRankInfo, + const std::vector &allMeshAggregationSendRecvInfo, u64 &workspaceMemSize, + u32 meshAggregationRankSize) +{ + for (const auto &oneMeshAggregationSendRecvInfo : allMeshAggregationSendRecvInfo) { + for (const auto &sendLength : oneMeshAggregationSendRecvInfo.sendLength) { + HCCL_DEBUG("[CalcWorkSpaceMemSize] sendLength[%llu]", sendLength); + } + for (const auto &sendOffset : oneMeshAggregationSendRecvInfo.sendOffset) { + HCCL_DEBUG("[CalcWorkSpaceMemSize] sendOffset[%llu]", sendOffset); + } + for (const auto &recvLength : oneMeshAggregationSendRecvInfo.recvLength) { + HCCL_DEBUG("[CalcWorkSpaceMemSize] recvLength[%llu]", recvLength); + } + for (const auto &recvOffset : oneMeshAggregationSendRecvInfo.recvOffset) { + HCCL_DEBUG("[CalcWorkSpaceMemSize] recvOffset[%llu]", recvOffset); + } + } + if (allMeshAggregationSendRecvInfo.size() % meshAggregationRankSize != 0 || + allMeshAggregationSendRecvInfo.size() == 0) { + workspaceMemSize = 0; + HCCL_ERROR("Invalid Send Recv Info Size[%u]", allMeshAggregationSendRecvInfo.size()); + return; + } + workspaceMemSize = 0; + u32 meshAggregationIndex = userRankInfo.userRank / meshAggregationRankSize; + u32 meshAggregationRankBegin = meshAggregationIndex * meshAggregationRankSize; + for (u32 infoIndex = userRankInfo.userRank % meshAggregationRankSize; infoIndex < userRankInfo.userRankSize; + infoIndex += meshAggregationRankSize) { + for (u32 k = meshAggregationRankBegin; k < 
meshAggregationRankBegin + meshAggregationRankSize; k++) { + workspaceMemSize += allMeshAggregationSendRecvInfo[k].sendLength[infoIndex]; + } + } + HCCL_INFO("[AlltoAllVStagedCalculator][CalcWorkSpaceMemSize] workspaceMemSize[%llu]", workspaceMemSize); +} + +HcclResult CollRunAlltoAllVStaged::CalcScratchMemSize(u64& scratchMemSize) +{ + scratchMemSize = 0U; + u64 maxWorkSpaceMemSize = 0; + + if (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + u64 workSpaceMemSize = 0; + AlltoAllUserRankInfo tmpUserRankInfo; + tmpUserRankInfo.userRankSize = topoAttr_.userRankSize; + tmpUserRankInfo.userRank = topoAttr_.userRank; + CalcWorkSpaceMemSize(tmpUserRankInfo, allMeshAggregationSendRecvInfo_, workSpaceMemSize, + topoAttr_.meshAggregationRankSize); + scratchMemSize = CalAlltoAllVScratchMemSize(workSpaceMemSize); + } else { + for (u32 rank = 0; rank < topoAttr_.userRankSize; rank++) { + u64 workSpaceMemSize = 0; + AlltoAllUserRankInfo tmpUserRankInfo; + tmpUserRankInfo.userRankSize = topoAttr_.userRankSize; + tmpUserRankInfo.userRank = rank; + CalcWorkSpaceMemSize(tmpUserRankInfo, allMeshAggregationSendRecvInfo_, workSpaceMemSize, + topoAttr_.meshAggregationRankSize); + maxWorkSpaceMemSize = std::max(workSpaceMemSize, maxWorkSpaceMemSize); + } + scratchMemSize = CalAlltoAllVScratchMemSize(maxWorkSpaceMemSize); + } + + HCCL_INFO("[CollRunAlltoAllVStaged][CalcScratchMemSize] scratchMemSize[%llu]", scratchMemSize); + return HCCL_SUCCESS; +} + +bool CollRunAlltoAllVStaged::CheckNeedRecreateComm(u64 lastScratchMemSize) +{ + u64 tmpScratchMemSize = 0; + CalcScratchMemSize(tmpScratchMemSize); + return ((lastScratchMemSize < tmpScratchMemSize) ? (true) : (false)); +} + +HcclResult CollRunAlltoAllVStaged::CheckNeedCreateVirtualLinks(AlgResourceRequest &resourceRequest) +{ + bool alltoallMeshReadOnly = FullmeshPairwiseSatisfyHighPerfAlltoallMeshCondition(topoAttr_.deviceType, + topoAttr_.userRankSize); + HCCL_DEBUG("[CollRunAlltoAllVStaged][CheckNeedCreateVirtualLinks] alltoallMeshReadOnly[%d]," \ + "resourceRequest.streamNum[%d], GetExternalInputHcclEnableFfts()[%d], isAlltoAllZCopyMode_[%d]", + alltoallMeshReadOnly, resourceRequest.streamNum, GetExternalInputHcclEnableFfts(), isAlltoAllZCopyMode_); + if (!alltoallMeshReadOnly && (resourceRequest.streamNum != 0) && (!GetExternalInputHcclEnableFfts()) + && isAlltoAllZCopyMode_) { + for (auto &levelNSubCommTransport : resourceRequest.opTransport) { + for (auto &singleSubCommTransport : levelNSubCommTransport) { + singleSubCommTransport.needVirtualLink = true; + HCCL_INFO("[CollRunAlltoAllVStaged][CheckNeedCreateVirtualLinks] needVirtualLink is true"); + } + } + } + return HCCL_SUCCESS; +} + +HcclResult CollRunAlltoAllVStaged::CalcLevel0CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel0(COMM_MESH_L0, CommType::COMM_TAG_MESH); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_MESH_L0], inputType, outputType)); + return HCCL_SUCCESS; +} + +HcclResult CollRunAlltoAllVStaged::CalcLevel1CommInfo(TransportMemType inputType, + TransportMemType outputType, std::vector& opTransport) +{ + CommParaInfo commParaInfo(COMM_MESH_L1, CommType::COMM_TAG_MESH); + CHK_RET(CalcCommPlaneInfo(tag_, commParaInfo, opTransport[COMM_MESH_L1], inputType, outputType)); + return HCCL_SUCCESS; +} + +HcclResult CollRunAlltoAllVStaged::CalcLevel2CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo 
commParaLevel2(COMM_LEVEL2, CommType::COMM_TAG_MESH); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel2, opTransport[COMM_LEVEL2], inputType, outputType)); + return HCCL_SUCCESS; +} + +HcclResult CollRunAlltoAllVStaged::CalStagedAlltoallVCommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + // 将网卡初始化判断,提到上层调用,减少无必要的循环依赖。 + bool alltoallMeshReadOnly = FullmeshPairwiseSatisfyHighPerfAlltoallMeshCondition(topoAttr_.deviceType, + topoAttr_.userRankSize); + + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + !isAlltoAllZCopyMode_) { // 单算子 && BCopy模式 + HCCL_INFO("cal comm in opbase and Bcopy mode"); + CalcLevel0CommInfo(TransportMemType::CCL_INPUT, TransportMemType::CCL_OUTPUT, opTransport); + CalcLevel1CommInfo(TransportMemType::CCL_INPUT, TransportMemType::CCL_OUTPUT, opTransport); + CalcLevel2CommInfo(TransportMemType::CCL_INPUT, TransportMemType::CCL_OUTPUT, opTransport); + } else if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + isAlltoAllZCopyMode_) { // 单算子 && ZCopy模式 + HCCL_INFO("cal comm in opbase and Zcopy mode"); + if (topoAttr_.isSingleMeshAggregation) { + CalcLevel0CommInfo(TransportMemType::CCL_INPUT, TransportMemType::CCL_OUTPUT, opTransport); + } else { + CalcLevel0CommInfo(TransportMemType::CCL_INPUT, (alltoallMeshReadOnly ? + TransportMemType::CCL_OUTPUT : TransportMemType::SCRATCH), opTransport); + CalcLevel1CommInfo(TransportMemType::SCRATCH, TransportMemType::CCL_OUTPUT, opTransport); + } + CalcLevel2CommInfo(TransportMemType::CCL_INPUT, TransportMemType::CCL_OUTPUT, opTransport); + } else { + HCCL_INFO("cal comm in graph mode"); + CalcLevel0CommInfo(TransportMemType::PARAM_INPUT, TransportMemType::SCRATCH, opTransport); + CalcLevel1CommInfo(TransportMemType::SCRATCH, TransportMemType::PARAM_OUTPUT, opTransport); + CalcLevel2CommInfo(TransportMemType::PARAM_INPUT, TransportMemType::PARAM_OUTPUT, opTransport); + } + HCCL_DEBUG("[CollRunAlltoAllVStaged][CalStagedAlltoallVCommInfo] ends"); + return HCCL_SUCCESS; +} + +HcclResult CollRunAlltoAllVStaged::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + + CalStagedAlltoallVCommInfo(inputType, outputType, opTransport); + return HCCL_SUCCESS; +} + +HcclResult CollRunAlltoAllVStaged::PrepareAlltoAllVStaged1(DeviceMem &sendBuf, DeviceMem &recvBuf, DeviceMem &scratchMem, + std::map> &sendAddrInfosIntra, + std::map> &recvAddrInfosIntra, + Stream &stream, const std::string &tag, std::unique_ptr &alltoallOuter, + ExecMem &execMem) +{ + // opbase BCopy 不支持fullmesh算法,因此不必做算法选择 + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + !isAlltoAllZCopyMode_) { // 单算子 && Buffer拷贝模式 + HCCL_INFO("Running alltoallv Staged Pairwise intra Server"); + alltoallOuter.reset(new (std::nothrow)AlltoAllVStagedPairwise(dispatcher_, stream)); + CHK_SMART_PTR_NULL(alltoallOuter); + CHK_RET(alltoallOuter->Prepare(sendBuf, scratchMem, execMem.inputMem, execMem.outputMem, sendAddrInfosIntra, + recvAddrInfosIntra, isAlltoAllZCopyMode_)); + } else { + bool isOpBaseZCopy = GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && isAlltoAllZCopyMode_; + DeviceMem inBuf = (isOpBaseZCopy) ? execMem.inputMem : sendBuf; + // 单MeshAggregation下, 分级算法不做第二级, 结果输出到outCCLbuffer_ + DeviceMem outBuf = (isOpBaseZCopy && topoAttr_.isSingleMeshAggregation) ? 
recvBuf : scratchMem; + // opbase ZCopy 与 graph,除input buffer差异外,其余行为应保持一致 + if (isOpBaseZCopy) { // 单算子 && ZCopy模式 + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, execMem.inputMem, sendBuf, stream)); + } + // 互联场景, alltoall暂不支持走fullmesh+pairwise + if ((GetExternalInputHcclAlgoConfig()[0] == HcclAlgoType::HCCL_ALGO_TYPE_PAIRWISE && + GetExternalInputHcclAlgoConfig()[1] == HcclAlgoType::HCCL_ALGO_TYPE_PAIRWISE) || + const_cast(topoAttr_).pairLinkCounter[static_cast(LinkTypeInServer::HCCS_SW_TYPE)] != 0 || + topoAttr_.meshAggregationRankSize == 1) { + HCCL_INFO("Running alltoallv Staged Pairwise intra Server"); + alltoallOuter.reset(new (std::nothrow)AlltoAllVStagedPairwise(dispatcher_, stream)); + CHK_SMART_PTR_NULL(alltoallOuter); + CHK_RET(alltoallOuter->Prepare(inBuf, outBuf, sendAddrInfosIntra, recvAddrInfosIntra, + isAlltoAllZCopyMode_)); + } else { + HCCL_INFO("Running alltoallv Staged Mesh intra Server"); + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { + ActiveSlaveStreams(AlltoAllVParam_.stream); + } + // 添加从流profiling, 用于维护planID + CHK_RET(AddSubStreamToProfiling()); + + if (GetExternalInputHcclEnableFfts() || streamInfo_.ringStreams.size() == 0) { + alltoallOuter.reset(new (std::nothrow) AlltoAllVStagedMesh(dispatcher_, stream, + streamInfo_.ringSignal, streamInfo_.ringSignalAux, topoAttr_.userRank, streamInfo_.ringStreams)); + } else { + alltoallOuter.reset(new (std::nothrow) AlltoAllVStagedMesh(vDispatcher_, stream, + streamInfo_.ringSignal, streamInfo_.ringSignalAux, topoAttr_.userRank, streamInfo_.ringStreams)); + } + CHK_SMART_PTR_NULL(alltoallOuter); + CHK_RET(alltoallOuter->Prepare(inBuf, outBuf, sendAddrInfosIntra, recvAddrInfosIntra, isAlltoAllZCopyMode_, + streamInfo_.ringStreams)); + } + } + return HCCL_SUCCESS; +} + +void CollRunAlltoAllVStaged::CalcInterMeshAggregationRecvRemoteOffset(const AlltoAllUserRankInfo &userRankInfo, + const std::vector &allSendRecvInfo, u32 index, u64 &remoteOffset, u32 meshAggregationRankSize) +{ + // 对于stage1 来说,相当于是从rand index 发送给 userRankInfo.userRank, 然后计算这种情况下的stage1 的接收偏移 + remoteOffset = 0; + u32 anchoruserRank_ = index; + u32 anchorIndex = userRankInfo.userRank; + u32 beginIndex = anchorIndex % meshAggregationRankSize; + u32 beginRank = anchoruserRank_ / meshAggregationRankSize * meshAggregationRankSize; + bool getAnchor = false; + for (index = beginIndex; index <= anchorIndex; index += meshAggregationRankSize) { + for (u32 rank = beginRank; rank < beginRank + meshAggregationRankSize; rank++) { + if (index == anchorIndex && rank == anchoruserRank_) { + getAnchor = true; + break; + } + remoteOffset += allSendRecvInfo[rank].sendLength[index]; + } + if (getAnchor) { + break; + } + } +} + +void CollRunAlltoAllVStaged::CalcInterMeshAggregationAlltoAllMemInfo( + const AlltoAllUserRankInfo &userRankInfo, const std::vector &allSendRecvInfo, + std::map> &sendAddrInfosInter, + std::map> &recvAddrInfosInter, + u32 meshAggregationRankSize) +{ + sendAddrInfosInter.clear(); + recvAddrInfosInter.clear(); + + u64 localOffsetMarker = 0; + for (u32 toRank = 0; toRank < userRankInfo.userRankSize; toRank++) { + u32 myRank = userRankInfo.userRank; + u32 myMeshAggregationRankBegin = myRank / meshAggregationRankSize * meshAggregationRankSize; + u32 myMeshAggregationRankEnd = myMeshAggregationRankBegin + meshAggregationRankSize; + + for (u32 myMeshAggregationRank = myMeshAggregationRankBegin; myMeshAggregationRank < myMeshAggregationRankEnd; + myMeshAggregationRank++) { + if (toRank % meshAggregationRankSize == myRank % 
meshAggregationRankSize) { + OneSendRecvAddrInfo sendAddrInfo; + sendAddrInfo.localLength = allSendRecvInfo[myMeshAggregationRank].sendLength[toRank]; + sendAddrInfo.localOffset = localOffsetMarker; + localOffsetMarker += sendAddrInfo.localLength; + sendAddrInfo.remoteOffset = allSendRecvInfo[toRank].recvOffset[myMeshAggregationRank]; + sendAddrInfo.remoteLength = allSendRecvInfo[toRank].recvLength[myMeshAggregationRank]; + u32 remoteRankInter = toRank / meshAggregationRankSize; + sendAddrInfosInter[remoteRankInter].push_back(sendAddrInfo); + HCCL_DEBUG("[CalcInterMeshAggregationAlltoAllMemInfo] sendAddrInfo localOffset[%llu], "\ + "localLength[%llu], remoteOffset[%llu], remoteLength[%llu]", sendAddrInfo.localOffset, + sendAddrInfo.localLength, sendAddrInfo.remoteOffset, sendAddrInfo.remoteLength); + } + } + } + + // 构造接收数据结构 + for (u32 index = 0; index < userRankInfo.userRankSize; index++) { + OneSendRecvAddrInfo recvAddrInfo; + u32 meshAggregationIndex = index / meshAggregationRankSize; + + recvAddrInfo.localOffset = allSendRecvInfo[userRankInfo.userRank].recvOffset[index]; + recvAddrInfo.localLength = allSendRecvInfo[userRankInfo.userRank].recvLength[index]; + // index 是 从那个rank 来的 + recvAddrInfo.remoteLength = allSendRecvInfo[index].sendLength[userRankInfo.userRank]; + u64 remoteOffset = 0; + CalcInterMeshAggregationRecvRemoteOffset(userRankInfo, allSendRecvInfo, index, remoteOffset, + meshAggregationRankSize); + + recvAddrInfo.remoteOffset = remoteOffset; + recvAddrInfosInter[meshAggregationIndex].push_back(recvAddrInfo); + HCCL_DEBUG("[CalcInterMeshAggregationAlltoAllMemInfo] recvAddrInfo localOffset[%llu], "\ + "localLength[%llu], remoteOffset[%llu], remoteLength[%llu]", recvAddrInfo.localOffset, + recvAddrInfo.localLength, recvAddrInfo.remoteOffset, recvAddrInfo.remoteLength); + } +} + +HcclResult CollRunAlltoAllVStaged::PrepareAlltoAllVStaged2(DeviceMem &recvBuf, DeviceMem &scratchMem, + std::map> &sendAddrInfosInter, + std::map> &recvAddrInfosInter, + Stream &stream, const std::string &tag, std::unique_ptr &alltoallInner, + ExecMem &execMem) +{ + alltoallInner.reset(new (std::nothrow)AlltoAllVStagedPairwise(dispatcher_, stream)); + CHK_SMART_PTR_NULL(alltoallInner); + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + !isAlltoAllZCopyMode_) { // 单算子 && BCopy模式 + CHK_RET(alltoallInner->Prepare(scratchMem, recvBuf, execMem.inputMem, execMem.outputMem, sendAddrInfosInter, + recvAddrInfosInter, isAlltoAllZCopyMode_)); + } else if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + isAlltoAllZCopyMode_) { // 单算子 && ZCopy模式 + CHK_RET(alltoallInner->Prepare(scratchMem, execMem.outputMem, execMem.inputMem, execMem.outputMem, + sendAddrInfosInter, recvAddrInfosInter, isAlltoAllZCopyMode_)); + } else { + CHK_RET(alltoallInner->Prepare(scratchMem, recvBuf, sendAddrInfosInter, recvAddrInfosInter, + isAlltoAllZCopyMode_)); + } + return HCCL_SUCCESS; +} + +HcclResult CollRunAlltoAllVStaged::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + HCCL_INFO("[CollRunAlltoAllVStaged][KernelRun] alltoall staged starts"); + CHK_PRT_RET(topoAttr_.userRankSize % topoAttr_.meshAggregationRankSize != 0, + HCCL_ERROR("userRankSize[%u] is not an Integer multiple of MeshAggregation Dev Num[%u]", + topoAttr_.userRankSize, topoAttr_.meshAggregationRankSize), HCCL_E_PARA); + + AlltoAllUserRankInfo userRankInfo; + userRankInfo.userRank = topoAttr_.userRank; + userRankInfo.userRankSize = topoAttr_.userRankSize; + bool alltoallMeshReadOnly = 
FullmeshPairwiseSatisfyHighPerfAlltoallMeshCondition(topoAttr_.deviceType, + topoAttr_.userRankSize); + + std::map> sendAddrInfosIntra; + std::map> recvAddrInfosIntra; + bool isSingleMesh = GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + isAlltoAllZCopyMode_ && topoAttr_.isSingleMeshAggregation; + CalcIntraMeshAggregationAlltoAllMemInfo(userRankInfo, allMeshAggregationSendRecvInfo_, sendAddrInfosIntra, + recvAddrInfosIntra, topoAttr_.meshAggregationRankSize, isSingleMesh); + + CHK_RET(CheckCommSize(COMM_MESH_L0, COMM_INDEX_0 + 1)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_MESH_L0, COMM_INDEX_0); + + if (alltoallMeshReadOnly) { + HCCL_INFO("[AlltoAllOperator][RunAlltoAllVStaged] staged 1 read only algo"); + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { + ActiveSlaveStreams(param.stream); + } + // 添加从流profiling, 用于维护planID + CHK_RET(AddSubStreamToProfiling()); + std::unique_ptr alltoallReadOnly = nullptr; + if (GetExternalInputHcclEnableFfts()) { + alltoallReadOnly.reset(new (std::nothrow) AlltoAllVMeshReadOnly(dispatcher_, + const_cast(param.stream), streamInfo_.ringStreams, streamInfo_.ringSignal, + streamInfo_.ringSignalAux, topoAttr_.userRank, topoAttr_.meshAggregationRankSize, + outerCommInfo.links, allMeshAggregationSendRecvInfo_)); + } else { + alltoallReadOnly.reset(new (std::nothrow) AlltoAllVMeshReadOnly(dispatcher_, + const_cast(param.stream), streamInfo_.ringStreams, streamInfo_.ringSignal, + streamInfo_.ringSignalAux, topoAttr_.userRank, topoAttr_.meshAggregationRankSize, + outerCommInfo.links, allMeshAggregationSendRecvInfo_)); + } + + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + CHK_RET(alltoallReadOnly->Prepare(algRes_.paramInputMem, (topoAttr_.isSingleMeshAggregation ? + algRes_.paramOutputMem : execMem.scratchMem), execMem.inputMem, execMem.outputMem, sendAddrInfosIntra, + recvAddrInfosIntra, GetWorkflowMode())); + } else { + CHK_RET(alltoallReadOnly->Prepare(algRes_.paramInputMem, (topoAttr_.isSingleMeshAggregation ? 
+ algRes_.paramOutputMem : execMem.scratchMem), algRes_.paramInputMem, algRes_.paramOutputMem, + sendAddrInfosIntra, recvAddrInfosIntra, GetWorkflowMode())); + } + alltoallReadOnly->RunAsync(); + } else { + std::unique_ptr alltoallOuter = nullptr; + CHK_RET(PrepareAlltoAllVStaged1(algRes_.paramInputMem, algRes_.paramOutputMem, execMem.scratchMem, + sendAddrInfosIntra, recvAddrInfosIntra, const_cast(param.stream), tag_, alltoallOuter, execMem)); + if ((streamInfo_.ringStreams.size() != 0) && + (!GetExternalInputHcclEnableFfts()) && isAlltoAllZCopyMode_) { + HCCL_INFO("[AlltoAllOperator][RunAlltoAllVStaged] staged 0 use parallel multi-thread delivery of tasks"); + CHK_RET(RunTemplateWithVirtualLink(alltoallOuter, outerCommInfo)); + // 多流场景下,并行多线程下发task处理 + CHK_RET(ParallelTaskLoaderProcess(tag_, const_cast(param.stream), outerCommInfo, + streamInfo_.ringStreams)); + } else { + CHK_RET(RunAlltoAllVTemplateStaged(alltoallOuter, outerCommInfo)); + } + + HCCL_INFO("[hcclImpl][RunAlltoAllVStaged] stage0 run success!"); + } + std::map> sendAddrInfosInter; + std::map> recvAddrInfosInter; + CalcInterMeshAggregationAlltoAllMemInfo(userRankInfo, allMeshAggregationSendRecvInfo_,sendAddrInfosInter, + recvAddrInfosInter, topoAttr_.meshAggregationRankSize); + + if (((GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + isAlltoAllZCopyMode_) || alltoallMeshReadOnly) && topoAttr_.isSingleMeshAggregation) { + HCCL_DEBUG("we don't need to do stage 2 when there is only one mesh aggregation"); + // we don't need to do stage 2 when there is only one mesh aggregation + } else { + HCCL_INFO("[hcclImpl][RunAlltoAllVStaged] stage1 run starts!"); + CHK_RET(CheckCommSize(COMM_MESH_L1, COMM_INDEX_0 + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_MESH_L1, COMM_INDEX_0); + std::unique_ptr alltoallInner = nullptr; + PrepareAlltoAllVStaged2(algRes_.paramOutputMem, execMem.scratchMem, sendAddrInfosInter, recvAddrInfosInter, + const_cast(param.stream), tag_, alltoallInner, execMem); + CHK_RET(RunAlltoAllVTemplateStaged(alltoallInner, innerCommInfo)); + } + + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + isAlltoAllZCopyMode_ && !topoAttr_.isSingleMeshAggregation) { + DeviceMem srcMem = (execMem.outputMem).range(0, algRes_.paramOutputMem.size()); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, algRes_.paramOutputMem, srcMem, const_cast(param.stream))); + } + + HCCL_INFO("[CollRunAlltoAllVStaged][kernelRun] alltoall staged ends"); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("RunAlltoAllVStaged", AlltoAllVStaged, CollRunAlltoAllVStaged); +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_v_staged_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_v_staged_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..5de041fee091c30c3ee6cd82aa8dbd5603c1a01a --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_to_all/coll_all_to_all_v_staged_executor.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
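CalcWorkSpaceMemSize above sizes the scratch a rank needs between the two stages: it sums, over every destination rank that shares this rank's position inside its mesh aggregation, the bytes that all ranks of the local aggregation send to that destination, which is exactly the data this rank forwards in the inter-aggregation stage. A self-contained mirror of that loop on invented lengths (2 aggregations of 2 ranks each):

#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the indexing of CalcWorkSpaceMemSize above on hypothetical data.
// sendLength[k][d] = bytes rank k sends to rank d.
int main()
{
    const uint32_t meshSize = 2, rankSize = 4, myRank = 1;
    std::vector<std::vector<uint64_t>> sendLength = {
        {10, 11, 12, 13},   // rank 0
        {20, 21, 22, 23},   // rank 1 (this rank)
        {30, 31, 32, 33},   // rank 2
        {40, 41, 42, 43},   // rank 3
    };

    uint64_t workspace = 0;
    uint32_t aggBegin = (myRank / meshSize) * meshSize;                         // first rank of my aggregation: 0
    for (uint32_t dst = myRank % meshSize; dst < rankSize; dst += meshSize) {   // dst in {1, 3}
        for (uint32_t k = aggBegin; k < aggBegin + meshSize; ++k) {             // k in {0, 1}
            workspace += sendLength[k][dst];
        }
    }
    // 11 + 21 + 13 + 23 = 68 bytes: rank 1 must buffer everything its aggregation
    // addresses to ranks 1 and 3 before the inter-aggregation stage forwards it.
    std::printf("workspace bytes = %llu\n", static_cast<unsigned long long>(workspace));
    return 0;
}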
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef COLL_RUN_ALLTOALLV_TWO_LEVEL_PIPELINE_H +#define COLL_RUN_ALLTOALLV_TWO_LEVEL_PIPELINE_H +#include "coll_all_to_all_executor.h" +namespace hccl { +class CollRunAlltoAllVStaged : public CollAlltoAllExecutor { + +public: + CollRunAlltoAllVStaged(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollRunAlltoAllVStaged() = default; + + bool CheckNeedRecreateComm(u64 lastScratchMemSize) override; + HcclResult CheckNeedCreateVirtualLinks(AlgResourceRequest &resourceRequest) override; + HcclResult ParallelTaskLoaderProcess(const std::string &tag, Stream &stream, SubCommInfo &outerCommInfo, + std::vector &ringStreams); + +private: + + HcclResult CalcStreamNum(u32& streamNum) override; + void CalcWorkSpaceMemSize(const AlltoAllUserRankInfo &userRankInfo, + const std::vector &allMeshAggregationSendRecvInfo, u64 &workspaceMemSize, + u32 meshAggregationRankSize); + HcclResult CalcScratchMemSize(u64& scratchMemSize) override; + + HcclResult CalcLevel0CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcLevel1CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcLevel2CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalStagedAlltoallVCommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport); + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; + + HcclResult PrepareAlltoAllVStaged1(DeviceMem &sendBuf, DeviceMem &recvBuf, DeviceMem &scratchMem, + std::map> &sendAddrInfosIntra, + std::map> &recvAddrInfosIntra, + Stream &stream, const std::string &tag, std::unique_ptr &alltoallOuter, + ExecMem &execMem); + void CalcInterMeshAggregationRecvRemoteOffset(const AlltoAllUserRankInfo &userRankInfo, + const std::vector &allSendRecvInfo, u32 index, u64 &remoteOffset, u32 meshAggregationRankSize); + void CalcInterMeshAggregationAlltoAllMemInfo( + const AlltoAllUserRankInfo &userRankInfo, const std::vector &allSendRecvInfo, + std::map> &sendAddrInfosInter, + std::map> &recvAddrInfosInter, + u32 meshAggregationRankSize); + HcclResult PrepareAlltoAllVStaged2(DeviceMem &recvBuf, DeviceMem &scratchMem, + std::map> &sendAddrInfosInter, + std::map> &recvAddrInfosInter, + Stream &stream, const std::string &tag, std::unique_ptr &alltoallInner, + ExecMem &execMem); +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/310P/CMakeLists.txt b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/310P/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..cd0f83b50a903ef6cee2d2570b4464aa0d73fd0b --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/310P/CMakeLists.txt @@ -0,0 +1,7 @@ +set(src_list + ${CMAKE_CURRENT_SOURCE_DIR}/coll_broadcast_for_310p_comm_executor.cc +) + +target_sources(hccl_alg PRIVATE + ${src_list} +) diff --git 
a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/310P/coll_broadcast_for_310p_comm_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/310P/coll_broadcast_for_310p_comm_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..7a53f91b6fc8e72f1c41843b8f967246e7a97096 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/310P/coll_broadcast_for_310p_comm_executor.cc @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + + +#include "coll_broadcast_for_310p_comm_executor.h" + +namespace hccl { +CollBroadcastFor310PCommExecutor::CollBroadcastFor310PCommExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollBroadcastExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = false; +} + +HcclResult CollBroadcastFor310PCommExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + + +HcclResult CollBroadcastFor310PCommExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, std::vector& opTransport) +{ + HCCL_INFO("[CollBroadcastFor310PCommExecutor][CalcOuterCommInfo]tag[%s ]start", tag_.c_str()); + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_RING_INNER); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + HCCL_INFO("[CollBroadcastFor310PCommExecutor][CalcOuterCommInfo]tag[%s] Calc RingComm finish", tag_.c_str()); + return HCCL_SUCCESS; +} + +HcclResult CollBroadcastFor310PCommExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + std::unique_ptr executor; + + executor.reset(new (std::nothrow) BroadcastRing(dispatcher_)); + CHK_SMART_PTR_NULL(executor); + // 获取root + u32 rootRank = 0; + CHK_RET(GetRankByUserRank(COMM_LEVEL0, COMM_INDEX_0, param.root, rootRank)); + + CHK_RET(executor->Prepare(execMem.inputMem, execMem.outputMem, execMem.outputMem, execMem.count, + param.DataDes.dataType, param.stream, HCCL_REDUCE_RESERVED, param.root)); + + u32 rankSize = outerCommInfo.localRankSize; + CHK_RET(executor->RegisterProfiler( + (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + outerCommInfo.localRank, + PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, param.stream)); + + CHK_RET(RunTemplate(executor, outerCommInfo)); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("BroadCastCommFor310P", BroadcastFor310PComm, CollBroadcastFor310PCommExecutor); +} // namespace hccl \ No newline at end of file diff --git 
a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/310P/coll_broadcast_for_310p_comm_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/310P/coll_broadcast_for_310p_comm_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..c0a2308d174054f84277d3e0bd821c10e0123cdd --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/310P/coll_broadcast_for_310p_comm_executor.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef COLL_BROADCAST_FOR_310P_RING_EXECUTOR_H +#define COLL_BROADCAST_FOR_310P_RING_EXECUTOR_H +#include "../coll_broadcast_executor.h" +namespace hccl { +class CollBroadcastFor310PCommExecutor : public CollBroadcastExecutor { + +public: + CollBroadcastFor310PCommExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollBroadcastFor310PCommExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + + /* *************** 算法编排 *************** */ + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/CMakeLists.txt b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..206e8360671f6b6f9295e9c75576b6a86dc513e4 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/CMakeLists.txt @@ -0,0 +1,13 @@ +set(src_list + ${CMAKE_CURRENT_SOURCE_DIR}/coll_broadcast_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_broadcast_double_ring_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_broadcast_mesh_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_broadcast_ring_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_broadcast_comm_executor.cc +) + +target_sources(hccl_alg PRIVATE + ${src_list} +) + +add_subdirectory(310P) diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_comm_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_comm_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..cafe1fca9a4aadbb81647ffddc87c56cd96b546c --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_comm_executor.cc @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. 
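The 310P executor above registers its profiler with `(rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + localRank`, packing two values into one ID. The value of the offset constant is not visible in this diff; the sketch below only illustrates the packing and unpacking idea under an assumed shift of 16 bits.

#include <cstdint>
#include <cstdio>

// Illustration of the profiler-ID packing used by RegisterProfiler above.
// kAssumedRankSizeOffset stands in for PROF_RANKSIZE_OFFSET_OF_PLANEID, whose
// real value is defined elsewhere in the code base.
constexpr uint32_t kAssumedRankSizeOffset = 16;

uint32_t PackProfilerId(uint32_t rankSize, uint32_t localRank)
{
    return (rankSize << kAssumedRankSizeOffset) + localRank;
}

int main()
{
    uint32_t id = PackProfilerId(8, 3);
    std::printf("packed=0x%x rankSize=%u localRank=%u\n",
                id, id >> kAssumedRankSizeOffset, id & ((1u << kAssumedRankSizeOffset) - 1));
    return 0;
}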
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "coll_broadcast_comm_executor.h" + +namespace hccl { + +CollBroadcastCommExecutor::CollBroadcastCommExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollBroadcastExecutor(dispatcher, topoMatcher) +{ +} + + +HcclResult CollBroadcastCommExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcCombinedCommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollBroadcastCommExecutor::CalcCombinedCommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaInfo(COMM_COMBINE, CommType::COMM_TAG_MAX); + if (UseInterServerNHRAlgo(algType_)) { + commParaInfo.commType = CommType::COMM_TAG_NONUNIFORM_HIERARCHICAL_RING; + } else if (UseInterServerNHRV1Algo(algType_)) { + commParaInfo.commType = CommType::COMM_TAG_NONUNIFORM_HIERARCHICAL_RING_V1; + } else if (UseInterServerNBAlgo(algType_)) { + commParaInfo.commType = CommType::COMM_TAG_NONUNIFORM_BRUCK; + } else { + commParaInfo.commType = CommType::COMM_TAG_RING_INNER; + } + CHK_RET(CalcCommPlaneInfo(tag_, commParaInfo, opTransport[COMM_COMBINE], inputType, outputType)); + + return HCCL_SUCCESS; +} + +HcclResult CollBroadcastCommExecutor::CalcStreamNum(u32& streamNum) +{ + // 只传递从流数量 + streamNum = 0; + HCCL_INFO("[CollBroadcastCommExecutor][CalcStreamNum]tag[%s] streamNum_ is [%u]", tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollBroadcastCommExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + CHK_RET(CheckCommSize(COMM_COMBINE, COMM_INDEX_0 + 1)); + SubCommInfo combinedCommInfo = GetSubCommInfo(COMM_COMBINE, 0); + + std::unique_ptr executor; + u64 curSize = execMem.count * SIZE_TABLE[param.DataDes.dataType]; + if (UseInterServerNHRAlgo(algType_)) { + if (curSize <= NHR_BCAST_SMALL_SIZE) { + executor.reset(new (std::nothrow) BroadcastNHROneshot(dispatcher_)); + } else { + executor.reset(new (std::nothrow) BroadcastNHR(dispatcher_)); + } + HCCL_INFO("broadcast comm: using nhr algo inter-server."); + } else if (UseInterServerNHRV1Algo(algType_)) { + executor.reset(new (std::nothrow) BroadcastNHRV1(dispatcher_)); + HCCL_INFO("broadcast comm: using nhr_v1 algo inter-server."); + } else if (UseInterServerNBAlgo(algType_)) { + if (ShouldUseBinaryBroadcastOfNB(curSize, combinedCommInfo.localRankSize, topoAttr_.userRankSize, + topoAttr_.deviceNumPerAggregation)) { + executor.reset(new (std::nothrow) BroadcastNBBinary(dispatcher_)); + } else { + executor.reset(new (std::nothrow) BroadcastNB(dispatcher_)); + } + HCCL_INFO("broadcast comm: using nonuniform-bruck algo inter-server."); + } else { + executor.reset(new (std::nothrow) BroadcastRing(dispatcher_)); + HCCL_INFO("broadcast comm: using ring algo inter-server."); + } + CHK_SMART_PTR_NULL(executor); + + // 获取root + u32 rootRank = 0; + CHK_RET(GetRankByUserRank(COMM_COMBINE, COMM_INDEX_0, param.root, rootRank)); + + CHK_RET(executor->Prepare(execMem.inputMem, execMem.outputMem, execMem.outputMem, 
execMem.count, + param.DataDes.dataType, param.stream, HCCL_REDUCE_RESERVED, rootRank)); + CHK_RET(RunTemplate(executor, combinedCommInfo)); + + return HCCL_SUCCESS; +} + +REGISTER_EXEC("BroadCastComm", BroadcastComm, CollBroadcastCommExecutor); + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_comm_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_comm_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..fecbeada90e6dc79fa74c5397bf2262023c9e655 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_comm_executor.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_BROADCAST_EXECUTOR_H +#define COLL_BROADCAST_EXECUTOR_H +#include "coll_broadcast_executor.h" +namespace hccl { +class CollBroadcastCommExecutor : public CollBroadcastExecutor { + +public: + CollBroadcastCommExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollBroadcastCommExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcStreamNum(u32& streamNum) override; + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcCombinedCommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport); + + /* *************** 算法编排 *************** */ + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_double_ring_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_double_ring_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..30924cff5e269acfb39422a7da95f74faa7e4f0a --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_double_ring_executor.cc @@ -0,0 +1,316 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
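KernelRun in the combined broadcast executor above picks the inter-server template from the configured algorithm type and the current payload size. A condensed mirror of that decision tree follows; the enum, the threshold parameter and the boolean flag are stand-ins for the real UseInterServer*Algo helpers, NHR_BCAST_SMALL_SIZE and ShouldUseBinaryBroadcastOfNB, whose definitions this diff does not show.

#include <cstdint>
#include <string>

// Condensed view of the template selection in CollBroadcastCommExecutor::KernelRun.
enum class InterServerAlgo { NHR, NHR_V1, NB, RING };

std::string PickBroadcastTemplate(InterServerAlgo algo, uint64_t curSize,
                                  bool nbPrefersBinary, uint64_t nhrSmallSizeThreshold)
{
    switch (algo) {
        case InterServerAlgo::NHR:      // small payloads use the one-shot NHR variant
            return curSize <= nhrSmallSizeThreshold ? "BroadcastNHROneshot" : "BroadcastNHR";
        case InterServerAlgo::NHR_V1:
            return "BroadcastNHRV1";
        case InterServerAlgo::NB:       // binary vs. plain nonuniform-bruck
            return nbPrefersBinary ? "BroadcastNBBinary" : "BroadcastNB";
        default:                        // fallback used when no NHR/NB algorithm is configured
            return "BroadcastRing";
    }
}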
+ */ + +#include "coll_broadcast_double_ring_executor.h" + +namespace hccl { + +CollBroadcastDoubleRingExecutor::CollBroadcastDoubleRingExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollBroadcastExecutor(dispatcher, topoMatcher) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + DMAReduceFlag_ = true; + } else { + DMAReduceFlag_ = false; + } +} + +HcclResult CollBroadcastDoubleRingExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = 0U; + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING; + } else { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE; + } + streamNum = totalStreamNum - 1; + HCCL_INFO("[CollBroadcastDoubleRingExecutor][CalcStreamNum] tag[%s] streamNum_[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollBroadcastDoubleRingExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CalcTransportMemType(inputType, outputType); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel1CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel2CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollBroadcastDoubleRingExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, std::vector& opTransport) +{ + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_RING_INNER); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + return HCCL_SUCCESS; +} + +HcclResult CollBroadcastDoubleRingExecutor::CalcLevel2CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel2(COMM_LEVEL2, CommType::COMM_TAG_MAX); + if (UseLevel2RingAlgo(algType_)) { + commParaLevel2.commType = CommType::COMM_TAG_RING_INNER; + } else { + commParaLevel2.commType = CommType::COMM_TAG_HALVING_DOUBLING; + } + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel2, opTransport[COMM_LEVEL2], inputType, outputType)); + return HCCL_SUCCESS; +} + +HcclResult CollBroadcastDoubleRingExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + HCCL_INFO("[BroadCastOperator][BroadCastDoubleRingExecutor] The BroadCastDoubleRingExecutor starts."); + u32 perDataSize = 0; + CHK_RET(SalGetDataTypeSize(param.DataDes.dataType, perDataSize)); + + std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 + std::vector> mulRingSlice; // 数据基于该rank上环0的偏移 + // step1: 节点内的scatter + u32 ringNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE; + CHK_RET(CheckCommSize(COMM_LEVEL0, ringNum)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + + // 按ranksize得到内存切分slice数 + u32 sliceNum = outerCommInfo.localRankSize; + // 将根节点数据切分成sliceNum份 + CHK_RET(ExecutorBase::PrepareSliceData(execMem.count, perDataSize, sliceNum, 0, dataSegsSlice)); + HCCL_DEBUG("[BroadCastDoubleRingExecutor]: ringNum[%u] sliceNum[%u]", ringNum, sliceNum); + + /* 外层:scatter */ + // 将每slice再切分成2份,按各ring的dev顺序排列 + // 构造ring algorithm对应的reduce-scatter实例 + mulRingSlice = PrepareMultiRingSlice(dataSegsSlice, param.tag, false, topoAttr_.nicList); + CHK_PRT_RET(mulRingSlice.size() != ringNum, + HCCL_ERROR("[BroadCastOperator][BroadCastDoubleRingExecutor]" + "ringNum[%u] !=mulRingSlice size[%llu]", + ringNum, mulRingSlice.size()), 
HCCL_E_INTERNAL); + + HcomCollOpInfo *scatterOpInfoPtr = nullptr; + HcomCollOpInfo scatterOpInfo = { + "", execMem.inputPtr, nullptr, param.DataDes.count, param.DataDes.dataType, param.root + }; + + if (DMAReduceFlag_) { + scatterOpInfoPtr = &scatterOpInfo; + } + CHK_RET(MultiRingScatter(param.tag, execMem.inputMem, execMem.outputMem, execMem.count, param.DataDes.dataType, + mulRingSlice, param.root, param.stream, scatterOpInfoPtr)); + + HCCL_INFO("Broadcast double ring stage0 run success"); + + SubCommInfo level2CommInfo = GetSubCommInfo(COMM_LEVEL2, COMM_INDEX_0); + u32 level2RankSize = level2CommInfo.localRankSize; + + u64 hdCount = 0; + u64 hdSize = 0; + + if (topoAttr_.devNumInLevel2 <= 1) { + HCCL_INFO("Broadcast double ring No level2."); + // step2: server间的broadcast + u32 segmentIdx; + u32 commIndex; + CHK_RET(PrepareInnerCommInfo(segmentIdx, commIndex, hdSize, outerCommInfo, mulRingSlice, param.tag)); + + hdCount = hdSize / perDataSize; + + HCCL_DEBUG("commIdx:%u TagCommInfo[%s].commInner.size():%llu", commIndex, param.tag.c_str(), level2RankSize); + + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + std::unique_ptr innerExecutor; + u64 curSize = execMem.count * SIZE_TABLE[param.DataDes.dataType]; + if (UseInterServerNHRAlgo(algType_)) { + HCCL_DEBUG("broadcast ring: curSize[%llu] deviceNumPerAggregation[%u] commOuterSize[%u]", + curSize, topoAttr_.deviceNumPerAggregation, outerCommInfo.localRankSize); + if (curSize / topoAttr_.deviceNumPerAggregation <= NHR_BCAST_SMALL_SIZE) { + innerExecutor.reset(new (std::nothrow) BroadcastNHROneshot(dispatcher_)); + } else { + innerExecutor.reset(new (std::nothrow) BroadcastNHR(dispatcher_)); + } + HCCL_INFO("broadcast ring: using nhr algo inter-server."); + } else if (UseInterServerNHRV1Algo(algType_)) { + innerExecutor.reset(new (std::nothrow) BroadcastNHRV1(dispatcher_)); + HCCL_INFO("broadcast ring: using nhr_v1 algo inter-server."); + } else if (UseInterServerNBAlgo(algType_)) { + const u32 innerRankSize = innerCommInfo.localRankSize; + if (ShouldUseBinaryBroadcastOfNB(curSize / topoAttr_.deviceNumPerAggregation, innerRankSize, + topoAttr_.userRankSize, topoAttr_.deviceNumPerAggregation)) { + innerExecutor.reset(new (std::nothrow) BroadcastNBBinary(dispatcher_)); + } else { + innerExecutor.reset(new (std::nothrow) BroadcastNB(dispatcher_)); + } + HCCL_INFO("broadcast ring: using nonuniform-bruck algo inter-server."); + } else { + innerExecutor.reset(new (std::nothrow) BcastRecursiveHalvingDoubling(dispatcher_)); + HCCL_INFO("broadcast ring: using Recursive halving-doubling algo inter-server."); + } + CHK_SMART_PTR_NULL(innerExecutor); + + u32 subUserrankRoot = topoMatcher_->GetSubRootUserRank(topoAttr_.userRank, param.root); + CHK_PRT_RET( + subUserrankRoot == INVALID_VALUE_RANKID, + HCCL_ERROR("[HcclImpl][BroadCastDoubleRingExecutor]subUserrankRoot[%u] is invalid,userRank[%u],root[%u]", + subUserrankRoot, topoAttr_.userRank, param.root), + HCCL_E_INTERNAL); + u32 planeRoot = 0; + CHK_RET(GetRankByUserRank(COMM_LEVEL1, commIndex, subUserrankRoot, planeRoot)); + u32 ranksize = innerCommInfo.localRankSize; + // 节点间的hd 使用环0来记录 + CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.outputMem, hdCount, param.DataDes.dataType, + param.stream, HCCL_REDUCE_RESERVED, planeRoot, std::vector(0), + dataSegsSlice[segmentIdx].offset)); + + CHK_RET(innerExecutor->RegisterProfiler((ranksize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + 
PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + + CHK_RET(RunTemplate(innerExecutor, innerCommInfo)); + + HCCL_INFO("Broadcast double ring stage1 run success"); + + // step3: 节点内的allgatherring + HcomCollOpInfo *allgatherOpInfoPtr = nullptr; + HcomCollOpInfo allgatherOpInfo = { + "", nullptr, execMem.outputPtr, param.DataDes.count, param.DataDes.dataType, param.root + }; + if (DMAReduceFlag_) { + allgatherOpInfoPtr = &allgatherOpInfo; + } + CHK_RET(MultiRingAllGather(param.tag, execMem.inputMem, execMem.outputMem, hdCount, param.DataDes.dataType, + mulRingSlice, param.stream, PROF_STAGE_2, 0, allgatherOpInfoPtr)); + + HCCL_INFO("Broadcast double ring stage2 run success"); + } else { + HCCL_INFO("Broadcast double ring WITH level2."); + // step2: 节点间的scatter + /* count数据准备 */ + std::vector level1dataSegsSlice; // 数据分成inner ranksize份,每份的起始偏移和大小 + u32 commIndex = outerCommInfo.localRank; + + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + + // level 1通信数据量 + u64 level1count = execMem.count / sliceNum; + + // 按ranksize得到内存切分slice数 + u32 level1sliceNum = innerCommInfo.localRankSize; + // 将根节点数据切分成level1sliceNum份 + CHK_RET(ExecutorBase::PrepareSliceData(level1count, perDataSize, level1sliceNum, 0, level1dataSegsSlice)); + + u64 level1segmentIdx = innerCommInfo.localRank; + + DeviceMem level1InputMem + = execMem.inputMem.range(level1dataSegsSlice[level1segmentIdx].offset, (level1count * perDataSize)); + DeviceMem level1OutputMem + = execMem.outputMem.range(level1dataSegsSlice[level1segmentIdx].offset, (level1count * perDataSize)); + + std::unique_ptr level1Executor; + level1Executor.reset(new (std::nothrow) ScatterRing(dispatcher_)); + CHK_SMART_PTR_NULL(level1Executor); + CHK_RET(level1Executor->Prepare(level1InputMem, level1InputMem, level1OutputMem, level1count, param.DataDes.dataType, param.stream, + HCCL_REDUCE_RESERVED, OUTER_BRIDGE_RANK_ID, level1dataSegsSlice, + level1dataSegsSlice[level1segmentIdx].offset)); + CHK_RET(level1Executor->RegisterProfiler((level1sliceNum << PROF_RANKSIZE_OFFSET_OF_PLANEID) + + innerCommInfo.localRank, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + + CHK_RET(RunTemplate(level1Executor, innerCommInfo)); + HCCL_INFO("Broadcast double ring [superpod] level1 run success"); + + // step3: 超节点间的broadcast + u64 level2hdSize; + u32 level2segmentIdx; + u32 level2commIndex; + std::vector> multiRingSlice; + multiRingSlice.push_back(level1dataSegsSlice); + CHK_RET(PrepareInnerCommInfo(level2segmentIdx, level2commIndex, level2hdSize, outerCommInfo, + multiRingSlice, param.tag)); + + u64 level2hdCount = level2hdSize / perDataSize; + + CHK_RET(CheckCommSize(COMM_LEVEL2, level2commIndex + 1)); + + std::unique_ptr level2Executor; + if (UseLevel2RingAlgo(algType_)) { + level2Executor.reset(new (std::nothrow) BroadcastRing(dispatcher_)); + HCCL_INFO("broadcast ring: using ring algo inter-server."); + } else { + level2Executor.reset(new (std::nothrow) BcastRecursiveHalvingDoubling(dispatcher_)); + HCCL_INFO("broadcast ring: using Recursive halving-doubling algo inter-server."); + } + CHK_SMART_PTR_NULL(level2Executor); + + u32 subUserrankRoot = topoMatcher_->GetSubRootUserRank(topoAttr_.userRank, param.root); + CHK_PRT_RET( + subUserrankRoot == INVALID_VALUE_RANKID, + HCCL_ERROR("[BroadCastOperator][BroadCastDoubleRingExecutor]subUserrankRoot[%u] is invalid,userRank[%u]," + "root[%u]", + subUserrankRoot, topoAttr_.userRank, param.root), + HCCL_E_INTERNAL); + + u32 planeRoot = 0; + 
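// Root handling for the level-2 (superpod) broadcast, continuing from subUserrankRoot above:
// the user-visible root is first mapped to the user rank that represents it inside this rank's
// sub-plane (GetSubRootUserRank), and that user rank is then converted into an index within the
// COMM_LEVEL2 sub-communicator (GetRankByUserRank below), so the level-2 template addresses the
// data owner by its plane-local rank rather than by the global user rank. Hypothetical example
// (sizes not taken from this diff): with 2 superpods of 8 ranks and param.root = 11, planeRoot
// becomes that root's 0-based position inside the level-2 communicator, not the value 11 itself.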
CHK_RET(GetRankByUserRank(COMM_LEVEL2, COMM_INDEX_0, subUserrankRoot, planeRoot)); + + u32 ranksize = level2CommInfo.localRankSize; + // 节点间的hd 使用环0来记录 + CHK_RET(level2Executor->Prepare(execMem.inputMem, execMem.inputMem, execMem.outputMem, level2hdCount, param.DataDes.dataType, + param.stream, HCCL_REDUCE_RESERVED, planeRoot, std::vector(0), + level1dataSegsSlice[level2segmentIdx].offset)); + + CHK_RET(level2Executor->RegisterProfiler((ranksize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(level2Executor, level2CommInfo)); + + // step4: 节点间的allgather + std::unique_ptr innerExecutor; + if (UseInterServerRingAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllGatherRing(dispatcher_)); + HCCL_INFO("allgather ring: using ring algo inter-server."); + } else { + innerExecutor.reset(new (std::nothrow) AllGatherRecursiveHalvingDoubling(dispatcher_)); + HCCL_INFO("allgather ring: using halving-doubling algo inter-server."); + } + + CHK_SMART_PTR_NULL(innerExecutor); + // 此处虽然带入inputMem作为scratch mem, 但inputMem 不能被使用 + CHK_RET(innerExecutor->Prepare(execMem.outputMem, execMem.outputMem, execMem.inputMem, level2hdCount, + param.DataDes.dataType, param.stream, HCCL_REDUCE_RESERVED, + INVALID_VALUE_RANKID, std::vector(COMM_INDEX_0), 0)); + + u32 rankSize = innerCommInfo.localRankSize; + CHK_RET(innerExecutor->RegisterProfiler((rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_2, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(innerExecutor, innerCommInfo)); + + // step5: 节点内的allgatherring + u64 level0count = level2hdCount * rankSize; + HcomCollOpInfo *allgatherOpInfoPtr = nullptr; + HcomCollOpInfo allgatherOpInfo = { + "", nullptr, execMem.outputPtr, execMem.count, param.DataDes.dataType, param.root + }; + if (DMAReduceFlag_) { + allgatherOpInfoPtr = &allgatherOpInfo; + } + CHK_RET(MultiRingAllGather(param.tag, execMem.inputMem, execMem.outputMem, level0count, param.DataDes.dataType, mulRingSlice, param.stream, + PROF_STAGE_2, 0, allgatherOpInfoPtr)); + HCCL_INFO("Broadcast[superpod] double ring stage5 run success"); + } + + return HCCL_SUCCESS; +} + +REGISTER_EXEC("BroadCastDoubleRingExecutor", BroadcastDoubleRing, CollBroadcastDoubleRingExecutor); + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_double_ring_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_double_ring_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..c50eea54f85f0143c8cbc3211c81feb2473aa55b --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_double_ring_executor.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
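The double-ring executor above decomposes a large broadcast into scatter, inter-server broadcast and allgather phases so that no single link carries the whole payload. A schematic outline of the simpler no-level-2 path, with helper names standing in for the MultiRingScatter, inner-template and MultiRingAllGather calls in KernelRun:

// Schematic of the phase structure used by CollBroadcastDoubleRingExecutor when no
// level-2 (superpod) tier exists. Names are illustrative, not real APIs.
void DoubleRingBroadcastOutline()
{
    // Phase 0: intra-server scatter - the root's payload is split into one slice per local
    //          rank and spread across both rings (MultiRingScatter).
    // Phase 1: inter-server broadcast - each local rank broadcasts the slice it now owns to
    //          its peers on other servers over its own plane (NHR / NB / halving-doubling).
    // Phase 2: intra-server allgather - the slices are reassembled on every rank, so each
    //          rank ends up with the root's complete buffer (MultiRingAllGather).
}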
+ */ + +#ifndef COLL_BROADCAST_DOUBLE_RING_EXECUTOR_H +#define COLL_BROADCAST_DOUBLE_RING_EXECUTOR_H +#include "coll_broadcast_executor.h" +namespace hccl { +class CollBroadcastDoubleRingExecutor : public CollBroadcastExecutor { + +public: + CollBroadcastDoubleRingExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollBroadcastDoubleRingExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcStreamNum(u32& streamNum) override; + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcLevel2CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + + /* *************** 算法编排 *************** */ + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..49b28b89e3af40936db17e405fd9e80c6d8bf267 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_executor.cc @@ -0,0 +1,340 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "coll_broadcast_executor.h" + +namespace hccl { + +CollBroadcastExecutor::CollBroadcastExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollCommExecutor(dispatcher, topoMatcher) +{ +} + +HcclResult CollBroadcastExecutor::Orchestrate(const OpParam& param, + const AlgResourceResponse& algRes) +{ + HcclResult ret = HCCL_SUCCESS; + + // 由于bcast/allgather/reducescatter/reduce/send/recv暂不支持server间ring,需继续使用HD或NHR + if (!UseInterServerNHRAlgo(algType_) && !UseInterServerNHRV1Algo(algType_) && !UseInterServerNBAlgo(algType_)) { + ret = SetInterServerHDAlgo(algType_); + HCCL_WARNING("[BroadCastOperator][Broadcast] do not support ring in AlgoLevel1 yet, reset algType=HD."); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[BroadCastOperator][Broadcast]errNo[0x%016llx] tag[%s],broadcast set inter server "\ + "halving-doubling algo failed", HCCL_ERROR_CODE(ret), param.tag.c_str()), ret); + } + + tag_ = param.tag; + algResResp_ = &algRes; + GetStreamInfo(algRes); + auto rtStream = param.stream.ptr(); + HCCL_PROFILER_ADD_TAG(param.tag, algoAttr_.identifier, GetWorkflowMode()); + + // 添加从流profiling, 用于维护planID + CHK_RET(AddSubStreamToProfiling()); + + /* ------------执行算法-------------- */ + HcclUs startut = TIME_NOW(); + + // 图模式和单卡场景下不需要Loop + ExecMem execMem; + execMem.count = param.DataDes.count; + execMem.inputPtr = param.inputPtr; + execMem.outputPtr = param.inputPtr; + HCCL_INFO("Orchestrate UserRank[%u], devicePhyId[%u], inputPtr_[%p], outputPtr[%p]", topoAttr_.userRank, + topoAttr_.devicePhyId, param.inputPtr, param.outputPtr); + if (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { // 图模式直接调KernelRun接口 + execMem.inputMem = algRes.paramInputMem; + execMem.outputMem = algRes.paramOutputMem; + ret = KernelRun(param, execMem); + } else if (topoAttr_.userRankSize == 1) { // 单卡 + return HCCL_SUCCESS; + } else { + ret = RunLoop(param, algRes); + } + + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollBroadcastExecutor][Orchestrate]errNo[0x%016llx]broadcast excutor kernel run failed", + HCCL_ERROR_CODE(ret)), ret); + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && !is310P3Common_) { + HCCL_PROFILER_DEL_TAG(param.tag); + HCCL_PROFILER_DEL_STREAM(rtStream); + } + HCCL_INFO("tag[%s], Broadcast executor orchestrate success, take time [%lld]us.", + param.tag.c_str(), DURATION_US(TIME_NOW() - startut)); + return HCCL_SUCCESS; +} + +HcclResult CollBroadcastExecutor::RunLoop(const OpParam ¶m, const AlgResourceResponse &algRes) +{ + u32 unitSize = SIZE_TABLE[param.DataDes.dataType]; + + u8 *curInputPtr = static_cast(param.inputPtr); + u8 *curOutputPtr = static_cast(param.outputPtr); + CHK_PTR_NULL(curInputPtr); + CHK_PTR_NULL(curOutputPtr); + u64 maxCountPerLoop = CalcLoopMaxCount(unitSize); + + HCCL_DEBUG("[CollBroadcastExecutor][RunLoop]tag[%s], userRankSize is [%u], maxCountPerLoop is [%llu].", + param.tag.c_str(), topoAttr_.userRankSize, maxCountPerLoop); + + for (u64 countLeft = param.DataDes.count, curCount = 0, inputOffset = 0; + countLeft > 0; countLeft -= curCount) { + curInputPtr += inputOffset; + // 判断剩余数据量对应的output size是否大于中转output size + curCount = (countLeft > maxCountPerLoop) ? 
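// Illustrative sketch (not HCCL source): the chunking arithmetic RunLoop above applies when the
// op-base path stages a large broadcast through the CCL buffer. The 200 MB buffer size and the
// fp16 element width are assumed example values, not constants taken from HCCL.
#include <cstdint>
#include <cstdio>

int main()
{
    const uint64_t cclBuffSize = 200ULL * 1024 * 1024;  // assumed CCL buffer capacity in bytes
    const uint32_t unitSize = 2;                        // e.g. fp16
    const uint64_t totalCount = 150ULL * 1024 * 1024;   // elements requested by the caller

    const uint64_t maxCountPerLoop = cclBuffSize / unitSize;  // same formula as CalcLoopMaxCount
    for (uint64_t countLeft = totalCount, offset = 0; countLeft > 0;) {
        const uint64_t curCount = (countLeft > maxCountPerLoop) ? maxCountPerLoop : countLeft;
        std::printf("loop: offset=%llu elems, count=%llu elems, bytes=%llu\n",
                    (unsigned long long)offset, (unsigned long long)curCount,
                    (unsigned long long)(curCount * unitSize));
        offset += curCount;
        countLeft -= curCount;
    }
    return 0;
}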
maxCountPerLoop : countLeft;
+        u64 curSize = curCount * unitSize; // 单位:字节
+
+        ExecMem execMem;
+        execMem.count = curCount;
+        execMem.inputMem = algRes.cclInputMem;
+        execMem.outputMem = algRes.cclInputMem; // broadcast只用一块CCL buffer
+        // 使用当前Loop偏移到的地址作为当前的inputPtr
+        execMem.inputPtr = curInputPtr;
+        execMem.outputPtr = curInputPtr;
+
+        HCCL_DEBUG("[CollBroadcastExecutor] RunLoop tag[%s], inputOffset[%llu], " \
+            "curInputPtr[%p], sendCount[%llu], sendSize[%llu], dataType[%s], realUserRank[%llu]",
+            param.tag.c_str(), inputOffset, curInputPtr, curCount, curSize,
+            GetDataTypeEnumStr(param.DataDes.dataType).c_str(), topoAttr_.realUserRank);
+
+        CHK_RET(RunLoopInner(param, execMem));
+
+        inputOffset = curSize;
+    }
+    return HCCL_SUCCESS;
+}
+
+HcclResult CollBroadcastExecutor::RunLoopInner(const OpParam &param, ExecMem &execMem)
+{
+    u32 unitSize = SIZE_TABLE[param.DataDes.dataType];
+    bool isRootRank = param.root == topoAttr_.realUserRank ? true : false;
+    u64 curSize = execMem.count * unitSize; // 单位:字节
+    auto inCCLbufferSize = execMem.inputMem.size();
+    u8 *curPtr = static_cast<u8 *>(execMem.inputPtr);
+    auto originalAlgTypeLevel0 = GetLevel0AlgType(algType_);
+    bool isMeshTopo = IsAlgTypeLevel0Mesh(originalAlgTypeLevel0);
+    bool isDMAreduceOn91073 = (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE
+        && (topoAttr_.deviceType == DevType::DEV_TYPE_910_73) && !isMeshTopo);
+    HCCL_DEBUG("[CollBroadcastExecutor][RunLoopInner]inputMem[%p], outputMem[%p], " \
+        "inputPtr[%p], curCount[%llu], curSize[%llu]",
+        execMem.inputMem.ptr(), execMem.outputMem.ptr(), execMem.inputPtr, execMem.count, curSize);
+    CHK_PRT_RET((execMem.count == 0),
+        HCCL_ERROR("[CollBroadcastExecutor][RunLoop]In OP_BASE curCount is zero."), HCCL_E_PARA);
+
+    bool hugeData = (inCCLbufferSize / topoAttr_.deviceNumPerAggregation > RDMA_SEND_MAX_SIZE) ||
+        (curSize > SDMA_SEND_MAX_SIZE);
+    bool isSmallData = IsBroadcastSmallData(curSize);
+    auto meta = HcclOpMetaInfo::GetOneForBroadcast(isRootRank, param.root, hugeData, isSmallData);
+    CHK_RET(InitTask(dispatcher_, const_cast<Stream&>(param.stream), meta.isEnableCache, meta.GetCacheKey()));
+    HCCL_INFO("RunLoopInner:curPtr[%p], curCount[%llu], curSize[%llu], isSmallData[%u], "
+        "deviceNumPerAggregation[%u]", curPtr, execMem.count, curSize, isSmallData,
+        topoAttr_.deviceNumPerAggregation);
+    /* 记录指令信息用于一致性校验 */
+    CHK_RET(RankConsistent::GetInstance().RecordOpPara(HcclCMDType::HCCL_CMD_BROADCAST, param.tag, execMem.count,
+        param.DataDes.dataType, param.root, inCCLbufferSize, 0));
+
+    // 执行
+
+    HcclResult ret;
+
+    // isDMAreduceOn91073场景
+    if (isDMAreduceOn91073) {
+        ret = KernelRun(param, execMem);
+        CHK_PRT_RET(ret != HCCL_SUCCESS,
+            HCCL_ERROR("[CollBroadcastExecutor][RunLoop]errNo[0x%016llx] DMA reduce 91073, tag[%s]",
+                HCCL_ERROR_CODE(ret), tag_.c_str()), ret);
+    } else {
+        // 如果使用in CCL buffer,需要将user buffer in中的结果拷贝到CCL buffer in
+        DeviceMem inCommMem = execMem.inputMem.range(0, curSize);
+        DeviceMem inMem(execMem.inputPtr, curSize);
+        if (topoAttr_.userRank == param.root) {
+            CHK_RET(HcclD2DMemcpyAsync(dispatcher_, inCommMem, inMem, const_cast<Stream&>(param.stream)));
+        }
+        HCCL_DEBUG("[CollBroadcastExecutor][RunLoop]copy from user in to ccl in.");
+
+        ret = KernelRun(param, execMem);
+        if (topoAttr_.realUserRank != param.root) {
+            CHK_RET(HcclD2DMemcpyAsync(dispatcher_, inMem, inCommMem, const_cast<Stream&>(param.stream)));
+        }
+
+        CHK_PRT_RET(ret != HCCL_SUCCESS,
+            HCCL_ERROR("[CollBroadcastExecutor][RunLoop]errNo[0x%016llx]kernel run error, tag[%s], " \
+                "inputMem ptr[%p], count[%llu],
dataType[%d]", + HCCL_ERROR_CODE(ret), param.tag.c_str(), execMem.inputMem.ptr(), + execMem.count, param.DataDes.dataType), ret); + } + + CHK_RET(RankConsistent::GetInstance().DelOpPara(param.tag)); + CHK_RET(LaunchTask(dispatcher_, const_cast(param.stream))); + return ret; +} + +u64 CollBroadcastExecutor::CalcLoopMaxCount(const u32 unitSize) +{ + // 中转内存单次最多能够接受的output count + u64 maxCountPerLoop = GetExternalInputCCLBuffSize() / unitSize; + HCCL_WARNING("[CollBroadcastExecutor][CalcLoopMaxCount]" \ + "using default maxCountPerLoop[%llu] as CCLBuffSize / unitSize.", maxCountPerLoop); + return maxCountPerLoop; +} + +bool CollBroadcastExecutor::IsBroadcastSmallData(u64 size) +{ + const AlgTypeLevel0 algLevel0 = GetLevel0AlgType(algType_); + + u64 actualSize; + u64 actualRankSize; + + if (algLevel0 == AlgTypeLevel0::ALG_LEVEL0_RESERVED) { + // level0算法配null走单层拓扑场景 + actualSize = size; + actualRankSize = topoAttr_.userRankSize; + } else { + // 非单层拓扑场景 + actualSize = size / topoAttr_.deviceNumPerAggregation; + actualRankSize = topoAttr_.userRankSize / topoAttr_.deviceNumPerAggregation; + } + + if (UseInterServerNHRAlgo(algType_)) { + return actualSize <= NHR_BCAST_SMALL_SIZE; + } else if (UseInterServerNBAlgo(algType_)) { + return ShouldUseBinaryBroadcastOfNB(actualSize, actualRankSize, topoAttr_.userRankSize, + topoAttr_.deviceNumPerAggregation); + } + return false; +} + +HcclResult CollBroadcastExecutor::CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_INPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_INPUT; + } + HCCL_INFO("[CollBroadcastRingExecutor][CalcTransportMemType] tag[%s] inputType[%d] outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollBroadcastExecutor::GetRankSliceSize(HcclDataType dataType, const u64 count, const u32 rankSize, + std::vector &sliceList) +{ + if (rankSize <= 0) { + HCCL_ERROR("[Get][RankSliceSize]errNo[0x%016llx] rankSize[%u] is invalid", HCCL_ERROR_CODE(HCCL_E_PARA), + rankSize); + return HCCL_E_PARA; + } + + u32 perDataSize = 0; + CHK_RET(SalGetDataTypeSize(dataType, perDataSize)); + + u64 align = (count * perDataSize) / rankSize; // 按128字节对齐整除均分 + if ((count % rankSize) > 0) { + align += 1; + } + + u64 sliceSize = ExecutorBase::RoundUpWithDivisor(align, HCCL_MIN_SLICE_ALIGN); + u64 residueSize = count * perDataSize; + + for (u32 i = 0; i < rankSize; i++) { + Slice slice; + slice.size = sliceSize < residueSize ? sliceSize : residueSize; + slice.offset = (slice.size == 0) ? 
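// Illustrative host-side sketch (not HCCL code) of the staging pattern RunLoopInner uses in the
// non-DMA-reduce branch above: the root copies its user buffer into the shared CCL input buffer
// before the kernel runs, and every non-root rank copies the broadcast result back out afterwards.
// broadcastViaStagingBuffer() and plainBroadcast() are made-up names; memcpy stands in for
// HcclD2DMemcpyAsync, and the CCL buffer is assumed to be at least as large as the user buffer.
#include <cstring>
#include <vector>

void broadcastViaStagingBuffer(std::vector<char> &userBuf, std::vector<char> &cclIn,
                               int myRank, int root,
                               void (*plainBroadcast)(std::vector<char> &, int))
{
    if (myRank == root) {
        std::memcpy(cclIn.data(), userBuf.data(), userBuf.size());  // user in -> CCL in (root only)
    }
    plainBroadcast(cclIn, root);                                    // collective runs on the CCL buffer
    if (myRank != root) {
        std::memcpy(userBuf.data(), cclIn.data(), userBuf.size());  // CCL in -> user buffer (non-root)
    }
}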
0 : (i * sliceSize); + residueSize -= slice.size; + + // 将cout转换为字节数 + sliceList.push_back(slice); + } + + return HCCL_SUCCESS; +} + +bool CollBroadcastExecutor::IsAlgTypeLevel0Mesh(AlgTypeLevel0 &originalAlgTypeLevel0) const +{ + return originalAlgTypeLevel0 == AlgTypeLevel0::ALG_LEVEL0_NP_MESH || + originalAlgTypeLevel0 == AlgTypeLevel0::ALG_LEVEL0_4P_MESH || + originalAlgTypeLevel0 == AlgTypeLevel0::ALG_LEVEL0_2P_MESH || + originalAlgTypeLevel0 == AlgTypeLevel0::ALG_LEVEL0_1P_MESH; +} + +HcclResult CollBroadcastExecutor::SetInterServerHDAlgo(AlgType &algType) const +{ + switch (algType) { + case AlgType::ALG_8P_RING_PLUS_PIPELINE: + case AlgType::ALG_8P_RING_PLUS_RING: + case AlgType::ALG_8P_RING_PLUS_NHR: + case AlgType::ALG_8P_RING_PLUS_NHR_V1: + case AlgType::ALG_8P_RING_PLUS_NB: + algType = AlgType::ALG_8P_RING_PLUS_HD; + break; + + case AlgType::ALG_4P_MESH_PLUS_PIPELINE: + case AlgType::ALG_4P_MESH_PLUS_RING: + case AlgType::ALG_4P_MESH_PLUS_NHR: + case AlgType::ALG_4P_MESH_PLUS_NHR_V1: + case AlgType::ALG_4P_MESH_PLUS_NB: + algType = AlgType::ALG_4P_MESH_PLUS_HD; + break; + + case AlgType::ALG_2P_MESH_PLUS_PIPELINE: + case AlgType::ALG_2P_MESH_PLUS_RING: + case AlgType::ALG_2P_MESH_PLUS_NHR: + case AlgType::ALG_2P_MESH_PLUS_NHR_V1: + case AlgType::ALG_2P_MESH_PLUS_NB: + algType = AlgType::ALG_2P_MESH_PLUS_HD; + break; + + case AlgType::ALG_1P_MESH_PLUS_PIPELINE: + case AlgType::ALG_1P_MESH_PLUS_RING: + case AlgType::ALG_1P_MESH_PLUS_NHR: + case AlgType::ALG_1P_MESH_PLUS_NHR_V1: + case AlgType::ALG_1P_MESH_PLUS_NB: + algType = AlgType::ALG_1P_MESH_PLUS_HD; + break; + + case AlgType::ALG_4P_RING_PLUS_PIPELINE: + case AlgType::ALG_4P_RING_PLUS_RING: + case AlgType::ALG_4P_RING_PLUS_NHR: + case AlgType::ALG_4P_RING_PLUS_NHR_V1: + case AlgType::ALG_4P_RING_PLUS_NB: + algType = AlgType::ALG_4P_RING_PLUS_HD; + break; + + case AlgType::ALG_NP_SINGLE_RING_PLUS_PIPELINE: + case AlgType::ALG_NP_SINGLE_RING_PLUS_RING: + case AlgType::ALG_NP_SINGLE_RING_PLUS_NHR: + case AlgType::ALG_NP_SINGLE_RING_PLUS_NHR_V1: + case AlgType::ALG_NP_SINGLE_RING_PLUS_NB: + algType = AlgType::ALG_NP_SINGLE_RING_PLUS_HD; + break; + + case AlgType::ALG_NP_MESH_PLUS_PIPELINE: + case AlgType::ALG_NP_MESH_PLUS_RING: + case AlgType::ALG_NP_MESH_PLUS_NHR: + case AlgType::ALG_NP_MESH_PLUS_NHR_V1: + case AlgType::ALG_NP_MESH_PLUS_NB: + algType = AlgType::ALG_NP_MESH_PLUS_HD; + break; + + case AlgType::ALG_NP_DOUBLE_RING_PLUS_PIPELINE: + case AlgType::ALG_DOUBLE_RING_PLUS_RING: + algType = AlgType::ALG_DOUBLE_RING_PLUS_HD; + break; + default: + break; + } + return HCCL_SUCCESS; +} + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..930c38bbfdf827ac10cf25b3237bf69145af1e24 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_executor.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
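// Worked example (illustrative, with assumed numbers): the per-rank slices computed by
// GetRankSliceSize above for count=1000 fp32 elements over rankSize=8, assuming
// HCCL_MIN_SLICE_ALIGN is 512 bytes (the real constant lives in the HCCL headers).
// RoundUpWithDivisor is re-implemented here only to keep the example self-contained.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint64_t RoundUpWithDivisor(uint64_t value, uint64_t divisor)
{
    return ((value + divisor - 1) / divisor) * divisor;  // assumed behaviour of the helper
}

int main()
{
    const uint64_t count = 1000, perDataSize = 4, rankSize = 8, minAlign = 512;
    uint64_t align = (count * perDataSize) / rankSize + ((count % rankSize) > 0 ? 1 : 0);
    uint64_t sliceSize = RoundUpWithDivisor(align, minAlign);  // 500 -> 512
    uint64_t residue = count * perDataSize;                    // 4000 bytes in total
    for (uint64_t i = 0; i < rankSize; ++i) {
        uint64_t size = std::min(sliceSize, residue);
        uint64_t offset = (size == 0) ? 0 : i * sliceSize;
        residue -= size;
        // ranks 0-6 get 512 B each, rank 7 gets the remaining 416 B
        std::printf("rank %llu: offset=%llu size=%llu\n", (unsigned long long)i,
                    (unsigned long long)offset, (unsigned long long)size);
    }
    return 0;
}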
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_BROADCAST_COMM_EXECUTOR_H +#define COLL_BROADCAST_COMM_EXECUTOR_H +#include "coll_comm_executor.h" +#include "coll_alg_operator.h" + +namespace hccl { +class CollBroadcastExecutor : public CollCommExecutor { + +public: + CollBroadcastExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollBroadcastExecutor() = default; + + HcclResult Orchestrate(const OpParam& param, const AlgResourceResponse& algRes) override; +protected: + /* *************** 算法编排 *************** */ + // Broadcast Loop Executor公共接口 + HcclResult RunLoop(const OpParam ¶m, const AlgResourceResponse &algRes); + bool IsBroadcastSmallData(u64 size); + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + u64 CalcLoopMaxCount(const u32 unitSize); + HcclResult GetRankSliceSize(HcclDataType dataType, const u64 count, const u32 rankSize, + std::vector &sliceList); + bool IsAlgTypeLevel0Mesh(AlgTypeLevel0 &originalAlgTypeLevel0) const; + HcclResult SetInterServerHDAlgo(AlgType &algType) const; + bool DMAReduceFlag_{false}; // 是否DMA消减 + +private: + HcclResult RunLoopInner(const OpParam ¶m, ExecMem &execMem); +}; +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_mesh_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_mesh_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..8687f9fd68179cba179c7a0839eb32946a5a8e21 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_mesh_executor.cc @@ -0,0 +1,210 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + + #include "coll_broadcast_mesh_executor.h" + + namespace hccl { + +CollBroadcastMeshExecutor::CollBroadcastMeshExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollBroadcastExecutor(dispatcher, topoMatcher) +{ +} + +HcclResult CollBroadcastMeshExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = 0U; + switch(algType_) { + case AlgType::ALG_4P_MESH_PLUS_HD: + case AlgType::ALG_4P_MESH_PLUS_RING: + case AlgType::ALG_4P_MESH_PLUS_NHR: + case AlgType::ALG_4P_MESH_PLUS_NHR_V1: + case AlgType::ALG_4P_MESH_PLUS_NB: + totalStreamNum = OUTER_PLANE_NUM_IN_4PMESH; + break; + default: + if ((GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) && + (topoAttr_.deviceType == DevType::DEV_TYPE_910B) && topoAttr_.isSingleMeshAggregation ) { + totalStreamNum = topoAttr_.deviceNumPerAggregation; + } else if ((topoAttr_.deviceType == DevType::DEV_TYPE_910_73)) { // && (isAicpuModeEn == true) + totalStreamNum = topoAttr_.deviceNumPerAggregation; + } else if ((GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) && + (topoAttr_.deviceType == DevType::DEV_TYPE_910B) && UseInterServerPipelineAlgo(algType_)) { + totalStreamNum = topoAttr_.deviceNumPerAggregation + 1; /* pipeline ring场景下性能优化 */ + } else { + totalStreamNum = topoAttr_.deviceNumPerAggregation - 1; + } + break; + } + + streamNum = totalStreamNum - 1; + HCCL_INFO("[CollBroadcastMeshExecutor][CalcStreamNum] tag[%s] streamNum_[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollBroadcastMeshExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel1CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollBroadcastMeshExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, std::vector& opTransport) +{ + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_MESH); + commParaLevel0.meshSinglePlane = true; + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + return HCCL_SUCCESS; +} + +HcclResult CollBroadcastMeshExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + u32 perDataSize = SIZE_TABLE[param.DataDes.dataType]; + + std::unique_ptr outer1Executor; + std::unique_ptr innerExecutor; + std::unique_ptr outer2Executor; + + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + u32 commIndex = outerCommInfo.localRank; + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + + outer1Executor.reset( + new (std::nothrow) ScatterMesh(dispatcher_, outerCommInfo.localRank, outerCommInfo.localRankSize)); + CHK_SMART_PTR_NULL(outer1Executor); + + /* 内层topo:all_reduce */ + /* 外层所有rank均参与内层的broadcast计算,所以此处对rank不作限制,但是每个rank需找到自己所在的内层通信域 */ + std::vector slice; + CHK_RET(GetRankSliceSize(param.DataDes.dataType, execMem.count, outerCommInfo.localRankSize, slice)); + + CHK_PRT_RET(slice.empty(), HCCL_ERROR("[BroadCastOperator][BroadCastMeshExecutor]got slice is empty"), + HCCL_E_INTERNAL); + + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + + u64 curSize = execMem.count * SIZE_TABLE[param.DataDes.dataType]; + if (UseInterServerNHRAlgo(algType_)) { + HCCL_DEBUG("broadcast mesh: curSize[%llu] deviceNumPerAggregation[%u] commOuterSize[%u]", + curSize, 
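// Illustrative sketch (assumed numbers): the slave-stream count the mesh executor derives above.
// For an op-base 910B server with a single mesh aggregation of 8 devices, totalStreamNum equals
// deviceNumPerAggregation; one plane rides on the main stream, so streamNum = 7 extra streams.
#include <cstdio>

int main()
{
    const unsigned deviceNumPerAggregation = 8;      // assumed
    const unsigned totalStreamNum = deviceNumPerAggregation;
    const unsigned streamNum = totalStreamNum - 1;   // slave streams beside the main stream
    std::printf("slave streams required: %u\n", streamNum);
    return 0;
}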
topoAttr_.deviceNumPerAggregation, outerCommInfo.localRankSize); + if (curSize / topoAttr_.deviceNumPerAggregation <= NHR_BCAST_SMALL_SIZE) { + innerExecutor.reset(new (std::nothrow) BroadcastNHROneshot(dispatcher_)); + } else { + innerExecutor.reset(new (std::nothrow) BroadcastNHR(dispatcher_)); + } + HCCL_INFO("broadcast mesh: using nhr algo inter-server."); + } else if (UseInterServerNHRV1Algo(algType_)) { + innerExecutor.reset(new (std::nothrow) BroadcastNHRV1(dispatcher_)); + HCCL_INFO("broadcast mesh: using nhr_v1 algo inter-server."); + } else if (UseInterServerNBAlgo(algType_)) { + const u32 innerRankSize = innerCommInfo.localRankSize; + if (ShouldUseBinaryBroadcastOfNB(curSize / topoAttr_.deviceNumPerAggregation, innerRankSize, + topoAttr_.userRankSize, topoAttr_.deviceNumPerAggregation)) { + innerExecutor.reset(new (std::nothrow) BroadcastNBBinary(dispatcher_)); + } else { + innerExecutor.reset(new (std::nothrow) BroadcastNB(dispatcher_)); + } + HCCL_INFO("broadcast mesh: using nonuniform-bruck algo inter-server."); + } else { + innerExecutor.reset(new (std::nothrow) BcastRecursiveHalvingDoubling(dispatcher_)); + HCCL_INFO("broadcast mesh: using Recursive halving-doubling algo inter-server."); + } + CHK_SMART_PTR_NULL(innerExecutor); + + /* 外层topo:all_gather */ + if (topoAttr_.deviceType == DevType::DEV_TYPE_910B) { + outer2Executor.reset( + new (std::nothrow) AllGatherMeshAtomic(dispatcher_, streamInfo_.ringStreams, + streamInfo_.ringSignal, streamInfo_.ringSignalAux, outerCommInfo.localRank, outerCommInfo.localRankSize, + topoAttr_.userRank)); + } else { + outer2Executor.reset( + new (std::nothrow) AllGatherMesh(dispatcher_, streamInfo_.ringStreams, streamInfo_.ringSignal, + streamInfo_.ringSignalAux, outerCommInfo.localRank, outerCommInfo.localRankSize, + topoAttr_.userRank)); + } + CHK_SMART_PTR_NULL(outer2Executor); + + /* 节点内执行器 stage0 */ + u32 rootRank = 0; + HcclResult ret = GetRankByUserRank(COMM_LEVEL0, COMM_INDEX_0, param.root, rootRank); + CHK_PRT_RET(ret == HCCL_E_PARA, + HCCL_ERROR("[BroadCastOperator][BroadCastMeshExecutor]invalid root[%u] to get userrank", param.root), ret); + + if (ret == HCCL_SUCCESS) { + CHK_RET(outer1Executor->Prepare(execMem.inputMem, execMem.outputMem, execMem.outputMem, execMem.count, + param.DataDes.dataType, param.stream, HCCL_REDUCE_RESERVED, rootRank, slice)); + + u32 rankSize = outerCommInfo.localRankSize; + CHK_RET(outer1Executor->RegisterProfiler( + (0 << PROF_RINGINDEX_OFFSET_OF_PLANEID)+(rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + + outerCommInfo.localRank, PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, param.stream)); + + CHK_RET(RunTemplate(outer1Executor, outerCommInfo)); + } else { + HCCL_ERROR("[BroadCastOperator][BroadCastMeshExecutor]invalid root[%u] to get userrank", param.root); + } + HCCL_INFO("broadcast meshhd stage0 run success"); + u64 hdCount = slice[outerCommInfo.localRank].size / perDataSize; + /* 节点间执行器 stage1 */ + + u32 subUserrankRoot = topoMatcher_->GetSubRootUserRank(topoAttr_.userRank, param.root); + CHK_PRT_RET(subUserrankRoot == INVALID_VALUE_RANKID, + HCCL_ERROR("[BroadCastOperator][BroadCastMeshExecutor]subUserrankRoot[%u] is invalid,userRank[%u],root[%u]", + subUserrankRoot, topoAttr_.userRank, param.root), + HCCL_E_INTERNAL); + + u32 subRoot = 0; + CHK_RET(GetRankByUserRank(COMM_LEVEL1, commIndex, subUserrankRoot, subRoot)); + + // 增加偏移参数 + CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.outputMem, execMem.outputMem, hdCount, + param.DataDes.dataType, param.stream, HCCL_REDUCE_RESERVED, subRoot, + 
std::vector(0), slice[outerCommInfo.localRank].offset)); + + u32 rankSize = innerCommInfo.localRankSize; + CHK_RET(innerExecutor->RegisterProfiler((0 << PROF_RINGINDEX_OFFSET_OF_PLANEID) + + (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + + CHK_RET(RunTemplate(innerExecutor, innerCommInfo)); + HCCL_INFO("broadcast meshhd stage1 run success"); + + /* 节点内执行器 stage2 */ + { + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { // offline + for (u32 streamIndex = 0; streamIndex < streamInfo_.ringStreams.size(); streamIndex++) { + CHK_RET(StreamActiveManager::GetInstance(topoAttr_.deviceLogicId).StreamActive( + streamInfo_.ringStreams[streamIndex].ptr(), param.stream.ptr())); + } + } + CHK_RET(outer2Executor->Prepare(execMem.outputMem, execMem.outputMem, execMem.outputMem, execMem.count, + param.DataDes.dataType, param.stream, HCCL_REDUCE_RESERVED, OUTER_BRIDGE_RANK_ID, slice)); + + u32 rankSize = outerCommInfo.localRankSize; + CHK_RET(outer2Executor->RegisterProfiler((0 << PROF_RINGINDEX_OFFSET_OF_PLANEID) + + (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + outerCommInfo.localRank, + PROF_STAGE_2, HCCL_EXEC_STEP_NOT_SET, param.stream)); + + CHK_RET(RunTemplate(outer2Executor, outerCommInfo)); + } + + HCCL_INFO("broadcast meshhd stage2 run success"); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("BroadCastMeshExecutor", BroadcastMesh, CollBroadcastMeshExecutor); + + } // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_mesh_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_mesh_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..def04b0e315f35fcc1986e358a6a58c42b1f991f --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_mesh_executor.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
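// Illustrative sketch (not HCCL source): the per-device size test used above to pick the one-shot
// NHR broadcast over the general NHR broadcast for the inter-server stage. The 16 KB threshold is
// an assumed placeholder for NHR_BCAST_SMALL_SIZE, whose real value is defined in the HCCL headers.
#include <cstdint>
#include <cstdio>

static bool useOneshotNhr(uint64_t curSize, uint32_t deviceNumPerAggregation, uint64_t smallSizeThreshold)
{
    return (curSize / deviceNumPerAggregation) <= smallSizeThreshold;
}

int main()
{
    const uint64_t assumedThreshold = 16 * 1024;  // placeholder for NHR_BCAST_SMALL_SIZE
    std::printf("64 KB over 8 devices -> oneshot? %d\n", (int)useOneshotNhr(64 * 1024, 8, assumedThreshold));   // 1
    std::printf("1 MB over 8 devices  -> oneshot? %d\n", (int)useOneshotNhr(1024 * 1024, 8, assumedThreshold)); // 0
    return 0;
}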
+ */ + +#ifndef COLL_BROADCAST_MESH_EXECUTOR_H +#define COLL_BROADCAST_MESH_EXECUTOR_H +#include "coll_broadcast_executor.h" +namespace hccl { +class CollBroadcastMeshExecutor : public CollBroadcastExecutor { + +public: + CollBroadcastMeshExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollBroadcastMeshExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcStreamNum(u32& streamNum) override; + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, std::vector& opTransport) override; + + /* *************** 算法编排 *************** */ + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif + diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_ring_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_ring_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..cf4102544ffb66d04a6434b2d88b15b7e3a7a120 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_ring_executor.cc @@ -0,0 +1,191 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "coll_broadcast_ring_executor.h" + +namespace hccl { + +CollBroadcastRingExecutor::CollBroadcastRingExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollBroadcastExecutor(dispatcher, topoMatcher) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + topoAttr_.deviceType == DevType::DEV_TYPE_910_73) { + DMAReduceFlag_ = true; + } else { + DMAReduceFlag_ = false; + } +} + +HcclResult CollBroadcastRingExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = (topoType_ == TopoType::TOPO_TYPE_8P_RING) ? 
OUTER_PLANE_NUM_IN_8PRING : + OUTER_PLANE_NUM_IN_NPRING_SINGLE; + + if (topoAttr_.deviceType == DevType::DEV_TYPE_910_73) { + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_SINGLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING; + } else { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_SINGLE; + } + } + + streamNum = totalStreamNum - 1; + HCCL_INFO("[CollBroadcastRingExecutor][CalcStreamNum] tag[%s] streamNum_[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollBroadcastRingExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel1CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollBroadcastRingExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, std::vector& opTransport) +{ + HCCL_INFO("[CollBroadcastRingExecutor][CalcOuterCommInfo]tag[%s ]start", tag_.c_str()); + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_RING_INNER); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + HCCL_INFO("[CollBroadcastRingExecutor][CalcOuterCommInfo]tag[%s] Calc RingComm finish", tag_.c_str()); + return HCCL_SUCCESS; +} + +HcclResult CollBroadcastRingExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + HCCL_INFO("[CollBroadcastRingExecutor][Run]The CollBroadcastRingExecutor starts."); + u32 perDataSize = 0; + CHK_RET(SalGetDataTypeSize(param.DataDes.dataType, perDataSize)); + + std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 + std::vector > multRingsSliceZero; // 数据基于该rank上环0的偏移 + // step1: 节点内的scatter + u32 ringNum = (topoType_ == TopoType::TOPO_TYPE_8P_RING) ? 
OUTER_PLANE_NUM_IN_8PRING : + OUTER_PLANE_NUM_IN_NPRING_SINGLE; + + CHK_RET(CheckCommSize(COMM_LEVEL0, ringNum)); + + SubCommInfo level0CommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + // 按ranksize得到内存切分slice数 + u32 sliceNum = level0CommInfo.localRankSize; + // 将根节点数据切分成sliceNum份 + CHK_RET(ExecutorBase::PrepareSliceData(execMem.count, perDataSize, sliceNum, 0, dataSegsSlice)); + HCCL_DEBUG("[CollBroadcastRingExecutor][KernelRun] execMem.count[%llu], perDataSize[%llu], sliceNum[%llu], ringNum[%llu] ", + execMem.count, perDataSize, sliceNum, ringNum); + + /* 外层:scatter */ + // 将每slice再切分成4份,按各ring的dev顺序排列 + if (ringNum == OUTER_PLANE_NUM_IN_8PRING) { + // 构造ring algorithm对应的scatter实例 + multRingsSliceZero = PrepareMultiRingSlice(dataSegsSlice, param.tag, false, topoAttr_.nicList); + CHK_PRT_RET(multRingsSliceZero.size() != ringNum, HCCL_ERROR("[CollBroadcastRingExecutor]"\ + "ringNum[%u] !=multRingsSliceZero size[%llu]", ringNum, multRingsSliceZero.size()), HCCL_E_INTERNAL); + } else { + multRingsSliceZero.push_back(dataSegsSlice); // 应该offset全为0,而大小和dataSegsSlice中一样,里面的offset不使用 + } + + HcomCollOpInfo *scatterOpInfoPtr = nullptr; + HcomCollOpInfo scatterOpInfo = { + "", execMem.inputPtr, nullptr, execMem.count, param.DataDes.dataType, param.root + }; + if (DMAReduceFlag_) { + scatterOpInfoPtr = &scatterOpInfo; + } + CHK_RET(MultiRingScatter(param.tag, execMem.inputMem, execMem.outputMem, execMem.count, param.DataDes.dataType, + multRingsSliceZero, param.root, param.stream, scatterOpInfoPtr)); + + HCCL_INFO("broadcast 8PringHD stage0 run success"); + + // step2: 节点间的broadcast + u64 hdSize; + u32 segmentIdx; + u32 commIndex; + CHK_RET(PrepareInnerCommInfo(segmentIdx, commIndex, hdSize, level0CommInfo, multRingsSliceZero, param.tag)); + + u64 hdCount = hdSize / perDataSize; + auto nicList = topoAttr_.nicList; + bool isMultiNic = topoType_ == TopoType::TOPO_TYPE_8P_RING && nicList.size() != DEVICE_EIGHT; + std::vector::iterator iterNic = std::find(nicList.begin(), nicList.end(), topoAttr_.devicePhyId); + bool innRunRet = isMultiNic && (iterNic == nicList.end()); + if (!innRunRet) { // 满足以下条件, 不做server间通信: 1. 8P ring的拓扑 2. 网口不满配 3. 
当前device不出网口 + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo level1CommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + u64 curSize = execMem.count * SIZE_TABLE[param.DataDes.dataType]; + std::unique_ptr innerExecutor; + if (UseInterServerNHRAlgo(algType_)) { + HCCL_DEBUG("broadcast ring: curSize[%llu] deviceNumPerAggregation[%u] commOuterSize[%u]", + curSize, topoAttr_.deviceNumPerAggregation, level0CommInfo.localRankSize); + if (curSize / topoAttr_.deviceNumPerAggregation <= NHR_BCAST_SMALL_SIZE) { + innerExecutor.reset(new (std::nothrow) BroadcastNHROneshot(dispatcher_)); + } else { + innerExecutor.reset(new (std::nothrow) BroadcastNHR(dispatcher_)); + } + HCCL_INFO("broadcast ring: using nhr algo inter-server."); + } else if (UseInterServerNHRV1Algo(algType_)) { + innerExecutor.reset(new (std::nothrow) BroadcastNHRV1(dispatcher_)); + HCCL_INFO("broadcast ring: using nhr_v1 algo inter-server."); + } else if (UseInterServerNBAlgo(algType_)) { + const u32 innerRankSize = level1CommInfo.localRankSize; + if (ShouldUseBinaryBroadcastOfNB(curSize / topoAttr_.deviceNumPerAggregation, innerRankSize, + topoAttr_.userRankSize, topoAttr_.deviceNumPerAggregation)) { + innerExecutor.reset(new (std::nothrow) BroadcastNBBinary(dispatcher_)); + } else { + innerExecutor.reset(new (std::nothrow) BroadcastNB(dispatcher_)); + } + HCCL_INFO("broadcast ring: using nonuniform-bruck algo inter-server."); + } else { + innerExecutor.reset(new (std::nothrow) BcastRecursiveHalvingDoubling(dispatcher_)); + HCCL_INFO("broadcast ring: using Recursive halving-doubling algo inter-server."); + } + CHK_SMART_PTR_NULL(innerExecutor); + + u32 subUserrankRoot = topoMatcher_->GetSubRootUserRank(topoAttr_.userRank, param.root); + CHK_PRT_RET(subUserrankRoot == INVALID_VALUE_RANKID, + HCCL_ERROR("[BroadCastOperator][BroadCastRingExecutor]subUserrankRoot[%u] is invalid,userRank[%u],root[%u]", + subUserrankRoot, topoAttr_.userRank, param.root), HCCL_E_INTERNAL); + u32 planeRoot = 0; + u32 level1RankSize = level1CommInfo.localRankSize; + u32 level1LocalRank = level1CommInfo.localRank; + CHK_RET(GetRankByUserRank(COMM_LEVEL1, commIndex, subUserrankRoot, planeRoot)); + + // 节点间的hd 使用环0来记录 + CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.outputMem, hdCount, param.DataDes.dataType, + param.stream, HCCL_REDUCE_RESERVED, planeRoot, std::vector(0), dataSegsSlice[segmentIdx].offset)); + + CHK_RET(innerExecutor->RegisterProfiler((level1RankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + level1LocalRank, \ + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + + CHK_RET(RunTemplate(innerExecutor, level1CommInfo)); + } + HCCL_INFO("broadcast 8PringHD stage1 run success"); + + // step3: 节点内的allgatherring + HcomCollOpInfo *allgatherOpInfoPtr = nullptr; + HcomCollOpInfo allgatherOpInfo = { + "", nullptr, execMem.outputPtr, execMem.count, param.DataDes.dataType, param.root, HCCL_REDUCE_RESERVED + }; + if (DMAReduceFlag_) { + allgatherOpInfoPtr = &allgatherOpInfo; + } + CHK_RET(MultiRingAllGather(param.tag, execMem.inputMem, execMem.outputMem, hdCount, param.DataDes.dataType, + multRingsSliceZero, param.stream, PROF_STAGE_2, 0, allgatherOpInfoPtr)); + + HCCL_INFO("broadcast 8PringHD stage2 run success"); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("BroadCastRingExecutor", BroadcastRing, CollBroadcastRingExecutor); + + } // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_ring_executor.h 
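// Illustrative sketch (not HCCL source): the predicate the ring executor uses above to decide
// whether a rank skips the inter-server broadcast stage. The stage is skipped only for an 8P-ring
// topology whose NIC list is not fully populated and whose current device owns no NIC. The NIC
// list and device ids below are assumed example values.
#include <algorithm>
#include <cstdio>
#include <vector>

static bool skipInterServerStage(bool is8pRingTopo, const std::vector<unsigned> &nicList, unsigned devicePhyId)
{
    const bool nicNotFull = nicList.size() != 8;  // DEVICE_EIGHT
    const bool hasNoNic = std::find(nicList.begin(), nicList.end(), devicePhyId) == nicList.end();
    return is8pRingTopo && nicNotFull && hasNoNic;
}

int main()
{
    const std::vector<unsigned> nicList = {0, 2, 4, 6};  // assumed half-populated NIC list
    std::printf("device 1 skips inter-server: %d\n", (int)skipInterServerStage(true, nicList, 1));  // 1
    std::printf("device 2 skips inter-server: %d\n", (int)skipInterServerStage(true, nicList, 2));  // 0
    return 0;
}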
b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_ring_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..2dd3dd270d07f9be88faa605e4a538ca556d8e84 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_broadcast/coll_broadcast_ring_executor.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_BROADCAST_RING_EXECUTOR_H +#define COLL_BROADCAST_RING_EXECUTOR_H +#include "coll_broadcast_executor.h" +namespace hccl { +class CollBroadcastRingExecutor : public CollBroadcastExecutor { + +public: + CollBroadcastRingExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollBroadcastRingExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcStreamNum(u32& streamNum) override; + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + + /* *************** 算法编排 *************** */ + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif + diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_comm_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_comm_executor.cc index 86914de216d2b9ec4073acf99ae7adfd9a7e75df..9ddb7506bf0dbfe0d3a5b0902f3fcfdf5a2f978d 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_comm_executor.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_comm_executor.cc @@ -16,8 +16,8 @@ #include "externalinput_pub.h" namespace hccl { -CollCommExecutor::CollCommExecutor(std::unique_ptr &pImpl) - : CollNativeExecutorBase(pImpl) +CollCommExecutor::CollCommExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher) + : CollNativeExecutorBase(dispatcher, topoMatcher) { } @@ -26,7 +26,12 @@ HcclResult CollCommExecutor::GetSubStreamInfoOnOneRing(const innerStreamInfo_t & std::vector> &mainSignalsInOneRing, std::vector> &subSignalsInOneRing) { - if (streamInfo.ringNum == OUTER_PLANE_NUM_IN_NPRING_DOUBLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING) { + if (topoMatcher_->GetExternalInputEnableRdmaSdmaConcurrent() && (topoAttr_.deviceType == DevType::DEV_TYPE_910_73) && + (topoType_ == TopoType::TOPO_TYPE_NP_DOUBLE_RING)) { + subStreamsInOneRing.push_back(streamInfo.ringStreams[ringIndex + RDMA_ADD_STREAMS_NUM]); + mainSignalsInOneRing.push_back(streamInfo.ringSignal[ringIndex + RDMA_ADD_STREAMS_NUM]); + subSignalsInOneRing.push_back(streamInfo.ringSignalAux[ringIndex + RDMA_ADD_STREAMS_NUM]); + } else if (streamInfo.ringNum == OUTER_PLANE_NUM_IN_NPRING_DOUBLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING) { // double ring subStreamsInOneRing.push_back(streamInfo.ringStreams[ringIndex + 1]); mainSignalsInOneRing.push_back(streamInfo.ringSignal[ringIndex + 
1]); @@ -40,16 +45,117 @@ HcclResult CollCommExecutor::GetSubStreamInfoOnOneRing(const innerStreamInfo_t & return HCCL_SUCCESS; } +HcclResult CollCommExecutor::MultiRingAllReduce(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, + const u64 count, const HcclDataType dataType, const HcclReduceOp reductionOp, + const std::vector> &multRingsSliceZero, Stream stream, s32 profStage, + const u64 baseOffset) +{ + HcclResult ret = HCCL_SUCCESS; + u32 ringNum = multRingsSliceZero.size(); + CHK_RET(CheckCommSize(COMM_LEVEL0, ringNum)); + + u64 reduceAttr = GetReduceAttr(inputMem, outputMem, dataType, reductionOp); + + std::vector> ringNics; + CHK_RET(GetRingNics(tag, ringNics)); + + for (u32 ringIndex = 0; ringIndex < ringNum; ringIndex++) { + std::vector singleRingSliceZero = multRingsSliceZero[ringIndex]; + CHK_PRT_RET(singleRingSliceZero.empty(), + HCCL_ERROR("[CollCommExecutor][MultiRingAllReduce]singleRingSliceZero is empty"), HCCL_E_INTERNAL); + + SubCommInfo outerRingCommInfo = GetSubCommInfo(COMM_LEVEL0, ringIndex); + + u32 rankSize = outerRingCommInfo.localRankSize; + u32 ringIndexOp = ringIndex; + std::unique_ptr executor; + executor.reset(new (std::nothrow) AllReduceRing(dispatcher_, reduceAttr)); + CHK_SMART_PTR_NULL(executor); + + if (ringIndex != (ringNum - 1)) { // 0~ringNum-2的环 + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { // offline + CHK_RET(StreamActiveManager::GetInstance(topoAttr_.deviceLogicId).StreamActive( + streamInfo_.ringStreams[ringIndex].ptr(), stream.ptr())); + } + + ret = LocalNotify::Wait(streamInfo_.ringStreams[ringIndex], dispatcher_, + streamInfo_.ringSignalAux[ringIndex], profStage); + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[CollCommExecutor][MultiRingAllReduce]stream[%u] wait failed", + ringIndex), ret); + ret = executor->Prepare(inputMem, outputMem, outputMem, count, dataType, + streamInfo_.ringStreams[ringIndex], reductionOp, OUTER_BRIDGE_RANK_ID, singleRingSliceZero, + baseOffset, ringNics[ringIndex]); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingAllReduce]stream[%u], allreduce(ring) prepare failed,"\ + "return[%d]", ringIndex, ret), ret); + + ret = executor->RegisterProfiler( + ((ringIndexOp + 1) << PROF_RINGINDEX_OFFSET_OF_PLANEID) + + (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + outerRingCommInfo.localRank, + profStage, HCCL_EXEC_STEP_NOT_SET, streamInfo_.ringStreams[ringIndex]); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingAllReduce]stream[%u], allreduce(ring) register Profiler "\ + "failed,return[%d]", ringIndex, ret), ret); + + ret = RunTemplate(executor, outerRingCommInfo); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingAllReduce]stream[%u], allreduce(ring) run failed,"\ + "return[%d]", ringIndex, ret), ret); + + ret = LocalNotify::Post(streamInfo_.ringStreams[ringIndex], dispatcher_, streamInfo_.ringSignal[ringIndex], + profStage); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingAllReduce]stream[%u] record failed", ringIndex), ret); + + ret = LocalNotify::Post(stream, dispatcher_, streamInfo_.ringSignalAux[ringIndex], profStage); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingAllReduce]stream[%u] record failed", ringIndex), ret); + } else { // 主环 + executor.reset(new (std::nothrow) AllReduceRing(dispatcher_, reduceAttr)); + CHK_SMART_PTR_NULL(executor); + ret = executor->Prepare(inputMem, outputMem, outputMem, count, dataType, stream, + 
reductionOp, OUTER_BRIDGE_RANK_ID, singleRingSliceZero, baseOffset, ringNics[ringIndex]); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingAllReduce]stream[%u], allreduce(ring) prepare failed, "\ + "return[%d]", ringIndex, ret), ret); + + ret = executor->RegisterProfiler( + ((ringIndexOp + 1) << PROF_RINGINDEX_OFFSET_OF_PLANEID) + + (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + outerRingCommInfo.localRank, + profStage, HCCL_EXEC_STEP_NOT_SET, stream); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingAllReduce]stream[%u], allreduce(ring) register Profiler "\ + "failed,return[%d]", ringIndex, ret), ret); + + ret = RunTemplate(executor, outerRingCommInfo); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingAllReduce]stream[%u], allreduce(ring) run failed, "\ + "return[%d]", ringIndex, ret), ret); + + for (u32 ring = 0; ring < (ringNum - 1); ring++) { + /* 等待executor执行完毕 */ + ret = LocalNotify::Wait(stream, dispatcher_, streamInfo_.ringSignal[ring], profStage); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingAllReduce]stream[%u] wait failed", ring), ret); + } + } + } + // 添加空task,保证执行时不乱序 + CHK_RET(ExecutorBase::ExecEmptyTask(inputMem, outputMem, stream, dispatcher_)); + return HCCL_SUCCESS; +} + HcclResult CollCommExecutor::MultiRingAllGather(const std::string &tag, DeviceMem inputMem, DeviceMem outputMem, const u64 count, const HcclDataType dataType, const std::vector > multRingsSliceZero, - Stream stream, s32 profStage, const u64 baseOffset, const HcomCollOpInfo *opInfo) + Stream stream, s32 profStage, const u64 baseOffset, const HcomCollOpInfo *opInfo, + const std::vector> multRingsUserMemSlice) { HcclResult ret = HCCL_SUCCESS; u32 ringNum = multRingsSliceZero.size(); CHK_RET(CheckCommSize(COMM_LEVEL0, ringNum)); std::vector> ringNics; - CHK_RET(hcclImpl_->GetRingNics(tag, ringNics)); + CHK_RET(GetRingNics(tag, ringNics)); // 拿到ring环映射关系 SubCommInfo outerZeroCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); auto nicList = topoAttr_.nicList; @@ -63,9 +169,14 @@ HcclResult CollCommExecutor::MultiRingAllGather(const std::string &tag, DeviceMe CHK_PRT_RET(singleRingSliceZero.empty(), HCCL_ERROR("[CollCommExecutor][MultiRingAllGather]"\ "singleRingSliceZero is empty"), HCCL_E_INTERNAL); + // 910_73场景 生成userMemOut_上对应的slices std::vector userMemOutputSlices; - CHK_RET( - CalUserMemSlices(dataType, opInfo, singleRingSliceZero, ringIndex, multiRingsOrder, userMemOutputSlices)); + if (multRingsUserMemSlice.size() == 0) { + CHK_RET(CalUserMemSlices(dataType, opInfo, singleRingSliceZero, ringIndex, multiRingsOrder, + userMemOutputSlices)); + } else { + userMemOutputSlices = multRingsUserMemSlice[ringIndex]; + } std::vector rankOrder; CHK_RET(GetRankOrder(multiRingsOrder, ringIndex, rankOrder)); @@ -74,6 +185,7 @@ HcclResult CollCommExecutor::MultiRingAllGather(const std::string &tag, DeviceMe u32 rankSize = outerRingCommInfo.localRankSize; u32 ringIndexOp = ringIndex; + // 910_73场景 准备环中的从流 std::vector subStreamsInOneRing; std::vector> mainSignalsInOneRing; std::vector> subSignalsInOneRing; @@ -83,10 +195,11 @@ HcclResult CollCommExecutor::MultiRingAllGather(const std::string &tag, DeviceMe } std::vector> threadManage; if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { - CHK_RET(hcclImpl_->GetStreamThreadManage(tag, streamInfo_.ringNum, threadManage)); + threadManage.resize(streamInfo_.ringNum - 1); + CHK_RET(GetStreamThreadManage(tag, streamInfo_.ringNum, 
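// Illustrative host-side analogy (not HCCL code): the main-stream / sub-stream handshake that
// MultiRingAllReduce above drives with LocalNotify. The main stream posts ringSignalAux to release
// each sub-ring, every sub-ring runs its slice on its own stream and posts ringSignal, and the
// main stream finally waits on all ringSignal notifications before continuing. std::promise,
// std::future and std::thread are stand-ins chosen only for this sketch.
#include <cstdio>
#include <future>
#include <thread>
#include <vector>

int main()
{
    constexpr int kSubRings = 3;
    std::vector<std::promise<void>> aux(kSubRings), done(kSubRings);
    std::vector<std::future<void>> auxFut, doneFut;
    for (int i = 0; i < kSubRings; ++i) {
        auxFut.push_back(aux[i].get_future());
        doneFut.push_back(done[i].get_future());
    }

    std::vector<std::thread> subRings;
    for (int i = 0; i < kSubRings; ++i) {
        subRings.emplace_back([&, i] {
            auxFut[i].wait();                 // sub-stream: LocalNotify::Wait(ringSignalAux[i])
            std::printf("sub-ring %d runs its slice\n", i);
            done[i].set_value();              // sub-stream: LocalNotify::Post(ringSignal[i])
        });
    }

    for (int i = 0; i < kSubRings; ++i) {
        aux[i].set_value();                   // main stream releases sub-ring i (Post ringSignalAux[i])
    }
    std::printf("main ring runs its slice on the main stream\n");
    for (int i = 0; i < kSubRings; ++i) {
        doneFut[i].wait();                    // main stream waits for sub-ring i (Wait ringSignal[i])
    }
    for (auto &t : subRings) {
        t.join();
    }
    return 0;
}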
threadManage)); } if (ringIndex != (ringNum - 1)) { // 最后一个环是主stream,所以这里减1,符合条件的走从stream - if (!GetExternalInputHcclEnableFfts() && + if (!topoMatcher_->GetExternalInputHcclEnableFfts() && GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { if (opInfo != nullptr) { threadManage[ringIndex]->Prepare( @@ -178,7 +291,7 @@ HcclResult CollCommExecutor::MultiRingAllGather(const std::string &tag, DeviceMe "return[%d]", ringIndex, ret), ret); for (u32 ring = 0; ring < (ringNum - 1); ring++) { - if (!GetExternalInputHcclEnableFfts() && + if (!topoMatcher_->GetExternalInputHcclEnableFfts() && GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { threadManage[ring]->WaitDone(); // 单算子模式,等待线程处理完成信号 } @@ -202,7 +315,7 @@ HcclResult CollCommExecutor::MultiRingAllGatherConcurrent(const std::string &tag u32 ringNum = multRingsSliceZero.size(); // 环数, 当前为4环 std::vector> ringNics; - CHK_RET(hcclImpl_->GetRingNics(tag, ringNics)); + CHK_RET(GetRingNics(tag, ringNics)); auto halfRingSize = ringNum; if (ringNum > RDMA_PLANE_NUM_IN_NPRING_DOUBLE) { halfRingSize = ringNum / 2; // 2环 @@ -220,6 +333,7 @@ HcclResult CollCommExecutor::MultiRingAllGatherConcurrent(const std::string &tag CHK_PRT_RET(singleRingSliceZero.empty(), HCCL_ERROR("[CommonOperator][MultiRingAllGatherConcurrent]"\ "singleRingSliceZero is empty"), HCCL_E_INTERNAL); + // 910_73场景 生成userMemOut_上对应的slices std::vector userMemOutputSlices; CHK_RET( CalUserMemSlices(dataType, opInfo, singleRingSliceZero, ringIndex, multiRingsOrder, userMemOutputSlices)); @@ -233,6 +347,7 @@ HcclResult CollCommExecutor::MultiRingAllGatherConcurrent(const std::string &tag u32 rankSize = outerRingCommInfo.localRankSize; u32 ringIndexOp = ringIndex; + // 910_73场景 准备环中的从流 std::vector subStreamsInOneRing; std::vector> mainSignalsInOneRing; std::vector> subSignalsInOneRing; @@ -242,10 +357,11 @@ HcclResult CollCommExecutor::MultiRingAllGatherConcurrent(const std::string &tag } std::vector> threadManage; if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { - CHK_RET(hcclImpl_->GetStreamThreadManage(tag, streamInfo_.ringNum, threadManage)); + threadManage.resize(streamInfo_.ringNum - 1); + CHK_RET(GetStreamThreadManage(tag, streamInfo_.ringNum, threadManage)); } if (ringIndex != (ringNum - 1)) { // 最后一个环是主stream,所以这里减1,符合条件的走从stream - if (!GetExternalInputHcclEnableFfts() && + if (!topoMatcher_->GetExternalInputHcclEnableFfts() && GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { if (opInfo != nullptr) { threadManage[ringIndex]->Prepare( @@ -338,7 +454,7 @@ HcclResult CollCommExecutor::MultiRingAllGatherConcurrent(const std::string &tag "return[%d]", ringIndex, ret), ret); for (u32 ring = 0; ring < (ringNum - 1); ring++) { - if (!GetExternalInputHcclEnableFfts() && + if (!topoMatcher_->GetExternalInputHcclEnableFfts() && GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { threadManage[ring]->WaitDone(); // 单算子模式,等待线程处理完成信号 } @@ -356,14 +472,15 @@ HcclResult CollCommExecutor::MultiRingAllGatherConcurrent(const std::string &tag HcclResult CollCommExecutor::MultiRingReduceScatter(const std::string &tag, DeviceMem inputMem, DeviceMem outputMem, const u64 count, const HcclDataType dataType, const HcclReduceOp reductionOp, const std::vector > multRingsSliceZero, Stream stream, s32 profStage, - const u64 baseOffset, const HcomCollOpInfo *opInfo) + const u64 baseOffset, const HcomCollOpInfo *opInfo, + const std::vector> multRingsUserMemSlice) { HcclResult ret = HCCL_SUCCESS; u32 ringNum = 
multRingsSliceZero.size(); CHK_RET(CheckCommSize(COMM_LEVEL0, ringNum)); std::vector> ringNics; - CHK_RET(hcclImpl_->GetRingNics(tag, ringNics)); + CHK_RET(GetRingNics(tag, ringNics)); // 拿到ring环映射关系 SubCommInfo outerZeroCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); auto nicList = topoAttr_.nicList; @@ -381,8 +498,12 @@ HcclResult CollCommExecutor::MultiRingReduceScatter(const std::string &tag, Devi // 生成userMemIn_上对应的slices std::vector userMemInputSlices; - CHK_RET( - CalUserMemSlices(dataType, opInfo, singleRingSliceZero, ringIndex, multiRingsOrder, userMemInputSlices)); + if (multRingsUserMemSlice.size() == 0) { + CHK_RET(CalUserMemSlices(dataType, opInfo, singleRingSliceZero, ringIndex, multiRingsOrder, + userMemInputSlices)); + } else { + userMemInputSlices = multRingsUserMemSlice[ringIndex]; + } std::vector rankOrder; CHK_RET(GetRankOrder(multiRingsOrder, ringIndex, rankOrder)); @@ -400,7 +521,8 @@ HcclResult CollCommExecutor::MultiRingReduceScatter(const std::string &tag, Devi } std::vector> threadManage; if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { - CHK_RET(hcclImpl_->GetStreamThreadManage(tag, streamInfo_.ringNum, threadManage)); + threadManage.resize(streamInfo_.ringNum - 1); + CHK_RET(GetStreamThreadManage(tag, streamInfo_.ringNum, threadManage)); } if (ringIndex != (ringNum - 1)) { // 0~ringNum-2的环 if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { // offline @@ -410,7 +532,7 @@ HcclResult CollCommExecutor::MultiRingReduceScatter(const std::string &tag, Devi HCCL_ERROR("[CollCommExecutor][MultiRingReduceScatter]active stream[%u], failed", ringIndex), ret); } - if (!GetExternalInputHcclEnableFfts() && + if (!topoMatcher_->GetExternalInputHcclEnableFfts() && GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { /* 更新线程参数 */ if (opInfo != nullptr) { @@ -502,7 +624,7 @@ HcclResult CollCommExecutor::MultiRingReduceScatter(const std::string &tag, Devi HCCL_ERROR("[CollCommExecutor][MultiRingReduceScatter]stream[%u],reduce scatter(ring) run "\ "failed,return[%d]", ringIndex, ret), ret); for (u32 ring = 0; ring < (ringNum - 1); ring++) { - if (!GetExternalInputHcclEnableFfts() && + if (!topoMatcher_->GetExternalInputHcclEnableFfts() && GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { threadManage[ring]->WaitDone(); } @@ -519,6 +641,103 @@ HcclResult CollCommExecutor::MultiRingReduceScatter(const std::string &tag, Devi return HCCL_SUCCESS; } +HcclResult CollCommExecutor::MultiRingGather(const std::string &tag, DeviceMem inputMem, DeviceMem outputMem, + const u64 count, const HcclDataType dataType, const std::vector > multRingsSliceZero, + HcclReduceOp op, u32 root, Stream stream, s32 profStage) +{ + u32 ringNum = multRingsSliceZero.size(); + std::vector> ringNics; + CHK_RET(GetRingNics(tag, ringNics)); + + HcclResult ret; + + for (u32 ringIndex = 0; ringIndex < ringNum; ringIndex++) { + std::vector singleRingSliceZero = multRingsSliceZero[ringIndex]; + CHK_PRT_RET(singleRingSliceZero.empty(), + HCCL_ERROR("[CommonOperator][MultiRingGather]singleRingSliceZero is empty"), HCCL_E_INTERNAL); + + SubCommInfo outerRingCommInfo = GetSubCommInfo(COMM_LEVEL0, ringIndex); + u32 rankSize = outerRingCommInfo.localRankSize; + u32 rootRank = 0; + ret = GetRankByUserRank(COMM_LEVEL0, COMM_INDEX_0, root, rootRank); + CHK_PRT_RET(ret == HCCL_E_PARA, + HCCL_ERROR("[CommonOperator][MultiRingGather]invalid root rank[%u] to get user rank", root), ret); + + std::unique_ptr executor = 
std::make_unique(dispatcher_); + CHK_SMART_PTR_NULL(executor); + + if (ringIndex != (ringNum - 1)) { // 0~ringNum-2的环 + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { // offline + CHK_RET(StreamActiveManager::GetInstance(topoAttr_.deviceLogicId).StreamActive( + streamInfo_.ringStreams[ringIndex].ptr(), stream.ptr())); + } + ret = LocalNotify::Wait(streamInfo_.ringStreams[ringIndex], dispatcher_, + streamInfo_.ringSignalAux[ringIndex], profStage); + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[CommonOperator][MultiRingGather]in stream[%u] wait failed", \ + ringIndex), ret); + if (singleRingSliceZero[0].size != 0) { + ret = executor->Prepare(inputMem, outputMem, outputMem, count, dataType, + streamInfo_.ringStreams[ringIndex], op, rootRank, singleRingSliceZero, 0, + ringNics[ringIndex]); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CommonOperator][MultiRingGather]stream[%u],gather(ring) prepare failed, "\ + "return[%d]", ringIndex, ret), ret); + + ret = executor->RegisterProfiler(outerRingCommInfo.localRank, profStage, HCCL_EXEC_STEP_NOT_SET, + streamInfo_.ringStreams[ringIndex]); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CommonOperator][MultiRingGather]stream[%u], gather(ring) register profiler "\ + "failed,return[%d]", ringIndex, ret), ret); + + ret = RunTemplate(executor, outerRingCommInfo); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CommonOperator][MultiRingGather]stream[%u],gather(ring) run failed,return[%d]", + ringIndex, ret), ret); + } + ret = LocalNotify::Post(streamInfo_.ringStreams[ringIndex], dispatcher_, streamInfo_.ringSignal[ringIndex], + profStage); + + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[CommonOperator][MultiRingGather]stream[%u] record failed", \ + ringIndex), ret); + + ret = LocalNotify::Post(stream, dispatcher_, streamInfo_.ringSignalAux[ringIndex], profStage); + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[CommonOperator][MultiRingGather]stream[%u] record failed", \ + ringIndex), ret); + } else { // 主环 + executor.reset(new (std::nothrow) GatherRing(dispatcher_)); + CHK_SMART_PTR_NULL(executor); + + ret = executor->Prepare(inputMem, outputMem, outputMem, count, dataType, stream, + op, rootRank, singleRingSliceZero, 0, ringNics[ringIndex]); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CommonOperator][MultiRingGather]stream[%u],gather(ring) prepare failed, "\ + "return[%d]", ringIndex, ret), ret); + + ret = executor->RegisterProfiler(((ringIndex + 1) << PROF_RINGINDEX_OFFSET_OF_PLANEID) + + (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + outerRingCommInfo.localRank, + profStage, HCCL_EXEC_STEP_NOT_SET, stream); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CommonOperator][MultiRingGather]stream[%u], gather(ring) register "\ + "profiler failed,return[%d]", ringIndex, ret), ret); + + ret = RunTemplate(executor, outerRingCommInfo); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CommonOperator][MultiRingGather]stream[%u],gather(ring) run failed, "\ + "return[%d]", ringIndex, ret), ret); + for (u32 ring = 0; ring < (ringNum - 1); ring++) { + /* 等待executor执行完毕 , 当前环没有分配数据,跳过此环处理,继续下一个环 */ + ret = LocalNotify::Wait(stream, dispatcher_, streamInfo_.ringSignal[ring], profStage); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CommonOperator][MultiRingGather]stream[%u] wait failed", ring), ret); + } + } + } + + // 添加空task,保证子图执行时不乱序 + CHK_RET(ExecutorBase::ExecEmptyTask(inputMem, outputMem, stream, dispatcher_)); + return HCCL_SUCCESS; +} + HcclResult 
CollCommExecutor::MultiRingReduceScatterConcurrent(const std::string &tag, DeviceMem inputMem, DeviceMem outputMem, const u64 count, const HcclDataType dataType, const HcclReduceOp reductionOp, const std::vector>> multRingsSliceZero, Stream stream, s32 profStage, @@ -528,7 +747,7 @@ HcclResult CollCommExecutor::MultiRingReduceScatterConcurrent(const std::string u32 ringNum = multRingsSliceZero.size(); std::vector> ringNics; - CHK_RET(hcclImpl_->GetRingNics(tag, ringNics)); + CHK_RET(GetRingNics(tag, ringNics)); u32 halfRingSize = ringNum; u32 DoubleRing = 2; if (ringNum > RDMA_PLANE_NUM_IN_NPRING_DOUBLE) { @@ -573,7 +792,8 @@ HcclResult CollCommExecutor::MultiRingReduceScatterConcurrent(const std::string } std::vector> threadManage; if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { - CHK_RET(hcclImpl_->GetStreamThreadManage(tag, streamInfo_.ringNum, threadManage)); + threadManage.resize(streamInfo_.ringNum - 1); + CHK_RET(GetStreamThreadManage(tag, streamInfo_.ringNum, threadManage)); } if (ringIndex != (ringNum - 1)) { // 0~ringNum-2的环 if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { // offline @@ -584,7 +804,7 @@ HcclResult CollCommExecutor::MultiRingReduceScatterConcurrent(const std::string ringIndex), ret); } - if (!GetExternalInputHcclEnableFfts() && + if (!topoMatcher_->GetExternalInputHcclEnableFfts() && GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { /* 更新线程参数 */ if (opInfo != nullptr) { @@ -680,7 +900,7 @@ HcclResult CollCommExecutor::MultiRingReduceScatterConcurrent(const std::string HCCL_ERROR("[CollCommExecutor][MultiRingReduceScatterConcurrent]stream[%u],reduce scatter(ring) run "\ "failed,return[%d]", ringIndex, ret), ret); for (u32 ring = 0; ring < (ringNum - 1); ring++) { - if (!GetExternalInputHcclEnableFfts() && + if (!topoMatcher_->GetExternalInputHcclEnableFfts() && GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { threadManage[ring]->WaitDone(); } @@ -697,6 +917,109 @@ HcclResult CollCommExecutor::MultiRingReduceScatterConcurrent(const std::string return HCCL_SUCCESS; } +HcclResult CollCommExecutor::MultiRingMultiRootScatter(const std::string &tag, DeviceMem &inputMem, + DeviceMem &outputMem, const u64 count, const HcclDataType dataType, + const std::vector> &multRingsSliceZero, u32 root, Stream stream, const u64 baseOffset) +{ + HcclResult ret = HCCL_SUCCESS; + u32 ringNum = multRingsSliceZero.size(); + CHK_RET(CheckCommSize(COMM_LEVEL0, ringNum)); + + std::vector> ringNics; + CHK_RET(GetRingNics(tag, ringNics)); + + for (u32 ringIndex = 0; ringIndex < ringNum; ringIndex++) { + std::vector singleRingSliceZero = multRingsSliceZero[ringIndex]; + CHK_PRT_RET(singleRingSliceZero.empty(), + HCCL_ERROR("[CollCommExecutor][MultiRingMultiRootScatter]singleRingSliceZero is empty"), HCCL_E_INTERNAL); + + SubCommInfo outerRingCommInfo = GetSubCommInfo(COMM_LEVEL0, ringIndex); + + u32 rankSize = outerRingCommInfo.localRankSize; + std::unique_ptr executor; + executor.reset(new (std::nothrow) MultiRootScatterRing(dispatcher_)); + CHK_SMART_PTR_NULL(executor); + + if (ringIndex != (ringNum - 1)) { + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { // offline + CHK_RET(StreamActiveManager::GetInstance(topoAttr_.deviceLogicId).StreamActive( + streamInfo_.ringStreams[ringIndex].ptr(), stream.ptr())); + } + } + + u32 rootRank = 0; + ret = GetRankByUserRank(COMM_LEVEL0, ringIndex, root, rootRank); + CHK_PRT_RET(ret == HCCL_E_PARA, + 
HCCL_ERROR("[CollCommExecutor][MultiRingMultiRootScatter]invalid root [%u] to get userrank", root), ret); + + if (ringIndex != (ringNum - 1)) { // 0~ringNum-2的环 + ret = LocalNotify::Wait(streamInfo_.ringStreams[ringIndex], dispatcher_, + streamInfo_.ringSignalAux[ringIndex], PROF_STAGE_0); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingMultiRootScatter]in stream[%u] wait failed", ringIndex), ret); + + ret = executor->Prepare(inputMem, outputMem, outputMem, count, dataType, + streamInfo_.ringStreams[ringIndex], HcclReduceOp::HCCL_REDUCE_RESERVED, OUTER_BRIDGE_RANK_ID, + singleRingSliceZero, baseOffset, ringNics[ringIndex]); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingMultiRootScatter]stream[%u],multirootscatter(ring) "\ + "prepare failed,return[%d]", ringIndex, ret), ret); + + ret = executor->RegisterProfiler( + ((ringIndex + 1) << PROF_RINGINDEX_OFFSET_OF_PLANEID) + (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + + outerRingCommInfo.localRank, PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, + streamInfo_.ringStreams[ringIndex]); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingMultiRootScatter]stream[%u], multirootscatter(ring) "\ + "register profiler failed,return[%d]", ringIndex, ret), ret); + + ret = RunTemplate(executor, outerRingCommInfo); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingMultiRootScatter]stream[%u],multirootscatter(ring) "\ + "failed,return[%d]", ringIndex, ret), ret); + + ret = LocalNotify::Post(streamInfo_.ringStreams[ringIndex], dispatcher_, streamInfo_.ringSignal[ringIndex], + PROF_STAGE_0); + + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingMultiRootScatter]stream[%u] record failed", ringIndex), ret); + + ret = LocalNotify::Post(stream, dispatcher_, streamInfo_.ringSignalAux[ringIndex], PROF_STAGE_0); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingMultiRootScatter]stream[%u] record failed", ringIndex), ret); + } else { // 主环 + executor.reset(new (std::nothrow) MultiRootScatterRing(dispatcher_)); + CHK_SMART_PTR_NULL(executor); + ret = executor->Prepare(inputMem, outputMem, outputMem, count, dataType, stream, + HCCL_REDUCE_RESERVED, OUTER_BRIDGE_RANK_ID, singleRingSliceZero, baseOffset, ringNics[ringIndex]); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingMultiRootScatter]stream[%u],multirootscatter(ring) "\ + "prepare failed,return[%d]", ringIndex, ret), ret); + + ret = executor->RegisterProfiler( + ((ringIndex + 1) << PROF_RINGINDEX_OFFSET_OF_PLANEID) + (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + + outerRingCommInfo.localRank, PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, stream); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingMultiRootScatter]stream[%u], multirootscatter(ring) "\ + "register profiler failed,return[%d]", ringIndex, ret), ret); + + ret = RunTemplate(executor, outerRingCommInfo); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingMultiRootScatter]stream[%u],multirootscatter(ring) run "\ + "failed,return[%d]", ringIndex, ret), ret); + for (u32 ring = 0; ring < (ringNum - 1); ring++) { + /* 等待executor执行完毕 , 当前环没有分配数据,跳过此环处理,继续下一个环 */ + ret = LocalNotify::Wait(stream, dispatcher_, streamInfo_.ringSignal[ring], PROF_STAGE_0); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingMultiRootScatter]stream[%u] wait failed", ring), ret); + } + } + } + // 添加空task,保证子图执行时不乱序 + 
CHK_RET(ExecutorBase::ExecEmptyTask(inputMem, outputMem, stream, dispatcher_)); + return HCCL_SUCCESS; +} + HcclResult CollCommExecutor::MultiStreamReduceScatterMeshAtomic(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, const u64 count, const HcclDataType dataType, const HcclReduceOp reductionOp, const std::vector &dataSliceVct, Stream &stream, @@ -854,6 +1177,21 @@ HcclResult CollCommExecutor::MultiStreamReduceScatterMesh(const std::string &tag return ret; } +HcclResult CollCommExecutor::PrepareReduceScatterSliceData(u64 dataCount, u32 unitSize, u32 sliceNum, + std::vector &dataSlice) +{ + CHK_PRT_RET((sliceNum == 0), HCCL_ERROR("[CollCommExecutor][PrepareReduceScatterSliceData]sliceNum is zero."), + HCCL_E_PARA); + + dataSlice.resize(sliceNum); + u64 sliceSize = dataCount * unitSize; + for (u32 i = 0; i < sliceNum; i++) { + dataSlice[i].size = sliceSize; + dataSlice[i].offset = (i * sliceSize); + } + return HCCL_SUCCESS; +} + std::vector> CollCommExecutor::GetRingsOrderByTopoType(u32 ranksSize, TopoType topoType, std::vector &nicList) { @@ -874,7 +1212,7 @@ std::vector> CollCommExecutor::GetRingsOrderByTopoType(u32 ran std::vector tmpOuter0; // 环0 std::vector tmpOuter1; // 环1 std::vector rohOuter; - if (GetExternalInputEnableRdmaSdmaConcurrent() && (CheckSdmaWithRohTopo(nicList, rohOuter))) { + if (topoMatcher_->GetExternalInputEnableRdmaSdmaConcurrent() && (topoMatcher_->CheckSdmaWithRohTopo(nicList, rohOuter))) { tmpOuter0 = rohOuter; // 环0, 8卡 { 0, 1, 3, 2, 4, 5, 7, 6 }; tmpOuter1.reserve(ranksSize); // 环1, 8卡 { 0, 6, 7, 5, 4, 2, 3, 1 }; tmpOuter1.push_back(rohOuter[0]); @@ -986,7 +1324,7 @@ void CollCommExecutor::NicSendSizeCal(const std::vector> &mut } sizeList.push_back(tempSize); } - hcclImpl_->SetNicSendSize(tag, sizeList); + SetNicSendSize(tag, sizeList); } std::vector > CollCommExecutor::PrepareMultiRingSlice(const std::vector &dataSegsSlice, @@ -1046,7 +1384,7 @@ std::vector > CollCommExecutor::PrepareMultiRingSlice(const s rankList.clear(); } - ret = hcclImpl_->SetRingNics(tag, ringRankList); + ret = SetRingNics(tag, ringRankList); if (ret != HCCL_SUCCESS) { HCCL_ERROR("[Prepare][MultiRingSlice]set nics in ring failed, ret[%u]", ret); std::vector > emptySlice; @@ -1077,10 +1415,11 @@ HcclResult CollCommExecutor::CalUserMemSlices(const HcclDataType dataType, const std::vector &userMemSlices) { if (opInfo == nullptr || opInfo->inputAddr == nullptr || opInfo->outputAddr == nullptr) { + // 910_73场景下,allreduce算子的userMem上的slice信息 userMemSlices = singleRingSliceZero; return HCCL_SUCCESS; } - + // 910_73场景下,reduce scatter和all gather算子的userMem上的slice信息 std::vector ring0 = multiRingsOrder[0]; for (u32 sliceIdx = 0; sliceIdx < singleRingSliceZero.size(); sliceIdx++) { Slice userMemSlice; @@ -1112,7 +1451,7 @@ HcclResult CollCommExecutor::GetRankOrder(const std::vector> &m u32 CollCommExecutor::RefreshCommIdx(u32 commIndex, std::vector nicList, u32 devicePhyId) { - if (GetExternalInputEnableRdmaSdmaConcurrent() && CheckRankNeighbors(nicList)) { + if (topoMatcher_->GetExternalInputEnableRdmaSdmaConcurrent() && CheckRankNeighbors(nicList)) { std::vector::iterator iterRank = std::find(nicList.begin(), nicList.end(), devicePhyId); // 按照实际topo寻找对应的rankID,即commIndex if (iterRank != nicList.end()) { @@ -1126,4 +1465,236 @@ u32 CollCommExecutor::RefreshCommIdx(u32 commIndex, std::vector nicList, u3 } return commIndex; } + +HcclResult CollCommExecutor::MultiRingScatter(const std::string &tag, DeviceMem inputMem, DeviceMem outputMem, + const u64 count, const HcclDataType 
dataType, const std::vector > multRingsSliceZero, + u32 root, Stream stream, const HcomCollOpInfo *opInfo) +{ + HcclResult ret = HCCL_SUCCESS; + u32 ringNum = multRingsSliceZero.size(); + + CHK_RET(CheckCommSize(COMM_LEVEL0, ringNum)); + + std::vector> ringNics; + CHK_RET(GetRingNics(tag, ringNics)); + + // 拿到ring环映射关系 + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + auto nicList = topoAttr_.nicList; + std::vector> multiRingsOrder = GetRingsOrderByTopoType(outerCommInfo.localRankSize, topoType_, nicList); + + // 空拷贝用于后续操作附着 + CHK_RET(ExecutorBase::ExecEmptyTask(inputMem, outputMem, stream, dispatcher_)); + for (u32 ringIndex = 0; ringIndex < ringNum; ringIndex++) { + std::vector singleRingSliceZero = multRingsSliceZero[ringIndex]; + CHK_PRT_RET(singleRingSliceZero.empty(), + HCCL_ERROR("[CollCommExecutor][MultiRingScatter]singleRingSliceZero is empty"), HCCL_E_INTERNAL); + + // 生成userMemIn_上对应的slices + std::vector userMemInputSlices; + CHK_RET( + CalUserMemSlices(dataType, opInfo, singleRingSliceZero, ringIndex, multiRingsOrder, userMemInputSlices)); + std::vector rankOrder; + CHK_RET(GetRankOrder(multiRingsOrder, ringIndex, rankOrder)); + SubCommInfo outerRingCommInfo = GetSubCommInfo(COMM_LEVEL0, ringIndex); + u32 rankSize = outerRingCommInfo.localRankSize; + + std::vector subStreamsInOneRing; + std::vector> mainSignalsInOneRing; + std::vector> subSignalsInOneRing; + std::unique_ptr executor; + if (opInfo != nullptr) { + CHK_RET(GetSubStreamInfoOnOneRing(streamInfo_, ringIndex, subStreamsInOneRing, mainSignalsInOneRing, + subSignalsInOneRing)); + executor.reset(new (std::nothrow) ScatterRingConcurrentDirect( + dispatcher_, opInfo, topoAttr_.userRank, subStreamsInOneRing, + mainSignalsInOneRing, subSignalsInOneRing, rankOrder, userMemInputSlices)); + } else { + executor.reset(new (std::nothrow) ScatterRing(dispatcher_)); + } + CHK_SMART_PTR_NULL(executor); + std::vector> threadManage; + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + CHK_RET(GetStreamThreadManage(tag, streamInfo_.ringNum, threadManage)); + } + + if (ringIndex != (ringNum - 1)) { + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { // offline + ret = StreamActiveManager::GetInstance(topoAttr_.deviceLogicId).StreamActive( + streamInfo_.ringStreams[ringIndex].ptr(), stream.ptr()); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingScatter]stream[%u],active stream failed", ringIndex), ret); + } + } + + u32 rootRank = 0; + ret = GetRankByUserRank(COMM_LEVEL0, ringIndex, root, rootRank); + CHK_PRT_RET(ret == HCCL_E_PARA, + HCCL_ERROR("[CollCommExecutor][MultiRingScatter]invalid root [%u] to get userrank", root), ret); + + if (ret == HCCL_SUCCESS) { + if (ringIndex != (ringNum - 1)) { // 0~ringNum-2的环 + ret = LocalNotify::Wait(streamInfo_.ringStreams[ringIndex], dispatcher_, + streamInfo_.ringSignalAux[ringIndex], PROF_STAGE_0); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingScatter]in stream[%u] wait failed", ringIndex), ret); + + ret = executor->Prepare(inputMem, inputMem, outputMem, count, dataType, + streamInfo_.ringStreams[ringIndex], HCCL_REDUCE_RESERVED, rootRank, singleRingSliceZero, 0, + ringNics[ringIndex]); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingScatter]stream[%u],scatter(ring) prepare failed, "\ + "return[%d]", ringIndex, ret), ret); + + ret = executor->RegisterProfiler(((ringIndex + 1) << PROF_RINGINDEX_OFFSET_OF_PLANEID) + + (rankSize << 
PROF_RANKSIZE_OFFSET_OF_PLANEID) + outerRingCommInfo.localRank, + PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, streamInfo_.ringStreams[ringIndex]); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingScatter]stream[%u], scatter(ring) register profiler "\ + "failed,return[%d]", ringIndex, ret), ret); + + ret = RunTemplate(executor, outerRingCommInfo); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingScatter]stream[%u],scatter(ring) run failed, "\ + "return[%d]", ringIndex, ret), ret); + + ret = LocalNotify::Post(streamInfo_.ringStreams[ringIndex], dispatcher_, + streamInfo_.ringSignal[ringIndex], PROF_STAGE_0); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingScatter]stream[%u] record failed", ringIndex), ret); + /* 主环record启动从环 */ + ret = LocalNotify::Post(stream, dispatcher_, streamInfo_.ringSignalAux[ringIndex], PROF_STAGE_0); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingScatter]stream[%u] record failed", ringIndex), ret); + } else { // 主环 + std::unique_ptr executor; + if (opInfo != nullptr) { + executor.reset(new (std::nothrow) ScatterRingConcurrentDirect( + dispatcher_, opInfo, topoAttr_.userRank, subStreamsInOneRing, mainSignalsInOneRing, + subSignalsInOneRing, rankOrder, userMemInputSlices)); + } else { + executor.reset(new (std::nothrow) ScatterRing(dispatcher_)); + } + CHK_SMART_PTR_NULL(executor); + ret = executor->Prepare(inputMem, inputMem, outputMem, count, dataType, stream, + HCCL_REDUCE_RESERVED, rootRank, singleRingSliceZero, 0, ringNics[ringIndex]); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingScatter]stream[%u],scatter(ring) prepare failed, "\ + "return[%d]", ringIndex, ret), ret); + ret = executor->RegisterProfiler(((ringIndex + 1) << PROF_RINGINDEX_OFFSET_OF_PLANEID) + + (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + outerRingCommInfo.localRank, + PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, stream); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingScatter]stream[%u], scatter(ring) register profiler "\ + "failed,return[%d]", ringIndex, ret), ret); + + ret = RunTemplate(executor, outerRingCommInfo); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingScatter]stream[%u],scatter(ring) run failed, "\ + "return[%d]", ringIndex, ret), ret); + + for (u32 ring = 0; ring < (ringNum - 1); ring++) { + /* 等待executor执行完毕 , 当前环没有分配数据,跳过此环处理,继续下一个环 */ + ret = LocalNotify::Wait(stream, dispatcher_, streamInfo_.ringSignal[ring], PROF_STAGE_0); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollCommExecutor][MultiRingScatter]stream[%u] wait failed", ring), ret); + } + } + } + } + // 添加空task,保证子图执行时不乱序 + CHK_RET(ExecutorBase::ExecEmptyTask(inputMem, outputMem, stream, dispatcher_)); + return HCCL_SUCCESS; +} + +HcclResult CollCommExecutor::GetStreamThreadManage(const std::string &tag, u32 streamNum, + std::vector> &threadManager) +{ std::unique_lock mutiStreamLock(threadManageMapLock_); + auto iterRank = threadManageMap_.find(tag); + if (iterRank == threadManageMap_.end()) { + std::vector> threadManagerVec; + threadManagerVec.resize(streamNum -1); + for (u32 ringIndex = 0; ringIndex < streamNum -1; ringIndex ++) { + threadManagerVec[ringIndex].reset(new (std::nothrow) ThreadManage(topoAttr_.deviceLogicId, + topoAttr_.userRank, + dispatcher_)); + CHK_SMART_PTR_NULL(threadManagerVec[ringIndex]); + HcclResult ret = threadManagerVec[ringIndex]->Init(); + CHK_PRT_RET(ret != HCCL_SUCCESS, + 
HCCL_ERROR("[Init][MultiRingResource]ringIndex[%u] ThreadManage failed,return[%d]", + ringIndex, ret), ret); + HCCL_INFO("ringThreadsManage Init success[%u]", ringIndex); + } + threadManageMap_.insert(std::pair<std::string, std::vector<std::shared_ptr<ThreadManage>>>(tag, std::move(threadManagerVec))); + } else { + threadManager = iterRank->second; + return HCCL_SUCCESS; + } + iterRank = threadManageMap_.find(tag); + threadManager = iterRank->second; + return HCCL_SUCCESS; +} + +HcclResult CollCommExecutor::SetRingNics(const std::string &tag, const std::vector<std::vector<u32>> &ringNics) +{ + std::unique_lock<std::mutex> lock(ringNicListLock_); + ringNicList_[tag] = ringNics; + return HCCL_SUCCESS; +} +HcclResult CollCommExecutor::GetRingNics(const std::string &tag, std::vector<std::vector<u32>> &ringNics) +{ + std::unique_lock<std::mutex> lock(ringNicListLock_); + auto iterRingNic = ringNicList_.find(tag); + if (iterRingNic == ringNicList_.end()) { + ringNics = {{0, 1, 2, 3, 4, 5, 6, 7}}; + } else { + ringNics = iterRingNic->second; + } + return HCCL_SUCCESS; +} +HcclResult CollCommExecutor::SetNicSendSize(const std::string &tag, std::vector<u64> &sizeList) +{ + std::unique_lock<std::mutex> lock(nicSendSizeListLock_); + nicSendSizeList_[tag] = sizeList; + return HCCL_SUCCESS; +} +HcclResult CollCommExecutor::PrepareInnerCommInfo(u32 &segmentIdx, u32 &commIndex, u64 &hdSize, + const SubCommInfo &commInfo, + const std::vector<std::vector<Slice>> &multRingsSliceZero, + const std::string &tag) +{ + segmentIdx = topoAttr_.devicePhyId; + commIndex = topoAttr_.devicePhyId; + CHK_PRT_RET(multRingsSliceZero.empty(), HCCL_ERROR("[Prepare][InnerCommInfo]slice map is empty"), HCCL_E_PARA); + if (multRingsSliceZero.size() > 1) { + std::vector<u32>::const_iterator iterNic = std::find(topoAttr_.nicList.begin(), + topoAttr_.nicList.end(), topoAttr_.devicePhyId); + if (iterNic != topoAttr_.nicList.end()) { // the current rank owns a communication NIC + u32 nicIdx = std::distance(topoAttr_.nicList.begin(), iterNic); + std::unique_lock<std::mutex> lock(nicSendSizeListLock_); + auto iter = nicSendSizeList_.find(tag); + CHK_PRT_RET(iter == nicSendSizeList_.end(), HCCL_ERROR("[Prepare][InnerCommInfo]find tag[%s] in "\ + "nicSendSizeList_ failed", tag.c_str()), HCCL_E_INTERNAL); + CHK_PRT_RET(nicIdx >= iter->second.size(), HCCL_ERROR("[Prepare][InnerCommInfo]tag[%s] nicIdx[%u] "\ + "invalid, expect less than %zu", tag.c_str(), nicIdx, iter->second.size()), HCCL_E_INTERNAL); + hdSize = iter->second[nicIdx]; // get the data size this NIC transmits from nicSendSizeList_ + u32 ringRanks = multRingsSliceZero[0].size(); // number of devices on a single ring + segmentIdx = ringRanks / topoAttr_.nicList.size() * nicIdx; // derive the start offset of this NIC's data from its position in the NIC list + if (topoAttr_.deviceType == DevType::DEV_TYPE_910_73) { + commIndex = segmentIdx; + } + } else { // the current rank owns no communication NIC, so it sends no data + hdSize = 0; + } + } else if (multRingsSliceZero.size() == 1) { + segmentIdx = commInfo.localRank; // for the 0-device / 4-device case + CHK_PRT_RET(segmentIdx >= multRingsSliceZero[0].size(), HCCL_ERROR("[Prepare][InnerCommInfo]index is out of "\ + "range. 
Idx[%u] Slice size[%llu]", segmentIdx, multRingsSliceZero[0].size()), HCCL_E_PARA); + hdSize = multRingsSliceZero[0][segmentIdx].size; + commIndex = segmentIdx; + } else { + return HCCL_E_PARA; + } + return HCCL_SUCCESS; +} } \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_comm_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_comm_executor.h index 959c7c5b0a5de2e7c9825eeb6baa797cf82cb6ba..a53abbfc7358f37cd1d22084bfa0dc70d3e41a79 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_comm_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_comm_executor.h @@ -17,14 +17,20 @@ namespace hccl { class CollCommExecutor : public CollNativeExecutorBase { public: - CollCommExecutor(std::unique_ptr &pImpl); + CollCommExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollCommExecutor() = default; - // CCL Op Share 目前只包含AllReduce涉及的接口 + // CCL Op Share + HcclResult MultiRingAllReduce(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, + const u64 count, const HcclDataType dataType, + const HcclReduceOp reductionOp, + const std::vector> &multRingsSliceZero, Stream stream, + s32 profStage, const u64 baseOffset = 0); HcclResult MultiRingReduceScatter(const std::string &tag, DeviceMem inputMem, DeviceMem outputMem, const u64 count, - const HcclDataType dataType, const HcclReduceOp reductionOp, - const std::vector> multRingsSliceZero, Stream stream, - s32 profStage, const u64 baseOffset = 0, const HcomCollOpInfo *opInfo = nullptr); + const HcclDataType dataType, const HcclReduceOp reductionOp, + const std::vector> multRingsSliceZero, Stream stream, + s32 profStage, const u64 baseOffset = 0, const HcomCollOpInfo *opInfo = nullptr, + const std::vector> multRingsUserMemSlice = std::vector> (0)); HcclResult MultiRingReduceScatterConcurrent(const std::string &tag, DeviceMem inputMem, DeviceMem outputMem, const u64 count, const HcclDataType dataType, @@ -33,15 +39,20 @@ public: s32 profStage, const u64 baseOffset, const HcomCollOpInfo *opInfo = nullptr); HcclResult MultiRingAllGather(const std::string &tag, DeviceMem inputMem, DeviceMem outputMem, const u64 count, - const HcclDataType dataType, - const std::vector > multRingsSliceZero, Stream stream, - s32 profStage, const u64 baseOffset = 0, const HcomCollOpInfo *opInfo = nullptr); + const HcclDataType dataType, + const std::vector > multRingsSliceZero, Stream stream, + s32 profStage, const u64 baseOffset = 0, const HcomCollOpInfo *opInfo = nullptr, + const std::vector> multRingsUserMemSlice = std::vector> (0)); HcclResult MultiRingAllGatherConcurrent(const std::string &tag, DeviceMem inputMem, DeviceMem outputMem, const u64 count, const HcclDataType dataType, const std::vector>> multRingsSliceZero, Stream stream, s32 profStage, const u64 baseOffset = 0, const HcomCollOpInfo *opInfo = nullptr); + HcclResult MultiRingMultiRootScatter(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, + const u64 count, const HcclDataType dataType, const std::vector> &multRingsSliceZero, + u32 root, Stream stream, const u64 baseOffset); + HcclResult MultiStreamReduceScatterMesh(const std::string &tag, DeviceMem inputMem, DeviceMem outputMem, const u64 count, const HcclDataType dataType, const HcclReduceOp reductionOp, @@ -49,6 +60,11 @@ public: Stream stream, const CommPlane commLevelIndex, const u64 baseOffset = 0); + + HcclResult MultiRingGather(const std::string 
&tag, DeviceMem inputMem, DeviceMem outputMem, const u64 count, + const HcclDataType dataType, const std::vector> multRingsSliceZero, + HcclReduceOp op, u32 root, Stream stream, s32 profStage); + HcclResult MultiStreamReduceScatterMeshAtomic(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, const u64 count, const HcclDataType dataType, const HcclReduceOp reductionOp, @@ -56,7 +72,11 @@ public: Stream &stream, const CommPlane commLevelIndex, const u64 baseOffset = 0, HcomCollOpInfo *opInfo = nullptr); + HcclResult PrepareReduceScatterSliceData(u64 dataCount, u32 unitSize, u32 sliceNum, std::vector &dataSlice); + HcclResult MultiRingScatter(const std::string &tag, DeviceMem inputMem, DeviceMem outputMem, const u64 count, + const HcclDataType dataType, const std::vector > multRingsSliceZero, + u32 root, Stream stream, const HcomCollOpInfo *opInfo); std::vector> GetRingsOrderByTopoType(u32 ranksSize, TopoType topoType, std::vector &nicList); HcclResult MutliSegSlicePrepare(const std::vector &dataSegsSlice, std::vector >& mutliSegsSlices, u32 ringCount); @@ -75,7 +95,11 @@ public: bool IsMultiMeshInlineReduce(void *inputPtr, void *outputPtr, HcclDataType dataType, HcclReduceOp op); u64 GetReduceAttr(DeviceMem &inputMem, DeviceMem &outputMem, HcclDataType dataType, HcclReduceOp op); -private: + HcclResult PrepareInnerCommInfo(u32 &segmentIdx, u32 &commIndex, u64 &hdSize, + const SubCommInfo &commInfo, + const std::vector > &multRingsSliceZero, + const std::string &tag); +protected: HcclResult GetSubStreamInfoOnOneRing(const innerStreamInfo_t &streamInfo, const u32 ringIndex, std::vector &subStreamsInOneRing, std::vector> &mainSignalsInOneRing, @@ -86,6 +110,17 @@ private: std::vector &userMemSlices); HcclResult GetRankOrder(const std::vector> &multiRingsOrder, u32 ringIndex, std::vector &rankOrder); + HcclResult SetRingNics(const std::string &tag, const std::vector> &ringNics); + HcclResult GetRingNics(const std::string &tag, std::vector> &ringNics); + HcclResult SetNicSendSize(const std::string &tag, std::vector &sizeList); + HcclResult GetStreamThreadManage(const std::string &tag, u32 streamNum, + std::vector> &threadManager); + std::mutex ringNicListLock_; + std::map>> ringNicList_; + std::mutex nicSendSizeListLock_; + std::map> nicSendSizeList_; + std::mutex threadManageMapLock_; + std::map>> threadManageMap_; }; } // namespace hccl diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_executor_base.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_executor_base.cc index 5c7bb901af89c65855aad95ffffe619b39b675e1..09261dd6ed951ea82f72dc88e0a14cbfbc955443 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_executor_base.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_executor_base.cc @@ -12,8 +12,8 @@ namespace hccl { -CollExecutorBase::CollExecutorBase(std::unique_ptr &pImpl) - : hcclImpl_(pImpl) +CollExecutorBase::CollExecutorBase(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher) + : dispatcher_(dispatcher), topoMatcher_(topoMatcher) { } @@ -23,6 +23,24 @@ HcclResult CollExecutorBase::SetAlgType(const AlgType algType) return HCCL_SUCCESS; } +HcclResult CollExecutorBase::SetVirtualDispatcher(const HcclDispatcher vDispatcher) +{ + vDispatcher_ = vDispatcher; + return HCCL_SUCCESS; +} + +HcclResult CollExecutorBase::SetCCLInBuffer(u64 cclbufferSize) +{ + inCCLbufferSize_ = cclbufferSize; + return HCCL_SUCCESS; +} + +HcclResult 
CollExecutorBase::SetParallelTaskLoader(ParallelTaskLoader* parallelTaskLoader) +{ + parallelTaskLoader_ = parallelTaskLoader; + return HCCL_SUCCESS; +} + HcclResult CollExecutorBase::RunTemplate(const std::unique_ptr &executor, const SubCommInfo &commInfo) { HcclResult ret = executor->RunAsync(commInfo.localRank, commInfo.localRankSize, commInfo.links); @@ -33,14 +51,46 @@ HcclResult CollExecutorBase::RunTemplate(const std::unique_ptr &ex commInfo.localRank, commInfo.localRankSize), ret); return HCCL_SUCCESS; } + +HcclResult CollExecutorBase::CalcIncreLinkRequest(const OpParam& param, AlgResourceRequest &resourceRequest) +{ + return HCCL_SUCCESS; +} +HcclResult CollExecutorBase::RunAlltoAllTemplate(const std::unique_ptr &executor, + const SubCommInfo &commInfo) +{ + HcclResult ret = executor->RunAsync(commInfo.localRank, commInfo.localRankSize, commInfo.links); + CHK_PRT_RET(ret == HCCL_E_AGAIN, HCCL_WARNING("[CollExecutorBase][RunAlltoAllTemplate]" \ + "group has been destroyed. Break!"), ret); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollExecutorBase][RunAlltoAllTemplate]run executor rank[%u] rank size[%u] failed", + commInfo.localRank, commInfo.localRankSize), ret); + return HCCL_SUCCESS; +} -bool CollExecutorBase::NeedIncrCreateLink(const OpParam& param) +HcclResult CollExecutorBase::RunAlltoAllVTemplateStaged(const std::unique_ptr &executor, + const SubCommInfo &commInfo) { - return false; + HcclResult ret = executor->RunAsync(commInfo.localRank, commInfo.localRankSize, commInfo.links); + CHK_PRT_RET(ret == HCCL_E_AGAIN, HCCL_WARNING("[CollExecutorBase][RunAlltoAllVTemplateStaged]" \ + "group has been destroyed. Break!"), ret); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollExecutorBase][RunAlltoAllVTemplateStaged]run executor rank[%u] rank size[%u] failed", + commInfo.localRank, commInfo.localRankSize), ret); + return HCCL_SUCCESS; } - -HcclResult CollExecutorBase::CalcIncreLinkRequest(const OpParam& param, AlgResourceRequest &resourceRequest) + +// deprecated +HcclResult CollExecutorBase::RunTemplateWithVirtualLink(const std::unique_ptr &executor, + const SubCommInfo &commInfo) { + HcclResult ret = executor->RunAsync(commInfo.localRank, commInfo.localRankSize, commInfo.virtualLinks); + CHK_PRT_RET(ret == HCCL_E_AGAIN, HCCL_WARNING("[CollExecutorBase][RunTemplateWithVirtualLink]" \ + "group has been destroyed. 
Break!"), ret); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollExecutorBase][RunTemplateWithVirtualLink]run executor rank[%u] rank size[%u] failed", + commInfo.localRank, commInfo.localRankSize), ret); return HCCL_SUCCESS; } + } \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_executor_base.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_executor_base.h index 3b150c93cb27a978223bdea25c79e110483510d1..99015bf30d19ba18038299e2b21ee1821a6ff670 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_executor_base.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_executor_base.h @@ -12,6 +12,7 @@ #define COLL_EXECUTOR_BASE_H #include "hccl_impl.h" +#include "topo_matcher.h" #include "coll_alg_param.h" #include "executor_impl.h" @@ -19,27 +20,45 @@ namespace hccl { class CollExecutorBase { public: - CollExecutorBase(std::unique_ptr &pImpl); + CollExecutorBase(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); virtual ~CollExecutorBase() = default; // 每次构造完必须调用 SetAlgType HcclResult SetAlgType(const AlgType algType); - // 对于原生算子,将在CollNativeExecutorBase中实现 + HcclResult SetVirtualDispatcher(const HcclDispatcher virtualDispatcher); + + HcclResult SetCCLInBuffer(u64 cclbufferSize); + + HcclResult SetParallelTaskLoader(ParallelTaskLoader *parallelTaskLoader); + virtual HcclResult CalcResRequest(const OpParam& param, AlgResourceRequest &resourceRequest) = 0; - // 对于原生算子,将由每个算子独立实现 + virtual bool CheckNeedRecreateComm(u64 lastScratchMemSize) = 0; virtual HcclResult Orchestrate(const OpParam& param, const AlgResourceResponse& algRes) = 0; - // 增量建链 - virtual bool NeedIncrCreateLink(const OpParam& param); - + // batchsendrecv需要增量建链 virtual HcclResult CalcIncreLinkRequest(const OpParam& param, AlgResourceRequest &resourceRequest); + virtual HcclResult SetExcutorExtraInfo(const std::vector &allMeshAggregationSendRecvInfo) + { + return HCCL_SUCCESS; + } static HcclResult RunTemplate(const std::unique_ptr &executor, const SubCommInfo &commInfo); + static HcclResult RunAlltoAllTemplate(const std::unique_ptr &executor, + const SubCommInfo &commInfo); + static HcclResult RunAlltoAllVTemplateStaged(const std::unique_ptr &executor, + const SubCommInfo &commInfo); + static HcclResult RunTemplateWithVirtualLink(const std::unique_ptr &executor, + const SubCommInfo &commInfo); + protected: - std::unique_ptr &hcclImpl_; + const HcclDispatcher dispatcher_; + HcclDispatcher vDispatcher_; + u64 inCCLbufferSize_{0}; // CCLIN大小,用于计算scratch AlgType algType_; // 算法类型 + std::unique_ptr &topoMatcher_; + ParallelTaskLoader* parallelTaskLoader_; // 并行下发taskloader管理 }; } #endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_native_executor_base.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_native_executor_base.cc index cf94168f9702cc458f03c24a4f614f21157808fd..625e3677c9518f8a7f77b77bf48b5b015136564c 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_native_executor_base.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_native_executor_base.cc @@ -12,17 +12,24 @@ #include "profiling_manager_pub.h" namespace hccl { -CollNativeExecutorBase::CollNativeExecutorBase(std::unique_ptr &pImpl) - : CollExecutorBase(pImpl), dispatcher_(pImpl->dispatcher_), - algoAttr_(pImpl->algoAttr_), topoAttr_(pImpl->topoAttr_), - 
is310P3Common_(hcclImpl_->Is310P3Common()) +CollNativeExecutorBase::CollNativeExecutorBase(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollExecutorBase(dispatcher, topoMatcher), topoAttr_(topoMatcher_->GetTopoInfo()), + algoAttr_(topoMatcher_->GetAlgoInfo()) { - hcclImpl_->GetTopoType(topoType_); + topoType_ = topoAttr_.topoType; + is310P3Common_ = topoAttr_.is310P3Common; } void CollNativeExecutorBase::ParseParam(const OpParam& param) { tag_ = param.tag; + root_ = param.root; +} + +bool CollNativeExecutorBase::CheckNeedRecreateComm(u64 lastScratchMemSize) +{ + return false; } // ----------------------资源计算接口---------------------- @@ -49,22 +56,7 @@ HcclResult CollNativeExecutorBase::CalcResRequest(const OpParam& param, AlgResou resourceRequest.streamNum, resourceRequest.notifyNum, resourceRequest.scratchMemSize, resourceRequest.needAivBuffer); // 打印建链诉求 - for (u32 levelIndex = 0; levelIndex < COMM_LEVEL_RESERVED; levelIndex++) { - LevelNSubCommTransport &levelTransport = resourceRequest.opTransport[levelIndex]; - u32 ringSize = levelTransport.size(); - for (u32 ringIndex = 0; ringIndex < ringSize; ringIndex++) { - SingleSubCommTransport &subCommTransport = levelTransport[ringIndex]; - u32 rankSize = subCommTransport.transportRequests.size(); - for (u32 rankIndex = 0; rankIndex < rankSize; rankIndex++) { - if (subCommTransport.transportRequests[rankIndex].isValid == true) { - HCCL_INFO("[CollNativeExecutorBase][CalcResRequest]" \ - "levelIndex[%u], ringIndex[%u], rankIndex[%u], userRank[%u], remoteRank[%u]", - levelIndex, ringIndex, rankIndex, subCommTransport.transportRequests[rankIndex].localUserRank, - subCommTransport.transportRequests[rankIndex].remoteUserRank); - } - } - } - } + PrintTransportRequest(resourceRequest); return HCCL_SUCCESS; } @@ -109,7 +101,7 @@ HcclResult CollNativeExecutorBase::CalcCommPlaneInfo(const std::string &tag, con std::vector &commTransport, TransportMemType inPutMemType, TransportMemType outPutMemType) { - return hcclImpl_->CalcCommPlaneInfo(tag, commParaInfo, commTransport, inPutMemType, outPutMemType); + return topoMatcher_->CalcCommPlaneInfo(tag, commParaInfo, commTransport, inPutMemType, outPutMemType); } HcclResult CollNativeExecutorBase::CalcLevel1CommInfo(TransportMemType inputType, @@ -118,7 +110,7 @@ HcclResult CollNativeExecutorBase::CalcLevel1CommInfo(TransportMemType inputType { HCCL_INFO("[CollNativeExecutorBase][CalcInnerCommInfo]tag[%s]start", tag_.c_str()); - CommParaInfo commParaLevel1(COMM_LEVEL1, CommType::COMM_TAG_MAX); + CommParaInfo commParaLevel1(COMM_LEVEL1, CommType::COMM_TAG_MAX, root_); if (UseInterServerRingAlgo(algType_)) { commParaLevel1.commType = CommType::COMM_TAG_RING_INNER; HCCL_INFO("[CollNativeExecutorBase][CalcInnerCommInfo]tag[%s] Calc RingCommInfo", tag_.c_str()); @@ -137,7 +129,7 @@ HcclResult CollNativeExecutorBase::CalcLevel1CommInfo(TransportMemType inputType } commParaLevel1.forceRdma = false; CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel1, opTransport[COMM_LEVEL1], inputType, outputType)); - if (GetExternalInputEnableRdmaSdmaConcurrent() && UseInterServerRingAlgo(algType_)) { + if (topoMatcher_->GetExternalInputEnableRdmaSdmaConcurrent() && UseInterServerRingAlgo(algType_)) { CommParaInfo commParaLevel1Rdma(COMM_LEVEL1_RDMA, CommType::COMM_TAG_RING_INNER); commParaLevel1Rdma.forceRdma = true; CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel1Rdma, opTransport[COMM_LEVEL1_RDMA], inputType, @@ -162,6 +154,26 @@ HcclResult CollNativeExecutorBase::CalcLevel2CommInfo(TransportMemType inputType 
return HCCL_SUCCESS; } +HcclResult CollNativeExecutorBase::PrintTransportRequest(AlgResourceRequest& resourceRequest) +{ + for (u32 levelIndex = 0; levelIndex < COMM_LEVEL_RESERVED; levelIndex++) { + LevelNSubCommTransport &levelTransport = resourceRequest.opTransport[levelIndex]; + u32 ringSize = levelTransport.size(); + for (u32 ringIndex = 0; ringIndex < ringSize; ringIndex++) { + SingleSubCommTransport &subCommTransport = levelTransport[ringIndex]; + u32 rankSize = subCommTransport.transportRequests.size(); + for (u32 rankIndex = 0; rankIndex < rankSize; rankIndex++) { + if (subCommTransport.transportRequests[rankIndex].isValid == true) { + HCCL_INFO("[CollNativeExecutorBase][CalcResRequest]" \ + "levelIndex[%u], ringIndex[%u], rankIndex[%u], userRank[%u], remoteRank[%u]", + levelIndex, ringIndex, rankIndex, subCommTransport.transportRequests[rankIndex].localUserRank, + subCommTransport.transportRequests[rankIndex].remoteUserRank); + } + } + } + } + return HCCL_SUCCESS; +} // ----------------------算法编排接口---------------------- HcclResult CollNativeExecutorBase::KernelRun(const OpParam ¶m, ExecMem &execMem) { @@ -240,6 +252,7 @@ SubCommInfo CollNativeExecutorBase::GetSubCommInfo(const CommPlane levelIndex, c info.localRank = transportInfo.userRank2subCommRank[topoAttr_.userRank]; info.localRankSize = transportInfo.transportRequests.size(); info.links = transportInfo.links; + info.virtualLinks = transportInfo.virtualLinks; return info; } @@ -262,6 +275,11 @@ bool CollNativeExecutorBase::UseInterServerRingAlgo(AlgType algType) return GetLevel1AlgType(algType) == AlgTypeLevel1::ALG_LEVEL1_RING; } +bool CollNativeExecutorBase::UseInterServerHDAlgo(AlgType algType) +{ + return GetLevel1AlgType(algType) == AlgTypeLevel1::ALG_LEVEL1_HD; +} + bool CollNativeExecutorBase::UseInterServerNHRAlgo(AlgType algType) { return GetLevel1AlgType(algType) == AlgTypeLevel1::ALG_LEVEL1_NHR; @@ -282,6 +300,11 @@ bool CollNativeExecutorBase::UseLevel2RingAlgo(AlgType algType) return GetLevel2AlgType(algType) == AlgTypeLevel2::ALG_LEVEL2_RING; } +bool CollNativeExecutorBase::UseInterServerPipelineAlgo(AlgType algType) +{ + return GetLevel1AlgType(algType) == AlgTypeLevel1::ALG_LEVEL1_PIPELINE; +} + AlgTypeLevel2 CollNativeExecutorBase::GetLevel2AlgType(const AlgType algType) const { const u32 algLevel2 = static_cast(algType) >> (HCCL_LEVEL_ALGO_WIDTH * 2); @@ -299,4 +322,21 @@ HcclResult CollNativeExecutorBase::BuildResourceRequest(u64 scratchMemSize, u32 return HCCL_SUCCESS; } -} \ No newline at end of file +HcclResult CollNativeExecutorBase::GetRankByUserRank(CommPlane levelIndex, u32 subLevelIndex, u32 userRank, u32 &rank) +{ + CHK_RET(CheckCommSize(levelIndex, subLevelIndex + 1)); + SingleSubCommTransport &transportInfo = + const_cast(algResResp_->opTransportResponse[levelIndex][subLevelIndex]); + rank = transportInfo.userRank2subCommRank[userRank]; + return HCCL_SUCCESS; +} + +HcclResult CollNativeExecutorBase::GetUserRankByRank(CommPlane levelIndex, u32 subLevelIndex, u32 rank, u32 &userRank) +{ + CHK_RET(CheckCommSize(levelIndex, subLevelIndex + 1)); + SingleSubCommTransport &transportInfo = + const_cast(algResResp_->opTransportResponse[levelIndex][subLevelIndex]); + userRank = transportInfo.subCommRank2UserRank[rank]; + return HCCL_SUCCESS; +} +} diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_native_executor_base.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_native_executor_base.h index 
a97d68bf1552ec8d93dcf123459b980b314b7cb3..99a5b6d1ee0745d4a4e601337c926f5c74424c9b 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_native_executor_base.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_native_executor_base.h @@ -31,10 +31,12 @@ struct ExecMem { class CollNativeExecutorBase : public CollExecutorBase { public: - CollNativeExecutorBase(std::unique_ptr &pImpl); + CollNativeExecutorBase(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollNativeExecutorBase() = default; HcclResult CalcResRequest(const OpParam& param, AlgResourceRequest &resourceRequest) override; + bool CheckNeedRecreateComm(u64 lastScratchMemSize) override; + protected: /* *************** 资源计算 *************** */ virtual void ParseParam(const OpParam& param); @@ -60,7 +62,7 @@ protected: TransportMemType outPutMemType); HcclResult BuildResourceRequest(u64 scratchMemSize, u32 streamNum, u32 notifyNum, bool needAivBuffer, std::vector& opTransport, AlgResourceRequest& resourceRequest); - + HcclResult PrintTransportRequest(AlgResourceRequest& resourceRequest); /* *************** 算法编排 *************** */ // 虚函数,执行具体的数据搬运、reduce操作。 各Executor重载。 // 按Inner、Outer、Level2可继续进行拆分。 @@ -74,6 +76,7 @@ protected: HcclResult AddSubStreamToProfiling(); // 检查通信域大小 HcclResult CheckCommSize(const CommPlane levelIndex, const u32 subLevelIndex); + // 获取不同类型通信域中的 transport 信息 // 为了避免循环调用时反复校验Range引发性能问题,此处不做Range校验,建议调用该接口前先调用CheckCommSize避免OutOfRange问题 SubCommInfo GetSubCommInfo(const CommPlane levelIndex, const u32 subLevelIndex); @@ -84,22 +87,26 @@ protected: AlgTypeLevel2 GetLevel2AlgType(const AlgType algType) const; bool UseInterServerRingAlgo(AlgType algType); + bool UseInterServerHDAlgo(AlgType algType); bool UseInterServerNHRAlgo(AlgType algType); bool UseInterServerNHRV1Algo(AlgType algType); bool UseInterServerNBAlgo(AlgType algType); bool UseLevel2RingAlgo(AlgType algType); + bool UseInterServerPipelineAlgo(AlgType algType); + HcclResult GetRankByUserRank(CommPlane levelIndex, u32 subLevelIndex, u32 userRank, u32 &rank); + HcclResult GetUserRankByRank(CommPlane levelIndex, u32 subLevelIndex, u32 rank, u32 &userRank); /* ---------------以下为 protected 成员变量定义领域-------------------------- */ std::string tag_; + u32 root_ = INVALID_VALUE_RANKID; const AlgResourceResponse *algResResp_ = nullptr; innerStreamInfo_t streamInfo_; - // Infos got from hcclImpl - const HcclDispatcher dispatcher_; - const HcclAlgoAttr &algoAttr_; - const HcclTopoAttr &topoAttr_; + // Infos got from topoMatcher_ + const HcclTopoInfo topoAttr_; + const HcclAlgoInfo algoAttr_; TopoType topoType_; bool is310P3Common_ = false; }; } -#endif \ No newline at end of file +#endif diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/CMakeLists.txt b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..2bd0aadd86952c174d54afae472c97909425d605 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/CMakeLists.txt @@ -0,0 +1,12 @@ +set(src_list + ${CMAKE_CURRENT_SOURCE_DIR}/coll_reduce_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_reduce_mesh_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_reduce_ring_plus_hd_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_reduce_double_ring_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_reduce_comm_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_reduce_single_rank_executor.cc 
+) + +target_sources(hccl_alg PRIVATE + ${src_list} +) \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_comm_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_comm_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..194befb103c65de067b635685ebb426ef014ebe6 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_comm_executor.cc @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + + +#include "coll_reduce_comm_executor.h" + +namespace hccl { + +CollReduceCommExecutor::CollReduceCommExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollReduceExecutor(dispatcher, topoMatcher) +{ +} + +HcclResult CollReduceCommExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CalcTransportMemType(inputType, outputType); + CalcCombinedCommInfo(inputType, outputType, opTransport); + return HCCL_SUCCESS; +} + +HcclResult CollReduceCommExecutor::CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_OUTPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_OUTPUT; + } + HCCL_INFO("[CollReduceCommExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollReduceCommExecutor::CalcCombinedCommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaInfo(COMM_COMBINE, CommType::COMM_TAG_MAX); + commParaInfo.commType = CommType::COMM_TAG_RING_INNER; + CHK_RET(CalcCommPlaneInfo(tag_, commParaInfo, opTransport[COMM_COMBINE], inputType, outputType)); + + return HCCL_SUCCESS; +} + +HcclResult CollReduceCommExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + CHK_RET(CheckCommSize(COMM_COMBINE, 1)); + SubCommInfo combinedCommInfo = GetSubCommInfo(COMM_COMBINE, 0); + + u64 reduceAttr = GetReduceAttr(execMem.inputMem, execMem.outputMem, param.DataDes.dataType, param.reduceType); + + std::unique_ptr executor; + executor.reset(new (std::nothrow) ReduceRing(dispatcher_, reduceAttr)); + HCCL_INFO("Reduce comm: using ring algo inter-server."); + CHK_SMART_PTR_NULL(executor); + + u32 rankSize = combinedCommInfo.localRankSize; + CHK_RET(executor->Prepare(execMem.inputMem, execMem.outputMem, execMem.outputMem, execMem.count, + param.DataDes.dataType, param.stream, param.reduceType, + OUTER_BRIDGE_RANK_ID, std::vector(0), 0)); + + CHK_RET(executor->RegisterProfiler( + (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + + 
combinedCommInfo.localRank, PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, param.stream)); + + CHK_RET(RunTemplate(executor, combinedCommInfo)); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("ReduceComm", ReduceComm, CollReduceCommExecutor); + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_comm_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_comm_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..8add1a636c4353eef090fd2add6bb0fb14dfdaa2 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_comm_executor.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_REDUCE_COMM_EXECUTOR_H +#define COLL_REDUCE_COMM_EXECUTOR_H +#include "coll_reduce_executor.h" +namespace hccl { +class CollReduceCommExecutor : public CollReduceExecutor { + +public: + CollReduceCommExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollReduceCommExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcCombinedCommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport); + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_double_ring_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_double_ring_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..cd39aecb4182721ba734f70595039d958b28c835 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_double_ring_executor.cc @@ -0,0 +1,260 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + + +#include "coll_reduce_double_ring_executor.h" + +namespace hccl { + +CollReduceDoubleRingExecutor::CollReduceDoubleRingExecutor(const HcclDispatcher dispatcher, + std::unique_ptr<TopoMatcher> &topoMatcher) + : CollReduceExecutor(dispatcher, topoMatcher) +{ +} + +HcclResult CollReduceDoubleRingExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = 0U; + // DoubleRing is only supported in the 910_73 scenario + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING; + } else { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE; + } + streamNum = totalStreamNum - 1; + HCCL_INFO("[CollReduceDoubleRingExecutor][CalcStreamNum] tag[%s] streamNum_[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollReduceDoubleRingExecutor::CalcCommInfo(std::vector<LevelNSubCommTransport>& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CalcTransportMemType(inputType, outputType); + CalcLevel0CommInfo(inputType, outputType, opTransport); + CalcLevel1CommInfo(inputType, outputType, opTransport); + CalcLevel2CommInfo(inputType, outputType, opTransport); + return HCCL_SUCCESS; +} + +HcclResult CollReduceDoubleRingExecutor::CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_OUTPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_OUTPUT; + } + HCCL_INFO("[CollReduceDoubleRingExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollReduceDoubleRingExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector<LevelNSubCommTransport>& opTransport) +{ + HCCL_INFO("[CollReduceDoubleRingExecutor][CalcOuterCommInfo]tag[%s] start", tag_.c_str()); + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_RING_INNER); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + HCCL_INFO("[CollReduceDoubleRingExecutor][CalcOuterCommInfo]tag[%s] Calc RingComm finish", tag_.c_str()); + return HCCL_SUCCESS; +} + +HcclResult CollReduceDoubleRingExecutor::CalcLevel2CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector<LevelNSubCommTransport>& opTransport) +{ + CommParaInfo commParaLevel2(COMM_LEVEL2, CommType::COMM_TAG_MAX); + if (UseLevel2RingAlgo(algType_)) { + commParaLevel2.commType = CommType::COMM_TAG_RING_INNER; + } else { + commParaLevel2.commType = CommType::COMM_TAG_HALVING_DOUBLING; + } + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel2, opTransport[COMM_LEVEL2], inputType, outputType)); + return HCCL_SUCCESS; +} + +HcclResult CollReduceDoubleRingExecutor::KernelRun(const OpParam &param, ExecMem &execMem) +{ + HCCL_INFO("[CollReduceDoubleRingExecutor][Run]The CollReduceDoubleRingExecutor starts."); + u32 perDataSize = 0; + CHK_RET(SalGetDataTypeSize(param.DataDes.dataType, perDataSize)); + std::vector<Slice> dataSegsSlice; // the data is split into rankSize pieces; start offset and size of each piece + std::vector<std::vector<Slice>> multiRingsSliceZero; // data offsets relative to ring 0 on this rank + u32 ringNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE; + CHK_RET(CheckCommSize(COMM_LEVEL0, ringNum)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + u32 sliceNum = outerCommInfo.localRankSize; + // compute the offset and size of the data on each ring from the data volume + 
CHK_RET(ExecutorBase::PrepareSliceData(execMem.count, perDataSize, sliceNum, 0, dataSegsSlice)); + + /* 三步算法step1:外层 - 节点内 reduce-scatter */ + // 构造ring algorithm对应的reduce-scatter实例 + multiRingsSliceZero = PrepareMultiRingSlice(dataSegsSlice, tag_, false, topoAttr_.nicList); + CHK_PRT_RET(multiRingsSliceZero.size() != ringNum, HCCL_ERROR("[CollReduceDoubleRingExecutor][Run]"\ + "ringNum[%u] != multiRingsSliceZero size[%llu]", ringNum, multiRingsSliceZero.size()), + HCCL_E_INTERNAL); + + HcomCollOpInfo *reduceScatterOpInfoPtr = nullptr; + // 第一步的reducescatter输出放在CCL buffer上,通过设置nullptr指示不做最后一步的DMA削减动作 + + CHK_RET(MultiRingReduceScatter(tag_, execMem.inputMem, execMem.outputMem, execMem.count, + param.DataDes.dataType, param.reduceType, multiRingsSliceZero, param.stream, + PROF_STAGE_0, 0, reduceScatterOpInfoPtr)); + HCCL_INFO("reduce double ring stage0 run success"); + + // step2: 节点间的reduce + u32 commIndex = 0; + u64 level1Size = 0; + u32 segmentIdx = 0; + CHK_RET(PrepareInnerCommInfo(segmentIdx, commIndex, level1Size, outerCommInfo, multiRingsSliceZero, tag_)); + u64 level1Count = level1Size / perDataSize; + if (topoAttr_.devNumInLevel2 <= 1) { + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + DeviceMem reduceInput = execMem.inputMem.range(dataSegsSlice[segmentIdx].offset, level1Size); + CHK_SMART_PTR_NULL(reduceInput); + DeviceMem reduceOutput = execMem.outputMem.range(dataSegsSlice[segmentIdx].offset, level1Size); + CHK_SMART_PTR_NULL(reduceOutput); + u64 reduceAttr = GetReduceAttr(reduceInput, reduceOutput, param.DataDes.dataType, param.reduceType); + std::unique_ptr innerExecutor; + if (UseInterServerRingAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) ReduceRing(dispatcher_, reduceAttr)); + HCCL_INFO("[CollReduceDoubleRingExecutor]using ring algo inter-server."); + } else { + innerExecutor.reset(new (std::nothrow) ReduceRecursiveHalvingDoubling(dispatcher_, reduceAttr)); + HCCL_INFO("[CollReduceDoubleRingExecutor]using Recursive halving-doubling algo inter-server."); + } + u32 rankSize = innerCommInfo.localRankSize; + u32 subUserrankRoot = topoMatcher_->GetSubRootUserRank(topoAttr_.userRank, param.root); + CHK_PRT_RET(subUserrankRoot == INVALID_VALUE_RANKID, + HCCL_ERROR("[CollReduceDoubleRingExecutor]subUserrankRoot[%u] is invalid,userRank[%u],root[%u]", + subUserrankRoot, topoAttr_.userRank, param.root), HCCL_E_INTERNAL); + u32 planeRoot = 0; + CHK_RET(GetRankByUserRank(COMM_LEVEL1, commIndex, subUserrankRoot, planeRoot)); + // 节点间的hd 使用环0来记录 + CHK_SMART_PTR_NULL(innerExecutor); + CHK_RET(innerExecutor->Prepare(reduceInput, reduceOutput, reduceOutput, level1Count, param.DataDes.dataType, + param.stream, param.reduceType, planeRoot, std::vector(0), + dataSegsSlice[segmentIdx].offset)); + CHK_RET(innerExecutor->RegisterProfiler( + (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(innerExecutor, innerCommInfo)); + } else { + //节点间 reduce scatter + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + SubCommInfo innerZeroCommInfo = GetSubCommInfo(COMM_LEVEL1, COMM_INDEX_0); + sliceNum = innerZeroCommInfo.localRankSize; + CHK_RET(ExecutorBase::PrepareSliceData(level1Count, perDataSize, sliceNum, 0, dataSegsSlice)); + + DeviceMem reducescatterInput = execMem.inputMem.range(dataSegsSlice[segmentIdx].offset, level1Size); + DeviceMem 
reducescatterOutput = execMem.outputMem.range(dataSegsSlice[segmentIdx].offset, level1Size); + + u64 reduceAttr = GetReduceAttr(reducescatterInput, reducescatterOutput, param.DataDes.dataType, param.reduceType); + std::unique_ptr level1RSExecutor; + u32 subUserrankRoot = topoMatcher_->GetSubRootUserRank(topoAttr_.userRank, param.root); + u32 planeRoot = 0; + CHK_RET(GetRankByUserRank(COMM_LEVEL1, commIndex, subUserrankRoot, planeRoot)); + if (UseInterServerRingAlgo(algType_)) { + level1RSExecutor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); + CHK_SMART_PTR_NULL(level1RSExecutor); + CHK_RET(level1RSExecutor->Prepare( + reducescatterInput, reducescatterInput, reducescatterOutput, level1Count, param.DataDes.dataType, param.stream, param.reduceType, + planeRoot, dataSegsSlice, dataSegsSlice[segmentIdx].offset)); + HCCL_INFO("[CollReduceDoubleRingExecutor]reducescatter ring: using ring algo inter-server."); + } else { + level1RSExecutor.reset(new (std::nothrow) ReduceScatterRecursiveHalvingDoubling(dispatcher_, reduceAttr)); + CHK_SMART_PTR_NULL(level1RSExecutor); + CHK_RET(level1RSExecutor->Prepare( + reducescatterInput, reducescatterOutput, reducescatterOutput, level1Count, param.DataDes.dataType, param.stream, param.reduceType, + planeRoot, dataSegsSlice, dataSegsSlice[segmentIdx].offset)); + HCCL_INFO("[CollReduceDoubleRingExecutor]reducescatter ring: using halving-doubling algo inter-server."); + } + CHK_RET(level1RSExecutor->RegisterProfiler( + (sliceNum << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(level1RSExecutor, innerCommInfo)); + HCCL_INFO("[CollReduceDoubleRingExecutor]reduce double ring [superpod] level1 reduce-scatter run success"); + + // 超节点 reduce + u64 rSize; + std::vector> rdSlice; + rdSlice.push_back(dataSegsSlice); + CHK_RET(PrepareInnerCommInfo(segmentIdx, commIndex, rSize, innerZeroCommInfo, rdSlice, tag_)); + u64 arCount = rSize / perDataSize; + + CHK_RET(CheckCommSize(COMM_LEVEL2, commIndex + 1)); + SubCommInfo level2CommInfo = GetSubCommInfo(COMM_LEVEL2, commIndex); + u32 rankSize = GetSubCommInfo(COMM_LEVEL2, COMM_INDEX_0).localRankSize; + + u32 subUserrankRootSupperPod = topoMatcher_->GetSubRootUserRankWithSuperPod(topoAttr_.userRank, param.root); + u32 planeRootSupperPod = 0; + CHK_RET(GetRankByUserRank(COMM_LEVEL2, commIndex, subUserrankRootSupperPod, planeRootSupperPod)); + + DeviceMem reduceInput = execMem.inputMem.range(dataSegsSlice[segmentIdx].offset, rSize); + DeviceMem reduceOutput = execMem.outputMem.range(dataSegsSlice[segmentIdx].offset, rSize); + + reduceAttr = GetReduceAttr(reduceInput, reduceOutput, param.DataDes.dataType, param.reduceType); + std::unique_ptr level2RExecutor; + if (UseLevel2RingAlgo(algType_)) { + level2RExecutor.reset(new (std::nothrow) ReduceRing(dispatcher_, reduceAttr)); + HCCL_INFO("[CollReduceDoubleRingExecutor]reducescatter ring: using ring algo inter-server."); + } else { + level2RExecutor.reset(new (std::nothrow) ReduceRecursiveHalvingDoubling(dispatcher_, reduceAttr)); + HCCL_INFO("[CollReduceDoubleRingExecutor]reducescatter ring: using halving-doubling algo inter-server."); + } + + CHK_RET(level2RExecutor->Prepare( + reduceInput, reduceOutput, reduceOutput, arCount, param.DataDes.dataType, param.stream, param.reduceType, planeRootSupperPod, + std::vector(0), dataSegsSlice[segmentIdx].offset)); + CHK_SMART_PTR_NULL(level2RExecutor); + CHK_RET(level2RExecutor->RegisterProfiler( + (rankSize << 
PROF_RANKSIZE_OFFSET_OF_PLANEID) + level2CommInfo.localRank, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(level2RExecutor, level2CommInfo)); + HCCL_INFO("[CollReduceDoubleRingExecutor]reduce double ring [superpod] level2 reduce run success"); + // 节点间 gather + std::unique_ptr level1GExecutor; + DeviceMem gatherInput = execMem.outputMem.range(dataSegsSlice[segmentIdx].offset, rSize); + DeviceMem gatherOutput = execMem.outputMem.range(dataSegsSlice[segmentIdx].offset, rSize*sliceNum); + level1GExecutor.reset(new (std::nothrow) GatherRing(dispatcher_)); + CHK_SMART_PTR_NULL(level1GExecutor); + CHK_RET(level1GExecutor->Prepare(gatherOutput, gatherOutput, gatherOutput, arCount, param.DataDes.dataType, param.stream, + HcclReduceOp::HCCL_REDUCE_RESERVED, planeRoot, dataSegsSlice, + dataSegsSlice[segmentIdx].offset)); + CHK_RET(level1GExecutor->RegisterProfiler( + (sliceNum << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(level1GExecutor, innerCommInfo)); + HCCL_INFO("[CollReduceDoubleRingExecutor]reduce double ring [superpod] level1 gather run success"); + } + HCCL_INFO("[CollReduceDoubleRingExecutor]stage1 run success"); + + // step3: 节点内的gatherring,只有在root所在server内进行gather操作 + SingleSubCommTransport &outerTransportInfo = + const_cast(algResResp_->opTransportResponse[COMM_LEVEL0][COMM_INDEX_0]); + + if (outerTransportInfo.userRank2subCommRank.find(param.root) != + outerTransportInfo.userRank2subCommRank.end()) { + CHK_RET(MultiRingGather(tag_, execMem.outputMem, execMem.outputMem, level1Count, param.DataDes.dataType, + multiRingsSliceZero, param.reduceType, param.root, const_cast(param.stream), PROF_STAGE_2)); + } + HCCL_INFO("[CollReduceDoubleRingExecutor]reduce double ring stage2 run success"); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("ReduceDoubleRingExecutor", ReduceDoubleRing, CollReduceDoubleRingExecutor); + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_double_ring_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_double_ring_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..7b515107941f23e4ad2cd6514db802e81ff3e55f --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_double_ring_executor.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef COLL_REDUCE_DOUBLE_RING_EXECUTOR_H +#define COLL_REDUCE_DOUBLE_RING_EXECUTOR_H +#include "coll_reduce_executor.h" +namespace hccl { +class CollReduceDoubleRingExecutor : public CollReduceExecutor { + +public: + CollReduceDoubleRingExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollReduceDoubleRingExecutor() = default; + +private: + /* *************** Resource calculation *************** */ + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcLevel2CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcStreamNum(u32& streamNum); + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** Algorithm orchestration *************** */ + HcclResult KernelRun(const OpParam &param, ExecMem &execMem) override; + + bool meshSinglePlane_ = false; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..71c5493375a7c8c05b3b70f3a272c6fb99e6218f --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_executor.cc @@ -0,0 +1,200 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ */ + + +#include "coll_reduce_executor.h" + +namespace hccl { + +CollReduceExecutor::CollReduceExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollCommExecutor(dispatcher, topoMatcher) +{ +} + +HcclResult CollReduceExecutor::Orchestrate(const OpParam& param, const AlgResourceResponse& algRes) +{ + HcclUs startut = TIME_NOW(); + + tag_ = param.tag; + if (UseInterServerHDAlgo(algType_)) { + u32 part1Size = 2 * (topoAttr_.moduleNum - (1 << static_cast(log2(topoAttr_.moduleNum)))); + u32 rootId = param.root / topoAttr_.deviceNumPerAggregation; + std::string appendTag = std::to_string((rootId >= part1Size) || ((rootId % 2) == 0)); + tag_ = param.tag + '_' + appendTag; + if (param.opBaseAtraceInfo != nullptr) { + CHK_RET(param.opBaseAtraceInfo->SavealgtypeTraceInfo(appendTag, param.tag)); + } + } + + algResResp_ = &algRes; + GetStreamInfo(algRes); + auto rtStream = param.stream.ptr(); + HCCL_PROFILER_ADD_TAG(tag_, algoAttr_.identifier, GetWorkflowMode()); + HCCL_PROFILER_ADD_STREAM(rtStream, tag_, 0, algType_); + HCCL_PROFILER_ADD_OPDATA(tag_, param.DataDes.count, param.inputPtr, param.outputPtr, param.DataDes.dataType, \ + param.root, algoAttr_.identifier); + HCCL_PROFILER_ADD_GROUPRANK(algoAttr_.identifier, topoAttr_.userRankSize, topoAttr_.userRank); + CHK_RET(AddSubStreamToProfiling()); + + HcclResult ret = HCCL_SUCCESS; + // 图模式和单卡场景下不需要Loop + ExecMem execMem; + execMem.count = param.DataDes.count; + execMem.inputPtr = param.inputPtr; + execMem.outputPtr = param.outputPtr; + if (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + execMem.inputMem = algRes.paramInputMem; + execMem.outputMem = algRes.paramOutputMem; + execMem.scratchMem = algRes.scratchMem; + ret = KernelRun(param, execMem); + } else if (topoAttr_.userRankSize == 1) { + execMem.inputMem = algRes.cclInputMem; + execMem.outputMem = algRes.cclOutputMem; + execMem.scratchMem = algRes.scratchMem; + ret = KernelRun(param, execMem); + } else { + ret = RunLoop(param, algRes); + } + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollReduceExecutor][Orchestrate]errNo[0x%016llx]reudce excutor kernel run failed", + HCCL_ERROR_CODE(ret)), ret); + + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && !is310P3Common_) { + HCCL_PROFILER_DEL_STREAM(rtStream); + HCCL_PROFILER_DEL_TAG(tag_); + HCCL_PROFILER_DEL_OPDATA(tag_); + HCCL_PROFILER_DEL_GROUPRANK(tag_); + } + + HCCL_INFO("tag[%s], Reduce executor orchestrate success, take time [%lld]us.", tag_.c_str(), + DURATION_US(TIME_NOW() - startut)); + + return HCCL_SUCCESS; +} + +HcclResult CollReduceExecutor::RunLoop(const OpParam ¶m, const AlgResourceResponse &algRes) +{ + u32 unitSize = SIZE_TABLE[param.DataDes.dataType]; + ReduceType reduceType = ((param.reduceType != HCCL_REDUCE_PROD) && + (param.DataDes.dataType != HCCL_DATA_TYPE_INT64)) ? 
+ ReduceType::INLINE_REDUCE : ReduceType::TBE_REDUCE; + + u8 *curInputPtr = static_cast(param.inputPtr); + u8 *curOutputPtr = static_cast(param.outputPtr); + CHK_PTR_NULL(curInputPtr); + CHK_PTR_NULL(curOutputPtr); + + u64 maxCountPerLoop = CalcLoopMaxCount(unitSize, algRes); // override + + HCCL_DEBUG("[CollReduceExecutor][RunLoop]tag[%s], userRankSize is [%llu], maxCountPerLoop is [%llu].", + tag_.c_str(), topoAttr_.userRankSize, maxCountPerLoop); + + u64 inputOffset = 0; + u64 outputOffset = 0; + u64 countLeft = param.DataDes.count; + while (countLeft > 0) { + curInputPtr += inputOffset; + curOutputPtr += outputOffset; + // 判断剩余数据量对应的output size是否大于中转output size + u64 curCount = (countLeft > maxCountPerLoop) ? maxCountPerLoop : countLeft; + u64 curSize = curCount * unitSize; // 单位:字节 + + HCCL_DEBUG("[CollReduceExecutor][RunLoop]tag[%s], inputOffset[%llu], outputOffset[%llu], " \ + "sendBuf[%p], recvBuf[%p], sendCount[%llu], dataType[%d].", + tag_.c_str(), inputOffset, outputOffset, curInputPtr, curOutputPtr, curCount, param.DataDes.dataType); + + ExecMem execMem; + execMem.count = curCount; + execMem.inputMem = algRes.cclInputMem; + execMem.outputMem = algRes.cclOutputMem; + execMem.scratchMem = algRes.scratchMem; + // 使用当前Loop偏移到的地址作为当前的inputPtr和outputPtr + execMem.inputPtr = curInputPtr; + execMem.outputPtr = curOutputPtr; + + CHK_RET(RunLoopInner(param, reduceType, execMem)); + + countLeft -= curCount; + inputOffset = curSize; + outputOffset = curSize; + } + return HCCL_SUCCESS; +} + +HcclResult CollReduceExecutor::RunLoopInner(const OpParam ¶m, const ReduceType &reduceType, ExecMem &execMem) +{ + u32 unitSize = SIZE_TABLE[param.DataDes.dataType]; + u64 curSize = execMem.count * unitSize; // 单位:字节 + HCCL_DEBUG("[CollReduceExecutor][RunLoopInner]inputMem[%p][%llu], outputMem[%p][%llu], " \ + "intputPtr[%p], outputPtr[%p], curCount[%llu], curSize[%llu]", + execMem.inputMem.ptr(), execMem.inputMem.size(), execMem.outputMem.ptr(), execMem.outputMem.size(), + execMem.inputPtr, execMem.outputPtr, execMem.count, curSize); + CHK_PRT_RET((execMem.count == 0), + HCCL_ERROR("[CollAllReduceExecutor][RunLoop]In OP_BASE curCount is zero."), HCCL_E_PARA); + + /* 设置子图复用标志 */ + bool isRootRank = param.root == topoAttr_.realUserRank ? 
true : false; + auto autoSelectedAlgTypeLevel1 = static_cast(algType_) >> HCCL_LEVEL_ALGO_WIDTH; + bool hugeData = IsHugeData(curSize); // override + auto opMeta = + HcclOpMetaInfo::GetOneForReduce(isRootRank, param.root, autoSelectedAlgTypeLevel1, param.DataDes.dataType, reduceType, hugeData); + CHK_RET(InitTask(dispatcher_, const_cast(param.stream), opMeta.isEnableCache, opMeta.GetCacheKey())); + /* 记录指令信息用于一致性校验 */ + CHK_RET(RankConsistent::GetInstance().RecordOpPara(HcclCMDType::HCCL_CMD_REDUCE, + tag_, execMem.count, param.DataDes.dataType, param.reduceType, param.root, execMem.inputMem.size(), + execMem.outputMem.size())); + + execMem.inputMem = DeviceMem::create(execMem.inputMem.ptr(), curSize); + execMem.outputMem = DeviceMem::create(execMem.outputMem.ptr(), curSize); + + // 执行 + // 如果使用in CCL buffer,需要将user buffer in中的结果拷贝到CCL buffer in + DeviceMem inMem(execMem.inputPtr, curSize); + DeviceMem inCommMem = execMem.inputMem.range(0, curSize); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, inCommMem, inMem, const_cast(param.stream))); + HCCL_DEBUG("[CollReduceExecutor][RunLoop]copy from user in to ccl in."); + + HcclResult ret = KernelRun(param, execMem); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollAllReduceExecutor][RunLoop]errNo[0x%016llx]kernel run error, tag[%s], " \ + "inputMem ptr[%p], outputMem ptr[%p], count[%llu], dataType[%d], reduce op type[%d]", + HCCL_ERROR_CODE(ret), tag_.c_str(), execMem.inputMem.ptr(), execMem.outputMem.ptr(), + execMem.count, param.DataDes.dataType, param.reduceType), + ret); + + if (topoAttr_.realUserRank == param.root) { // 只root rank需要把数据从中转内存拷贝出去 + DeviceMem outMem(execMem.outputPtr, curSize); + DeviceMem outCommMem = execMem.outputMem.range(0, curSize); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, outMem, outCommMem, const_cast(param.stream))); + } + + CHK_RET(RankConsistent::GetInstance().DelOpPara(tag_)); + CHK_RET(LaunchTask(dispatcher_, const_cast(param.stream))); + return ret; +} + +u64 CollReduceExecutor::CalcLoopMaxCount(const u32 unitSize, const AlgResourceResponse& algRes) +{ + // 中转内存单次最多能够接受的output count + u64 maxCountPerLoop = algRes.cclInputMem.size() / unitSize; + HCCL_WARNING("[CollReduceExecutor][CalcLoopMaxCount]" \ + "using default maxCountPerLoop[%llu] as CCLBuffSize / unitSize.", maxCountPerLoop); + return maxCountPerLoop; +} + +bool CollReduceExecutor::IsHugeData(const u64 curSize) +{ + HCCL_WARNING("[CollReduceExecutor][IsHugeData]opMeta is using the default option."); + bool hugeData = (curSize / HCCL_INTERNODE_MAX_DATA_RATE > RDMA_SEND_MAX_SIZE) || + (curSize > SDMA_SEND_MAX_SIZE); + return hugeData; +} +} \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..c1b21ed6497532c60c30d7ff91e1e6c09dc96034 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_executor.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_REDUCE_EXECUTOR_H +#define COLL_REDUCE_EXECUTOR_H +#include "coll_comm_executor.h" + +namespace hccl { +class CollReduceExecutor : public CollCommExecutor { + +public: + CollReduceExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollReduceExecutor() = default; + + HcclResult Orchestrate(const OpParam& param, const AlgResourceResponse& algRes) override; +protected: + /* *************** 算法编排 *************** */ + // Reduce Loop Executor公共接口 + virtual u64 CalcLoopMaxCount(const u32 unitSize, const AlgResourceResponse& algRes); + virtual bool IsHugeData(const u64 curSize); + HcclResult RunLoop(const OpParam ¶m, const AlgResourceResponse &algRes); + +private: + HcclResult RunLoopInner(const OpParam ¶m, const ReduceType &reduceType, ExecMem &execMem); +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_mesh_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_mesh_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..49cd04e1c45ca4974dbc21793126fefcae6404db --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_mesh_executor.cc @@ -0,0 +1,177 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + + +#include "coll_reduce_mesh_executor.h" + +namespace hccl { + +CollReduceMeshExecutor::CollReduceMeshExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollReduceExecutor(dispatcher, topoMatcher) +{ +} + +HcclResult CollReduceMeshExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = topoAttr_.deviceNumPerAggregation > 1U ? 
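+    // added note: the mesh executor uses one stream per other local device, so
+    // totalStreamNum = deviceNumPerAggregation - 1, and streamNum then excludes the main stream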
topoAttr_.deviceNumPerAggregation - 1U : 1U; + streamNum = totalStreamNum - 1U; + HCCL_INFO("[CollReduceMeshExecutor][CalcStreamNum] tag[%s] streamNum[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollReduceMeshExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CalcTransportMemType(inputType, outputType); + CalcLevel0CommInfo(inputType, outputType, opTransport); + CalcLevel1CommInfo(inputType, outputType, opTransport); + return HCCL_SUCCESS; +} + +HcclResult CollReduceMeshExecutor::CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_OUTPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_OUTPUT; + } + HCCL_INFO("[CollReduceMeshExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollReduceMeshExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_MESH); + commParaLevel0.meshSinglePlane = meshSinglePlane_; + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + return HCCL_SUCCESS; +} + +HcclResult CollReduceMeshExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + u32 perDataSize = SIZE_TABLE[param.DataDes.dataType]; + + std::vector dataSegsSlicePerDie; // 数据分成ranksize份,每份的起始偏移和大小 + std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 + std::vector > multiStreamSlice; // 每个stream使用的数据基于用户buffer的偏移 + std::unique_ptr innerExecutor; + std::unique_ptr outer2Executor; + + CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + + u32 sliceNum = outerCommInfo.localRankSize; + // 根据数据量算每个环上数据的偏移和大小 + CHK_RET(ExecutorBase::PrepareSliceData(execMem.count, perDataSize, sliceNum, 0, dataSegsSlice)); + // mesh算法stream数量为server内rank数减1 + + ActiveSlaveStreams(param.stream); + + if (topoMatcher_->GetExternalInputHcclDeterministic() == DETERMINISTIC_CONFIG_DISABLE && (param.DataDes.dataType != HCCL_DATA_TYPE_INT64) && + (topoAttr_.deviceType == DevType::DEV_TYPE_910B && param.reduceType != HCCL_REDUCE_PROD)) { + CHK_RET(MultiStreamReduceScatterMeshAtomic(tag_, execMem.inputMem, execMem.outputMem, execMem.count, param.DataDes.dataType, param.reduceType, + dataSegsSlice, const_cast(param.stream), COMM_LEVEL0)); + } else { + std::vector> multiStreamSlice; // 每个stream使用的数据基于用户buffer的偏移 + // mesh算法stream数量为rank数减1 + CHK_RET(ExecutorBase::PrepareSliceMeshStreams(dataSegsSlice, sliceNum - 1, multiStreamSlice)); + CHK_RET(MultiStreamReduceScatterMesh(tag_, execMem.inputMem, execMem.outputMem, execMem.count, param.DataDes.dataType, param.reduceType, multiStreamSlice, + const_cast(param.stream), COMM_LEVEL0)); + } + HCCL_INFO("reduce mesh stage0 run success"); + + // step2: 节点间的reduce + u32 commIndex = outerCommInfo.localRank; + CHK_PRT_RET(commIndex >= dataSegsSlice.size(), HCCL_ERROR("[CollReduceMeshExecutor][Run]commIndex[%u] >= dataSegsSlice size[%llu]", + commIndex, dataSegsSlice.size()), HCCL_E_INTERNAL); + + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo 
innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + + HCCL_DEBUG("commIdx:%u TagCommInfo[%s].commInner.size():%llu", commIndex, tag_.c_str(), + innerCommInfo.links.size()); + + DeviceMem reduceInput = execMem.inputMem.range(dataSegsSlice[commIndex].offset, dataSegsSlice[commIndex].size); + CHK_SMART_PTR_NULL(reduceInput); + DeviceMem reduceOutput = execMem.outputMem.range(dataSegsSlice[commIndex].offset, dataSegsSlice[commIndex].size); + CHK_SMART_PTR_NULL(reduceOutput); + + u32 rankSize = innerCommInfo.localRankSize; + if (rankSize > 1) { + u64 reduceAttr = GetReduceAttr(reduceInput, reduceOutput, param.DataDes.dataType, param.reduceType); + + u32 subUserrankRoot = topoMatcher_->GetSubRootUserRank(topoAttr_.userRank, param.root); + + CHK_PRT_RET(subUserrankRoot == INVALID_VALUE_RANKID, + HCCL_ERROR("[ReduceOperator][ReduceMeshExecutor]subUserrankRoot[%u] is invalid,userRank[%u],root[%u]", + subUserrankRoot, topoAttr_.userRank, param.root), HCCL_E_INTERNAL); + + u32 planeRoot = 0; + CHK_RET(GetRankByUserRank(COMM_LEVEL1, commIndex, subUserrankRoot, planeRoot)); + + std::unique_ptr innerExecutor; + if (UseInterServerRingAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) ReduceRing(dispatcher_, reduceAttr)); + } else { + innerExecutor.reset(new (std::nothrow) ReduceRecursiveHalvingDoubling(dispatcher_, reduceAttr)); + } + CHK_SMART_PTR_NULL(innerExecutor); + // 节点间的hd 使用环0来记录 + u64 hdCount = dataSegsSlice[commIndex].size / perDataSize; + + CHK_RET(innerExecutor->Prepare(reduceInput, reduceOutput, reduceOutput, hdCount, param.DataDes.dataType, param.stream, + param.reduceType, planeRoot, std::vector(0), dataSegsSlice[commIndex].offset)); + + CHK_RET(innerExecutor->RegisterProfiler((innerCommInfo.localRankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + + CHK_RET(RunTemplate(innerExecutor, innerCommInfo)); + } else { + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, reduceOutput, reduceInput, const_cast(param.stream))); + } + + HCCL_INFO("reduce mesh stage1 run success"); + + SingleSubCommTransport &outerTransportInfo = + const_cast(algResResp_->opTransportResponse[COMM_LEVEL0][COMM_INDEX_0]); + + if (outerTransportInfo.userRank2subCommRank.find(param.root) != + outerTransportInfo.userRank2subCommRank.end()) { + const u32 rootRank = outerTransportInfo.userRank2subCommRank[param.root]; + + std::unique_ptr outerExecutor; + outerExecutor.reset(new (std::nothrow) GatherMesh(dispatcher_, streamInfo_.ringStreams, + streamInfo_.ringSignal, streamInfo_.ringSignalAux, topoAttr_.userRank)); + CHK_SMART_PTR_NULL(outerExecutor); + CHK_RET(outerExecutor->Prepare(execMem.outputMem, execMem.outputMem, execMem.inputMem, execMem.count, + param.DataDes.dataType, const_cast(param.stream), param.reduceType, rootRank, dataSegsSlice)); + + u32 rankSize = outerCommInfo.localRankSize; + CHK_RET(outerExecutor->RegisterProfiler((0 << PROF_RINGINDEX_OFFSET_OF_PLANEID) + + (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + outerCommInfo.localRank, + PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(outerExecutor, outerCommInfo)); + } + HCCL_INFO("reduce mesh stage2 run success"); + + return HCCL_SUCCESS; +} + +REGISTER_EXEC("ReduceMeshExecutor", ReduceMesh, CollReduceMeshExecutor); + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_mesh_executor.h 
b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_mesh_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..415fcbaa26c394581b487b634de10cd11b68f800 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_mesh_executor.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_REDUCE_MESH_EXECUTOR_H +#define COLL_REDUCE_MESH_EXECUTOR_H +#include "coll_reduce_executor.h" +namespace hccl { +class CollReduceMeshExecutor : public CollReduceExecutor { + +public: + CollReduceMeshExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollReduceMeshExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcStreamNum(u32& streamNum); + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; + + bool meshSinglePlane_ = false; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_ring_plus_hd_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_ring_plus_hd_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..f526927e5edd09656e8a78d05916026fa7d0cc65 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_ring_plus_hd_executor.cc @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + + +#include "coll_reduce_ring_plus_hd_executor.h" + +namespace hccl { + +CollReduceRingPlusHdExecutor::CollReduceRingPlusHdExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollReduceExecutor(dispatcher, topoMatcher) +{ +} + +HcclResult CollReduceRingPlusHdExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = 1U; + switch (algType_) { + case AlgType::ALG_8P_RING_PLUS_HD: + case AlgType::ALG_8P_RING_PLUS_RING: + case AlgType::ALG_8P_RING_PLUS_NHR: + case AlgType::ALG_8P_RING_PLUS_NHR_V1: + case AlgType::ALG_8P_RING_PLUS_NB: + case AlgType::ALG_8P_RING_PLUS_PIPELINE: + totalStreamNum = OUTER_PLANE_NUM_IN_8PRING; + break; + case AlgType::ALG_NP_SINGLE_RING_PLUS_RING: + case AlgType::ALG_NP_SINGLE_RING_PLUS_HD: + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_SINGLE; + break; + default: + break; + } + streamNum = totalStreamNum - 1; + HCCL_INFO("[CollReduceRingPlusHdExecutor][CalcStreamNum] tag[%s] streamNum[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollReduceRingPlusHdExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CalcTransportMemType(inputType, outputType); + CalcLevel0CommInfo(inputType, outputType, opTransport); + CalcLevel1CommInfo(inputType, outputType, opTransport); + return HCCL_SUCCESS; +} + +HcclResult CollReduceRingPlusHdExecutor::CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_OUTPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_OUTPUT; + } + HCCL_INFO("[CollReduceRingPlusHdExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollReduceRingPlusHdExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + HCCL_INFO("[CollReduceRingPlusHdExecutor][CalcOuterCommInfo]tag[%s ]start", tag_.c_str()); + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_RING_INNER); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + HCCL_INFO("[CollReduceRingPlusHdExecutor][CalcOuterCommInfo]tag[%s] Calc RingComm finish", tag_.c_str()); + return HCCL_SUCCESS; +} + +HcclResult CollReduceRingPlusHdExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + u32 perDataSize = SIZE_TABLE[param.DataDes.dataType]; + + std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 + std::vector > mulRingSlice; // 数据基于该rank上环0的偏移 + + // step1: 节点内的reducescatter + u32 ringNum = (topoType_ == TopoType::TOPO_TYPE_8P_RING) ? 
OUTER_PLANE_NUM_IN_8PRING : + OUTER_PLANE_NUM_IN_NPRING_SINGLE; + + CHK_RET(CheckCommSize(COMM_LEVEL0, ringNum)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + + // 按ranksize得到内存切分slice数为8 + u32 sliceNum = outerCommInfo.localRankSize; + CHK_RET(ExecutorBase::PrepareSliceData(execMem.count, perDataSize, sliceNum, 0, dataSegsSlice)); + + /* 外层:reducescatter */ + // 将每slice再切分成4份,按各ring的dev顺序排列 + if (ringNum == OUTER_PLANE_NUM_IN_8PRING) { + // 构造ring algorithm对应的reduce-scatter实例 + mulRingSlice = PrepareMultiRingSlice(dataSegsSlice, tag_, false, topoAttr_.nicList); + CHK_PRT_RET(mulRingSlice.size() != ringNum, HCCL_ERROR("[CollReduceRingPlusHdExecutor]ringNum[%u] "\ + "!=mulRingSlice size[%llu]", ringNum, mulRingSlice.size()), HCCL_E_INTERNAL); + } else { + mulRingSlice.push_back(dataSegsSlice); // 应该offset全为0,而大小和dataSegsSlice中一样,里面的offset不使用 + } + CHK_RET(MultiRingReduceScatter(tag_, execMem.inputMem, execMem.outputMem, execMem.count, + param.DataDes.dataType, param.reduceType, mulRingSlice, param.stream, + PROF_STAGE_0, 0, nullptr)); + + HCCL_INFO("reduce 8PringHD stage0 run success"); + + // step2: 节点间的reduce + u64 hdSize; + u32 segmentIdx; + u32 commIndex; + CHK_RET(PrepareInnerCommInfo(segmentIdx, commIndex, hdSize, outerCommInfo, mulRingSlice, tag_)); + + u64 hdCount = hdSize / perDataSize; + + HCCL_DEBUG("commIdx:%u TagCommInfo[%s].commInner.size():%llu", commIndex, tag_.c_str(), + outerCommInfo.localRankSize); + + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + + DeviceMem reduceInput = execMem.inputMem.range(dataSegsSlice[segmentIdx].offset, hdSize); + CHK_SMART_PTR_NULL(reduceInput); + DeviceMem reduceOutput = execMem.outputMem.range(dataSegsSlice[segmentIdx].offset, hdSize); + CHK_SMART_PTR_NULL(reduceOutput); + + u64 reduceAttr = GetReduceAttr(reduceInput, reduceOutput, param.DataDes.dataType, param.reduceType); + + std::unique_ptr innerExecutor; + if (UseInterServerRingAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) ReduceRing(dispatcher_, reduceAttr)); + } else { + innerExecutor.reset(new (std::nothrow) ReduceRecursiveHalvingDoubling(dispatcher_, reduceAttr)); + } + CHK_SMART_PTR_NULL(innerExecutor); + + u32 subUserrankRoot = topoMatcher_->GetSubRootUserRank(topoAttr_.userRank, param.root); + CHK_PRT_RET(subUserrankRoot == INVALID_VALUE_RANKID, + HCCL_ERROR("[ReduceOperator][ReduceRingPlusHd]subUserrankRoot[%u] is invalid,userRank[%u],root[%u]", + subUserrankRoot, topoAttr_.userRank, param.root), HCCL_E_INTERNAL); + + u32 planeRoot = 0; + CHK_RET(GetRankByUserRank(COMM_LEVEL1, commIndex, subUserrankRoot, planeRoot)); + + u32 ranksize = innerCommInfo.localRankSize; + // 节点间的hd 使用环0来记录 + CHK_RET(innerExecutor->Prepare(reduceInput, reduceOutput, reduceOutput, hdCount, param.DataDes.dataType, + param.stream, param.reduceType, planeRoot, std::vector(0), dataSegsSlice[segmentIdx].offset)); + + CHK_RET(innerExecutor->RegisterProfiler((ranksize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, \ + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + + CHK_RET(RunTemplate(innerExecutor, innerCommInfo)); + + HCCL_INFO("reduce 8PringHD stage1 run success"); + + // step3: 节点内的gatherring,只有在root所在server内进行gather操作 + SingleSubCommTransport &outerTransportInfo = + const_cast(algResResp_->opTransportResponse[COMM_LEVEL0][COMM_INDEX_0]); + + if (outerTransportInfo.userRank2subCommRank.find(param.root) != + outerTransportInfo.userRank2subCommRank.end()) { + 
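+        // added note: the step-3 gather runs only on the server holding the root rank, i.e. when the
+        // root's user rank maps into this level-0 sub-communicator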
CHK_RET(MultiRingGather(tag_, execMem.outputMem, execMem.outputMem, hdCount, param.DataDes.dataType, + mulRingSlice, param.reduceType, param.root, const_cast(param.stream), PROF_STAGE_2)); + } + HCCL_INFO("reduce 8PringHD stage2 run success"); + + return HCCL_SUCCESS; +} + +REGISTER_EXEC("ReduceRingPlusHd", ReduceRingPlusHd, CollReduceRingPlusHdExecutor); + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_ring_plus_hd_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_ring_plus_hd_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..d4f19eb5bc2d515120935a5508bbee52884d0ab6 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_ring_plus_hd_executor.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_REDUCE_RING_PLUS_HD_EXECUTOR_H +#define COLL_REDUCE_RING_PLUS_HD_EXECUTOR_H +#include "coll_reduce_executor.h" +namespace hccl { +class CollReduceRingPlusHdExecutor : public CollReduceExecutor { + +public: + CollReduceRingPlusHdExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollReduceRingPlusHdExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcStreamNum(u32& streamNum); + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; + + bool meshSinglePlane_ = false; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_single_rank_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_single_rank_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..7456bdf1d7c47ef9cceb5aeac314b760efdc17a6 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_single_rank_executor.cc @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "coll_reduce_single_rank_executor.h" + +namespace hccl { + +CollReduceSingleRankExecutor::CollReduceSingleRankExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollReduceExecutor(dispatcher, topoMatcher) +{ +} + +HcclResult CollReduceSingleRankExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + u64 totalSize = execMem.count * SIZE_TABLE[param.DataDes.dataType]; + ReduceType reduceType = + ((param.reduceType != HCCL_REDUCE_PROD) && (param.DataDes.dataType != HCCL_DATA_TYPE_INT64)) ? + ReduceType::INLINE_REDUCE : + ReduceType::TBE_REDUCE; + auto autoSelectedAlgTypeLevel1 = static_cast(algType_) >> HCCL_LEVEL_ALGO_WIDTH; + bool isRootRank = param.root == topoAttr_.realUserRank ? true : false; + bool hugeData = IsHugeData(totalSize); // override + + auto opMeta = HcclOpMetaInfo::GetOneForReduce(isRootRank, param.root, autoSelectedAlgTypeLevel1, + param.DataDes.dataType, reduceType, hugeData); + CHK_RET(InitTask(dispatcher_, const_cast(param.stream), opMeta.isEnableCache, opMeta.GetCacheKey())); + + DeviceMem srcMem(execMem.inputPtr, totalSize); + DeviceMem dstMem(execMem.outputPtr, totalSize); + HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, const_cast(param.stream)); + CHK_RET(LaunchTask(dispatcher_, const_cast(param.stream))); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("ReduceSingleExecutor", ReduceSingleRank, CollReduceSingleRankExecutor); + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_single_rank_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_single_rank_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..b6a8bbd0c73fce5fbda47c4b8493bfea3253233a --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce/coll_reduce_single_rank_executor.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef COLL_REDUCE_SINGLE_RANK_EXECUTOR_H +#define COLL_REDUCE_SINGLE_RANK_EXECUTOR_H +#include "coll_reduce_executor.h" +namespace hccl { +class CollReduceSingleRankExecutor : public CollReduceExecutor { + +public: + CollReduceSingleRankExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollReduceSingleRankExecutor() = default; + +private: + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/310P/CMakeLists.txt b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/310P/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..6e3aca1fc22de063611cf39b1d108cdb402a33c4 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/310P/CMakeLists.txt @@ -0,0 +1,7 @@ +set(src_list + ${CMAKE_CURRENT_SOURCE_DIR}/coll_reduce_scatter_for_310p_ring_executor.cc +) + +target_sources(hccl_alg PRIVATE + ${src_list} +) diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/310P/coll_reduce_scatter_for_310p_ring_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/310P/coll_reduce_scatter_for_310p_ring_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..8f2ecb04eb02aafc7e7f1dce34a9cb8173a06609 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/310P/coll_reduce_scatter_for_310p_ring_executor.cc @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "coll_reduce_scatter_for_310p_ring_executor.h" + +namespace hccl { +CollReduceScatterFor310PRingExecutor::CollReduceScatterFor310PRingExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollReduceScatterExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = false; +} + +HcclResult CollReduceScatterFor310PRingExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CalcTransportMemType(inputType, outputType); + CalcLevel0CommInfo(inputType, outputType, opTransport); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterFor310PRingExecutor::CalcTransportMemType(TransportMemType &inputType, + TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_OUTPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_OUTPUT; + } + HCCL_INFO("[CollReduceScatterFor310PRingExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterFor310PRingExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + HCCL_INFO("[CollReduceScatterFor310PRingExecutor][CalcLevel0CommInfo]tag[%s ]start", tag_.c_str()); + CommParaInfo commParaInfo(COMM_LEVEL0, CommType::COMM_TAG_RING_INNER); + CHK_RET(CalcCommPlaneInfo(tag_, commParaInfo, opTransport[COMM_LEVEL0], inputType, outputType)); + HCCL_INFO("[CollReduceScatterFor310PRingExecutor][CalcLevel0CommInfo]tag[%s] Calc RingComm finish", tag_.c_str()); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterFor310PRingExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + + bool isInlineReduce = IsSupportSDMAReduce(execMem.inputMem.ptr(), execMem.outputMem.ptr(), param.DataDes.dataType, + param.reduceType); + u64 reduceAttr = 0; + if (isInlineReduce) { + SalSetBitOne(reduceAttr, ATTR_POS_INLINE_REDUCE); + } + + std::unique_ptr executor; + executor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); + CHK_SMART_PTR_NULL(executor); + + CHK_RET(executor->Prepare(execMem.inputMem, execMem.outputMem, execMem.outputMem, execMem.count, + param.DataDes.dataType, param.stream, param.reduceType)); + + CHK_RET(executor->RegisterProfiler( + (outerCommInfo.localRankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + outerCommInfo.localRank, + PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, param.stream)); + + CHK_RET(RunTemplate(executor, outerCommInfo)); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("ReduceScatterRing", ReduceScatterFor310PRing, CollReduceScatterFor310PRingExecutor); +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/310P/coll_reduce_scatter_for_310p_ring_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/310P/coll_reduce_scatter_for_310p_ring_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..97a8eb2e083d14a384dadd8f963c8ca70afeb0ab --- /dev/null +++ 
b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/310P/coll_reduce_scatter_for_310p_ring_executor.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_REDUCESCATTER_FOR_310P_RING_EXECUTOR_H +#define COLL_REDUCESCATTER_FOR_310P_RING_EXECUTOR_H +#include "coll_reduce_scatter_executor.h" + +namespace hccl { +class CollReduceScatterFor310PRingExecutor : public CollReduceScatterExecutor { + +public: + explicit CollReduceScatterFor310PRingExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollReduceScatterFor310PRingExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + HcclResult CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + + /* *************** 算法编排 *************** */ + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/CMakeLists.txt b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4b846641219a500cf8d64574c99c7c7424e4b2ea --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/CMakeLists.txt @@ -0,0 +1,18 @@ +set(src_list + ${CMAKE_CURRENT_SOURCE_DIR}/coll_reduce_scatter_comm_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_reduce_scatter_deter_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_reduce_scatter_ring_for_910_73_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_reduce_scatter_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_reduce_scatter_mesh_dma_elimination.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_reduce_scatter_mesh_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_reduce_scatter_mesh_opbase_pipeline_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_reduce_scatter_ring_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_reduce_scatter_single_rank_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_reduce_scatter_double_ring_concurrent_executor.cc +) + +target_sources(hccl_alg PRIVATE + ${src_list} +) + +add_subdirectory(310P) \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_comm_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_comm_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..6462eab8fc00b4595a15e9cb1a9792a6f69d772a --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_comm_executor.cc @@ -0,0 +1,169 
@@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "coll_reduce_scatter_comm_executor.h" + +namespace hccl { + +CollReduceScatterCommExecutor::CollReduceScatterCommExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollReduceScatterExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = false; +} + +void CollReduceScatterCommExecutor::ParseParam(const OpParam& param) +{ + tag_ = param.tag; + + // 是否需要scratch memory + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + (topoAttr_.deviceType == DevType::DEV_TYPE_910B || topoAttr_.deviceType == DevType::DEV_TYPE_910_73) && + IsSupportSDMAReduce(param.inputPtr, param.outputPtr, param.DataDes.dataType, param.reduceType) && + IsSupportRDMAReduce(param.DataDes.dataType, param.reduceType)) { + scratchMemFlag_ = false; + } else { + scratchMemFlag_ = true; + } + + // 记录图模式总数据量 + totalSize_ = topoAttr_.userRankSize * param.DataDes.count * SIZE_TABLE[param.DataDes.dataType]; +} + +HcclResult CollReduceScatterCommExecutor::CalcScratchMemSize(u64& scratchMemSize) +{ + if (scratchMemFlag_) { + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + scratchMemSize = inCCLbufferSize_ + CCE_REDUCE_ALIGN_FACTOR * CCE_REDUCE_ALIGN_SIZE; + } else { + scratchMemSize = totalSize_ + CCE_REDUCE_ALIGN_FACTOR * CCE_REDUCE_ALIGN_SIZE; + } + } else { + scratchMemSize = 0U; + } + + HCCL_INFO("[CollReduceScatterCommExecutor][CalcScratchMemSize] tag[%s] scratchMemSize[%u]", + tag_.c_str(), scratchMemSize); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterCommExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcCombinedCommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterCommExecutor::CalcTransportMemType(TransportMemType &inputType, + TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + if (scratchMemFlag_) { + outputType = TransportMemType::SCRATCH; + } else { + outputType = TransportMemType::CCL_OUTPUT; + } + } else { + inputType = TransportMemType::PARAM_INPUT; + if (scratchMemFlag_) { + outputType = TransportMemType::SCRATCH; + } else { + outputType = TransportMemType::PARAM_OUTPUT; + } + } + HCCL_INFO("[CollReduceScatterCommExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterCommExecutor::CalcCombinedCommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaInfo(COMM_COMBINE, CommType::COMM_TAG_MAX); + if (UseInterServerNHRAlgo(algType_)) { + commParaInfo.commType = CommType::COMM_TAG_NONUNIFORM_HIERARCHICAL_RING; + 
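+        // added note: the transport plane type chosen here mirrors the inter-server executor selected
+        // later in KernelRun (NHR / NHR_V1 / NB, with ring as the fallback)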
} else if (UseInterServerNHRV1Algo(algType_)) { + commParaInfo.commType = CommType::COMM_TAG_NONUNIFORM_HIERARCHICAL_RING_V1; + } else if (UseInterServerNBAlgo(algType_)) { + commParaInfo.commType = CommType::COMM_TAG_NONUNIFORM_BRUCK; + } else { + commParaInfo.commType = CommType::COMM_TAG_RING_INNER; + } + CHK_RET(CalcCommPlaneInfo(tag_, commParaInfo, opTransport[COMM_COMBINE], inputType, outputType)); + + return HCCL_SUCCESS; +} + +u64 CollReduceScatterCommExecutor::CalcLoopMaxCount(const u32 unitSize) +{ + // 中转内存单次最多能够接受的output count + u64 maxCountPerLoop = inCCLbufferSize_ / (topoAttr_.userRankSize * unitSize); + return maxCountPerLoop; +} + +bool CollReduceScatterCommExecutor::IsHugeData(const u64 curSize) +{ + bool hugeData = (curSize * topoAttr_.userRankSize / HCCL_INTERNODE_MAX_DATA_RATE > RDMA_SEND_MAX_SIZE) || + (curSize > SDMA_SEND_MAX_SIZE); + return hugeData; +} + +HcclResult CollReduceScatterCommExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + CHK_RET(CheckCommSize(COMM_COMBINE, COMM_INDEX_0 + 1)); + SubCommInfo combinedCommInfo = GetSubCommInfo(COMM_COMBINE, COMM_INDEX_0); + + u64 reduceAttr = GetReduceAttr(execMem.inputMem, execMem.outputMem, param.DataDes.dataType, param.reduceType); + + // 构造ring algorithm对应的reduce-scatter实例 + std::unique_ptr executor; + if (UseInterServerNHRAlgo(algType_)) { + executor.reset(new (std::nothrow) ReduceScatterNHR(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter comm: using nhr algo inter-server."); + CHK_SMART_PTR_NULL(executor); + CHK_RET(executor->Prepare(execMem.inputMem, execMem.outputMem, execMem.scratchMem, execMem.count, + param.DataDes.dataType, param.stream, param.reduceType)); + } else if (UseInterServerNHRV1Algo(algType_)) { + executor.reset(new (std::nothrow) ReduceScatterNHRV1(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter comm: using nhr_v1 algo inter-server."); + CHK_SMART_PTR_NULL(executor); + CHK_RET(executor->Prepare(execMem.inputMem, execMem.outputMem, execMem.scratchMem, execMem.count, + param.DataDes.dataType, param.stream, param.reduceType)); + CHK_RET(RunTemplate(executor, combinedCommInfo)); + } else if (UseInterServerNBAlgo(algType_)) { + executor.reset(new (std::nothrow) ReduceScatterNB(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter comm: using nonuniform-bruck algo inter-server."); + CHK_SMART_PTR_NULL(executor); + CHK_RET(executor->Prepare(execMem.inputMem, execMem.outputMem, execMem.scratchMem, execMem.count, + param.DataDes.dataType, param.stream, param.reduceType)); + CHK_RET(RunTemplate(executor, combinedCommInfo)); + } else { + executor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter comm: using ring algo inter-server."); + CHK_SMART_PTR_NULL(executor); + CHK_RET(executor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, execMem.count, + param.DataDes.dataType, param.stream, param.reduceType)); + CHK_RET(RunTemplate(executor, combinedCommInfo)); + // 将cclInBuffer中与userRank_对应的部分拷贝至cclOutBuffer + u64 dataSize = execMem.count * SIZE_TABLE[param.DataDes.dataType]; + DeviceMem srcMem = execMem.inputMem.range(dataSize * topoAttr_.userRank, dataSize); + DeviceMem dstMem = execMem.outputMem.range(0, dataSize); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, const_cast(param.stream))); + } + + return HCCL_SUCCESS; +} + +REGISTER_EXEC("ReduceScatterComm", ReduceScatterComm, CollReduceScatterCommExecutor); +} \ No newline at end of file diff --git 
a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_comm_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_comm_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..b2dde5467e235e23868c93c72c3bd303524ab112 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_comm_executor.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_REDUCESCATTER_COMM_EXECUTOR_H +#define COLL_REDUCESCATTER_COMM_EXECUTOR_H +#include "coll_reduce_scatter_executor.h" + +namespace hccl { +class CollReduceScatterCommExecutor : public CollReduceScatterExecutor { +public: + explicit CollReduceScatterCommExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollReduceScatterCommExecutor() = default; + +private: + void ParseParam(const OpParam& param) override; + /* *************** 资源计算 *************** */ + HcclResult CalcScratchMemSize(u64& scratchMemSize) override; + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcCombinedCommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport); + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + u64 CalcLoopMaxCount(const u32 unitSize) override; + bool IsHugeData(const u64 curSize) override; + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_deter_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_deter_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..2ad212f2a874e5b9e3da021aef4824ff1e5e0b11 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_deter_executor.cc @@ -0,0 +1,164 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "coll_reduce_scatter_deter_executor.h" + +namespace hccl { + +CollReduceScatterDeterExecutor::CollReduceScatterDeterExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollReduceScatterExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = true; + CCLMemSlice_ = false; +} + +void CollReduceScatterDeterExecutor::ParseParam(const OpParam& param) +{ + tag_ = param.tag; + + // 910B 图模式非确定计算,inlineReduce使能,MESH拓扑场景下,创建一个mesh平面 + meshSinglePlane_ = false; + + // 是否需要scratch memory 选中确定性计算Executor,其他条件必定满足,只需区分是否为图模式 + scratchMemFlag_ = (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE); + + // 记录图模式总数据量 + totalSize_ = topoAttr_.userRankSize * param.DataDes.count * SIZE_TABLE[param.DataDes.dataType]; +} + +HcclResult CollReduceScatterDeterExecutor::CalcScratchMemSize(u64& scratchMemSize) +{ + if (scratchMemFlag_) { // 确定性计算只有图模式需要scratch memory + scratchMemSize = totalSize_ + CCE_REDUCE_ALIGN_FACTOR * CCE_REDUCE_ALIGN_SIZE; + } else { + scratchMemSize = 0U; + } + HCCL_INFO("[CollReduceScatterDeterExecutor][CalcScratchMemSize] tag[%s] scratchMemSize[%u]", + tag_.c_str(), scratchMemSize); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterDeterExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = 0U; + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { + totalStreamNum = topoAttr_.deviceNumPerAggregation - 1U; + } else { + totalStreamNum = topoAttr_.deviceNumPerAggregation; + } + streamNum = totalStreamNum - 1U; + HCCL_INFO("[CollReduceScatterDeterExecutor][CalcStreamNum] tag[%s] streamNum[%u]", tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterDeterExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterDeterExecutor::CalcTransportMemType(TransportMemType &inputType, + TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + if (scratchMemFlag_) { + outputType = TransportMemType::SCRATCH; + } else { + outputType = TransportMemType::CCL_OUTPUT; + } + } else { + inputType = TransportMemType::PARAM_INPUT; + if (scratchMemFlag_) { + outputType = TransportMemType::SCRATCH; + } else { + outputType = TransportMemType::PARAM_OUTPUT; + } + } + HCCL_INFO("[CollReduceScatterDeterExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterDeterExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_MESH); + commParaLevel0.meshSinglePlane = meshSinglePlane_; + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + return HCCL_SUCCESS; +} + +u64 CollReduceScatterDeterExecutor::CalcLoopMaxCount(const u32 unitSize) +{ + u64 maxCountPerLoop = (inCCLbufferSize_ - HCCL_MIN_SLICE_ALIGN_910B * topoAttr_.deviceNumPerAggregation) / + unitSize / (topoAttr_.deviceNumPerAggregation - 1); + maxCountPerLoop = maxCountPerLoop / HCCL_MIN_SLICE_ALIGN_910B; + maxCountPerLoop = maxCountPerLoop * HCCL_MIN_SLICE_ALIGN_910B; + 
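    // The two statements above floor maxCountPerLoop to a multiple of
    // HCCL_MIN_SLICE_ALIGN_910B (integer divide, then multiply back) so that every
    // per-loop slice stays aligned for CCE reduce. The budget itself reserves one
    // alignment unit per device in the aggregation and spreads the remainder over
    // the deviceNumPerAggregation - 1 peer slices.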
return maxCountPerLoop; +} + +bool CollReduceScatterDeterExecutor::IsHugeData(const u64 curSize) +{ + bool hugeData = curSize > SDMA_SEND_MAX_SIZE; + return hugeData; +} + +bool CollReduceScatterDeterExecutor::IsSmallData(const u64 totalSize, const u64 curSize) +{ + bool smallData = curSize <= HCCL_SMALL_COUNT_32_KB; + return smallData; +} + +HcclResult CollReduceScatterDeterExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + u32 unitSize = SIZE_TABLE[param.DataDes.dataType]; + std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 + std::unique_ptr outerExecutor; + + CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + + CHK_RET(ActiveSlaveStreams(param.stream)); + + u64 reduceAttr = GetReduceAttr(execMem.inputMem, execMem.outputMem, param.DataDes.dataType, param.reduceType); + HcomCollOpInfo opInfo = {"", execMem.inputPtr, execMem.outputPtr, param.DataDes.count, param.DataDes.dataType, + param.root, param.reduceType}; + + if ((param.DataDes.count * unitSize > HCCL_SMALL_COUNT_32_KB) || + (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) || + ((topoAttr_.deviceNumPerAggregation != DEVICE_EIGHT) && (topoAttr_.deviceNumPerAggregation != DEVICE_FOUR))) { + outerExecutor.reset(new (std::nothrow) ReduceScatterLocalReduce(dispatcher_, reduceAttr, + streamInfo_.ringStreams, streamInfo_.ringSignal, streamInfo_.ringSignalAux, topoAttr_.userRank, &opInfo)); + } else { + outerExecutor.reset(new (std::nothrow) ReduceScatterHDStage(dispatcher_, reduceAttr, streamInfo_.ringStreams, + streamInfo_.ringSignal, streamInfo_.ringSignalAux, topoAttr_.userRank, &opInfo)); + } + + CHK_SMART_PTR_NULL(outerExecutor); + CHK_RET(outerExecutor->Prepare(execMem.inputMem, execMem.scratchMem, execMem.outputMem, execMem.count, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, dataSegsSlice, 0)); + + CHK_RET(outerExecutor->RegisterProfiler( + (outerCommInfo.localRankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + outerCommInfo.localRank, + PROF_STAGE_2, HCCL_EXEC_STEP_NOT_SET, param.stream)); + + CHK_RET(RunTemplate(outerExecutor, outerCommInfo)); + HCCL_INFO("reducescatter mesh deter run success"); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("ReduceScatterDeterExecutor", ReduceScatterDeter, CollReduceScatterDeterExecutor); +} \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_deter_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_deter_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..798fcff4c9bbcf40ff4bfe5b19a2f766c5ae53ef --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_deter_executor.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef COLL_REDUCESCATTER_DETER_EXECUTOR_H +#define COLL_REDUCESCATTER_DETER_EXECUTOR_H +#include "coll_reduce_scatter_executor.h" + +namespace hccl { +class CollReduceScatterDeterExecutor : public CollReduceScatterExecutor { +public: + explicit CollReduceScatterDeterExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollReduceScatterDeterExecutor() = default; + +private: + void ParseParam(const OpParam& param) override; + /* *************** 资源计算 *************** */ + HcclResult CalcScratchMemSize(u64& scratchMemSize) override; + HcclResult CalcStreamNum(u32& streamNum) override; + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + u64 CalcLoopMaxCount(const u32 unitSize) override; + bool IsHugeData(const u64 curSize) override; + bool IsSmallData(const u64 totalSize, const u64 curSize) override; + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; + + bool meshSinglePlane_ = false; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_double_ring_concurrent_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_double_ring_concurrent_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..8f250303061bb1fad08a60ad9d093b2878ab0cd3 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_double_ring_concurrent_executor.cc @@ -0,0 +1,562 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "coll_reduce_scatter_double_ring_concurrent_executor.h" + +namespace hccl { + +CollReduceScatterDoubleRingConcurrentExecutor::CollReduceScatterDoubleRingConcurrentExecutor( + const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollReduceScatterExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = false; +} + +void CollReduceScatterDoubleRingConcurrentExecutor::ParseParam(const OpParam& param) +{ + tag_ = param.tag; + + // 是否需要scratch memory + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + IsSupportSDMAReduce(param.inputPtr, param.outputPtr, param.DataDes.dataType, param.reduceType) && + IsSupportRDMAReduce(param.DataDes.dataType, param.reduceType)) { + scratchMemFlag_ = false; + } else { + scratchMemFlag_ = true; + } + + // 记录图模式总数据量 + totalSize_ = topoAttr_.userRankSize * param.DataDes.count * SIZE_TABLE[param.DataDes.dataType]; +} + +HcclResult CollReduceScatterDoubleRingConcurrentExecutor::CalcScratchMemSize(u64& scratchMemSize) +{ + if (scratchMemFlag_) { + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + scratchMemSize = inCCLbufferSize_ + CCE_REDUCE_ALIGN_FACTOR * CCE_REDUCE_ALIGN_SIZE; + } else { + scratchMemSize = totalSize_ + CCE_REDUCE_ALIGN_FACTOR * CCE_REDUCE_ALIGN_SIZE; + } + } else { + scratchMemSize = 0U; + } + HCCL_INFO("[CollReduceScatterDoubleRingConcurrentExecutor][CalcScratchMemSize] tag[%s] scratchMemSize[%u]", + tag_.c_str(), scratchMemSize); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterDoubleRingConcurrentExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = 0U; + // DoubleRing只支持910_73场景 + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING; + } else { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE; + } + if (GetExternalInputEnableRdmaSdmaConcurrent()) { + totalStreamNum += RDMA_PLANE_NUM_IN_NPRING_DOUBLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING; + } + streamNum = totalStreamNum - 1; + HCCL_INFO("[CollReduceScatterDoubleRingConcurrentExecutor][CalcStreamNum] tag[%s] streamNum[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterDoubleRingConcurrentExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel1CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel2CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterDoubleRingConcurrentExecutor::CalcTransportMemType(TransportMemType &inputType, + TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + if (scratchMemFlag_) { + outputType = TransportMemType::SCRATCH; + } else { + outputType = TransportMemType::CCL_OUTPUT; + } + } else { + inputType = TransportMemType::PARAM_INPUT; + if (scratchMemFlag_) { + outputType = TransportMemType::SCRATCH; + } else { + outputType = TransportMemType::PARAM_OUTPUT; + } + } + HCCL_INFO("[CollReduceScatterDoubleRingConcurrentExecutor][CalcTransportMemType] tag[%s] " + "inputType[%d], outputType[%d]", tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult 
CollReduceScatterDoubleRingConcurrentExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_RING_INNER); + commParaLevel0.forceRdma = false; + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + if (GetExternalInputEnableRdmaSdmaConcurrent()) { + CommParaInfo commParaLevel0Rdma(COMM_LEVEL0_RDMA, CommType::COMM_TAG_RING_INNER); + commParaLevel0Rdma.forceRdma = true; + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0Rdma, opTransport[COMM_LEVEL0_RDMA], + inputType, outputType)); + } + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterDoubleRingConcurrentExecutor::CalcLevel2CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel2(COMM_LEVEL2, CommType::COMM_TAG_MAX); + if (UseLevel2RingAlgo(algType_)) { + commParaLevel2.commType = CommType::COMM_TAG_RING_INNER; + } else { + commParaLevel2.commType = CommType::COMM_TAG_HALVING_DOUBLING; + } + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel2, opTransport[COMM_LEVEL2], inputType, outputType)); + return HCCL_SUCCESS; +} + +u64 CollReduceScatterDoubleRingConcurrentExecutor::CalcLoopMaxCount(const u32 unitSize) +{ + // 中转内存单次最多能够接受的output count,放开ranksize限制 + u64 maxCountPerLoop = inCCLbufferSize_ / (topoAttr_.userRankSize * unitSize); + return maxCountPerLoop; +} + +bool CollReduceScatterDoubleRingConcurrentExecutor::IsHugeData(const u64 curSize) +{ + bool hugeData = (curSize * topoAttr_.userRankSize / HCCL_INTERNODE_MAX_DATA_RATE > RDMA_SEND_MAX_SIZE) || + (curSize > SDMA_SEND_MAX_SIZE); + return hugeData; +} + +u32 CollReduceScatterDoubleRingConcurrentExecutor::CalcDataSplit(const u64 curSize) +{ + u32 dataSplit = 0; + u64 dataValue = curSize * topoAttr_.userRankSize; + if ((topoAttr_.serverNum > 1) && ((dataValue / topoAttr_.serverNum) <= HCCL_SDMA_RDMA_SPLIT_SIZE)) { + dataSplit = 1; + } else if (dataValue <= HCCL_SDMA_RDMA_SPLIT_SIZE) { + dataSplit = HCCL_SPLIT_FLAG; + } + return dataSplit; +} + +HcclResult CollReduceScatterDoubleRingConcurrentExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + HCCL_INFO("[CollReduceScatterDoubleRingConcurrentExecutor][KernelRun] The ReduceScatterDoubleRingConcurrentExecutor" + "starts."); + u32 perDataSize = 0; + CHK_RET(SalGetDataTypeSize(param.DataDes.dataType, perDataSize)); + + u32 ringNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE; + CHK_RET(CheckCommSize(COMM_LEVEL0, ringNum)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + + std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 + std::vector > multiStreamSlice; // 每个stream使用的数据基于用户buffer的偏移 + + u32 sliceNum = outerCommInfo.localRankSize; + Slice sliceTemp; + u32 commIndex = outerCommInfo.localRank; + commIndex = RefreshCommIdx(commIndex, topoAttr_.nicList, topoAttr_.devicePhyId); + + /* 超节点间通信域是commLevel2 */ + CHK_RET(CheckCommSize(COMM_LEVEL2, COMM_INDEX_0 + 1)); + SubCommInfo level2CommInfo = GetSubCommInfo(COMM_LEVEL2, COMM_INDEX_0); + u32 level2RankSize = level2CommInfo.localRankSize; + + if (level2RankSize > 1) { + /* ****************** 超节点间 reducescatter *******************************/ + u64 reduceAttr = GetReduceAttr(execMem.inputMem, execMem.scratchMem, param.DataDes.dataType, param.reduceType); + std::unique_ptr level2Executor; + + if (UseLevel2RingAlgo(algType_)) { + level2Executor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); + 
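            // Level-2 (super-pod) phase: the ring template is prepared with
            // inputMem.size() / (level2RankSize * perDataSize) elements, i.e. the
            // data is sliced evenly across the super-pod peers before the
            // node-level and intra-node phases run on the local share.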
HCCL_INFO("reducescatter ring: using ring algo inter-superPod."); + CHK_SMART_PTR_NULL(level2Executor); + + u64 ringCount = execMem.inputMem.size() / (level2RankSize * perDataSize); + CHK_RET(level2Executor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, ringCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, std::vector(0))); + } else { + level2Executor.reset(new (std::nothrow) ReduceScatterRecursiveHalvingDoubling(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using halving-doubling algo inter-superPod."); + + CHK_SMART_PTR_NULL(level2Executor); + u64 inputDataCount = execMem.inputMem.size() / perDataSize; // count是output的数据个数 + CHK_RET(level2Executor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, inputDataCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, std::vector(0))); + } + CHK_RET(level2Executor->RegisterProfiler( + (level2RankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + level2CommInfo.localRank, + PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(level2Executor, level2CommInfo)); + + /* ****************** 节点间 reducescatter *******************************/ + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + u32 innerRankSize = innerCommInfo.localRankSize; + + if (innerRankSize > 1) { + std::unique_ptr innerExecutor; + u32 level1Index = innerCommInfo.localRank; + + if (UseInterServerRingAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using ring algo inter-server."); + CHK_SMART_PTR_NULL(innerExecutor); + + u64 ringSize = execMem.inputMem.size() / (innerRankSize * level2RankSize); + u64 ringCount = ringSize / perDataSize; + u64 level1SliceOffset = ringSize * level1Index; + DeviceMem level1InputMem = execMem.inputMem.range(level1SliceOffset, ringSize); + CHK_SMART_PTR_NULL(level1InputMem.ptr()); + + CHK_RET(innerExecutor->Prepare(level1InputMem, level1InputMem, execMem.scratchMem, ringCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, std::vector(0), + level1SliceOffset)); + } else { + innerExecutor.reset(new (std::nothrow) ReduceScatterRecursiveHalvingDoubling(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using halving-doubling algo inter-server."); + + CHK_SMART_PTR_NULL(innerExecutor); + u64 inputDataCount = execMem.inputMem.size() / (perDataSize * level2RankSize); + u64 level1SliceSize = execMem.inputMem.size() / level2RankSize; + u64 level1SliceOffset = level1SliceSize * level1Index; + + DeviceMem level1InputMem = execMem.inputMem.range(level1SliceOffset, level1SliceSize); + // count是output的数据个数 + CHK_RET(innerExecutor->Prepare(level1InputMem, level1InputMem, execMem.scratchMem, inputDataCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, std::vector(0), + level1SliceOffset)); + } + CHK_RET(innerExecutor->RegisterProfiler( + (innerRankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(innerExecutor, innerCommInfo)); + } + + /* *********** 节点内reducescatter (正常场景) *****************************/ + CHK_RET(ActiveSlaveStreams(param.stream)); + + bool useInlineRduce = false; + bool isInlineReduce = IsSupportSDMAReduce(execMem.inputMem.ptr(), execMem.scratchMem.ptr(), + param.DataDes.dataType, 
param.reduceType); + useInlineRduce = isInlineReduce && algoAttr_.inlineReduceSwitchOn; + multiStreamSlice = ReduceScatterRingSlicePrepare(ringNum, sliceNum, useInlineRduce, execMem.outputMem, + dataSegsSlice, param.tag); + bool bRet = (multiStreamSlice.size() != ringNum); + CHK_PRT_RET(bRet, + HCCL_ERROR("[CollReduceScatterDoubleRingConcurrentExecutor][KernelRun]sliceNum-1[%u] != multiStreamSlice" \ + "size[%llu]", sliceNum - 1, multiStreamSlice.size()), HCCL_E_INTERNAL); + + DeviceMem srcMem; + // 每个server分配的slice大小 + u64 serverSliceSize = execMem.inputMem.size() / (innerRankSize * level2RankSize); + // 每个服务器对应的偏移 + u32 serverIndex = innerCommInfo.localRank; + u64 serverSliceOffset = serverSliceSize * serverIndex; + HCCL_DEBUG("inputMem.size=%llu, outerCommInfo.localRankSize=%u, serverSliceSize=%llu, serverSliceOffset=%llu "\ + "commIndex=%u innerCommInfo.localRank=%u", execMem.inputMem.size(), outerCommInfo.localRankSize, + serverSliceSize, serverSliceOffset, commIndex, innerCommInfo.localRank); + DeviceMem reduceScatterRingInput = execMem.inputMem.range(serverSliceOffset, serverSliceSize); + DeviceMem reduceScatterRingOutput = execMem.scratchMem.range(serverSliceOffset, serverSliceSize); + + u64 countLocal = serverSliceSize / perDataSize; + CHK_RET(MultiRingReduceScatter(param.tag, reduceScatterRingInput, reduceScatterRingOutput, countLocal, + param.DataDes.dataType, param.reduceType, multiStreamSlice, param.stream, PROF_STAGE_1, serverSliceOffset)); + + srcMem = execMem.inputMem.range(serverSliceOffset + dataSegsSlice[commIndex].offset, + execMem.count * perDataSize); + CHK_SMART_PTR_NULL(srcMem.ptr()); + + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, execMem.outputMem, srcMem, const_cast(param.stream))); + HCCL_INFO("reducescatter double ring concurrent run success"); + return HCCL_SUCCESS; + } + + /* ****************** 节点间 reducescatter *******************************/ + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + u32 innerRankSize = innerCommInfo.localRankSize; + if (innerRankSize > 1) { + std::vector innerDataSegsSlice; // 节点间数据分成ranksize份,每份的起始偏移和大小 + // 基于2环数据切分2环SDMA+2环ROH; bool = true表示SDMA + std::vector>> innerMultSlice; + innerDataSegsSlice.resize(innerRankSize); + for (u32 i = 0; i < innerRankSize; i++) { + innerDataSegsSlice[i].size = execMem.inputMem.size() / innerRankSize; + innerDataSegsSlice[i].offset = (i * execMem.inputMem.size() / innerRankSize); + } + + u32 syncTrans = BEST_SPLIT_VALUE; + u64 totalDataSize = execMem.inputMem.size() / innerRankSize; + if (totalDataSize <= HCCL_SDMA_RDMA_SPLIT_SIZE) { + syncTrans = MAX_SPLIT_VALUE; + } + // 把innerDataSegsSlice的一份数据分成 SDMA+RDMA + innerMultSlice.resize(RDMA_PLANE_NUM_IN_NPRING_DOUBLE); + std::vector sdmaSlice; + std::vector rdmaSlice; + for (u32 segsIndex = 0; segsIndex < innerDataSegsSlice.size(); segsIndex++) { + auto totalSize = innerDataSegsSlice[segsIndex].size; + auto sdmaSliceOffset = innerDataSegsSlice[segsIndex].offset; + auto sdmaSliceSize = (totalSize <= HCCL_MIN_SLICE_ALIGN_910_73) ? 
totalSize: + ((syncTrans * totalSize / MAX_SPLIT_VALUE) / HCCL_MIN_SLICE_ALIGN_910_73) * HCCL_MIN_SLICE_ALIGN_910_73; + Slice sdmaSliceTmp; + sdmaSliceTmp.offset = sdmaSliceOffset; + sdmaSliceTmp.size = sdmaSliceSize; + Slice rdmaSliceTmp; + rdmaSliceTmp.offset = sdmaSliceOffset + sdmaSliceSize; + rdmaSliceTmp.size = totalSize - sdmaSliceSize; + sdmaSlice.push_back(sdmaSliceTmp); + rdmaSlice.push_back(rdmaSliceTmp); + HCCL_DEBUG("Inner data segId:%u, Orignal [offset %llu, size %llu], sdma [offset %llu, size %llu], "\ + "rdma [offset %llu, size %llu]", segsIndex, sdmaSliceOffset, totalSize, + sdmaSliceTmp.offset, sdmaSliceTmp.size, rdmaSliceTmp.offset, rdmaSliceTmp.size); + } + innerMultSlice[0] = std::make_pair(true, sdmaSlice); // true表示使用sdma + innerMultSlice[1] = std::make_pair(false, rdmaSlice); // false表示rdma + if (syncTrans == MAX_SPLIT_VALUE) { + innerMultSlice.erase(innerMultSlice.end() - 1, innerMultSlice.end()); + } + + u32 commPlaneNum = innerMultSlice.size(); + std::vector> ringNics; + CHK_RET(GetRingNics(param.tag, ringNics)); + u64 reduceAttr = GetReduceAttr(execMem.inputMem, execMem.scratchMem, param.DataDes.dataType, param.reduceType); + std::unique_ptr innerExecutor; + HcclResult ret = HCCL_SUCCESS; + // 节点间共2个通信域,分别走SDMA和RDMA + for (u32 planeIndex = 0; planeIndex < commPlaneNum; planeIndex++) { + std::vector singleSlice = innerMultSlice[planeIndex].second; + CHK_PRT_RET(singleSlice.empty(), + HCCL_ERROR("[CollReduceScatterDoubleRingConcurrentExecutor][KernelRun]singleSlice is empty"), + HCCL_E_INTERNAL); + + CHK_RET(CheckCommSize(COMM_LEVEL1_RDMA, commIndex + 1)); + SubCommInfo innerRdmaCommInfo = GetSubCommInfo(COMM_LEVEL1_RDMA, commIndex); + SubCommInfo ringCommInfo = innerMultSlice[planeIndex].first ? innerCommInfo : innerRdmaCommInfo; + + if (planeIndex != (commPlaneNum - 1)) { // 0~ringNum-2的环 + ret = streamInfo_.ringSignalAux[planeIndex]->Wait( + streamInfo_.ringStreams[planeIndex], dispatcher_, PROF_STAGE_0); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollReduceScatterDoubleRingConcurrentExecutor][KernelRun]stream[%u] wait failed", + planeIndex), ret); + + if (UseInterServerRingAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using ring algo inter-server."); + CHK_SMART_PTR_NULL(innerExecutor); + + u64 ringSize = execMem.inputMem.size() / innerRankSize; + u64 ringCount = ringSize / perDataSize; + + CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, ringCount, + param.DataDes.dataType, streamInfo_.ringStreams[planeIndex], param.reduceType, + OUTER_BRIDGE_RANK_ID, singleSlice)); + } else if (UseInterServerNHRAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) ReduceScatterNHR(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using nonuniform-hierarchical-ring algo inter-server."); + CHK_SMART_PTR_NULL(innerExecutor); + + u64 ringSize = execMem.inputMem.size() / innerRankSize; + u64 ringCount = ringSize / perDataSize; + CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, ringCount, + param.DataDes.dataType, streamInfo_.ringStreams[planeIndex], param.reduceType, + OUTER_BRIDGE_RANK_ID, singleSlice)); + } else if (UseInterServerNBAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) ReduceScatterNB(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using nonuniform-bruck algo inter-server."); + CHK_SMART_PTR_NULL(innerExecutor); + + u64 ringSize = execMem.inputMem.size() / 
innerRankSize; + u64 ringCount = ringSize / perDataSize; + CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, ringCount, + param.DataDes.dataType, streamInfo_.ringStreams[planeIndex], param.reduceType, + OUTER_BRIDGE_RANK_ID, singleSlice)); + } else { + innerExecutor.reset(new (std::nothrow) ReduceScatterRecursiveHalvingDoubling(dispatcher_, + reduceAttr)); + HCCL_INFO("reducescatter ring: using halving-doubling algo inter-server."); + + CHK_SMART_PTR_NULL(innerExecutor); + u64 inputDataCount = execMem.inputMem.size() / perDataSize; + CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, + inputDataCount, param.DataDes.dataType, streamInfo_.ringStreams[planeIndex], param.reduceType, + OUTER_BRIDGE_RANK_ID, singleSlice)); + } + CHK_RET(innerExecutor->RegisterProfiler( + (innerRankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + ringCommInfo.localRank, + PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, param.stream)); + + CHK_RET(RunTemplate(innerExecutor, ringCommInfo)); + + ret = streamInfo_.ringSignal[planeIndex]->Post( + streamInfo_.ringStreams[planeIndex], dispatcher_, PROF_STAGE_0); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollReduceScatterDoubleRingConcurrentExecutor][KernelRun]stream[%u] record failed", + planeIndex), ret); + + /* 主环record启动从环 */ + ret = streamInfo_.ringSignalAux[planeIndex]->Post(const_cast(param.stream), dispatcher_, + PROF_STAGE_0); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollReduceScatterDoubleRingConcurrentExecutor][KernelRun]stream[%u] record failed", + planeIndex), ret); + } else { + if (UseInterServerRingAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using ring algo inter-server."); + CHK_SMART_PTR_NULL(innerExecutor); + + u64 ringSize = execMem.inputMem.size() / innerRankSize; + u64 ringCount = ringSize / perDataSize; + + CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, ringCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, singleSlice)); + } else if (UseInterServerNHRAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) ReduceScatterNHR(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using nonuniform-hierarchical-ring algo inter-server."); + CHK_SMART_PTR_NULL(innerExecutor); + + u64 ringSize = execMem.inputMem.size() / innerRankSize; + u64 ringCount = ringSize / perDataSize; + CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, ringCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, singleSlice)); + } else if (UseInterServerNBAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) ReduceScatterNB(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using nonuniform-bruck algo inter-server."); + CHK_SMART_PTR_NULL(innerExecutor); + + u64 ringSize = execMem.inputMem.size() / innerRankSize; + u64 ringCount = ringSize / perDataSize; + CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, ringCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, singleSlice)); + } else { + innerExecutor.reset(new (std::nothrow) ReduceScatterRecursiveHalvingDoubling(dispatcher_, + reduceAttr)); + HCCL_INFO("reducescatter ring: using halving-doubling algo inter-server."); + + CHK_SMART_PTR_NULL(innerExecutor); + u64 inputDataCount = execMem.inputMem.size() / perDataSize; + 
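                // Halving-doubling on the main plane is prepared with the full
                // inputMem element count (the original comment notes this count is
                // the number of output elements); the SDMA/RDMA portion this plane
                // actually reduces is conveyed through singleSlice.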
CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, + inputDataCount, param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, + singleSlice)); // count是output的数据个数 + } + CHK_RET(innerExecutor->RegisterProfiler( + (innerRankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(innerExecutor, ringCommInfo)); + + for (u32 ring = 0; ring < (commPlaneNum - 1); ring++) { + /* 等待executor执行完毕 */ + ret = streamInfo_.ringSignal[ring]->Wait(const_cast(param.stream), dispatcher_, + PROF_STAGE_0); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollReduceScatterDoubleRingConcurrentExecutor][KernelRun]stream[%u] wait failed", + ring), ret); + } + } + } + CHK_RET(ExecutorBase::ExecEmptyTask(execMem.inputMem, execMem.outputMem, const_cast(param.stream), + dispatcher_)); + } + /* *********** 节点内reducescatter (正常场景) *****************************/ + std::vector>> mult4RingsSlice; // 基于2环数据切分2环SDMA+2环ROH bool = true表示SDMA + CHK_RET(ActiveSlaveStreams(param.stream)); + + bool useInlineRduce = false; + bool isInlineReduce = IsSupportSDMAReduce(execMem.inputMem.ptr(), execMem.scratchMem.ptr(), param.DataDes.dataType, + param.reduceType); + useInlineRduce = isInlineReduce && algoAttr_.inlineReduceSwitchOn; + multiStreamSlice = ReduceScatterRingSlicePrepare(ringNum, sliceNum, useInlineRduce, execMem.outputMem, + dataSegsSlice, param.tag); + bool bRet = (multiStreamSlice.size() != ringNum); + CHK_PRT_RET(bRet, + HCCL_ERROR("[CollReduceScatterDoubleRingConcurrentExecutor][KernelRun]sliceNum-1[%u] != multiStreamSlice " + "size[%llu]", sliceNum - 1, multiStreamSlice.size()), HCCL_E_INTERNAL); + u32 syncTrans = BEST_SPLIT_VALUE; + u64 totalDataSize = execMem.outputMem.size() * dataSegsSlice.size(); + if (totalDataSize <= HCCL_SDMA_RDMA_SPLIT_SIZE) { + syncTrans = MAX_SPLIT_VALUE; + } + mult4RingsSlice.resize(multiStreamSlice.size() * SLICES_FACTOR); + for (u32 ringIndex = 0; ringIndex < multiStreamSlice.size(); ringIndex++) { + std::vector sdmaSlice; + std::vector rdmaSlice; + for (u32 segsIndex = 0; segsIndex < multiStreamSlice[ringIndex].size(); segsIndex++) { + auto totalSize = multiStreamSlice[ringIndex][segsIndex].size; + auto sdmaSliceOffset = multiStreamSlice[ringIndex][segsIndex].offset; + auto sdmaSliceSize = (totalSize <= HCCL_MIN_SLICE_ALIGN_910_73) ? 
totalSize: + ((syncTrans * totalSize / MAX_SPLIT_VALUE) / HCCL_MIN_SLICE_ALIGN_910_73) * HCCL_MIN_SLICE_ALIGN_910_73; + Slice sdmaSliceTmp; + sdmaSliceTmp.offset = sdmaSliceOffset; + sdmaSliceTmp.size = sdmaSliceSize; + Slice rdmaSliceTmp; + rdmaSliceTmp.offset = sdmaSliceOffset + sdmaSliceSize; + rdmaSliceTmp.size = totalSize - sdmaSliceSize; + sdmaSlice.push_back(sdmaSliceTmp); + rdmaSlice.push_back(rdmaSliceTmp); + HCCL_DEBUG("Intra index:%u, segId:%u, Orignal [offset %llu, size %llu], sdma [offset %llu, size %llu], "\ + "rdma [offset %llu, size %llu]", ringIndex, segsIndex, sdmaSliceOffset, totalSize, + sdmaSliceTmp.offset, sdmaSliceTmp.size, rdmaSliceTmp.offset, rdmaSliceTmp.size); + } + mult4RingsSlice[ringIndex] = std::make_pair(true, sdmaSlice); // true表示使用sdma + mult4RingsSlice[ringIndex + multiStreamSlice.size()] = std::make_pair(false, rdmaSlice); // false表示rdma + } + if (syncTrans == MAX_SPLIT_VALUE) { + mult4RingsSlice.erase(mult4RingsSlice.end() - multiStreamSlice.size(), mult4RingsSlice.end()); + } + + DeviceMem srcMem; + // 每个server分配的slice大小 + u64 serverSliceSize = execMem.inputMem.size() / innerRankSize; + // 每个服务器对应的偏移 + u32 serverIndex = innerCommInfo.localRank; + u64 serverSliceOffset = serverSliceSize * serverIndex; + HCCL_DEBUG("inputMem.size=%llu, outerCommInfo.localRankSize=%u, serverSliceSize=%llu, serverSliceOffset=%llu "\ + "commIndex=%u innerCommInfo.localRank=%u", execMem.inputMem.size(), outerCommInfo.localRankSize, + serverSliceSize, serverSliceOffset, commIndex, innerCommInfo.localRank); + DeviceMem reduceScatterRingInput = execMem.inputMem.range(serverSliceOffset, serverSliceSize); + CHK_SMART_PTR_NULL(reduceScatterRingInput.ptr()); + DeviceMem reduceScatterRingOutput = execMem.scratchMem.range(serverSliceOffset, serverSliceSize); + CHK_SMART_PTR_NULL(reduceScatterRingOutput.ptr()); + u64 countLocal = serverSliceSize / perDataSize; + CHK_RET(MultiRingReduceScatterConcurrent(param.tag, reduceScatterRingInput, reduceScatterRingOutput, countLocal, + param.DataDes.dataType, param.reduceType, mult4RingsSlice, param.stream, PROF_STAGE_1, serverSliceOffset, + nullptr)); + + srcMem = execMem.inputMem.range(serverSliceOffset + dataSegsSlice[commIndex].offset, execMem.count * perDataSize); + CHK_SMART_PTR_NULL(srcMem.ptr()); + + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, execMem.outputMem, srcMem, const_cast(param.stream))); + + HCCL_INFO("reducescatter double ring concurrent run success"); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("ReduceScatterDoubleRingConcurrentExecutor", ReduceScatterDoubleRingConcurrent, + CollReduceScatterDoubleRingConcurrentExecutor); +} \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_double_ring_concurrent_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_double_ring_concurrent_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..c08742585da93fcbb9ca616f18b158a3787d813f --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_double_ring_concurrent_executor.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_REDUCESCATTER_DOUBLE_RING_CONCURRENT_EXECUTOR_H +#define COLL_REDUCESCATTER_DOUBLE_RING_CONCURRENT_EXECUTOR_H +#include "coll_reduce_scatter_executor.h" + +namespace hccl { +class CollReduceScatterDoubleRingConcurrentExecutor : public CollReduceScatterExecutor { +public: + explicit CollReduceScatterDoubleRingConcurrentExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollReduceScatterDoubleRingConcurrentExecutor() = default; + +private: + void ParseParam(const OpParam& param) override; + /* *************** 资源计算 *************** */ + HcclResult CalcScratchMemSize(u64& scratchMemSize) override; + HcclResult CalcStreamNum(u32& streamNum) override; + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcLevel2CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + u64 CalcLoopMaxCount(const u32 unitSize) override; + bool IsHugeData(const u64 curSize) override; + u32 CalcDataSplit(const u64 curSize) override; + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..2a0825d7b64aaf435433c3f459edc0f011db804d --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_executor.cc @@ -0,0 +1,262 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "coll_reduce_scatter_executor.h" + +namespace hccl { + +CollReduceScatterExecutor::CollReduceScatterExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollCommExecutor(dispatcher, topoMatcher) +{ +} + +HcclResult CollReduceScatterExecutor::Orchestrate(const OpParam& param, + const AlgResourceResponse& algRes) +{ + HcclUs startut = TIME_NOW(); + ParseParam(param); + tag_ = param.tag; + GetStreamInfo(algRes); + algResResp_ = &algRes; + auto rtStream = param.stream.ptr(); + HCCL_PROFILER_ADD_TAG(param.tag, algoAttr_.identifier, GetWorkflowMode()); + HCCL_PROFILER_ADD_STREAM(rtStream, param.tag, 0, algType_); + HCCL_PROFILER_ADD_OPDATA(param.tag, param.DataDes.count, param.inputPtr, param.outputPtr, param.DataDes.dataType, \ + INVALID_VALUE_RANKID, algoAttr_.identifier); + HCCL_PROFILER_ADD_GROUPRANK(algoAttr_.identifier, topoAttr_.userRankSize, topoAttr_.userRank); + CHK_RET(AddSubStreamToProfiling()); + + HcclResult ret = HCCL_SUCCESS; + // 图模式和单卡场景下不需要Loop + if (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + ExecMem execMem; + execMem.count = param.DataDes.count; + execMem.inputPtr = param.inputPtr; + execMem.outputPtr = param.outputPtr; + execMem.inputMem = algRes.paramInputMem; + execMem.outputMem = algRes.paramOutputMem; + execMem.scratchMem = algRes.scratchMem; + ret = KernelRun(param, execMem); + } else if (topoAttr_.userRankSize == 1) { + ExecMem execMem; + execMem.count = param.DataDes.count; + execMem.inputPtr = param.inputPtr; + execMem.outputPtr = param.outputPtr; + execMem.inputMem = algRes.cclInputMem; + execMem.outputMem = algRes.cclOutputMem; + execMem.scratchMem = algRes.scratchMem; + ret = KernelRun(param, execMem); + } else { + ret = RunLoop(param, algRes); + } + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollReduceScatterExecutor][Orchestrate]errNo[0x%016llx]excutor kernel run failed", + HCCL_ERROR_CODE(ret)), ret); + + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && !is310P3Common_) { + HCCL_PROFILER_DEL_STREAM(rtStream); + HCCL_PROFILER_DEL_TAG(param.tag); + HCCL_PROFILER_DEL_OPDATA(param.tag); + HCCL_PROFILER_DEL_GROUPRANK(param.tag); + } + HCCL_INFO("tag[%s], ReduceScatter executor orchestrate success, take time [%lld]us.", + param.tag.c_str(), DURATION_US(TIME_NOW() - startut)); + return HCCL_SUCCESS; +} + +u64 CollReduceScatterExecutor::CalcLoopMaxCount(const u32 unitSize) +{ + // 中转内存单次最多能够接受的output count + u64 maxCountPerLoop = inCCLbufferSize_ / (topoAttr_.userRankSize * unitSize); + HCCL_INFO("[CollReduceScatterExecutor][CalcLoopMaxCount]" \ + "using default maxCountPerLoop[%llu] as CCLBuffSize / (userRankSize * unitSize).", maxCountPerLoop); + return maxCountPerLoop; +} + +bool CollReduceScatterExecutor::IsHugeData(const u64 curSize) +{ + bool hugeData = (curSize * topoAttr_.userRankSize / HCCL_INTERNODE_MAX_DATA_RATE > RDMA_SEND_MAX_SIZE) || + (curSize > SDMA_SEND_MAX_SIZE); + return hugeData; +} + +bool CollReduceScatterExecutor::IsSmallData(const u64 totalSize, const u64 curSize) +{ + HCCL_INFO("[CollReduceScatterExecutor][IsSmallData]opMeta is using the default option: not small data."); + return false; +} + +u32 CollReduceScatterExecutor::CalcDataSplit(const u64 curSize) +{ + HCCL_INFO("[CollReduceScatterExecutor][CalcDataSplit]opMeta is using the default option: not data split."); + return 0; +} + +HcclResult CollReduceScatterExecutor::RunLoop(const OpParam ¶m, const AlgResourceResponse &algRes) +{ + u32 unitSize = SIZE_TABLE[param.DataDes.dataType]; + 
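    // RunLoop splits the user buffer into chunks of at most CalcLoopMaxCount()
    // elements, stages each chunk through the CCL buffers (or the scratch buffer
    // when scratchMemFlag_ is set), and advances the per-iteration input/output
    // pointers by the byte size consumed in the previous pass.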
ReduceType reduceType = ((param.reduceType != HCCL_REDUCE_PROD) && + (param.DataDes.dataType != HCCL_DATA_TYPE_INT64)) ? + ReduceType::INLINE_REDUCE : ReduceType::TBE_REDUCE; + + u8 *curInputPtr = static_cast(param.inputPtr); + u8 *curOutputPtr = static_cast(param.outputPtr); + CHK_PTR_NULL(curInputPtr); + CHK_PTR_NULL(curOutputPtr); + + u64 maxCountPerLoop = CalcLoopMaxCount(unitSize); + HCCL_DEBUG("[CollReduceScatterExecutor][RunLoop]tag[%s], userRankSize is [%llu], maxCountPerLoop is [%llu].", + param.tag.c_str(), topoAttr_.userRankSize, maxCountPerLoop); + + HcclResult ret; + for (u64 countLeft = param.DataDes.count, curCount = 0, inputOffset = 0, outputOffset = 0; + countLeft > 0; countLeft -= curCount) { + curInputPtr += inputOffset; + curOutputPtr += outputOffset; + // 判断剩余数据量对应的output size是否大于中转output size + curCount = (countLeft > maxCountPerLoop) ? maxCountPerLoop : countLeft; + u64 curSize = curCount * unitSize; // 单位:字节 + + HCCL_DEBUG("[CollReduceScatterExecutor][RunLoop]tag[%s], inputOffset[%llu], outputOffset[%llu], " \ + "sendBuf[%p], recvBuf[%p], sendCount[%llu], dataType[%d].", + param.tag.c_str(), inputOffset, outputOffset, curInputPtr, curOutputPtr, curCount, param.DataDes.dataType); + + ExecMem execMem; + execMem.count = curCount; + execMem.inputMem = algRes.cclInputMem; + execMem.outputMem = algRes.cclOutputMem; + if (scratchMemFlag_) { + execMem.scratchMem = algRes.scratchMem; + } else { + execMem.scratchMem = algRes.cclOutputMem; // 不需要申请则传入outputmem为scratchmem + } + HCCL_DEBUG("[CollReduceScatterExecutor][RunLoop]scratchMem address [%p]", execMem.scratchMem.ptr()); + + // 使用当前Loop偏移到的地址作为当前的inputPtr和outputPtr + execMem.inputPtr = curInputPtr; + execMem.outputPtr = curOutputPtr; + + ret = RunLoopInner(param, reduceType, execMem); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollReduceScatterExecutor][RunLoop]errNo[0x%016llx]kernel run error, tag[%s]", + HCCL_ERROR_CODE(ret), param.tag.c_str()), ret); + + inputOffset = curSize; + outputOffset = curSize; + } + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterExecutor::RunLoopInner(const OpParam ¶m, const ReduceType &reduceType, ExecMem &execMem) +{ + u32 unitSize = SIZE_TABLE[param.DataDes.dataType]; + u64 curSize = execMem.count * unitSize; // 单位:字节 + CHK_PRT_RET((execMem.count == 0), + HCCL_ERROR("[CollReduceScatterExecutor][RunLoopInner]In OP_BASE curCount is zero."), HCCL_E_PARA); + + if (!is310P3Common_) { + /* 设置子图复用标志 */ + auto autoSelectedAlgTypeLevel1 = static_cast(algType_) >> HCCL_LEVEL_ALGO_WIDTH; + bool hugeData = IsHugeData(curSize); + bool smallData = IsSmallData(param.DataDes.count * unitSize, curSize); + auto opMeta = HcclOpMetaInfo::GetOneForReduceScatter(autoSelectedAlgTypeLevel1, + param.DataDes.dataType, reduceType, hugeData, smallData); + opMeta.dataSplit = CalcDataSplit(curSize); + CHK_RET(InitTask(dispatcher_, const_cast(param.stream), opMeta.isEnableCache, opMeta.GetCacheKey())); + CHK_RET(RankConsistent::GetInstance().RecordOpPara(HcclCMDType::HCCL_CMD_REDUCE_SCATTER, param.tag, + execMem.count, param.DataDes.dataType, param.reduceType, execMem.inputMem.size(), + execMem.outputMem.size())); + } + + if (CCLMemSlice_) { + execMem.inputMem = execMem.inputMem.range(0, curSize * topoAttr_.userRankSize); + execMem.outputMem = execMem.outputMem.range(0, curSize); + if (scratchMemFlag_) { + execMem.scratchMem = execMem.scratchMem.range(0, curSize * topoAttr_.userRankSize); + } + } + + // 执行 + if (!DMAReduceFlag_) { + // 如果使用in CCL buffer,需要将user buffer in中的结果拷贝到CCL buffer in + DeviceMem 
dstMem; + DeviceMem srcMem; + for (u32 i = 0; i < topoAttr_.userRankSize; i++) { + // 拷贝input上每个slice的数据到中转内存,源端每个slice的size固定为output的size + dstMem = execMem.inputMem.range(curSize * i, curSize); + srcMem = DeviceMem::create(static_cast(execMem.inputPtr) + param.DataDes.count * unitSize * i, + curSize); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, const_cast(param.stream))); + } + HCCL_DEBUG("[CollReduceScatterExecutor][RunLoopInner]copy from user in to ccl in."); + } + + HcclResult ret = KernelRun(param, execMem); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollReduceScatterExecutor][RunLoopInner]errNo[0x%016llx]kernel run error, tag[%s], " \ + "inputMem ptr[%p], outputMem ptr[%p], count[%llu], dataType[%d], reduce op type[%d]", + HCCL_ERROR_CODE(ret), param.tag.c_str(), execMem.inputMem.ptr(), execMem.outputMem.ptr(), + execMem.count, param.DataDes.dataType, param.reduceType), + ret); + + if (!DMAReduceFlag_) { + // 如果使用CCL buffer,需要将CCL buffer out中的结果拷贝到user buffer out + DeviceMem srcMem = execMem.outputMem.range(0, curSize); + DeviceMem dstMem = DeviceMem::create(execMem.outputPtr, curSize); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, const_cast(param.stream))); + } + + if (!is310P3Common_) { + CHK_RET(RankConsistent::GetInstance().DelOpPara(param.tag)); + CHK_RET(LaunchTask(dispatcher_, const_cast(param.stream))); + } + return ret; +} + +std::vector> CollReduceScatterExecutor::ReduceScatterRingSlicePrepare(u32 ringNum, u32 sliceNum, + bool useInlineReduce, DeviceMem& outputMem, std::vector& dataSegsSlice, const std::string &tag) +{ + std::vector> multiStreamSlice; + u64 outputMenSize = outputMem.size(); + dataSegsSlice.clear(); + Slice sliceTemp; + for (u32 i = 0; i < sliceNum; i++) { // 根据数据量算每个环上数据的偏移和大小 + sliceTemp.size = outputMenSize; + sliceTemp.offset = outputMenSize * i; + dataSegsSlice.push_back(sliceTemp); + } + + // 再将每个 slice 划分为 ringNum 份 + if (ringNum == OUTER_PLANE_NUM_IN_8PRING) { + if (useInlineReduce) { + multiStreamSlice = PrepareMultiRingSlice(dataSegsSlice, tag); + } else if (outputMem.size() % CCE_REDUCE_ALIGN_SIZE == 0) { + multiStreamSlice = PrepareMultiRingSlice(dataSegsSlice, tag); + } else { + multiStreamSlice = PrepareMultiRingSlice(dataSegsSlice, tag, true); + } + } else if (ringNum == OUTER_PLANE_NUM_IN_NPRING_DOUBLE) { + // 双环场景,需要传入正确的 niclist (不涉及网口裁剪) + if (useInlineReduce) { + multiStreamSlice = PrepareMultiRingSlice(dataSegsSlice, tag, false, topoAttr_.nicList); + } else if (outputMem.size() % CCE_REDUCE_ALIGN_SIZE == 0) { + multiStreamSlice = PrepareMultiRingSlice(dataSegsSlice, tag, false, topoAttr_.nicList); + } else { + multiStreamSlice = PrepareMultiRingSlice(dataSegsSlice, tag, true, topoAttr_.nicList); + } + } else { + multiStreamSlice.push_back(dataSegsSlice); + } + + return multiStreamSlice; +} + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..c3accd89c6ecbced1ca5bee4efc084a0b5c2272b --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_executor.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_REDUCESCATTER_EXECUTOR_H +#define COLL_REDUCESCATTER_EXECUTOR_H +#include "coll_comm_executor.h" + +namespace hccl { + +constexpr u64 CCE_REDUCE_ALIGN_FACTOR = 2; // cce reduce数据大小32字节对齐 2是指前后各有 + +class CollReduceScatterExecutor : public CollCommExecutor { +public: + explicit CollReduceScatterExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollReduceScatterExecutor() = default; + + HcclResult Orchestrate(const OpParam& param, const AlgResourceResponse& algRes) override; +protected: + // ReduceScatter Loop Executor公共接口 + virtual u64 CalcLoopMaxCount(const u32 unitSize); + virtual bool IsHugeData(const u64 curSize); + virtual bool IsSmallData(const u64 totalSize, const u64 curSize); + virtual u32 CalcDataSplit(const u64 curSize); + virtual HcclResult RunLoop(const OpParam ¶m, const AlgResourceResponse &algRes); + + // 工具类 + std::vector> ReduceScatterRingSlicePrepare(u32 ringNum, u32 sliceNum, + bool useInlineReduce, DeviceMem& outputMem, std::vector& dataSegsSlice, const std::string &tag); + + bool CCLMemSlice_{true}; // 每次Loop是否需要对CCLMem进行切片 + bool DMAReduceFlag_{false}; // 是否DMA消减 + bool scratchMemFlag_{false}; // 是否需要申请scratch memory,不需要申请则传入outputmem为scratchmem + u64 totalSize_{0}; // 总数据量 + +private: + HcclResult RunLoopInner(const OpParam ¶m, const ReduceType &reduceType, ExecMem &execMem); +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_mesh_dma_elimination.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_mesh_dma_elimination.cc new file mode 100644 index 0000000000000000000000000000000000000000..ff4b6602b9b1147871d3063beb98a120a3926c2e --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_mesh_dma_elimination.cc @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "coll_reduce_scatter_mesh_dma_elimination.h" + +namespace hccl { + +CollReduceScatterMeshDmaEliminationExecutor::CollReduceScatterMeshDmaEliminationExecutor( + const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollReduceScatterExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = true; + CCLMemSlice_ = false; +} + +void CollReduceScatterMeshDmaEliminationExecutor::ParseParam(const OpParam& param) +{ + tag_ = param.tag; +} + +HcclResult CollReduceScatterMeshDmaEliminationExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = topoAttr_.deviceNumPerAggregation; + streamNum = totalStreamNum - 1U; + HCCL_INFO("[CollReduceScatterMeshDmaEliminationExecutor][CalcStreamNum] tag[%s] streamNum[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterMeshDmaEliminationExecutor::CalcCommInfo( + std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterMeshDmaEliminationExecutor::CalcTransportMemType(TransportMemType &inputType, + TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_OUTPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_OUTPUT; + } + HCCL_INFO("[CollReduceScatterMeshDmaEliminationExecutor][CalcTransportMemType] tag[%s] inputType[%d]," + " outputType[%d]", tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterMeshDmaEliminationExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_MESH); + commParaLevel0.meshSinglePlane = true; + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + return HCCL_SUCCESS; +} + +u64 CollReduceScatterMeshDmaEliminationExecutor::CalcLoopMaxCount(const u32 unitSize) +{ + // 中转内存单次最多能够接受的output count,放开ranksize限制 + u64 maxCountPerLoop = inCCLbufferSize_ / unitSize; + return maxCountPerLoop; +} + +bool CollReduceScatterMeshDmaEliminationExecutor::IsHugeData(const u64 curSize) +{ + bool hugeData = curSize > SDMA_SEND_MAX_SIZE; + return hugeData; +} + +bool CollReduceScatterMeshDmaEliminationExecutor::IsSmallData(const u64 totalSize, const u64 curSize) +{ + bool smallData = curSize <= HCCL_SMALL_COUNT_32_KB; + return smallData; +} + +HcclResult CollReduceScatterMeshDmaEliminationExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + u32 perDataSize = SIZE_TABLE[param.DataDes.dataType]; + + CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + + u32 commIndex = outerCommInfo.localRank; // 找到rank所在的节点间平面 + + /* *******************节点内reducescatter ******************************************/ + CHK_RET(ActiveSlaveStreams(param.stream)); + + u32 sliceNum = outerCommInfo.localRankSize; + // 根据数据量算每个环上数据的偏移和大小,把做完hd的slice均分成RankSize份 + std::vector dataSegsSlice; + CHK_RET(PrepareReduceScatterSliceData(execMem.count, perDataSize, sliceNum, dataSegsSlice)); + + HCCL_DEBUG("inputMem.size()=%llu, outerCommInfo.localRankSize=%u, commIndex=%u", + 
execMem.inputMem.size(), outerCommInfo.localRankSize, commIndex); + + HcomCollOpInfo *opInfoPtr = nullptr; + HcomCollOpInfo opInfo = {"", execMem.inputPtr, execMem.outputPtr, param.DataDes.count, param.DataDes.dataType, + param.root, param.reduceType}; + if (DMAReduceFlag_) { + opInfoPtr = &opInfo; + } + + if (topoMatcher_->GetExternalInputHcclDeterministic() == DETERMINISTIC_CONFIG_DISABLE && + (param.DataDes.dataType != HCCL_DATA_TYPE_INT64) && + (topoAttr_.deviceType == DevType::DEV_TYPE_910B && param.reduceType != HCCL_REDUCE_PROD)) { + CHK_RET(MultiStreamReduceScatterMeshAtomic(param.tag, execMem.inputMem, execMem.scratchMem, + execMem.count, param.DataDes.dataType, param.reduceType, dataSegsSlice, const_cast(param.stream), + COMM_LEVEL0, 0, opInfoPtr)); + } else { + std::vector > multiStreamSlice; // 每个stream使用的数据基于用户buffer的偏移 + // mesh算法stream数量为rank数减1 + CHK_RET(ExecutorBase::PrepareSliceMeshStreams(dataSegsSlice, sliceNum - 1, multiStreamSlice)); + CHK_RET(MultiStreamReduceScatterMesh(param.tag, execMem.inputMem, execMem.scratchMem, + execMem.count, param.DataDes.dataType, param.reduceType, multiStreamSlice, param.stream, COMM_LEVEL0, 0)); + } + + return HCCL_SUCCESS; +} + +REGISTER_EXEC("ReduceScatterMeshDmaEliminationExecutor", + ReduceScatterMeshDmaElimination, CollReduceScatterMeshDmaEliminationExecutor); +} \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_mesh_dma_elimination.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_mesh_dma_elimination.h new file mode 100644 index 0000000000000000000000000000000000000000..72e2ce4888623dc3d115bd93b5019cefef7abc0e --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_mesh_dma_elimination.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+
+#ifndef COLL_REDUCESCATTER_MESH_REDUCE_ELIMINATION_EXECUTOR_H
+#define COLL_REDUCESCATTER_MESH_REDUCE_ELIMINATION_EXECUTOR_H
+#include "coll_reduce_scatter_executor.h"
+
+namespace hccl {
+class CollReduceScatterMeshDmaEliminationExecutor : public CollReduceScatterExecutor {
+public:
+    explicit CollReduceScatterMeshDmaEliminationExecutor(const HcclDispatcher dispatcher,
+        std::unique_ptr<TopoMatcher> &topoMatcher);
+    ~CollReduceScatterMeshDmaEliminationExecutor() = default;
+
+private:
+    void ParseParam(const OpParam& param) override;
+    /* *************** Resource calculation *************** */
+    HcclResult CalcStreamNum(u32& streamNum) override;
+    HcclResult CalcCommInfo(std::vector<LevelNSubCommTransport>& opTransport) override;
+    HcclResult CalcLevel0CommInfo(TransportMemType inputType,
+        TransportMemType outputType,
+        std::vector<LevelNSubCommTransport>& opTransport) override;
+    HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType);
+
+    /* *************** Algorithm orchestration *************** */
+    u64 CalcLoopMaxCount(const u32 unitSize) override;
+    bool IsHugeData(const u64 curSize) override;
+    bool IsSmallData(const u64 totalSize, const u64 curSize) override;
+    HcclResult KernelRun(const OpParam &param, ExecMem &execMem) override;
+};
+
+} // namespace hccl
+
+#endif
\ No newline at end of file
diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_mesh_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_mesh_executor.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fa8550f8300e358135be7763d58df19596c11ac9
--- /dev/null
+++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_mesh_executor.cc
@@ -0,0 +1,245 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#include "coll_reduce_scatter_mesh_executor.h"
+
+namespace hccl {
+
+CollReduceScatterMeshExecutor::CollReduceScatterMeshExecutor(const HcclDispatcher dispatcher,
+    std::unique_ptr<TopoMatcher> &topoMatcher)
+    : CollReduceScatterExecutor(dispatcher, topoMatcher)
+{
+    DMAReduceFlag_ = false;
+}
+
+void CollReduceScatterMeshExecutor::ParseParam(const OpParam& param)
+{
+    tag_ = param.tag;
+
+    // 910B graph mode with non-deterministic computation, inlineReduce enabled and a MESH topology: create a single mesh plane
+    bool isInlineReduce = IsSupportSDMAReduce(param.inputPtr, param.outputPtr, param.DataDes.dataType,
+        param.reduceType);
+    meshSinglePlane_ = (topoAttr_.deviceType == DevType::DEV_TYPE_910B) &&
+        topoMatcher_->GetExternalInputHcclDeterministic() == DETERMINISTIC_CONFIG_DISABLE &&
+        isInlineReduce && (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE);
+
+    // Whether scratch memory is needed
+    if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE &&
+        (topoAttr_.deviceType == DevType::DEV_TYPE_910B || topoAttr_.deviceType == DevType::DEV_TYPE_910_73) &&
+        IsSupportSDMAReduce(param.inputPtr, param.outputPtr, param.DataDes.dataType, param.reduceType) &&
+        IsSupportRDMAReduce(param.DataDes.dataType, param.reduceType)) {
+        scratchMemFlag_ = false;
+    } else {
+        scratchMemFlag_ = true;
+    }
+
+    // Record the total data size for graph mode
+    totalSize_ = topoAttr_.userRankSize * param.DataDes.count * SIZE_TABLE[param.DataDes.dataType];
+}
+
+HcclResult CollReduceScatterMeshExecutor::CalcScratchMemSize(u64& scratchMemSize)
+{
+    if (scratchMemFlag_) {
+        if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) {
+            scratchMemSize = inCCLbufferSize_ + CCE_REDUCE_ALIGN_FACTOR * CCE_REDUCE_ALIGN_SIZE;
+        } else {
+            scratchMemSize = totalSize_ + CCE_REDUCE_ALIGN_FACTOR * CCE_REDUCE_ALIGN_SIZE;
+        }
+    } else {
+        scratchMemSize = 0U;
+    }
+    HCCL_INFO("[CollReduceScatterMeshExecutor][CalcScratchMemSize] tag[%s] scratchMemSize[%u]",
+        tag_.c_str(), scratchMemSize);
+    return HCCL_SUCCESS;
+}
+
+HcclResult CollReduceScatterMeshExecutor::CalcStreamNum(u32& streamNum)
+{
+    u32 totalStreamNum = topoAttr_.deviceNumPerAggregation > 1U ?
+        topoAttr_.deviceNumPerAggregation - 1U : 1U;
+    streamNum = totalStreamNum - 1U;
+    HCCL_INFO("[CollReduceScatterMeshExecutor][CalcStreamNum] tag[%s] streamNum[%u]", tag_.c_str(), streamNum);
+    return HCCL_SUCCESS;
+}
+
+HcclResult CollReduceScatterMeshExecutor::CalcCommInfo(std::vector<LevelNSubCommTransport>& opTransport)
+{
+    TransportMemType inputType = TransportMemType::RESERVED;
+    TransportMemType outputType = TransportMemType::RESERVED;
+    CHK_RET(CalcTransportMemType(inputType, outputType));
+    CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport));
+    CHK_RET(CalcLevel1CommInfo(inputType, outputType, opTransport));
+    return HCCL_SUCCESS;
+}
+
+HcclResult CollReduceScatterMeshExecutor::CalcTransportMemType(TransportMemType &inputType,
+    TransportMemType &outputType)
+{
+    if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) {
+        inputType = TransportMemType::CCL_INPUT;
+        if (scratchMemFlag_) {
+            outputType = TransportMemType::SCRATCH;
+        } else {
+            outputType = TransportMemType::CCL_OUTPUT;
+        }
+    } else {
+        inputType = TransportMemType::PARAM_INPUT;
+        if (scratchMemFlag_) {
+            outputType = TransportMemType::SCRATCH;
+        } else {
+            outputType = TransportMemType::PARAM_OUTPUT;
+        }
+    }
+    HCCL_INFO("[CollReduceScatterMeshExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]",
+        tag_.c_str(), inputType, outputType);
+    return HCCL_SUCCESS;
+}
+
+HcclResult CollReduceScatterMeshExecutor::CalcLevel0CommInfo(TransportMemType inputType,
+    TransportMemType outputType,
+    std::vector<LevelNSubCommTransport>& opTransport)
+{
+    CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_MESH);
+    commParaLevel0.meshSinglePlane = meshSinglePlane_;
+    CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType));
+    return HCCL_SUCCESS;
+}
+
+u64 CollReduceScatterMeshExecutor::CalcLoopMaxCount(const u32 unitSize)
+{
+    // Maximum output count the staging buffer can accept per loop
+    u64 maxCountPerLoop = inCCLbufferSize_ / (topoAttr_.userRankSize * unitSize);
+    return maxCountPerLoop;
+}
+
+bool CollReduceScatterMeshExecutor::IsHugeData(const u64 curSize)
+{
+    bool hugeData = (curSize * topoAttr_.userRankSize / HCCL_INTERNODE_MAX_DATA_RATE > RDMA_SEND_MAX_SIZE) ||
+        (curSize > SDMA_SEND_MAX_SIZE);
+    return hugeData;
+}
+
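+// KernelRun: step 1 performs the inter-server reduce-scatter on COMM_LEVEL1, picking ring / NHR / NHR_V1 /
+// NB / recursive halving-doubling from algType_; step 2 then reduce-scatters each server's slice inside the
+// server across the mesh planes and copies this rank's result to the output memory.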
+HcclResult CollReduceScatterMeshExecutor::KernelRun(const OpParam &param, ExecMem &execMem)
+{
+    u32 perDataSize = SIZE_TABLE[param.DataDes.dataType];
+
+    CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1));
+    SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0);
+
+    /* ****************** Step 1: inter-server reduce-scatter *******************************/
+    u32 commIndex = outerCommInfo.localRank; // Find the inter-server plane this rank belongs to
+
+    CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1));
+    SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex);
+
+    u32 innerRankSize = innerCommInfo.localRankSize;
+    if (innerRankSize > 1) {
+        u64 reduceAttr = GetReduceAttr(execMem.inputMem, execMem.outputMem, param.DataDes.dataType, param.reduceType);
+        std::unique_ptr<ExecutorBase> innerExecutor;
+        if (UseInterServerRingAlgo(algType_)) {
+            innerExecutor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr));
+            CHK_SMART_PTR_NULL(innerExecutor);
+            HCCL_INFO("reducescatter mesh: using ring algo inter-server.");
+            u64 ringSize = execMem.inputMem.size() / innerRankSize;
+            u64 ringCount = ringSize / perDataSize;
+            // Allocate temporary memory to serve as the scratch memory
+            CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, ringCount,
+                param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID,
+                std::vector<Slice>(0)));
+        } else if (UseInterServerNHRAlgo(algType_)) {
+            innerExecutor.reset(new (std::nothrow) ReduceScatterNHR(dispatcher_, reduceAttr));
+            HCCL_INFO("reducescatter mesh: using nhr algo inter-server.");
+            CHK_SMART_PTR_NULL(innerExecutor);
+            u64 ringSize = execMem.inputMem.size() / innerRankSize;
+            u64 ringCount = ringSize / perDataSize;
+            // Allocate temporary memory to serve as the scratch memory
+            CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, ringCount,
+                param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, std::vector<Slice>(0)));
+        } else if (UseInterServerNHRV1Algo(algType_)) {
+            innerExecutor.reset(new (std::nothrow) ReduceScatterNHRV1(dispatcher_, reduceAttr));
+            HCCL_INFO("reducescatter mesh: using nhr_v1 algo inter-server.");
+            CHK_SMART_PTR_NULL(innerExecutor);
+            u64 ringSize = execMem.inputMem.size() / innerRankSize;
+            u64 ringCount = ringSize / perDataSize;
+            // Allocate temporary memory to serve as the scratch memory
+            CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, ringCount,
+                param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, std::vector<Slice>(0)));
+        } else if (UseInterServerNBAlgo(algType_)) {
+            innerExecutor.reset(new (std::nothrow) ReduceScatterNB(dispatcher_, reduceAttr));
+            HCCL_INFO("reducescatter mesh: using nonuniform-bruck algo inter-server.");
+            CHK_SMART_PTR_NULL(innerExecutor);
+            u64 ringSize = execMem.inputMem.size() / innerRankSize;
+            u64 ringCount = ringSize / perDataSize;
+            // Allocate temporary memory to serve as the scratch memory
+            CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, ringCount,
+                param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, std::vector<Slice>(0)));
+        } else {
+            innerExecutor.reset(new (std::nothrow) ReduceScatterRecursiveHalvingDoubling(dispatcher_, reduceAttr));
+            CHK_SMART_PTR_NULL(innerExecutor);
+            HCCL_INFO("reducescatter mesh: using halving-doubling algo inter-server.");
+            // Allocate temporary memory to serve as the scratch memory
+            u64 inputDataCount = execMem.inputMem.size() / perDataSize; // count is the number of output elements
+            CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, inputDataCount,
+                param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, std::vector<Slice>(0)));
+        }
+        CHK_RET(innerExecutor->RegisterProfiler(
+            (innerRankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank,
+            PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, param.stream));
+
+        CHK_RET(RunTemplate(innerExecutor, innerCommInfo));
+    }
+
+    /* ******************* Step 2: intra-server reduce-scatter ******************************************/
+    CHK_RET(ActiveSlaveStreams(param.stream));
+
+    u32 sliceNum = outerCommInfo.localRankSize;
+    // Compute each ring's data offset and size from the data volume; split the HD-completed slice evenly into rankSize parts
+    std::vector<Slice> dataSegsSlice;
+    CHK_RET(PrepareReduceScatterSliceData(execMem.count, perDataSize, sliceNum, dataSegsSlice));
+
+    // Slice size assigned to each server
+    u64 serverSliceSize = execMem.inputMem.size() / innerRankSize;
+    // Offset of this server's slice
+    u64 serverSliceOffset = serverSliceSize * innerCommInfo.localRank;
+
+    HCCL_DEBUG("inputMem.size=%llu, outerCommInfo.localRankSize=%u, serverSliceSize=%llu, serverSliceOffset=%llu "\
+        "commIndex=%u innerCommInfo.localRank=%u", execMem.inputMem.size(), outerCommInfo.localRankSize,
+        serverSliceSize, serverSliceOffset, commIndex, innerCommInfo.localRank);
+
+    DeviceMem reduceScatterMeshInput = execMem.inputMem.range(serverSliceOffset, serverSliceSize);
+    CHK_SMART_PTR_NULL(reduceScatterMeshInput);
+    DeviceMem reduceScatterMeshOutput = execMem.scratchMem.range(serverSliceOffset,
+        serverSliceSize);
+    CHK_SMART_PTR_NULL(reduceScatterMeshOutput);
+
+    HcomCollOpInfo *opInfoPtr = nullptr;
+
+    if (topoMatcher_->GetExternalInputHcclDeterministic() == DETERMINISTIC_CONFIG_DISABLE &&
+        (param.DataDes.dataType != HCCL_DATA_TYPE_INT64) &&
+        (topoAttr_.deviceType == DevType::DEV_TYPE_910B && param.reduceType != HCCL_REDUCE_PROD)) {
+        CHK_RET(MultiStreamReduceScatterMeshAtomic(param.tag, reduceScatterMeshInput, reduceScatterMeshOutput, // non-deterministic
+            execMem.count, param.DataDes.dataType, param.reduceType, dataSegsSlice, const_cast<Stream&>(param.stream),
+            COMM_LEVEL0, serverSliceOffset, opInfoPtr));
+    } else {
+        std::vector<std::vector<Slice> > multiStreamSlice; // Data used by each stream, as offsets into the user buffer
+        // The mesh algorithm uses (rank count - 1) streams
+        CHK_RET(ExecutorBase::PrepareSliceMeshStreams(dataSegsSlice, sliceNum - 1, multiStreamSlice));
+        CHK_RET(MultiStreamReduceScatterMesh(param.tag, reduceScatterMeshInput, reduceScatterMeshOutput, // deterministic
+            execMem.count, param.DataDes.dataType, param.reduceType, multiStreamSlice,
+            const_cast<Stream&>(param.stream), COMM_LEVEL0, serverSliceOffset));
+    }
+
+    DeviceMem srcMem = execMem.inputMem.range(serverSliceOffset + dataSegsSlice[commIndex].offset,
+        execMem.count * perDataSize);
+    CHK_SMART_PTR_NULL(srcMem);
+    CHK_RET(HcclD2DMemcpyAsync(dispatcher_, execMem.outputMem, srcMem, const_cast<Stream&>(param.stream)));
+
+    return HCCL_SUCCESS;
+}
+
+REGISTER_EXEC("ReduceScatterMeshExecutor", ReduceScatterMesh, CollReduceScatterMeshExecutor);
+}
\ No newline at end of file
diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_mesh_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_mesh_executor.h
new file mode 100644
index 0000000000000000000000000000000000000000..52a5644cb477d7543ba0fd2a72e942943595f269
--- /dev/null
+++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_mesh_executor.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#ifndef COLL_REDUCESCATTER_MESH_EXECUTOR_H
+#define COLL_REDUCESCATTER_MESH_EXECUTOR_H
+#include "coll_reduce_scatter_executor.h"
+
+namespace hccl {
+class CollReduceScatterMeshExecutor : public CollReduceScatterExecutor {
+public:
+    explicit CollReduceScatterMeshExecutor(const HcclDispatcher dispatcher,
+        std::unique_ptr<TopoMatcher> &topoMatcher);
+    ~CollReduceScatterMeshExecutor() = default;
+
+private:
+    void ParseParam(const OpParam& param) override;
+    /* *************** Resource calculation *************** */
+    HcclResult CalcScratchMemSize(u64& scratchMemSize) override;
+    HcclResult CalcStreamNum(u32& streamNum) override;
+    HcclResult CalcCommInfo(std::vector<LevelNSubCommTransport>& opTransport) override;
+    HcclResult CalcLevel0CommInfo(TransportMemType inputType,
+        TransportMemType outputType,
+        std::vector<LevelNSubCommTransport>& opTransport) override;
+    HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType);
+
+    /* *************** Algorithm orchestration *************** */
+    u64 CalcLoopMaxCount(const u32 unitSize) override;
+    bool IsHugeData(const u64 curSize) override;
+    HcclResult KernelRun(const OpParam &param, ExecMem &execMem) override;
+
+    bool meshSinglePlane_ = false;
+};
+
+} // namespace hccl
+
+#endif
\ No newline at end of file
diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_mesh_opbase_pipeline_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_mesh_opbase_pipeline_executor.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e4bd274c5241d2c07f008c82ff7a313d0e75ae8b
--- /dev/null
+++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_mesh_opbase_pipeline_executor.cc
@@ -0,0 +1,174 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */ + +#include "coll_reduce_scatter_mesh_opbase_pipeline_executor.h" + +namespace hccl { + +CollReduceScatterMeshOpbasePipelineExecutor::CollReduceScatterMeshOpbasePipelineExecutor( + const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollReduceScatterExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = true; +} + +void CollReduceScatterMeshOpbasePipelineExecutor::ParseParam(const OpParam& param) +{ + tag_ = param.tag; +} + +HcclResult CollReduceScatterMeshOpbasePipelineExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = topoAttr_.deviceNumPerAggregation + 1U; + streamNum = totalStreamNum - 1U; + HCCL_INFO("[CollReduceScatterMeshOpbasePipelineExecutor][CalcStreamNum] tag[%s] streamNum[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterMeshOpbasePipelineExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel1CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterMeshOpbasePipelineExecutor::CalcTransportMemType(TransportMemType &inputType, + TransportMemType &outputType) +{ + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_OUTPUT; + HCCL_INFO("[CollReduceScatterMeshOpbasePipelineExecutor][CalcTransportMemType] tag[%s] inputType[%d]," + " outputType[%d]", tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterMeshOpbasePipelineExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaInfo(COMM_LEVEL0, CommType::COMM_TAG_MESH); + commParaInfo.meshSinglePlane = true; + CHK_RET(CalcCommPlaneInfo(tag_, commParaInfo, opTransport[COMM_LEVEL0], inputType, outputType)); + return HCCL_SUCCESS; +} + +// PipeLine模式下使用Ring算法 +HcclResult CollReduceScatterMeshOpbasePipelineExecutor::CalcLevel1CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaInfo(COMM_LEVEL1, CommType::COMM_TAG_RING_INNER); + CHK_RET(CalcCommPlaneInfo(tag_, commParaInfo, opTransport[COMM_LEVEL1], inputType, outputType)); + return HCCL_SUCCESS; +} + +u64 CollReduceScatterMeshOpbasePipelineExecutor::CalcLoopMaxCount(const u32 unitSize) +{ + // 中转内存单次最多能够接受的output count,放开ranksize限制 + u64 maxCountPerLoop = ((inCCLbufferSize_ / (HCCL_MIN_SLICE_ALIGN_910B * PIPELINE_DEPTH)) \ + * HCCL_MIN_SLICE_ALIGN_910B - HCCL_MIN_SLICE_ALIGN_910B) / unitSize; + HCCL_INFO("[CollReduceScatterMeshOpbasePipelineExecutor][CalcLoopMaxCount] maxCountPerLoop[%llu]", maxCountPerLoop); + return maxCountPerLoop; +} + +bool CollReduceScatterMeshOpbasePipelineExecutor::IsHugeData(const u64 curSize) +{ + bool hugeData = curSize > RDMA_SEND_MAX_SIZE || curSize > SDMA_SEND_MAX_SIZE; + return hugeData; +} + +HcclResult CollReduceScatterMeshOpbasePipelineExecutor::RunLoop(const OpParam ¶m, const AlgResourceResponse &algRes) +{ + HCCL_INFO("[CollReduceScatterMeshOpbasePipelineExecutor][RunLoop] begins."); + + u32 unitSize = SIZE_TABLE[param.DataDes.dataType]; + ReduceType reduceType = ((param.reduceType != HCCL_REDUCE_PROD) && + (param.DataDes.dataType != HCCL_DATA_TYPE_INT64)) ? 
+ ReduceType::INLINE_REDUCE : ReduceType::TBE_REDUCE; + + u8 *curInputPtr = static_cast(param.inputPtr); + u8 *curOutputPtr = static_cast(param.outputPtr); + CHK_PTR_NULL(curInputPtr); + CHK_PTR_NULL(curOutputPtr); + + u64 maxCountPerLoop = CalcLoopMaxCount(unitSize); + HCCL_DEBUG("[CollReduceScatterMeshOpbasePipelineExecutor][RunLoop]tag[%s], userRankSize is [%llu], maxCountPerLoop " + "is [%llu].", param.tag.c_str(), topoAttr_.userRankSize, maxCountPerLoop); + + // 先获取 comm inner \ comm outer 的value + CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + u32 commIndex = outerCommInfo.localRank; + + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + + DeviceMem userInMem = DeviceMem::create(param.inputPtr, param.DataDes.count * unitSize); + u64 reduceAttr = GetReduceAttr(userInMem, const_cast(algRes.cclInputMem), param.DataDes.dataType, + param.reduceType); // scratchMem用的cclin + u64 bufferSize = algRes.cclInputMem.size(); + + auto originalAlgTypeLevel1 = static_cast(algType_) >> HCCL_LEVEL_ALGO_WIDTH; + + for (u64 countLeft = param.DataDes.count, curCount = 0, curOffset = 0, curSize = 0; countLeft > 0; + countLeft -= curCount) { + curInputPtr += curSize; + curOutputPtr += curSize; + + curCount = (countLeft > maxCountPerLoop) ? maxCountPerLoop : countLeft; + curSize = curCount * unitSize; + + HCCL_DEBUG("[CollReduceScatterMeshOpbasePipelineExecutor][RunLoop]tag[%s], curOffset[%llu], " \ + "curInputPtr[%p], curOutputPtr[%p], curCount[%llu], dataType[%d].", + param.tag.c_str(), curOffset, curInputPtr, curOutputPtr, curCount, param.DataDes.dataType); + + bool hugeData = IsHugeData(curSize); + auto meta = HcclOpMetaInfo::GetOneForReduceScatter(originalAlgTypeLevel1, param.DataDes.dataType, reduceType, + hugeData); + CHK_RET(InitTask(dispatcher_, const_cast(param.stream), meta.isEnableCache, meta.GetCacheKey())); + + ExecMem execMem; + execMem.count = curCount; + execMem.inputMem = algRes.cclInputMem; + execMem.outputMem = algRes.cclOutputMem; + execMem.scratchMem = algRes.scratchMem; + // 使用当前Loop偏移到的地址作为当前的inputPtr和outputPtr + execMem.inputPtr = curInputPtr; + execMem.outputPtr = curOutputPtr; + + std::unique_ptr executor; + executor.reset(new (std::nothrow) ReduceScatterPipeline(dispatcher_, reduceAttr)); + CHK_SMART_PTR_NULL(executor); + + HcomCollOpInfo opInfo = {"", execMem.inputPtr, execMem.outputPtr, param.DataDes.count, param.DataDes.dataType, + param.root, param.reduceType}; + + CHK_RET(executor->Prepare(&opInfo, execMem.inputMem, curCount, bufferSize, curOffset, outerCommInfo, + innerCommInfo, const_cast(param.stream), streamInfo_.ringStreams, streamInfo_.ringSignal, + streamInfo_.ringSignalAux)); + CHK_RET(executor->RunAsync()); + + CHK_RET(LaunchTask(dispatcher_, const_cast(param.stream))); + + curOffset += curSize; + } + return HCCL_SUCCESS; +} + +REGISTER_EXEC("ReduceScatterMeshOpbasePipelineExecutor", ReduceScatterMeshOpbasePipeline, + CollReduceScatterMeshOpbasePipelineExecutor); + +} + diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_mesh_opbase_pipeline_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_mesh_opbase_pipeline_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..5639477cc9cb9830172bfeb2e7f11c06b05df8e7 --- /dev/null +++ 
b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_mesh_opbase_pipeline_executor.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#ifndef COLL_REDUCESCATTER_MESH_OPBASE_PIPELINE_EXECUTOR_H
+#define COLL_REDUCESCATTER_MESH_OPBASE_PIPELINE_EXECUTOR_H
+#include "coll_reduce_scatter_executor.h"
+
+namespace hccl {
+class CollReduceScatterMeshOpbasePipelineExecutor : public CollReduceScatterExecutor {
+public:
+    explicit CollReduceScatterMeshOpbasePipelineExecutor(const HcclDispatcher dispatcher,
+        std::unique_ptr<TopoMatcher> &topoMatcher);
+    ~CollReduceScatterMeshOpbasePipelineExecutor() = default;
+
+private:
+    void ParseParam(const OpParam& param) override;
+    /* *************** Resource calculation *************** */
+    HcclResult CalcStreamNum(u32& streamNum) override;
+    HcclResult CalcCommInfo(std::vector<LevelNSubCommTransport>& opTransport) override;
+    HcclResult CalcLevel0CommInfo(TransportMemType inputType,
+        TransportMemType outputType,
+        std::vector<LevelNSubCommTransport>& opTransport) override;
+    HcclResult CalcLevel1CommInfo(TransportMemType inputType,
+        TransportMemType outputType,
+        std::vector<LevelNSubCommTransport>& opTransport) override;
+    HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType);
+
+    /* *************** Algorithm orchestration *************** */
+    u64 CalcLoopMaxCount(const u32 unitSize) override;
+    bool IsHugeData(const u64 curSize) override;
+
+    HcclResult RunLoop(const OpParam &param, const AlgResourceResponse &algRes) override;
+};
+
+} // namespace hccl
+
+#endif
\ No newline at end of file
diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_ring_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_ring_executor.cc
new file mode 100644
index 0000000000000000000000000000000000000000..387a52b922435d8e6c8d325e5da0747e733e2ef8
--- /dev/null
+++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_ring_executor.cc
@@ -0,0 +1,350 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */ + +#include "coll_reduce_scatter_ring_executor.h" + +namespace hccl { + +CollReduceScatterRingExecutor::CollReduceScatterRingExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollReduceScatterExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + topoAttr_.deviceType == DevType::DEV_TYPE_910_73); +} + +void CollReduceScatterRingExecutor::ParseParam(const OpParam& param) +{ + tag_ = param.tag; + + // 是否需要scratch memory + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + (topoAttr_.deviceType == DevType::DEV_TYPE_910B || topoAttr_.deviceType == DevType::DEV_TYPE_910_73) && + IsSupportSDMAReduce(param.inputPtr, param.outputPtr, param.DataDes.dataType, param.reduceType) && + IsSupportRDMAReduce(param.DataDes.dataType, param.reduceType)) { + scratchMemFlag_ = false; + } else { + scratchMemFlag_ = true; + } + + // 记录图模式总数据量 + totalSize_ = topoAttr_.userRankSize * param.DataDes.count * SIZE_TABLE[param.DataDes.dataType]; +} + +HcclResult CollReduceScatterRingExecutor::CalcScratchMemSize(u64& scratchMemSize) +{ + if (scratchMemFlag_) { + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + scratchMemSize = inCCLbufferSize_ + CCE_REDUCE_ALIGN_FACTOR * CCE_REDUCE_ALIGN_SIZE; + } else { + scratchMemSize = totalSize_ + CCE_REDUCE_ALIGN_FACTOR * CCE_REDUCE_ALIGN_SIZE; + } + } else { + scratchMemSize = 0U; + } + HCCL_INFO("[CollReduceScatterRingExecutor][CalcScratchMemSize] tag[%s] scratchMemSize[%u]", + tag_.c_str(), scratchMemSize); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterRingExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = 1U; + switch (algType_) { + case AlgType::ALG_8P_RING_PLUS_HD: + case AlgType::ALG_8P_RING_PLUS_RING: + case AlgType::ALG_8P_RING_PLUS_NHR: + case AlgType::ALG_8P_RING_PLUS_NHR_V1: + case AlgType::ALG_8P_RING_PLUS_NB: + case AlgType::ALG_8P_RING_PLUS_PIPELINE: + totalStreamNum = OUTER_PLANE_NUM_IN_8PRING; + break; + case AlgType::ALG_NP_SINGLE_RING_PLUS_RING: + case AlgType::ALG_NP_SINGLE_RING_PLUS_HD: + if (topoAttr_.deviceType == DevType::DEV_TYPE_910_73) { + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_SINGLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING; + } else { + totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_SINGLE; + } + } + break; + default: + break; + } + streamNum = totalStreamNum - 1; + HCCL_INFO("[CollReduceScatterRingExecutor][CalcStreamNum] tag[%s] streamNum[%u]", tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterRingExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel1CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterRingExecutor::CalcTransportMemType(TransportMemType &inputType, + TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + if (scratchMemFlag_) { + outputType = TransportMemType::SCRATCH; + } else { + outputType = TransportMemType::CCL_OUTPUT; + } + } else { + inputType = TransportMemType::PARAM_INPUT; + if (scratchMemFlag_) { + outputType = TransportMemType::SCRATCH; + } 
else { + outputType = TransportMemType::PARAM_OUTPUT; + } + } + HCCL_INFO("[CollReduceScatterRingExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterRingExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + HCCL_INFO("[CollReduceScatterRingExecutor][CalcLevel0CommInfo]tag[%s ]start", tag_.c_str()); + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_RING_INNER); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + HCCL_INFO("[CollReduceScatterRingExecutor][CalcLevel0CommInfo]tag[%s] Calc RingComm finish", tag_.c_str()); + return HCCL_SUCCESS; +} + +u64 CollReduceScatterRingExecutor::CalcLoopMaxCount(const u32 unitSize) +{ + // 中转内存单次最多能够接受的output count,放开ranksize限制 + u64 maxCountPerLoop = inCCLbufferSize_ / (topoAttr_.userRankSize * unitSize); + return maxCountPerLoop; +} + +bool CollReduceScatterRingExecutor::IsHugeData(const u64 curSize) +{ + bool hugeData; + if (DMAReduceFlag_) { + hugeData = curSize > SDMA_SEND_MAX_SIZE; + } else { + hugeData = (curSize * topoAttr_.userRankSize / HCCL_INTERNODE_MAX_DATA_RATE > RDMA_SEND_MAX_SIZE) || + (curSize > SDMA_SEND_MAX_SIZE); + } + + return hugeData; +} + +HcclResult CollReduceScatterRingExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + HCCL_INFO("[CollReduceScatterRingExecutor][KernelRun] The ReduceScatterRingExecutor starts."); + u32 perDataSize = 0; + CHK_RET(SalGetDataTypeSize(param.DataDes.dataType, perDataSize)); + + CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + + u32 ringNum = (topoType_ == TopoType::TOPO_TYPE_8P_RING) ? OUTER_PLANE_NUM_IN_8PRING : + OUTER_PLANE_NUM_IN_NPRING_SINGLE; + + u32 commIndex = (ringNum == OUTER_PLANE_NUM_IN_8PRING) ? topoAttr_.devicePhyId : outerCommInfo.localRank; + + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + + /* ******************网口裁剪步骤: 节点内allreduce *******************************/ + std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 + std::vector > multiStreamSlice; // 每个stream使用的数据基于用户buffer的偏移 + u32 sliceNum = outerCommInfo.localRankSize; + // Slice sliceTemp; + bool isMultiNic = topoType_ == TopoType::TOPO_TYPE_8P_RING && topoAttr_.nicList.size() != DEVICE_EIGHT; + if (isMultiNic) { + u64 inputDataCount = execMem.inputMem.size() / perDataSize; + CHK_RET(ExecutorBase::PrepareSliceData(inputDataCount, perDataSize, sliceNum, 0, dataSegsSlice)); + multiStreamSlice = PrepareMultiRingSlice(dataSegsSlice, param.tag); + CHK_PRT_RET(multiStreamSlice.size() != ringNum, + HCCL_ERROR("[CollReduceScatterRingExecutor][KernelRun]ringNum[%u] != multiStreamSlice size[%llu]", + ringNum, multiStreamSlice.size()), HCCL_E_INTERNAL); + + CHK_RET(MultiRingAllReduce(param.tag, execMem.inputMem, execMem.scratchMem, inputDataCount, + param.DataDes.dataType, param.reduceType, multiStreamSlice, param.stream, PROF_STAGE_0)); + + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, execMem.inputMem, execMem.scratchMem, + const_cast(param.stream))); + } + + std::vector &nicList = const_cast &>(topoAttr_.nicList); + std::vector::iterator iterNic = std::find(nicList.begin(), nicList.end(), topoAttr_.devicePhyId); + bool innRunRet = isMultiNic && (iterNic == nicList.end()); + if (!innRunRet) { // 1. 8P ring的拓扑。2. 网口不满配。3. 
当前device不出网口。 的情况下不进行节点间的reduce scatter + /* ******************第一步: 节点间reducescatter *******************************/ + u32 innerRankSize = innerCommInfo.localRankSize; + if (innerRankSize > 1) { + u64 reduceAttr = GetReduceAttr(execMem.inputMem, execMem.scratchMem, param.DataDes.dataType, + param.reduceType); + std::unique_ptr innerExecutor; + + if (UseInterServerRingAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using ring algo inter-server."); + CHK_SMART_PTR_NULL(innerExecutor); + + u64 ringSize = execMem.inputMem.size() / innerRankSize; + u64 ringCount = ringSize / perDataSize; + + CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, ringCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, + std::vector(0))); + } else if (UseInterServerNHRAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) ReduceScatterNHR(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using nhr algo inter-server."); + CHK_SMART_PTR_NULL(innerExecutor); + + u64 ringSize = execMem.inputMem.size() / innerRankSize; + u64 ringCount = ringSize / perDataSize; + CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, ringCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, + std::vector(0))); + } else if (UseInterServerNHRV1Algo(algType_)) { + innerExecutor.reset(new (std::nothrow) ReduceScatterNHRV1(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using nhr_v1 algo inter-server."); + CHK_SMART_PTR_NULL(innerExecutor); + + u64 ringSize = execMem.inputMem.size() / innerRankSize; + u64 ringCount = ringSize / perDataSize; + CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, ringCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, + std::vector(0))); + } else if (UseInterServerNBAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) ReduceScatterNB(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using nonuniform-bruck algo inter-server."); + CHK_SMART_PTR_NULL(innerExecutor); + + u64 ringSize = execMem.inputMem.size() / innerRankSize; + u64 ringCount = ringSize / perDataSize; + CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, ringCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, + std::vector(0))); + } else { + innerExecutor.reset(new (std::nothrow) ReduceScatterRecursiveHalvingDoubling(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using halving-doubling algo inter-server."); + + CHK_SMART_PTR_NULL(innerExecutor); + u64 inputDataCount = execMem.inputMem.size() / perDataSize; // count是output的数据个数 + CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, inputDataCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, + std::vector(0))); + } + CHK_RET(innerExecutor->RegisterProfiler( + (innerRankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(innerExecutor, innerCommInfo)); + } + } + + /* ***********第二步: 节点内reducescatter(正常场景), 节点内多根结点scatter(网口裁剪)*****************************/ + CHK_RET(ActiveSlaveStreams(param.stream)); + + bool useInlineRduce = false; + bool isInlineReduce = IsSupportSDMAReduce(execMem.inputMem.ptr(), execMem.scratchMem.ptr(), 
param.DataDes.dataType, + param.reduceType); + useInlineRduce = isInlineReduce && algoAttr_.inlineReduceSwitchOn; + multiStreamSlice = ReduceScatterRingSlicePrepare(ringNum, sliceNum, useInlineRduce, execMem.outputMem, + dataSegsSlice, param.tag); + bool bRet = (multiStreamSlice.size() != ringNum); + CHK_PRT_RET(bRet, + HCCL_ERROR("[CollReduceScatterRingExecutor][KernelRun]sliceNum-1[%u] != multiStreamSlice size[%llu]", \ + sliceNum - 1, multiStreamSlice.size()), HCCL_E_INTERNAL); + + if (isMultiNic) { // 网口裁剪情况下需要改变slice最终在rank上位置 + PrepareMultiRingSlice(dataSegsSlice, param.tag, false, nicList); // 刷新多环ringRankList信息 + std::vector> ringNics; + CHK_RET(GetRingNics(param.tag, ringNics)); + + for (u32 ringIdx = 0; ringIdx < ringNum; ringIdx++) { // 按第一个网口位置改变slice最终在rank上的位置 + u32 firstNicIdx = ringNics[ringIdx][0]; + std::rotate(multiStreamSlice[ringIdx].begin(), multiStreamSlice[ringIdx].begin() + firstNicIdx, + multiStreamSlice[ringIdx].end()); + } + } + + DeviceMem srcMem; + if (isMultiNic) { + u32 innerRankSize = topoAttr_.userRankSize / DEVICE_EIGHT; // currComm->commOuter[0]->UserRankSize(); + // 每个server分配的slice大小 + CHK_PRT_RET(innerRankSize == 0, + HCCL_ERROR("[CollReduceScatterRingExecutor][KernelRun]innerRankSize is illegal"), HCCL_E_PARA); + u64 serverSliceSize = execMem.inputMem.size() / innerRankSize; + // 每个服务器对应的偏移 + u32 serverIndex = innerCommInfo.localRank; + CHK_PRT_RET(serverIndex == INVALID_VALUE_RANKID, + HCCL_ERROR("[CollReduceScatterRingExecutor][KernelRun]get rank of " + "bridgeRank failed, commIdx[%u]", commIndex), HCCL_E_PARA); + u64 serverSliceOffset = serverSliceSize * serverIndex; + if (UseInterServerRingAlgo(algType_)) { + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, execMem.scratchMem, execMem.inputMem, + const_cast(param.stream))); + } + DeviceMem reduceScatterRingOutput = execMem.scratchMem.range(serverSliceOffset, serverSliceSize); + CHK_SMART_PTR_NULL(reduceScatterRingOutput.ptr()); + u64 countLocal = serverSliceSize / perDataSize; + CHK_RET(MultiRingMultiRootScatter(param.tag, reduceScatterRingOutput, reduceScatterRingOutput, countLocal, + param.DataDes.dataType, multiStreamSlice, serverIndex * DEVICE_EIGHT, param.stream, serverSliceOffset)); + + srcMem = reduceScatterRingOutput.range(dataSegsSlice[topoAttr_.devicePhyId].offset, + execMem.count * perDataSize); + CHK_SMART_PTR_NULL(srcMem.ptr()); + } else { + u32 innerRankSize = innerCommInfo.localRankSize; + // 每个server分配的slice大小 + u64 serverSliceSize = execMem.inputMem.size() / innerRankSize; + // 每个服务器对应的偏移 + u32 serverIndex = innerCommInfo.localRank; + u64 serverSliceOffset = serverSliceSize * serverIndex; + HCCL_DEBUG("inputMem.size=%llu, outerCommInfo.localRankSize=%u, serverSliceSize=%llu, serverSliceOffset=%llu "\ + "commIndex=%u commInner[commIndex]->rank=%u", execMem.inputMem.size(), outerCommInfo.localRankSize, + serverSliceSize, serverSliceOffset, commIndex, innerCommInfo.localRank); + DeviceMem reduceScatterRingInput = execMem.inputMem.range(serverSliceOffset, serverSliceSize); + CHK_SMART_PTR_NULL(reduceScatterRingInput.ptr()); + DeviceMem reduceScatterRingOutput = execMem.scratchMem.range(serverSliceOffset, serverSliceSize); + CHK_SMART_PTR_NULL(reduceScatterRingOutput.ptr()); + u64 countLocal = serverSliceSize / perDataSize; + + HcomCollOpInfo opInfo = {"", execMem.inputPtr, execMem.outputPtr, param.DataDes.count, param.DataDes.dataType, + param.root, param.reduceType}; + HcomCollOpInfo *opInfoPtr = nullptr; + if (DMAReduceFlag_) { + opInfoPtr = &opInfo; + } + + 
CHK_RET(MultiRingReduceScatter(param.tag, reduceScatterRingInput, reduceScatterRingOutput, countLocal, + param.DataDes.dataType, param.reduceType, multiStreamSlice, param.stream, PROF_STAGE_1, serverSliceOffset, + opInfoPtr)); + + srcMem = execMem.inputMem.range(serverSliceOffset + dataSegsSlice[commIndex].offset, + execMem.count * perDataSize); + CHK_SMART_PTR_NULL(srcMem.ptr()); + } + + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, execMem.outputMem, srcMem, const_cast(param.stream))); + + return HCCL_SUCCESS; +} + +REGISTER_EXEC("ReduceScatterRingExecutor", ReduceScatterRing, CollReduceScatterRingExecutor); + +} + diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_ring_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_ring_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..c4c5950b910d12f8b045c3035fddfdc6d3d751bd --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_ring_executor.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_REDUCESCATTER_RING_EXECUTOR_H +#define COLL_REDUCESCATTER_RING_EXECUTOR_H +#include "coll_reduce_scatter_executor.h" + +namespace hccl { +class CollReduceScatterRingExecutor : public CollReduceScatterExecutor { +public: + explicit CollReduceScatterRingExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollReduceScatterRingExecutor() = default; + +private: + void ParseParam(const OpParam& param) override; + /* *************** 资源计算 *************** */ + HcclResult CalcScratchMemSize(u64& scratchMemSize) override; + HcclResult CalcStreamNum(u32& streamNum) override; + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + u64 CalcLoopMaxCount(const u32 unitSize) override; + bool IsHugeData(const u64 curSize) override; + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_ring_for_910_73_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_ring_for_910_73_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..d142afbc6084776191262394307b900278033eff --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_ring_for_910_73_executor.cc @@ -0,0 +1,411 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. 
+ * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "coll_reduce_scatter_ring_for_910_73_executor.h" + +namespace hccl { + +CollReduceScatterRingFor91073Executor::CollReduceScatterRingFor91073Executor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollReduceScatterExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE); +} + +void CollReduceScatterRingFor91073Executor::ParseParam(const OpParam& param) +{ + tag_ = param.tag; + + // 是否需要scratch memory + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && + IsSupportSDMAReduce(param.inputPtr, param.outputPtr, param.DataDes.dataType, param.reduceType) && + IsSupportRDMAReduce(param.DataDes.dataType, param.reduceType)) { + scratchMemFlag_ = false; + } else { + scratchMemFlag_ = true; + } + + // 记录图模式总数据量 + totalSize_ = topoAttr_.userRankSize * param.DataDes.count * SIZE_TABLE[param.DataDes.dataType]; +} + +HcclResult CollReduceScatterRingFor91073Executor::CalcScratchMemSize(u64& scratchMemSize) +{ + if (scratchMemFlag_) { + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + scratchMemSize = inCCLbufferSize_ + CCE_REDUCE_ALIGN_FACTOR * CCE_REDUCE_ALIGN_SIZE; + } else { + scratchMemSize = totalSize_ + CCE_REDUCE_ALIGN_FACTOR * CCE_REDUCE_ALIGN_SIZE; + } + } else { + scratchMemSize = 0U; + } + HCCL_INFO("[CollReduceScatterRingFor91073Executor][CalcScratchMemSize] tag[%s] scratchMemSize[%u]", + tag_.c_str(), scratchMemSize); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterRingFor91073Executor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = (topoType_ == TopoType::TOPO_TYPE_NP_DOUBLE_RING ? 
OUTER_PLANE_NUM_IN_NPRING_DOUBLE : + OUTER_PLANE_NUM_IN_NPRING_SINGLE); + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + totalStreamNum *= STREAM_NUM_FOR_DMAREDUCE_ONE_RING; + } + streamNum = totalStreamNum - 1; + HCCL_INFO("[CollReduceScatterRingFor91073Executor][CalcStreamNum] tag[%s] streamNum[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterRingFor91073Executor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel1CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel2CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterRingFor91073Executor::CalcTransportMemType(TransportMemType &inputType, + TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + if (scratchMemFlag_) { + outputType = TransportMemType::SCRATCH; + } else { + outputType = TransportMemType::CCL_OUTPUT; + } + } else { + inputType = TransportMemType::PARAM_INPUT; + if (scratchMemFlag_) { + outputType = TransportMemType::SCRATCH; + } else { + outputType = TransportMemType::PARAM_OUTPUT; + } + } + HCCL_INFO("[CollReduceScatterRingFor91073Executor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterRingFor91073Executor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_RING_INNER); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + return HCCL_SUCCESS; +} + +HcclResult CollReduceScatterRingFor91073Executor::CalcLevel2CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel2(COMM_LEVEL2, CommType::COMM_TAG_MAX); + if (UseLevel2RingAlgo(algType_)) { + commParaLevel2.commType = CommType::COMM_TAG_RING_INNER; + } else { + commParaLevel2.commType = CommType::COMM_TAG_HALVING_DOUBLING; + } + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel2, opTransport[COMM_LEVEL2], inputType, outputType)); + return HCCL_SUCCESS; +} + +u64 CollReduceScatterRingFor91073Executor::CalcLoopMaxCount(const u32 unitSize) +{ + // 中转内存单次最多能够接受的output count,放开ranksize限制 + u64 maxCountPerLoop = inCCLbufferSize_ / (topoAttr_.userRankSize * unitSize); + return maxCountPerLoop; +} + +bool CollReduceScatterRingFor91073Executor::IsHugeData(const u64 curSize) +{ + bool hugeData; + if (DMAReduceFlag_) { + hugeData = curSize > SDMA_SEND_MAX_SIZE; + } else { + hugeData = (curSize * topoAttr_.userRankSize / HCCL_INTERNODE_MAX_DATA_RATE > RDMA_SEND_MAX_SIZE) || + (curSize > SDMA_SEND_MAX_SIZE); + } + + return hugeData; +} + +HcclResult CollReduceScatterRingFor91073Executor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + HCCL_INFO("[CollReduceScatterRingFor91073Executor][KernelRun] The ReduceScatterDoubleRingExecutor starts."); + u32 perDataSize = 0; + CHK_RET(SalGetDataTypeSize(param.DataDes.dataType, perDataSize)); + + CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1)); + SubCommInfo outerCommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + + 
u32 ringNum; + if (topoType_ == TopoType::TOPO_TYPE_NP_DOUBLE_RING) { + ringNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE; + } else { + ringNum = OUTER_PLANE_NUM_IN_NPRING_SINGLE; + } + + u32 sliceNum = outerCommInfo.localRankSize; + Slice sliceTemp; + u32 commIndex = outerCommInfo.localRank; + commIndex = RefreshCommIdx(commIndex, topoAttr_.nicList, topoAttr_.devicePhyId); + + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo innerCommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + + CHK_RET(CheckCommSize(COMM_LEVEL2, COMM_INDEX_0 + 1)); + SubCommInfo level2CommInfo = GetSubCommInfo(COMM_LEVEL2, COMM_INDEX_0); + u32 level2RankSize = level2CommInfo.localRankSize; + + std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 + std::vector > multiStreamSlice; // 每个stream使用的数据基于用户buffer的偏移 + /* 超节点间通信域是commLevel2 */ + if (level2RankSize > 1) { + /* ****************** 超节点间 reducescatter *******************************/ + u64 reduceAttr = GetReduceAttr(execMem.inputMem, execMem.scratchMem, param.DataDes.dataType, param.reduceType); + std::unique_ptr level2Executor; + + if (UseLevel2RingAlgo(algType_)) { + level2Executor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using ring algo inter-superPod."); + CHK_SMART_PTR_NULL(level2Executor); + + u64 ringCount = execMem.inputMem.size() / (level2RankSize * perDataSize); + CHK_RET(level2Executor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, ringCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, std::vector(0))); + } else { + level2Executor.reset(new (std::nothrow) ReduceScatterRecursiveHalvingDoubling(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using halving-doubling algo inter-superPod."); + + CHK_SMART_PTR_NULL(level2Executor); + u64 inputDataCount = execMem.inputMem.size() / perDataSize; // count是output的数据个数 + CHK_RET(level2Executor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, inputDataCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, std::vector(0))); + } + CHK_RET(level2Executor->RegisterProfiler( + (level2RankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + level2CommInfo.localRank, + PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(level2Executor, level2CommInfo)); + + /* ****************** 节点间 reducescatter *******************************/ + + u32 innerRankSize = innerCommInfo.localRankSize; + if (innerRankSize > 1) { + std::unique_ptr innerExecutor; + u32 level1Index = innerCommInfo.localRank; + + if (UseInterServerRingAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using ring algo inter-server."); + CHK_SMART_PTR_NULL(innerExecutor); + + u64 ringSize = execMem.inputMem.size() / (innerRankSize * level2RankSize); + u64 ringCount = ringSize / perDataSize; + u64 level1SliceOffset = ringSize * level1Index; + DeviceMem level1InputMem = execMem.inputMem.range(level1SliceOffset, ringSize); + CHK_SMART_PTR_NULL(level1InputMem.ptr()); + + CHK_RET(innerExecutor->Prepare(level1InputMem, level1InputMem, execMem.scratchMem, ringCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, std::vector(0), + level1SliceOffset)); + } else { + innerExecutor.reset(new (std::nothrow) ReduceScatterRecursiveHalvingDoubling(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using halving-doubling algo inter-server."); + + 
CHK_SMART_PTR_NULL(innerExecutor); + u64 inputDataCount = execMem.inputMem.size() / (perDataSize * level2RankSize); + u64 level1SliceSize = execMem.inputMem.size() / level2RankSize; + u64 level1SliceOffset = level1SliceSize * level1Index; + + DeviceMem level1InputMem = execMem.inputMem.range(level1SliceOffset, level1SliceSize); + // count是output的数据个数 + CHK_RET(innerExecutor->Prepare(level1InputMem, level1InputMem, execMem.scratchMem, inputDataCount, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, std::vector(0), + level1SliceOffset)); + } + CHK_RET(innerExecutor->RegisterProfiler( + (innerRankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(innerExecutor, innerCommInfo)); + } + + /* *********** 节点内reducescatter (正常场景) *****************************/ + CHK_RET(ActiveSlaveStreams(param.stream)); + + bool useInlineRduce = false; + bool isInlineReduce = IsSupportSDMAReduce(execMem.inputMem.ptr(), execMem.scratchMem.ptr(), + param.DataDes.dataType, param.reduceType); + useInlineRduce = isInlineReduce && algoAttr_.inlineReduceSwitchOn; + multiStreamSlice = ReduceScatterRingSlicePrepare(ringNum, sliceNum, useInlineRduce, execMem.outputMem, + dataSegsSlice, param.tag); + bool bRet = (multiStreamSlice.size() != ringNum); + CHK_PRT_RET(bRet, + HCCL_ERROR("[CollReduceScatterRingFor91073Executor][KernelRun]sliceNum-1[%u] != multiStreamSlice" \ + "size[%llu]", sliceNum - 1, multiStreamSlice.size()), HCCL_E_INTERNAL); + + DeviceMem srcMem; + // 每个server分配的slice大小 + u64 serverSliceSize = execMem.inputMem.size() / (innerRankSize * level2RankSize); + // 每个服务器对应的偏移 + u32 serverIndex = innerCommInfo.localRank; + u64 serverSliceOffset = serverSliceSize * serverIndex; + HCCL_DEBUG("inputMem.size=%llu, outerCommInfo.localRankSize=%u, serverSliceSize=%llu, serverSliceOffset=%llu "\ + "commIndex=%u innerCommInfo.localRank=%u", execMem.inputMem.size(), outerCommInfo.localRankSize, + serverSliceSize, serverSliceOffset, commIndex, innerCommInfo.localRank); + DeviceMem reduceScatterRingInput = execMem.inputMem.range(serverSliceOffset, serverSliceSize); + DeviceMem reduceScatterRingOutput = execMem.scratchMem.range(serverSliceOffset, serverSliceSize); + + u64 countLocal = serverSliceSize / perDataSize; + CHK_RET(MultiRingReduceScatter(param.tag, reduceScatterRingInput, reduceScatterRingOutput, countLocal, + param.DataDes.dataType, param.reduceType, multiStreamSlice, param.stream, PROF_STAGE_1, serverSliceOffset)); + + srcMem = execMem.inputMem.range(serverSliceOffset + dataSegsSlice[commIndex].offset, + execMem.count * perDataSize); + CHK_SMART_PTR_NULL(srcMem.ptr()); + + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, execMem.outputMem, srcMem, const_cast(param.stream))); + HCCL_INFO("reducescatter double ring run success"); + return HCCL_SUCCESS; + } + + // 节点内reduce scatter + CHK_RET(ActiveSlaveStreams(param.stream)); + u32 innerRankSize = innerCommInfo.localRankSize; + + // 计算slice + std::vector > level0DataSegsSlice; + bool useInlineRduce = false; + bool isInlineReduce = IsSupportSDMAReduce(execMem.inputMem.ptr(), execMem.scratchMem.ptr(), param.DataDes.dataType, + param.reduceType); + useInlineRduce = isInlineReduce && algoAttr_.inlineReduceSwitchOn; + multiStreamSlice = ReduceScatterRingSlicePrepare(ringNum, sliceNum, useInlineRduce, execMem.outputMem, + dataSegsSlice, param.tag); + for (u32 ringIndex = 0; ringIndex < multiStreamSlice.size(); ringIndex++) { + std::vector dataSlice; + for (u32 
level0Idx = 0; level0Idx < sliceNum; level0Idx++) { + Slice sliceTemp; + for (u32 level1Idx = 0; level1Idx < innerRankSize; level1Idx++) { + sliceTemp.size = multiStreamSlice[ringIndex][level0Idx].size; + sliceTemp.offset = + multiStreamSlice[ringIndex][level0Idx].offset + level1Idx * sliceNum * execMem.outputMem.size(); + dataSlice.push_back(sliceTemp); + } + } + level0DataSegsSlice.push_back(dataSlice); + } + std::vector> multRingsUserMemSlice; + + HcomCollOpInfo opInfo = {"", execMem.inputPtr, execMem.outputPtr, param.DataDes.count, param.DataDes.dataType, + param.root, param.reduceType}; + HcomCollOpInfo *opInfoPtr = nullptr; + if (DMAReduceFlag_) { + opInfoPtr = &opInfo; + } + + if (opInfoPtr == nullptr) { + multRingsUserMemSlice = level0DataSegsSlice; + } else { + for (u32 ringIndex = 0; ringIndex < level0DataSegsSlice.size(); ringIndex++) { + std::vector level1UserMemSlice; + for (auto &cclSlice : level0DataSegsSlice[ringIndex]) { + Slice tmpSlice; + tmpSlice.size = cclSlice.size; + tmpSlice.offset = + (cclSlice.offset / execMem.outputMem.size()) * param.DataDes.count * perDataSize + + multiStreamSlice[ringIndex][0].offset; + level1UserMemSlice.push_back(tmpSlice); + HCCL_DEBUG("rank[%u], ringIndex[%u], tmpSlice.offset=[%llu], size=[%llu]", + topoAttr_.userRank, ringIndex, tmpSlice.offset, tmpSlice.size); + } + multRingsUserMemSlice.push_back(level1UserMemSlice); + } + } + // 区分消减拷贝场景 + if (opInfoPtr != nullptr && innerRankSize > 1) { + HcomCollOpInfo opInfoByReduceScatterDMAreduce = *opInfoPtr; + opInfoByReduceScatterDMAreduce.outputAddr = nullptr; + CHK_RET(MultiRingReduceScatter(param.tag, execMem.inputMem, execMem.scratchMem, execMem.count, + param.DataDes.dataType, param.reduceType, level0DataSegsSlice, + param.stream, PROF_STAGE_1, 0, &opInfoByReduceScatterDMAreduce, multRingsUserMemSlice)); + } else { + CHK_RET(MultiRingReduceScatter(param.tag, execMem.inputMem, execMem.scratchMem, execMem.count, + param.DataDes.dataType, param.reduceType, + level0DataSegsSlice, param.stream, PROF_STAGE_1, 0, opInfoPtr, multRingsUserMemSlice)); + } + // 对于单server图模式场景最后一步需要把数据从ccl input拷贝到ccl output上 + if (innerRankSize == 1 && opInfoPtr == nullptr) { + DeviceMem srcMem = execMem.inputMem.range(topoAttr_.userRank * execMem.outputMem.size(), + execMem.outputMem.size()); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, execMem.outputMem, srcMem, const_cast(param.stream))); + } + if (innerRankSize > 1) { + // 节点间做reduce scatter(ring/NHR) + u64 reduceAttr = GetReduceAttr(execMem.inputMem, execMem.scratchMem, param.DataDes.dataType, param.reduceType); + std::unique_ptr innerExecutor; + + // 计算slice + u32 level0ServerIndex = 0; + HcclResult ret = GetRankByUserRank(COMM_LEVEL0, COMM_INDEX_0, topoAttr_.userRank, level0ServerIndex); + + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[CollReduceScatterRingFor91073Executor][KernelRun]Get Rank[%u] " + "by User Rank[%u] from CommOuter[%u] Failed!", level0ServerIndex, topoAttr_.userRank, commIndex), ret); + + std::vector level1DataSegsSlice; + for (u32 i = 0; i < innerRankSize; i++) { + sliceTemp.size = execMem.outputMem.size(); + u32 level1UserRank; + CHK_RET(GetUserRankByRank(COMM_LEVEL1, commIndex, i, level1UserRank)); + sliceTemp.offset = level1UserRank * execMem.outputMem.size(); + level1DataSegsSlice.push_back(sliceTemp); + HCCL_DEBUG("rank[%u], level1DataSegsSlice[%u].offset=%llu, size=[%llu]", topoAttr_.userRank, i, + sliceTemp.offset, sliceTemp.size); + } + if (UseInterServerRingAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) 
ReduceScatterRing(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using ring algo inter-server."); + } else if (UseInterServerNBAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) ReduceScatterNB(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using nonuniform-bruck algo inter-server."); + } else { + innerExecutor.reset(new (std::nothrow) ReduceScatterNHR(dispatcher_, reduceAttr)); + HCCL_INFO("reducescatter ring: using nonuniform-hierarchical-ring algo inter-server."); + } + CHK_SMART_PTR_NULL(innerExecutor); + + CHK_RET(innerExecutor->Prepare(execMem.inputMem, execMem.inputMem, execMem.scratchMem, execMem.count, + param.DataDes.dataType, param.stream, param.reduceType, OUTER_BRIDGE_RANK_ID, level1DataSegsSlice)); + CHK_RET(innerExecutor->RegisterProfiler( + (innerRankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + innerCommInfo.localRank, + PROF_STAGE_2, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(innerExecutor, innerCommInfo)); + + // 区分消减拷贝场景(消减拷贝数据需要拷贝到user output上) + DeviceMem srcMem = execMem.inputMem.range(topoAttr_.userRank * execMem.outputMem.size(), + execMem.outputMem.size()); + if (opInfoPtr != nullptr) { + DeviceMem dstMem = DeviceMem::create(static_cast(opInfoPtr->outputAddr), execMem.outputMem.size()); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, const_cast(param.stream))); + } else { + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, execMem.outputMem, srcMem, const_cast(param.stream))); + } + } + + HCCL_INFO("reducescatter double ring run success"); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("ReduceScatterRingFor91073Executor", ReduceScatterRingFor91073, CollReduceScatterRingFor91073Executor); +} \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_ring_for_910_73_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_ring_for_910_73_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..d7cef6dc405660ad65450bd7de6a87ea50081bf2 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_ring_for_910_73_executor.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef COLL_REDUCESCATTER_RING_FOR_910_73_EXECUTOR_H +#define COLL_REDUCESCATTER_RING_FOR_910_73_EXECUTOR_H +#include "coll_reduce_scatter_executor.h" + +namespace hccl { +class CollReduceScatterRingFor91073Executor : public CollReduceScatterExecutor { +public: + explicit CollReduceScatterRingFor91073Executor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollReduceScatterRingFor91073Executor() = default; + +private: + void ParseParam(const OpParam& param) override; + /* *************** 资源计算 *************** */ + HcclResult CalcScratchMemSize(u64& scratchMemSize) override; + HcclResult CalcStreamNum(u32& streamNum) override; + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcLevel2CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + u64 CalcLoopMaxCount(const u32 unitSize) override; + bool IsHugeData(const u64 curSize) override; + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_single_rank_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_single_rank_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..e9f7b65cb94346978f8924ec4eaf37f773fa566d --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_single_rank_executor.cc @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "coll_reduce_scatter_single_rank_executor.h" + +namespace hccl { + +CollReduceScatterSingleRankExecutor::CollReduceScatterSingleRankExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollReduceScatterExecutor(dispatcher, topoMatcher) +{ +} + +HcclResult CollReduceScatterSingleRankExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + u64 totalSize = execMem.count * SIZE_TABLE[param.DataDes.dataType]; + ReduceType reduceType = + ((param.reduceType != HCCL_REDUCE_PROD) && (param.DataDes.dataType != HCCL_DATA_TYPE_INT64)) ? 
+        ReduceType::INLINE_REDUCE : ReduceType::TBE_REDUCE;
+
+    auto originalAlgTypeLevel1 = static_cast<u64>(algType_) >> HCCL_LEVEL_ALGO_WIDTH;
+    bool hugeData = totalSize > SDMA_SEND_MAX_SIZE;
+    bool smallData = totalSize <= HCCL_SMALL_COUNT_32_KB;
+    if (execMem.inputPtr == execMem.outputPtr) {
+        auto opMeta = HcclOpMetaInfo::GetOneForReduceScatter(originalAlgTypeLevel1, param.DataDes.dataType, reduceType,
+            hugeData, smallData, CopyPattern::ZCOPY); // the CopyPattern field selects the corresponding sub-graph
+        CHK_RET(InitTask(dispatcher_, const_cast<Stream&>(param.stream), opMeta.isEnableCache, opMeta.GetCacheKey()));
+    } else { // rankSize = 1; input and output addresses differ, so copy input -> output
+        auto opMeta = HcclOpMetaInfo::GetOneForReduceScatter(originalAlgTypeLevel1, param.DataDes.dataType, reduceType,
+            hugeData, smallData, CopyPattern::BCOPY);
+        CHK_RET(InitTask(dispatcher_, const_cast<Stream&>(param.stream), opMeta.isEnableCache, opMeta.GetCacheKey()));
+        DeviceMem srcMem(execMem.inputPtr, totalSize);
+        DeviceMem dstMem(execMem.outputPtr, totalSize);
+        CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, const_cast<Stream&>(param.stream)));
+    }
+    CHK_RET(LaunchTask(dispatcher_, const_cast<Stream&>(param.stream)));
+
+    return HCCL_SUCCESS;
+}
+
+REGISTER_EXEC("ReduceScatterSingleExecutor", ReduceScatterSingleRank, CollReduceScatterSingleRankExecutor);
+
+} // namespace hccl
\ No newline at end of file
diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_single_rank_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_single_rank_executor.h
new file mode 100644
index 0000000000000000000000000000000000000000..d6671d3fdebea8cf12416476b9e130c4e2a25e11
--- /dev/null
+++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_reduce_scatter/coll_reduce_scatter_single_rank_executor.h
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */ + +#ifndef COLL_REDUCESCATTER_SINGLE_RANK_EXECUTOR_H +#define COLL_REDUCESCATTER_SINGLE_RANK_EXECUTOR_H +#include "coll_reduce_scatter_executor.h" + +namespace hccl { +class CollReduceScatterSingleRankExecutor : public CollReduceScatterExecutor { + +public: + explicit CollReduceScatterSingleRankExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollReduceScatterSingleRankExecutor() = default; + +private: + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/CMakeLists.txt b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4b26f95942de0297d64608befc48a0ba80fa27ca --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/CMakeLists.txt @@ -0,0 +1,11 @@ +set(src_list + ${CMAKE_CURRENT_SOURCE_DIR}/coll_scatter_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_scatter_ring_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_scatter_comm_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_scatter_mesh_executor.cc +) + +target_sources(hccl_alg PRIVATE + ${src_list} +) + diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_comm_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_comm_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..c4f89e5144d83ce371fd5876b457e16e4dca0af8 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_comm_executor.cc @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "coll_scatter_comm_executor.h" + +namespace hccl { +CollScatterCommExecutor::CollScatterCommExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollScatterExecutor(dispatcher, topoMatcher) +{ +} + +HcclResult CollScatterCommExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcCombinedCommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollScatterCommExecutor::CalcCombinedCommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaInfo(COMM_COMBINE, CommType::COMM_TAG_MAX); + if (UseInterServerNHRAlgo(algType_)) { + commParaInfo.commType = CommType::COMM_TAG_NONUNIFORM_HIERARCHICAL_RING; + } else if (UseInterServerNBAlgo(algType_)) { + commParaInfo.commType = CommType::COMM_TAG_NONUNIFORM_BRUCK; + } else { + commParaInfo.commType = CommType::COMM_TAG_RING_INNER; + } + CHK_RET(CalcCommPlaneInfo(tag_, commParaInfo, opTransport[COMM_COMBINE], inputType, outputType)); + return HCCL_SUCCESS; +} + +HcclResult CollScatterCommExecutor::CalcStreamNum(u32& streamNum) +{ + // 只传递从流数量 + streamNum = 0; + HCCL_INFO("[CollScatterCommExecutor][CalcStreamNum]tag[%s] streamNum_ is [%u]", tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollScatterCommExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + DeviceMem& inputMem = execMem.inputMem; + DeviceMem& outputMem = execMem.outputMem; + u64 count = execMem.count; + auto root = param.root; + auto dataType = param.DataDes.dataType; + Stream& stream = const_cast(param.stream); + u32 userRank = topoAttr_.userRank; + + u32 commIndex = COMM_INDEX_0; + // 统一走server间 + CHK_RET(CheckCommSize(COMM_COMBINE, commIndex + 1)); + SubCommInfo combinedCommInfo = GetSubCommInfo(COMM_COMBINE, 0); + + CHK_RET(KernelRunInner(inputMem, count, dataType, commIndex, root, userRank, COMM_COMBINE, stream)); + + // 将scratchMem赋值给outputMem + u8 *inputMemPtr = static_cast(inputMem.ptr()); + DeviceMem resultMem(inputMemPtr + outputMem.size() * combinedCommInfo.localRank, outputMem.size()); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, outputMem, resultMem, stream)); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("ScatterCommExecutor", ScatterComm, CollScatterCommExecutor); + +} diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_comm_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_comm_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..78248a78b86116f8c3e94413c78c03b1c3727208 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_comm_executor.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef COLL_SCATTER_COMM_EXECUTOR_H +#define COLL_SCATTER_COMM_EXECUTOR_H + +#include "coll_scatter_executor.h" +#include "coll_alg_exec_registry.h" + +namespace hccl { + +// 所有 Scatter Executor 的基类,继承自 NativeExecutor +class CollScatterCommExecutor : public CollScatterExecutor { +public: + explicit CollScatterCommExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollScatterCommExecutor() = default; +protected: + /* *************** 资源计算 *************** */ + HcclResult CalcCommInfo(std::vector& opTransport) override; + + HcclResult CalcCombinedCommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport); + + HcclResult CalcStreamNum(u32& streamNum) override; + + /* *************** 算法编排 *************** */ + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +private: +}; + +} // namespace hccl + +#endif diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..d1983df339524ec4c5c0a989c9a15561ca735b4a --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_executor.cc @@ -0,0 +1,286 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + + +#include "coll_scatter_executor.h" +#include "executor_impl.h" +#include "stream_active_manager.h" +#include "device_capacity.h" +#include "coll_alg_operator.h" + +namespace hccl { +CollScatterExecutor::CollScatterExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollNativeExecutorBase(dispatcher, topoMatcher) +{ +} + +HcclResult CollScatterExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel1CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollScatterExecutor::CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_INPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_INPUT; + } + return HCCL_SUCCESS; +} + +bool CollScatterExecutor::IsHugeData(u64 curSize) +{ + bool hugeData = curSize / topoAttr_.deviceNumPerAggregation / HCCL_INTERNODE_MAX_DATA_RATE > RDMA_SEND_MAX_SIZE || + curSize > SDMA_SEND_MAX_SIZE; + return hugeData; +} + +HcclResult CollScatterExecutor::RunLoop(const OpParam ¶m, const AlgResourceResponse &algRes) +{ + auto dataType = param.DataDes.dataType; + u32 unitSize = SIZE_TABLE[dataType]; + RankId root = param.root; + + auto totalRecvCount = param.DataDes.count; + + u8 *curUserInputPtr = static_cast(param.inputPtr); + u8 *curUserOutputPtr = static_cast(param.outputPtr); + if (topoAttr_.userRank == root) { + CHK_PTR_NULL(curUserInputPtr); + } + CHK_PTR_NULL(curUserOutputPtr); + + auto inCCLbuffer = algRes.cclInputMem; + auto outCCLbuffer = algRes.cclOutputMem; + + u64 maxCountPerLoop = inCCLbuffer.size() / (topoAttr_.userRankSize * unitSize); // 中转内存单次最多能够接受的output count + + HCCL_DEBUG("[CollScatterExecutor][RunLoop]tag[%s], userRankSize is [%u], root is [%u], " + "maxCountPerLoop is [%llu], totalRecvCount is [%llu]", + tag_.c_str(), topoAttr_.userRankSize, root, maxCountPerLoop, totalRecvCount); + + for (u64 countLeft = totalRecvCount, curRecvCount = 0, inputOffset = 0, outputOffset = 0; + countLeft > 0; countLeft -= curRecvCount) { + curUserInputPtr += inputOffset; + curUserOutputPtr += outputOffset; + + // 判断剩余数据量对应的input size是否大于中转input size + curRecvCount = + ((countLeft * unitSize * topoAttr_.userRankSize) > inCCLbuffer.size()) ? 
maxCountPerLoop : countLeft; + CHK_PRT_RET((curRecvCount == 0), HCCL_ERROR("[RunLoop][Scatter]In OP_BASE curRecvCount is zero"), HCCL_E_PARA); + u64 curRecvSize = curRecvCount * unitSize; // 单位:字节 + u64 curSendSize = topoAttr_.userRankSize * curRecvSize; // 单位:字节 + + DeviceMem curCCLInputMem(inCCLbuffer.ptr(), curSendSize); + DeviceMem curCCLOutputMem(outCCLbuffer.ptr(), curRecvSize); + + ExecMem execMem; + execMem.count = curRecvCount; + execMem.inputMem = curCCLInputMem; + execMem.outputMem = curCCLOutputMem; + execMem.scratchMem = algRes.scratchMem; + // 使用当前Loop偏移到的地址作为当前的inputPtr和outputPtr + execMem.inputPtr = curUserInputPtr; + execMem.outputPtr = curUserOutputPtr; + + HCCL_DEBUG("[RunLoop][Scatter] ScatterLoop: inputOffset[%llu], outputOffset[%llu], " + "curUserInputPtr[%llx], curUserOutputPtr[%llx], curRecvCount[%llu], curRecvSize[%llu], " + "curSendSize[%llu], inCCLbuffer.ptr[%llx], outCCLbuffer.ptr[%llx]", + inputOffset, outputOffset, curUserInputPtr, curUserOutputPtr, curRecvCount, curRecvSize, + curSendSize, inCCLbuffer.ptr(), outCCLbuffer.ptr()); + + CHK_RET(RunLoopInner(param, execMem, algRes)); + + inputOffset = curRecvSize; + outputOffset = curRecvSize; + CHK_RET(LaunchTask(dispatcher_, const_cast(param.stream))); + } + return HCCL_SUCCESS; +} + +HcclResult CollScatterExecutor::RunLoopInner( + const OpParam ¶m, ExecMem &execMem, const AlgResourceResponse &algRes) +{ + auto dataType = param.DataDes.dataType; + u32 unitSize = SIZE_TABLE[dataType]; + RankId root = param.root; + + auto totalRecvCount = param.DataDes.count; + u64 totalRecvSize = totalRecvCount * unitSize; + + u64 recvSize = execMem.outputMem.size(); + + auto meta = HcclOpMetaInfo::GetOneForScatter(root, IsHugeData(execMem.outputMem.size())); + CHK_RET(InitTask(dispatcher_, const_cast(param.stream), meta.isEnableCache, meta.GetCacheKey())); + + DeviceMem dstMem; + DeviceMem srcMem; + if (topoAttr_.userRank == root) { + // 本rank为root节点,非root节点不需要拷贝到中转内存 + for (u32 i = 0; i < topoAttr_.userRankSize; i++) { + // 拷贝input上每个slice的数据到中转内存,源端每个slice的size固定为totalRecvSize + srcMem = DeviceMem::create((u8*)execMem.inputPtr + totalRecvSize * i, recvSize); + dstMem = algRes.cclInputMem.range(recvSize * i, recvSize); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, const_cast(param.stream))); + } + } + + /* 记录指令信息用于一致性校验 */ + CHK_RET(RankConsistent::GetInstance().RecordOpPara(HcclCMDType::HCCL_CMD_SCATTER, tag_.c_str(), + execMem.count, dataType, root, algRes.cclInputMem.size(), algRes.cclOutputMem.size())); + + /* 入参的正确性由HCCL确保 */ + HcclResult ret = KernelRun(param, execMem); + + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollScatterExecutor][RunLoop]errNo[0x%016llx] OP_BASE hcclComm scatter error, tag[%s], " + "input_ptr[%p], output_ptr[%p], recvSize[%llu], data_type[%d], root[%u]", + HCCL_ERROR_CODE(ret), tag_.c_str(), algRes.cclInputMem.ptr(), algRes.cclOutputMem.ptr(), + recvSize, dataType, root), + ret); + + CHK_RET(RankConsistent::GetInstance().DelOpPara(tag_)); + + // 将 CCLOut 上的数据搬运到 userOut + srcMem = algRes.cclOutputMem.range(0, recvSize); + dstMem = DeviceMem::create(execMem.outputPtr, recvSize); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, const_cast(param.stream))); + return HCCL_SUCCESS; +} + +HcclResult CollScatterExecutor::PrepareDataSlice(u64 dataCount, u32 unitSize, u32 sliceNum, + std::vector &dataSlice) +{ + CHK_PRT_RET((sliceNum == 0), HCCL_ERROR("[CollScatterExecutor][PrepareDataSlice]sliceNum is zero."), HCCL_E_PARA); + + dataSlice.resize(sliceNum); + u64 sliceSize = 
dataCount * unitSize;
+    for (u32 i = 0; i < sliceNum; i++) {
+        dataSlice[i].size = sliceSize;
+        dataSlice[i].offset = (i * sliceSize);
+    }
+    return HCCL_SUCCESS;
+}
+
+HcclResult CollScatterExecutor::ReorderSlice(std::vector<Slice> &dataSlice, std::vector<u32> &order)
+{
+    CHK_PRT_RET((dataSlice.size() != order.size()),
+        HCCL_ERROR("[ReorderSlice] data slice size [%zu] is not equal to order size [%zu]",
+            dataSlice.size(), order.size()), HCCL_E_INTERNAL);
+    std::vector<Slice> tempDataSegsSlice(dataSlice.size());
+    for (size_t i = 0; i < dataSlice.size(); i++) {
+        CHK_PRT_RET(order[i] >= dataSlice.size(),
+            HCCL_ERROR("[ReorderSlice] order value [%zu] >= dataSlice size [%zu]", order[i], dataSlice.size()),
+            HCCL_E_INTERNAL);
+        tempDataSegsSlice[i] = dataSlice[order[i]];
+    }
+    dataSlice = tempDataSegsSlice;
+    return HCCL_SUCCESS;
+}
+
+HcclResult CollScatterExecutor::KernelRunInner(DeviceMem& inputMem, u64 count, HcclDataType dataType,
+    u32 &commIndex, u32 root, u32 &subRoot, CommPlane commLevel, Stream& stream)
+{
+    CHK_RET(CheckCommSize(commLevel, commIndex + 1));
+    SubCommInfo subCommInfo = GetSubCommInfo(commLevel, commIndex);
+
+    u32 subCommSize = subCommInfo.localRankSize;
+
+    if (subCommSize <= 1 || subRoot != topoAttr_.userRank) {
+        HCCL_INFO("[ScatterRing][KernelRunInner]: no need to run the inter-server scatter, subCommSize[%u], subRoot[%u], "
+            "userRank[%u]", subCommSize, subRoot, topoAttr_.userRank);
+        return HCCL_SUCCESS;
+    }
+
+    HCCL_INFO("[ScatterRing][KernelRunInner]: start to run the inter-server scatter, subCommSize[%u], subRoot[%u], "
+        "userRank[%u]", subCommSize, subRoot, topoAttr_.userRank);
+
+    u32 rootRankInner = 0;
+    CHK_RET(GetRankByUserRank(commLevel, commIndex, root, rootRankInner));
+
+    std::unique_ptr<ExecutorBase> innerExecutor;
+    if (UseInterServerNBAlgo(algType_)) {
+        // the inter-server NB algorithm goes through ScatterNB
+        innerExecutor.reset(new (std::nothrow) ScatterNB(dispatcher_));
+        CHK_SMART_PTR_NULL(innerExecutor);
+        HCCL_INFO("[ScatterRing][KernelRunInner]: using NB algo inter-server.");
+        // temporary memory is used as the scratch memory
+        CHK_RET(innerExecutor->Prepare(inputMem, inputMem, inputMem, count * topoAttr_.userRankSize,
+            dataType, stream, HCCL_REDUCE_RESERVED, rootRankInner, std::vector<Slice>(0)));
+    } else if (UseInterServerNHRAlgo(algType_)) {
+        innerExecutor.reset(new (std::nothrow) ScatterNHR(dispatcher_));
+        CHK_SMART_PTR_NULL(innerExecutor);
+        HCCL_INFO("[ScatterRing][KernelRunInner]: using NHR algo inter-server.");
+        CHK_RET(innerExecutor->Prepare(inputMem, inputMem, inputMem, count * topoAttr_.userRankSize,
+            dataType, stream, HCCL_REDUCE_RESERVED, rootRankInner, std::vector<Slice>(0)));
+    } else {
+        innerExecutor.reset(new (std::nothrow) ScatterRing(dispatcher_));
+        CHK_SMART_PTR_NULL(innerExecutor);
+        HCCL_INFO("[ScatterRing][KernelRunInner]: using ring algo inter-server.");
+        CHK_RET(innerExecutor->Prepare(inputMem, inputMem, inputMem, count * topoAttr_.userRankSize,
+            dataType, stream, HCCL_REDUCE_RESERVED, rootRankInner, std::vector<Slice>(0))); // count is the number of elements in the output
+    }
+    CHK_RET(innerExecutor->RegisterProfiler(
+        (subCommSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + subCommInfo.localRank,
+        PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, stream));
+
+    CHK_RET(RunTemplate(innerExecutor, subCommInfo));
+    return HCCL_SUCCESS;
+}
+
+HcclResult CollScatterExecutor::Orchestrate(const OpParam& param,
+    const AlgResourceResponse& algRes)
+{
+    HcclUs startut = TIME_NOW();
+    tag_ = param.tag;
+    algResResp_ = &algRes;
+    GetStreamInfo(algRes);
+    auto rtStream = param.stream.ptr();
+    HCCL_PROFILER_ADD_TAG(param.tag, algoAttr_.identifier, GetWorkflowMode());
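+    // Note: in op-base mode Orchestrate goes through RunLoop below, which stages the transfer
+    // through the CCL buffers chunk by chunk; in graph mode KernelRun is invoked once directly
+    // on the parameter memory.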
+    HCCL_PROFILER_ADD_STREAM(rtStream, param.tag, 0, algType_);
+    CHK_RET(AddSubStreamToProfiling());
+
+    HcclResult ret = HCCL_SUCCESS;
+    // Graph mode and the single-rank scenario do not need the loop
+    ExecMem execMem;
+    execMem.count = param.DataDes.count;
+    execMem.inputPtr = param.inputPtr;
+    execMem.outputPtr = param.outputPtr;
+    if (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) {
+        execMem.inputMem = algRes.paramInputMem;
+        execMem.outputMem = algRes.paramOutputMem;
+        execMem.scratchMem = algRes.scratchMem;
+        ret = KernelRun(param, execMem);
+    } else {
+        ret = RunLoop(param, algRes);
+    }
+    CHK_PRT_RET(ret != HCCL_SUCCESS,
+        HCCL_ERROR("[CollScatterExecutor][Orchestrate]errNo[0x%016llx]scatter executor kernel run failed",
+            HCCL_ERROR_CODE(ret)), ret);
+
+    if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && !is310P3Common_) {
+        HCCL_PROFILER_DEL_STREAM(rtStream);
+        HCCL_PROFILER_DEL_TAG(param.tag);
+    }
+    HCCL_INFO("tag[%s] Scatter executor orchestrate success, take time [%lld]us.",
+        param.tag.c_str(), DURATION_US(TIME_NOW() - startut));
+    return HCCL_SUCCESS;
+}
+
+}
diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_executor.h
new file mode 100644
index 0000000000000000000000000000000000000000..6ecc63fc9dbff02a36edb07dae082d29185b0408
--- /dev/null
+++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_executor.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */ + +#ifndef COLL_SCATTER_EXECUTOR_H +#define COLL_SCATTER_EXECUTOR_H + +#include "coll_native_executor_base.h" +#include "coll_alg_exec_registry.h" + +namespace hccl { + +// 所有 Scatter Executor 的基类,继承自 NativeExecutor +class CollScatterExecutor : public CollNativeExecutorBase { +public: + explicit CollScatterExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollScatterExecutor() = default; + + HcclResult Orchestrate(const OpParam& param, const AlgResourceResponse& algRes) override; +protected: + /* *************** 资源计算 *************** */ + virtual HcclResult CalcCommInfo(std::vector& opTransport); + virtual HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + // 按Inner、Outer、Level2可继续进行拆分。 + virtual HcclResult KernelRunInner(DeviceMem &inputMem, u64 count, HcclDataType dataType, u32 &commIndex, + u32 root, u32 &subRoot, CommPlane commLevel, Stream &stream); + // 用于需要Loop的Executor + virtual HcclResult RunLoop(const OpParam ¶m, const AlgResourceResponse &algRes); + virtual HcclResult RunLoopInner(const OpParam ¶m, ExecMem &execMem, const AlgResourceResponse &algRes); + + virtual bool IsHugeData(u64 curSize); + /* *************** 通用工具 *************** */ + virtual HcclResult PrepareDataSlice(u64 dataCount, u32 unitSize, u32 sliceNum, + std::vector &dataSlice); + virtual HcclResult ReorderSlice(std::vector &dataSlice, std::vector &order); +private: +}; + +} // namespace hccl + +#endif diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_mesh_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_mesh_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..5da43232aa6812d20c996f0cd7f51edea9a3ba21 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_mesh_executor.cc @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + + +#include "coll_scatter_mesh_executor.h" + +namespace hccl { +CollScatterMeshExecutor::CollScatterMeshExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollScatterExecutor(dispatcher, topoMatcher) +{ +} + +HcclResult CollScatterMeshExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + HCCL_INFO("[CollScatterMeshExecutor][CalcLevel0CommInfo]tag[%s ]start", tag_.c_str()); + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_MESH); + + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + HCCL_INFO("[CollScatterMeshExecutor][CalcLevel0CommInfo]tag[%s] Calc meshComm finish", tag_.c_str()); + return HCCL_SUCCESS; +} + +HcclResult CollScatterMeshExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = topoAttr_.deviceNumPerAggregation > 1U ? 
topoAttr_.deviceNumPerAggregation - 1U : 1U; + streamNum = totalStreamNum - 1U; + return HCCL_SUCCESS; +} + +HcclResult CollScatterMeshExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + Stream& stream = const_cast(param.stream); + + u32 perDataSize = SIZE_TABLE[param.DataDes.dataType]; + + SubCommInfo level0CommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + u32 level0LocalRank = level0CommInfo.localRank; + u32 level0LocalRankSize = level0CommInfo.localRankSize; + + u32 commIndex = level0LocalRank; + + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo level1CommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + u32 level1LocalRank = level1CommInfo.localRank; + u32 level1LocalRankSize = level1CommInfo.localRankSize; + + bool bRet = level0LocalRankSize == 0; + CHK_PRT_RET(bRet, HCCL_ERROR("[CollScatterMeshExecutor][KernelRun]tag[%s],comm outer is empty", tag_.c_str()), + HCCL_E_INTERNAL); + + /* ***********第一步: 节点间scatter ****************************/ + u32 subRoot = topoMatcher_->GetSubRootForScatter(param.root); + CHK_PRT_RET(subRoot == INVALID_VALUE_RANKID, + HCCL_ERROR("[CollScatterMeshExecutor][KernelRun]GetSubRootForScatter failed, ", + "userRank[%u], root[%u], subRoot[%u]", topoAttr_.userRank, param.root, subRoot), HCCL_E_INTERNAL); + HCCL_DEBUG("[CollScatterMeshExecutor][KernelRun]GetSubRootForScatter, userRank[%u], root[%u], subRoot[%u]", + topoAttr_.userRank, param.root, subRoot); + CHK_RET(KernelRunInner(execMem.inputMem, execMem.count, param.DataDes.dataType, commIndex, + param.root, subRoot, COMM_LEVEL1, stream)); + + /* ***********第二步: 节点内scatter*****************************/ + // 根据数据量算每个环上数据的偏移和大小 + u32 sliceNum = level0LocalRankSize; + std::vector dataSegsSlice; + CHK_RET(PrepareDataSlice(execMem.count, perDataSize, sliceNum, dataSegsSlice)); + + // 每个server分配的slice大小 + u64 serverSliceSize = execMem.inputMem.size() / level1LocalRankSize; + // 每个服务器对应的偏移 + u64 serverSliceOffset = serverSliceSize * level1LocalRank; + DeviceMem scatterMeshInput = execMem.inputMem.range(serverSliceOffset, serverSliceSize); + CHK_SMART_PTR_NULL(scatterMeshInput); + DeviceMem scatterMeshOutput = execMem.inputMem.range(serverSliceOffset, serverSliceSize); + CHK_SMART_PTR_NULL(scatterMeshOutput); + + std::unique_ptr outerExecutor; + outerExecutor.reset( + new (std::nothrow) ScatterMesh(dispatcher_, level0LocalRank, level0LocalRankSize)); + CHK_SMART_PTR_NULL(outerExecutor); + + // 偏移需要带入prepare + u32 rootRankOuter = 0; + CHK_RET(GetRankByUserRank(COMM_LEVEL0, COMM_INDEX_0, param.root, rootRankOuter)); + CHK_PRT_RET(rootRankOuter == INVALID_VALUE_RANKID, + HCCL_ERROR("[CollScatterMeshExecutor][KernelRun]rootRankOuter[%u] is invalid, userRank[%u], subRoot[%u]", + rootRankOuter, topoAttr_.userRank, subRoot), HCCL_E_INTERNAL); + + CHK_RET(outerExecutor->Prepare(scatterMeshInput, scatterMeshOutput, execMem.inputMem, execMem.count, + param.DataDes.dataType, stream, HCCL_REDUCE_RESERVED, rootRankOuter, dataSegsSlice, serverSliceOffset)); + + HcclResult ret = RunTemplate(outerExecutor, level0CommInfo); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollScatterMeshExecutor][KernelRun]scatter(mesh) RunTemplate failed,return[%d]", ret), ret); + + // 将scratchMem赋值给outputMem + u8 *scatterMeshOutputPtr = static_cast(scatterMeshOutput.ptr()); + DeviceMem resultMem(scatterMeshOutputPtr + execMem.outputMem.size() * level0LocalRank, execMem.outputMem.size()); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, execMem.outputMem, resultMem, stream)); + return HCCL_SUCCESS; +} + 
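+// Illustrative sketch with hypothetical numbers: assume two servers (level1LocalRankSize = 2),
+// four devices per server (sliceNum = 4) and a 32 KiB inputMem. KernelRun above then computes
+//   serverSliceSize   = execMem.inputMem.size() / level1LocalRankSize = 16 KiB
+//   serverSliceOffset = serverSliceSize * level1LocalRank             = 0 KiB or 16 KiB
+// PrepareDataSlice() builds sliceNum slices of execMem.count * perDataSize bytes each, which the
+// intra-server ScatterMesh distributes within that window; the final D2D copy takes the slice at
+// offset execMem.outputMem.size() * level0LocalRank as this device's result.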
+REGISTER_EXEC("ScatterMeshExecutor", ScatterMesh, CollScatterMeshExecutor); + +} + diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_mesh_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_mesh_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..206880687bd66af4c39a84a968f912fc3b52960e --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_mesh_executor.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_SCATTER_MESH_EXECUTOR_H +#define COLL_SCATTER_MESH_EXECUTOR_H + +#include "coll_scatter_executor.h" +#include "coll_alg_exec_registry.h" + +namespace hccl { + +// 所有 Scatter Executor 的基类,继承自 NativeExecutor +class CollScatterMeshExecutor : public CollScatterExecutor { +public: + explicit CollScatterMeshExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollScatterMeshExecutor() = default; +protected: + /* *************** 资源计算 *************** */ + HcclResult CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + + HcclResult CalcStreamNum(u32& streamNum) override; + + /* *************** 算法编排 *************** */ + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +private: +}; + +} // namespace hccl + +#endif diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_ring_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_ring_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..a7943d5d2c3d9d7d15959c16606b474081f215d0 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_ring_executor.cc @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + + +#include "coll_scatter_ring_executor.h" + +namespace hccl { +CollScatterRingExecutor::CollScatterRingExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollScatterExecutor(dispatcher, topoMatcher) +{ +} + +HcclResult CollScatterRingExecutor::CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) +{ + HCCL_INFO("[CollScatterRingExecutor][CalcLevel0CommInfo]tag[%s ]start", tag_.c_str()); + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_RING_INNER); + + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + HCCL_INFO("[CollScatterRingExecutor][CalcLevel0CommInfo]tag[%s] Calc RingComm finish", tag_.c_str()); + return HCCL_SUCCESS; +} + +HcclResult CollScatterRingExecutor::CalcStreamNum(u32 &streamNum) +{ + u32 totalStreamNum = OUTER_PLANE_NUM_IN_NPRING_SINGLE; + switch (algType_) { + case AlgType::ALG_8P_RING_PLUS_HD: + case AlgType::ALG_8P_RING_PLUS_RING: + case AlgType::ALG_8P_RING_PLUS_NHR: + case AlgType::ALG_8P_RING_PLUS_NHR_V1: + case AlgType::ALG_8P_RING_PLUS_NB: + case AlgType::ALG_8P_RING_PLUS_PIPELINE: + totalStreamNum = OUTER_PLANE_NUM_IN_8PRING; + break; + default: + break; + } + streamNum = totalStreamNum - 1; + HCCL_INFO("[CollScatterRingExecutor][CalcStreamNum] tag[%s] streamNum[%u]", tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollScatterRingExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + Stream& stream = const_cast(param.stream); + + u32 perDataSize = SIZE_TABLE[param.DataDes.dataType]; + + SubCommInfo level0CommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + u32 level0LocalRank = level0CommInfo.localRank; + u32 level0LocalRankSize = level0CommInfo.localRankSize; + + u32 commIndex = (topoType_ == TopoType::TOPO_TYPE_8P_RING) ? 
topoAttr_.devicePhyId : level0LocalRank; + + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo level1CommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + u32 level1LocalRank = level1CommInfo.localRank; + u32 level1LocalRankSize = level1CommInfo.localRankSize; + + bool bRet = level0LocalRankSize == 0; + CHK_PRT_RET(bRet, HCCL_ERROR("[CollScatterRingExecutor][KernelRun]tag[%s],comm outer is empty", tag_.c_str()), + HCCL_E_INTERNAL); + + /* ***********第一步: 节点间scatter ****************************/ + u32 subRoot = topoMatcher_->GetSubRootForScatter(param.root); + CHK_PRT_RET(subRoot == INVALID_VALUE_RANKID, + HCCL_ERROR("[CollScatterRingExecutor][KernelRun]GetSubRootForScatter failed, ", + "userRank[%u], root[%u], subRoot[%u]", topoAttr_.userRank, param.root, subRoot), HCCL_E_INTERNAL); + HCCL_DEBUG("[CollScatterRingExecutor][KernelRun]GetSubRootForScatter, userRank[%u], root[%u], subRoot[%u]", + topoAttr_.userRank, param.root, subRoot); + CHK_RET(KernelRunInner(execMem.inputMem, execMem.count, param.DataDes.dataType, commIndex, + param.root, subRoot, COMM_LEVEL1, stream)); + + /* ***********第二步: 节点内scatter*****************************/ + u32 sliceNum = level0LocalRankSize; + std::vector dataSegsSlice; + u32 outputOffset = level0LocalRank; + CHK_RET(PrepareScatterRingSliceData(execMem.count, perDataSize, sliceNum, dataSegsSlice, outputOffset)); + + // 每个server分配的slice大小 + u64 serverSliceSize = execMem.inputMem.size() / level1LocalRankSize; + // 每个服务器对应的偏移 + u64 serverSliceOffset = serverSliceSize * level1LocalRank; + DeviceMem scatterRingInput = execMem.inputMem.range(serverSliceOffset, serverSliceSize); + CHK_SMART_PTR_NULL(scatterRingInput); + DeviceMem scatterRingOutput = execMem.inputMem.range(serverSliceOffset, serverSliceSize); + CHK_SMART_PTR_NULL(scatterRingOutput); + + std::unique_ptr outerExecutor; + outerExecutor.reset(new (std::nothrow) ScatterRing(dispatcher_)); + CHK_SMART_PTR_NULL(outerExecutor); + + // 偏移需要带入prepare + u32 rootRankOuter = 0; + CHK_RET(GetRankByUserRank(COMM_LEVEL0, COMM_INDEX_0, param.root, rootRankOuter)); + CHK_PRT_RET(rootRankOuter == INVALID_VALUE_RANKID, + HCCL_ERROR("[CollScatterRingExecutor][KernelRun]rootRankOuter[%u] is invalid, userRank[%u], subRoot[%u]", + rootRankOuter, topoAttr_.userRank, subRoot), + HCCL_E_INTERNAL); + + CHK_RET(outerExecutor->Prepare(scatterRingInput, scatterRingOutput, execMem.inputMem, execMem.count, + param.DataDes.dataType, stream, HCCL_REDUCE_RESERVED, rootRankOuter, dataSegsSlice, serverSliceOffset)); + + HcclResult ret = RunTemplate(outerExecutor, level0CommInfo); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollScatterRingExecutor][KernelRun]scatter(ring) RunTemplate failed,return[%d]", ret), + ret); + + // 将scratchMem赋值给outputMem + u8 *scatterRingOutputPtr = static_cast(scatterRingOutput.ptr()); + DeviceMem resultMem(scatterRingOutputPtr + execMem.outputMem.size() * outputOffset, execMem.outputMem.size()); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, execMem.outputMem, resultMem, stream)); + return HCCL_SUCCESS; +} + +HcclResult CollScatterRingExecutor::PrepareScatterRingSliceData(u64 dataCount, u32 unitSize, u32 sliceNum, + std::vector &dataSlice, u32 &outputOffset) +{ + CHK_PRT_RET((sliceNum == 0), + HCCL_ERROR("[CollScatterRingExecutor][PrepareScatterRingSliceData]sliceNum is zero."), HCCL_E_PARA); + + // 根据数据量算每个环上数据的偏移和大小 + CHK_RET(PrepareDataSlice(dataCount, unitSize, sliceNum, dataSlice)); + + if (topoType_ == TopoType::TOPO_TYPE_8P_RING) { + std::vector tmpRing0 = { 0, 1, 2, 6, 5, 4, 7, 3 }; + 
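+        // tmpRing0 is the fixed device order of the 8P ring; the same permutation is applied to
+        // this rank's outputOffset (next statement) and to the slice table via ReorderSlice(), so
+        // the slice layout stays consistent with the ring order.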
outputOffset = tmpRing0[outputOffset]; + CHK_RET(ReorderSlice(dataSlice, tmpRing0)); + } + return HCCL_SUCCESS; +} + +REGISTER_EXEC("ScatterRingExecutor", ScatterRing, CollScatterRingExecutor); + +} + diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_ring_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_ring_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..774fcf92cd13ca409aa143e1c0ffb8d1268caf05 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_scatter/coll_scatter_ring_executor.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_SCATTER_RING_EXECUTOR_H +#define COLL_SCATTER_RING_EXECUTOR_H + +#include "coll_scatter_executor.h" +#include "coll_alg_exec_registry.h" + +namespace hccl { + +// 所有 Scatter Executor 的基类,继承自 NativeExecutor +class CollScatterRingExecutor : public CollScatterExecutor { +public: + explicit CollScatterRingExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); + ~CollScatterRingExecutor() = default; +protected: + /* *************** 资源计算 *************** */ + HcclResult CalcLevel0CommInfo(TransportMemType inputType, + TransportMemType outputType, + std::vector& opTransport) override; + + HcclResult CalcStreamNum(u32& streamNum) override; + + /* *************** 算法编排 *************** */ + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +private: + HcclResult PrepareScatterRingSliceData(u64 dataCount, u32 unitSize, u32 sliceNum, + std::vector &dataSlice, u32 &outputOffset); +}; + +} // namespace hccl + +#endif diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/CMakeLists.txt b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/CMakeLists.txt index ec123eae6eb8433f2f8c67eec70be1ab4c560d24..c6cb023cdfe70595e1ea6315eaa997f8057a6bfe 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/CMakeLists.txt +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/CMakeLists.txt @@ -1,7 +1,9 @@ set(src_list ${CMAKE_CURRENT_SOURCE_DIR}/coll_batch_send_recv_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_send_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_receive_executor.cc ) - + target_sources(hccl_alg PRIVATE ${src_list} ) \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_batch_send_recv_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_batch_send_recv_executor.cc index cfdfecfc5467257a28fabdd0e05c71585286aae0..fdbfbe5f0367a407e57f3359cf050923cd5bacf9 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_batch_send_recv_executor.cc +++ 
b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_batch_send_recv_executor.cc @@ -11,8 +11,10 @@ #include "coll_batch_send_recv_executor.h" namespace hccl { -CollBatchSendRecvExecutor::CollBatchSendRecvExecutor(std::unique_ptr &pImpl) - : CollCommExecutor(pImpl) + +CollBatchSendRecvExecutor::CollBatchSendRecvExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollCommExecutor(dispatcher, topoMatcher) { } @@ -30,43 +32,16 @@ void CollBatchSendRecvExecutor::ParseParam(const OpParam& param) HcclSendRecvItem** itemPtr = param.BatchSendRecvDataDes.orderedList; u32 itemNum = param.BatchSendRecvDataDes.itemNum; commTargetUserRankSet_.clear(); - std::set totalTargetUserRankSet = {}; - (void)hcclImpl_->GetTotalTargetRankSet(totalTargetUserRankSet); for (u32 i = 0; i < itemNum; i++) { commTargetUserRankSet_.insert((*(itemPtr + i))->remoteRank); - totalTargetUserRankSet.insert((*(itemPtr + i))->remoteRank); - HCCL_INFO("[CollBatchSendRecvExecutor][ParseParam] insert remoteUserRank to Set %u", + HCCL_INFO("[CollBatchSendRecvExecutor][ParseParam] insert remoteUserRank[%u] to Set", (*(itemPtr + i))->remoteRank); } - (void)hcclImpl_->UpdateTotalTargetRankSet(totalTargetUserRankSet); } -bool CollBatchSendRecvExecutor::NeedIncrCreateLink(const OpParam& param) -{ - tag_ = param.tag; - bool isNeed = false; - HcclSendRecvItem** itemPtr = param.BatchSendRecvDataDes.orderedList; - u32 itemNum = param.BatchSendRecvDataDes.itemNum; - - commTargetUserRankSet_.clear(); - std::set totalTargetUserRankSet = {}; - CHK_RET(hcclImpl_->GetTotalTargetRankSet(totalTargetUserRankSet)); - for (u32 i = 0; i < itemNum; i++) { - auto it = totalTargetUserRankSet.find((*(itemPtr + i))->remoteRank); - if (it == totalTargetUserRankSet.end()) { - commTargetUserRankSet_.insert((*(itemPtr + i))->remoteRank); - totalTargetUserRankSet.insert((*(itemPtr + i))->remoteRank); - HCCL_INFO("[CollBatchSendRecvExecutor][NeedIncrCreateLink] Add targetUserRank[%u].", - (*(itemPtr + i))->remoteRank); - isNeed = true; - } - } - CHK_RET(hcclImpl_->UpdateTotalTargetRankSet(totalTargetUserRankSet)); - return isNeed; -} - HcclResult CollBatchSendRecvExecutor::CalcIncreLinkRequest(const OpParam& param, AlgResourceRequest& resourceRequest) { + (void)ParseParam(param); u64 scratchMemSize = 0U; u32 streamNum = 0U; u32 notifyNum = 0U; @@ -75,13 +50,8 @@ HcclResult CollBatchSendRecvExecutor::CalcIncreLinkRequest(const OpParam& param, std::vector opTransport { std::vector(static_cast(COMM_LEVEL_RESERVED)) }; - CalcCommInfo(opTransport); - BuildResourceRequest(scratchMemSize, streamNum, notifyNum, needAivBuffer, opTransport, resourceRequest); - HCCL_INFO("[CollBatchSendRecvExecutor][CalcIncreLinkRequest] StreamNum[%u], notifyNum[%u], sctrachMemSize[%llu]," \ - "needAivBuffer[%u]", resourceRequest.streamNum, resourceRequest.notifyNum, resourceRequest.scratchMemSize, - resourceRequest.needAivBuffer); return HCCL_SUCCESS; } @@ -96,9 +66,9 @@ HcclResult CollBatchSendRecvExecutor::Orchestrate(const OpParam& param, const Al HCCL_PROFILER_ADD_TAG(param.tag, algoAttr_.identifier, GetWorkflowMode()); HCCL_PROFILER_ADD_STREAM(rtStream, param.tag, 0, algType_); - CHK_RET(hcclImpl_->AddSubStreamToProfiling(param.tag, HcclCMDType::HCCL_CMD_BATCH_SEND_RECV)); + CHK_RET(AddSubStreamToProfiling()); - if (GetExternalInputHcclEnableFfts()) { + if (topoMatcher_->GetExternalInputHcclEnableFfts()) { auto meta = HcclOpMetaInfo::GetOneForBatchSendRecv(); CHK_RET(InitTask(dispatcher_, const_cast(param.stream), 
meta.isEnableCache, meta.GetCacheKey())); // 多流子图前后需加空拷贝 @@ -133,7 +103,7 @@ HcclResult CollBatchSendRecvExecutor::Orchestrate(const OpParam& param, const Al PROF_STAGE_0); CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[BatchSendRecv] stream wait failed"), ret); - if (GetExternalInputHcclEnableFfts()) { + if (topoMatcher_->GetExternalInputHcclEnableFfts()) { // 多流子图前后需加空拷贝 CHK_RET(ExecutorBase::ExecEmptyTask(const_cast(algResource.cclInputMem), const_cast(algResource.cclOutputMem), const_cast(param.stream), dispatcher_)); @@ -281,7 +251,7 @@ u64 CollBatchSendRecvExecutor::CalcRecvLoopMaxCount(DeviceMem& outCCLBuffer, con HcclResult CollBatchSendRecvExecutor::CalcStreamNum(u32& streamNum) { streamNum = 1U; - HCCL_INFO("[CollBatchSendRecvExecutor][CalcScratchMemSize] tag_[%s].", tag_.c_str()); + HCCL_INFO("[CollBatchSendRecvExecutor][CalcScratchMemSize] tag_[%s], streamNum[%u].", tag_.c_str(), streamNum); return HCCL_SUCCESS; } HcclResult CollBatchSendRecvExecutor::CalcCommInfo(std::vector& opTransport) diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_batch_send_recv_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_batch_send_recv_executor.h index b6ab2e2957b956d80b16ec112a671e2f9afe5904..087d0e4e5a77eef5b23b6e6fe79a8eb1db596f68 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_batch_send_recv_executor.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_batch_send_recv_executor.h @@ -16,12 +16,11 @@ namespace hccl { class CollBatchSendRecvExecutor : public CollCommExecutor { public: - CollBatchSendRecvExecutor(std::unique_ptr &pImpl); + CollBatchSendRecvExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); ~CollBatchSendRecvExecutor() = default; HcclResult Orchestrate(const OpParam& param, const AlgResourceResponse& algRes) override; // 增量建链资源计算接口 HcclResult CalcIncreLinkRequest(const OpParam& param, AlgResourceRequest& resourceRequest) override; - bool NeedIncrCreateLink(const OpParam& param) override; private: /* *************** 资源计算 *************** */ void ParseParam(const OpParam& param) override; diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_receive_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_receive_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..9766af39c6323fa6fe45b8ba86feeae08684921d --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_receive_executor.cc @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "coll_receive_executor.h" + +namespace hccl { + +CollReceiveExecutor::CollReceiveExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollNativeExecutorBase(dispatcher, topoMatcher) +{ +} + +HcclResult CollReceiveExecutor::Orchestrate(const OpParam& param, const AlgResourceResponse& algRes) +{ + HcclUs startut = TIME_NOW(); + tag_ = param.tag; + algResResp_ = &algRes; + GetStreamInfo(algRes); + + HcclResult ret = HCCL_SUCCESS; + // Graph mode does not need the loop + if (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + DeviceMem outputMem = algRes.paramOutputMem; + ret = RunTemplate(param, outputMem); + } else { + ret = RunLoop(param, algRes); + } + + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollReceiveExecutor][Orchestrate]errNo[0x%016llx]recv executor kernel run failed", + HCCL_ERROR_CODE(ret)), ret); + + HCCL_INFO("tag[%s], Receive Executor orchestrate success, take time [%lld]us.", + param.tag.c_str(), DURATION_US(TIME_NOW() - startut)); + return HCCL_SUCCESS; +} + +HcclResult CollReceiveExecutor::CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_OUTPUT; + outputType = TransportMemType::CCL_OUTPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_OUTPUT; + } + HCCL_INFO("[CollRecvExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollReceiveExecutor::CalcP2PCommInfo(TransportMemType inputType, + TransportMemType outputType, std::vector& opTransport, u32 srcRank) +{ + HCCL_INFO("[CollRecvExecutor][CalcOuterCommInfo]tag[%s] start", tag_.c_str()); + CommParaInfo commP2P(COMM_COMBINE, CommType::COMM_TAG_P2P); + commP2P.peerUserRank = srcRank; + CHK_RET(CalcCommPlaneInfo(tag_, commP2P, opTransport[COMM_COMBINE], inputType, outputType)); + HCCL_INFO("[CollRecvExecutor][CalcOuterCommInfo]tag[%s] Calc RingComm finish", tag_.c_str()); + return HCCL_SUCCESS; +} + +HcclResult CollReceiveExecutor::CalcCommInfo(std::vector& opTransport, u32 srcRank) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CalcTransportMemType(inputType, outputType); + CalcP2PCommInfo(inputType, outputType, opTransport, srcRank); + return HCCL_SUCCESS; +} + +HcclResult CollReceiveExecutor::CalcResRequest(const OpParam& param, AlgResourceRequest& resourceRequest) +{ + ParseParam(param); + + u64 scratchMemSize = 0U; + u32 streamNum = 0U; + u32 notifyNum = 0U; + bool needAivBuffer = false; + std::vector opTransport { + std::vector(static_cast(COMM_LEVEL_RESERVED)) + }; + + CalcCommInfo(opTransport, param.srcRank); + + BuildResourceRequest(scratchMemSize, streamNum, notifyNum, needAivBuffer, opTransport, resourceRequest); + HCCL_INFO("streamNum[%u], notifyNum[%u], scratchMemSize[%llu], needAivBuffer[%u]", + resourceRequest.streamNum, resourceRequest.notifyNum, resourceRequest.scratchMemSize, + resourceRequest.needAivBuffer); + // Print the link setup request + PrintTransportRequest(resourceRequest); + return HCCL_SUCCESS; +} + +HcclResult CollReceiveExecutor::RunLoop(const OpParam &param, const AlgResourceResponse &algRes) +{ + HcclResult ret; + + u64 commOutputSize = algRes.cclOutputMem.size(); + + u32 unitSize = SIZE_TABLE[param.DataDes.dataType]; + + auto meta = HcclOpMetaInfo::GetOneForRecieve(); + u8 *curOutputPtr = static_cast(param.outputPtr); + CHK_PTR_NULL(curOutputPtr); + + u64 outputOffset = 0; + u64 countLeft = param.DataDes.count; + while (countLeft > 0) { + CHK_RET(InitTask(dispatcher_, const_cast(param.stream), meta.isEnableCache, meta.GetCacheKey())); + curOutputPtr += outputOffset; + HCCL_DEBUG("RecvOutPlace:outputOffset[%llu]", outputOffset); + u64 curCount = ((countLeft * unitSize) > commOutputSize) ? (commOutputSize / unitSize) : countLeft; + u64 curSize = curCount * unitSize; // unit: bytes + HCCL_DEBUG("RecvOutPlace:curOutputPtr[%p], curCount[%llu], curSize[%llu]", curOutputPtr, curCount, curSize); + + /* Record the op parameters for consistency check */ + ret = RankConsistent::GetInstance().RecordOpPara(HcclCMDType::HCCL_CMD_RECEIVE, param.tag, curCount, + param.DataDes.dataType, commOutputSize, 0, HCCL_WORLD_GROUP); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("errNo[0x%016llx] record CMD with parameter error", HCCL_ERROR_CODE(ret)), ret); + + DeviceMem cclOutputMem(algRes.cclOutputMem.ptr(), curSize); + ret = RunTemplate(param, cclOutputMem); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("errNo[0x%016llx] RecvOutPlace: recv error, tag[%s], ptr[%p], count[%llu], dataType[%d]", + HCCL_ERROR_CODE(ret), param.tag.c_str(), curOutputPtr, curCount, param.DataDes.dataType), + ret); + + ret = RankConsistent::GetInstance().DelOpPara(param.tag); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("errNo[0x%016llx] delete CMD with parameters error. tag[%s]", HCCL_ERROR_CODE(ret), + param.tag.c_str()), ret); + + DeviceMem outCommMem(cclOutputMem.ptr(), curSize); + DeviceMem outMem(curOutputPtr, curSize); + + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, outMem, outCommMem, const_cast(param.stream))); + CHK_PRT_RET((curCount == 0), HCCL_ERROR("In OP_BASE curCount is zero"), HCCL_E_PARA); + countLeft -= curCount; + outputOffset = curSize; + + CHK_RET(LaunchTask(dispatcher_, const_cast(param.stream))); + } + return HCCL_SUCCESS; +} + +HcclResult CollReceiveExecutor::RunTemplate(const OpParam &param, DeviceMem &outputMem) +{ + SubCommInfo commInfo = GetSubCommInfo(COMM_COMBINE, 0); + if (commInfo.links.size() == 0) { + HCCL_ERROR("[CollReceiveExecutor]links size is 0"); + } + LINK transportLink = commInfo.links[0]; + + SendReceive ReceiveExecutor(dispatcher_, transportLink); + ReceiveExecutor.ReceivePrepare(outputMem, param.srcRank, param.stream); + ReceiveExecutor.RegisterProfiler(0, PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, param.stream); + CHK_RET(ReceiveExecutor.ReceiveRunAsync()); + + return HCCL_SUCCESS; +} + +REGISTER_EXEC("ReceiveExecutor", Receive, CollReceiveExecutor); + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_receive_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_receive_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..2dd2abab0fb7020bea21614dcc89c01b49b181ca --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_receive_executor.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_RECEIVE_EXECUTOR_H +#define COLL_RECEIVE_EXECUTOR_H +#include "coll_comm_executor.h" +#include "executor_base_pub.h" + +namespace hccl { +class CollReceiveExecutor : public CollNativeExecutorBase { + +public: + CollReceiveExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollReceiveExecutor() = default; + + HcclResult Orchestrate(const OpParam& param, const AlgResourceResponse& algRes) override; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcResRequest(const OpParam& param, AlgResourceRequest &resourceRequest) override; + HcclResult CalcCommInfo(std::vector& opTransport, u32 srcRank); + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + HcclResult CalcP2PCommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport, u32 srcRank); + + /* *************** 算法编排 *************** */ + HcclResult RunLoop(const OpParam ¶m, const AlgResourceResponse &algRes); + HcclResult RunTemplate(const OpParam ¶m, DeviceMem &outputMem); +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_send_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_send_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..bd9f6208790e98907ca5a7d6bdbd13fa9b7cbd9d --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_send_executor.cc @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "coll_send_executor.h" + +namespace hccl { + +CollSendExecutor::CollSendExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollNativeExecutorBase(dispatcher, topoMatcher) +{ +} + +HcclResult CollSendExecutor::Orchestrate(const OpParam& param, const AlgResourceResponse& algRes) +{ + HcclUs startut = TIME_NOW(); + tag_ = param.tag; + algResResp_ = &algRes; + GetStreamInfo(algRes); + + HcclResult ret = HCCL_SUCCESS; + // Graph mode does not need the loop + if (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + DeviceMem InputMem = algRes.paramInputMem; + ret = RunTemplate(param, InputMem); + } else { + ret = RunLoop(param, algRes); + } + + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollSendExecutor][Orchestrate]errNo[0x%016llx]send executor kernel run failed", + HCCL_ERROR_CODE(ret)), ret); + + HCCL_INFO("tag[%s] Send Executor orchestrate success, take time [%lld]us.", + param.tag.c_str(), DURATION_US(TIME_NOW() - startut)); + return HCCL_SUCCESS; +} + +HcclResult CollSendExecutor::CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType) +{ + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_INPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_OUTPUT; + } + HCCL_INFO("[CollSendExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollSendExecutor::CalcP2PCommInfo(TransportMemType inputType, + TransportMemType outputType, std::vector& opTransport, u32 dstRank) +{ + HCCL_INFO("[CollSendExecutor][CalcOuterCommInfo]tag[%s] start", tag_.c_str()); + CommParaInfo commP2P(COMM_COMBINE, CommType::COMM_TAG_P2P); + commP2P.peerUserRank = dstRank; + CHK_RET(CalcCommPlaneInfo(tag_, commP2P, opTransport[COMM_COMBINE], inputType, outputType)); + HCCL_INFO("[CollSendExecutor][CalcOuterCommInfo]tag[%s] Calc P2PComm finish", tag_.c_str()); + return HCCL_SUCCESS; +} + +HcclResult CollSendExecutor::CalcCommInfo(std::vector& opTransport, u32 dstRank) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CalcTransportMemType(inputType, outputType); + CalcP2PCommInfo(inputType, outputType, opTransport, dstRank); + return HCCL_SUCCESS; +} + +HcclResult CollSendExecutor::CalcResRequest(const OpParam& param, AlgResourceRequest& resourceRequest) +{ + ParseParam(param); + + u64 scratchMemSize = 0U; + u32 streamNum = 0U; + u32 notifyNum = 0U; + bool needAivBuffer = false; + std::vector opTransport { + std::vector(static_cast(COMM_LEVEL_RESERVED)) + }; + + CalcCommInfo(opTransport, param.dstRank); + + BuildResourceRequest(scratchMemSize, streamNum, notifyNum, needAivBuffer, opTransport, resourceRequest); + HCCL_INFO("streamNum[%u], notifyNum[%u], scratchMemSize[%llu], needAivBuffer[%u]", + resourceRequest.streamNum, resourceRequest.notifyNum, resourceRequest.scratchMemSize, + resourceRequest.needAivBuffer); + // Print the link setup request + PrintTransportRequest(resourceRequest); + return HCCL_SUCCESS; +} + +HcclResult CollSendExecutor::RunLoop(const OpParam &param, const AlgResourceResponse &algRes) +{ + HcclResult ret; + + u64 commInputSize = algRes.cclInputMem.size(); + + u32 unitSize = SIZE_TABLE[param.DataDes.dataType]; + auto meta = HcclOpMetaInfo::GetOneForSend(); + u8 *curInputPtr = static_cast(param.inputPtr); + CHK_PTR_NULL(curInputPtr); + + u64 inputOffset = 0; + u64 countLeft = param.DataDes.count; + while (countLeft > 0) { + CHK_RET(InitTask(dispatcher_, const_cast(param.stream), meta.isEnableCache, meta.GetCacheKey())); + curInputPtr += inputOffset; + + HCCL_DEBUG("SendOutPlace:inputOffset[%llu]", inputOffset); + u64 curCount = ((countLeft * unitSize) > commInputSize) ? (commInputSize / unitSize) : countLeft; + u64 curSize = curCount * unitSize; // unit: bytes + + HCCL_DEBUG("SendOutPlace:curInputPtr[%p], curCount[%llu], curSize[%llu]", curInputPtr, curCount, curSize); + + DeviceMem inCommMem(algRes.cclInputMem.ptr(), curSize); + DeviceMem inMem(curInputPtr, curSize); + + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, inCommMem, inMem, const_cast(param.stream))); + + /* Record the op parameters for consistency check */ + ret = RankConsistent::GetInstance().RecordOpPara(HcclCMDType::HCCL_CMD_SEND, param.tag, curCount, + param.DataDes.dataType, commInputSize, 0, HCCL_WORLD_GROUP); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("errNo[0x%016llx] record CMD with parameter error", HCCL_ERROR_CODE(ret)), ret); + + ret = RunTemplate(param, inCommMem); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("errNo[0x%016llx] SendOutPlace: send error, tag[%s], ptr[%p], count[%llu], dataType[%d]", + HCCL_ERROR_CODE(ret), param.tag.c_str(), curInputPtr, curCount, param.DataDes.dataType), + ret); + + ret = RankConsistent::GetInstance().DelOpPara(param.tag); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("errNo[0x%016llx] delete CMD with parameters error. tag[%s]", HCCL_ERROR_CODE(ret), + param.tag.c_str()), ret); + CHK_PRT_RET((curCount == 0), HCCL_ERROR("In OP_BASE curCount is zero"), HCCL_E_PARA); + countLeft -= curCount; + inputOffset = curSize; + + CHK_RET(LaunchTask(dispatcher_, const_cast(param.stream))); + } + return HCCL_SUCCESS; +} + +HcclResult CollSendExecutor::RunTemplate(const OpParam &param, DeviceMem &inputMem) +{ + SubCommInfo commInfo = GetSubCommInfo(COMM_COMBINE, 0); + if (commInfo.links.size() == 0) { + HCCL_ERROR("[CollSendExecutor]links size is 0"); + } + LINK transportLink = commInfo.links[0]; + + SendReceive sendExecutor(dispatcher_, transportLink); + sendExecutor.SendPrepare(inputMem, param.dstRank, param.stream); + sendExecutor.RegisterProfiler(0, PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, param.stream); + CHK_RET(sendExecutor.SendRunAsync()); + + return HCCL_SUCCESS; +} + +REGISTER_EXEC("SendExecutor", Send, CollSendExecutor); + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_send_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_send_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..8dbee2ca17b7a89cec13631081bd25119539b430 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_send_receive/coll_send_executor.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ */ + +#ifndef COLL_SEND_EXECUTOR_H +#define COLL_SEND_EXECUTOR_H +#include "coll_comm_executor.h" +#include "executor_base_pub.h" + +namespace hccl { +class CollSendExecutor : public CollNativeExecutorBase { + +public: + CollSendExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollSendExecutor() = default; + + HcclResult Orchestrate(const OpParam& param, const AlgResourceResponse& algRes) override; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcResRequest(const OpParam& param, AlgResourceRequest &resourceRequest) override; + HcclResult CalcCommInfo(std::vector& opTransport, u32 dstRank); + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + HcclResult CalcP2PCommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport, u32 dstRank); + + /* *************** 算法编排 *************** */ + HcclResult RunLoop(const OpParam ¶m, const AlgResourceResponse &algRes); + HcclResult RunTemplate(const OpParam ¶m, DeviceMem &inputMem); +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/registry/coll_alg_exec_registry.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/registry/coll_alg_exec_registry.cc index 928954d3ab9904b13557e37797cd66923eee50bf..0df994453069719e59995f81585e6142026d34e5 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/registry/coll_alg_exec_registry.cc +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/registry/coll_alg_exec_registry.cc @@ -30,14 +30,14 @@ HcclResult CollAlgExecRegistry::Register(const std::string &tag, const CollExecC } std::unique_ptr CollAlgExecRegistry::GetAlgExec( - const std::string &tag, std::unique_ptr &pImpl) + const std::string &tag, const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher) { if (execCreators_.find(tag) == execCreators_.end()) { HCCL_DEBUG("[CollAlgExecRegistry]Creator for executor tag[%s] has not registered.", tag.c_str()); return nullptr; } HCCL_DEBUG("[CollAlgExecRegistry][GetAlgExec]get executor by algName[%s]", tag.c_str()); - return std::unique_ptr(execCreators_[tag](pImpl)); + return std::unique_ptr(execCreators_[tag](dispatcher, topoMatcher)); } } // namespace Hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/registry/coll_alg_exec_registry.h b/src/domain/collective_communication/algorithm/impl/coll_executor/registry/coll_alg_exec_registry.h index 7cd38bf6f8cb7d1871149e0daf26be4d428e81db..b396c26b95b104ea5fb1f34ec1285947ccd1b7a6 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/registry/coll_alg_exec_registry.h +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/registry/coll_alg_exec_registry.h @@ -19,19 +19,21 @@ namespace hccl { -using CollExecCreator = std::function &)>; -template static CollExecutorBase *DefaultExecCreator(std::unique_ptr &pImpl) +using CollExecCreator = std::function &)>; +template +static CollExecutorBase *DefaultExecCreator(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher) { static_assert(std::is_base_of::value, "Executor type must derived from Hccl::CollExecutorBase"); - return new (std::nothrow) P(pImpl); + return new (std::nothrow) P(dispatcher, topoMatcher); } class CollAlgExecRegistry { public: static CollAlgExecRegistry *Instance(); HcclResult Register(const std::string &tag, const CollExecCreator 
&collAlgExecCreator); - std::unique_ptr GetAlgExec(const std::string &tag, std::unique_ptr &pImpl); + std::unique_ptr GetAlgExec(const std::string &tag, const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher); private: std::unordered_map execCreators_; @@ -45,6 +47,5 @@ private: #define REGISTER_EXEC_HELPER_1(ctr, tag, name, collExecBase) REGISTER_EXEC_HELPER(ctr, tag, name, collExecBase) #define REGISTER_EXEC(tag, name, collExecBase) REGISTER_EXEC_HELPER_1(__COUNTER__, tag, name, collExecBase) - } // namespace hccl #endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/executor_impl.h b/src/domain/collective_communication/algorithm/impl/executor_impl.h index 56d526320417bb4bf64690e22c8e03c3ff5aecf7..91896cb3a3bfb7e11897346904c71936601f849c 100644 --- a/src/domain/collective_communication/algorithm/impl/executor_impl.h +++ b/src/domain/collective_communication/algorithm/impl/executor_impl.h @@ -78,6 +78,10 @@ #include "reduce_scatter_nb_pub.h" #include "reduce_scatter_pipeline_pub.h" #include "all_reduce_opbase_pipeline_pub.h" +#include "allltoall_pipeline_mesh_pairwise_ping_pong_pub.h" +#include "allltoall_pipeline_mesh_pairwise_ccl_enough_pub.h" +#include "allltoall_pipeline_base_pub.h" +#include "alltoallv_mesh_read_only_pub.h" namespace hccl { } diff --git a/src/domain/collective_communication/algorithm/impl/hccl_alg.cc b/src/domain/collective_communication/algorithm/impl/hccl_alg.cc index 052acea67355ee3a963e7f0edc9c5104616e3973..4100c38e0ed051827a12ece53f359d59cdb37775 100644 --- a/src/domain/collective_communication/algorithm/impl/hccl_alg.cc +++ b/src/domain/collective_communication/algorithm/impl/hccl_alg.cc @@ -20,6 +20,7 @@ #include "send_receive_operator.h" #include "alltoall_operator.h" #include "coll_alg_op_registry.h" +#include "topo_matcher.h" namespace hccl { HcclAlg::HcclAlg() @@ -38,10 +39,26 @@ HcclResult HcclAlg::Init(const void* transportResourceInfoAddr, size_t transport const std::unique_ptr &queueNotifyManager, HcclAlgoAttr &algoAttr, HcclTopoAttr &topoAttr, bool isHeterogComm) { + CHK_RET(InitAlgoInfo(algoAttr)); + CHK_RET(InitTopoInfoPartOne(topoAttr)); pimpl_.reset((new (std::nothrow) hcclImpl(dispatcher, vDispatcher, notifyPool, netDevCtxMap, queueNotifyManager, workSpaceRes, cclBufferManager, transportResourceInfoAddr, transportResourceInfoSize, algoAttr, topoAttr))); CHK_SMART_PTR_NULL(pimpl_); - return pimpl_->Init(isHeterogComm); + CHK_RET(pimpl_->Init(isHeterogComm)); + std::vector>> CommPlaneRanks; + CHK_RET(pimpl_->GetCommPlaneRanks(CommPlaneRanks)); + std::vector isBridgeVector; + CHK_RET(pimpl_->GetIsBridgeVector(isBridgeVector)); + CHK_RET(InitTopoInfoPartTwo()); + CHK_RET(InitExternalEnable()); + std::vector>> serverAndsuperPodToRank; + serverAndsuperPodToRank.clear(); + CHK_RET(pimpl_->GetRankVecInfo(serverAndsuperPodToRank)); + + topoMatcher_.reset((new (std::nothrow) TopoMatcher(CommPlaneRanks, isBridgeVector, + topoInfo_, algoInfo_, externalEnable_, + serverAndsuperPodToRank))); + return HCCL_SUCCESS; } // 上层保证,以下方法在初始化成功后才会调用,所以未对pimpl_进行保护判断 HcclResult HcclAlg::ReleaseCommInfos() @@ -55,66 +72,31 @@ std::unique_ptr HcclAlg::GetAlgOperator(const HcclCMDType &opTy HCCL_ERROR("[HcclAlg][GetAlgOperator] impl ptr is null, get algorithm operator failed."); return nullptr; } - return CollAlgOpRegistry::Instance()->GetAlgOp(opType, pimpl_); + if (!topoMatcher_) { + HCCL_ERROR("[HcclAlg][GetAlgOperator] topoMatcher ptr is null, get algorithm operator failed."); + return nullptr; + } + return 
CollAlgOpRegistry::Instance()->GetAlgOp(opType, pimpl_, topoMatcher_); } HcclResult HcclAlg::AllGather(const std::string &tag, void *inputPtr, void *outputPtr, u64 inputCount, HcclDataType dataType, Stream stream, HcomCollOpInfo *opInfo) { - AllGatherOperator operation(pimpl_); + AllGatherOperator operation(pimpl_, topoMatcher_); return operation.AllGather(tag, inputPtr, outputPtr, inputCount, dataType, stream, opInfo); } HcclResult HcclAlg::AllGatherOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, u64 inputCount, HcclDataType dataType, Stream stream, const std::unique_ptr &opBaseAtraceInfo) { - AllGatherOperator operation(pimpl_); + AllGatherOperator operation(pimpl_, topoMatcher_); return operation.AllGatherOutPlace(tag, inputPtr, outputPtr, inputCount, dataType, stream, opBaseAtraceInfo); } -HcclResult HcclAlg::AlltoAllV(const void *sendBuf, const void *sendCounts, const void *sdispls, HcclDataType sendType, - const void *recvBuf, const void *recvCounts, const void *rdispls, HcclDataType recvType, Stream stream, - const std::string &tag) -{ - AlltoAllOperator operation(pimpl_); - return operation.AlltoAllV( - sendBuf, sendCounts, sdispls, sendType, recvBuf, recvCounts, rdispls, recvType, stream, tag); -} - -HcclResult HcclAlg::AlltoAllVOutPlace(const void *sendBuf, const void *sendCounts, const void *sdispls, - HcclDataType sendType, const void *recvBuf, const void *recvCounts, const void *rdispls, HcclDataType recvType, - Stream stream, const std::string &tag) -{ - AlltoAllOperator operation(pimpl_); - return operation.AlltoAllVOutPlace( - sendBuf, sendCounts, sdispls, sendType, recvBuf, recvCounts, rdispls, recvType, stream, tag); -} - -HcclResult HcclAlg::AlltoAllVC(const void *sendBuf, const void *sendCountMatrix, HcclDataType sendType, - const void *recvBuf, HcclDataType recvType, Stream stream, const std::string &tag) -{ - AlltoAllOperator operation(pimpl_); - return operation.AlltoAllVC(sendBuf, sendCountMatrix, sendType, recvBuf, recvType, stream, tag); -} - -HcclResult HcclAlg::AlltoAllVCOutPlace(const void *sendBuf, const void *sendCountMatrix, HcclDataType sendType, - const void *recvBuf, HcclDataType recvType, Stream stream, const std::string &tag) -{ - AlltoAllOperator operation(pimpl_); - return operation.AlltoAllVCOutPlace(sendBuf, sendCountMatrix, sendType, recvBuf, recvType, stream, tag); -} - -HcclResult HcclAlg::AlltoAll(const void *sendBuf, u64 sendCount, HcclDataType sendType, const void *recvBuf, - u64 recvCount, HcclDataType recvType, Stream stream, const std::string &tag) -{ - AlltoAllOperator operation(pimpl_); - return operation.AlltoAll(sendBuf, sendCount, sendType, recvBuf, recvCount, recvType, stream, tag); -} - HcclResult HcclAlg::Broadcast( const std::string &tag, void *ptr, u64 count, HcclDataType dataType, u32 root, Stream stream) { - BroadCastOperator operation(pimpl_); + BroadCastOperator operation(pimpl_, topoMatcher_); return operation.Broadcast(tag, ptr, count, dataType, root, stream); } @@ -122,14 +104,14 @@ HcclResult HcclAlg::BroadcastOutPlace( const std::string &tag, void *ptr, u64 count, HcclDataType dataType, u32 root, Stream stream, const std::unique_ptr &opBaseAtraceInfo) { - BroadCastOperator operation(pimpl_); + BroadCastOperator operation(pimpl_, topoMatcher_); return operation.BroadcastOutPlace(tag, ptr, count, dataType, root, stream); } HcclResult HcclAlg::Scatter(const std::string &tag, void *inputPtr, void *outputPtr, u64 recvCount, HcclDataType dataType, u32 root, Stream stream) { - ScatterOperator operation(pimpl_); 
+ ScatterOperator operation(pimpl_, topoMatcher_); return operation.Scatter(tag, inputPtr, outputPtr, recvCount, dataType, root, stream); } @@ -137,14 +119,14 @@ HcclResult HcclAlg::ScatterOutPlace(const std::string &tag, void *inputPtr, void HcclDataType dataType, u32 root, Stream stream, const std::unique_ptr &opBaseAtraceInfo) { - ScatterOperator operation(pimpl_); + ScatterOperator operation(pimpl_, topoMatcher_); return operation.ScatterOutPlace(tag, inputPtr, outputPtr, recvCount, dataType, root, stream); } HcclResult HcclAlg::Reduce(const std::string &tag, void *inputPtr, void *outputPtr, u64 count, HcclDataType dataType, HcclReduceOp op, u32 root, Stream stream) { - ReduceOperator operation(pimpl_); + ReduceOperator operation(pimpl_, topoMatcher_); return operation.Reduce(tag, inputPtr, outputPtr, count, dataType, op, root, stream); } @@ -152,64 +134,49 @@ HcclResult HcclAlg::ReduceOutPlace(const std::string &tag, void *inputPtr, void HcclDataType dataType, HcclReduceOp op, u32 root, Stream stream, const std::unique_ptr &opBaseAtraceInfo) { - ReduceOperator operation(pimpl_); + ReduceOperator operation(pimpl_, topoMatcher_); return operation.ReduceOutPlace(tag, inputPtr, outputPtr, count, dataType, op, root, stream); } -HcclResult HcclAlg::ReduceScatter(const std::string &tag, void *inputPtr, void *outputPtr, u64 count, - HcclDataType dataType, HcclReduceOp op, Stream stream, HcomCollOpInfo *opInfo) -{ - ReduceScatterOperator operation(pimpl_); - return operation.ReduceScatter(tag, inputPtr, outputPtr, count, dataType, op, stream, opInfo); -} - -HcclResult HcclAlg::ReduceScatterOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, u64 count, - HcclDataType dataType, HcclReduceOp op, Stream stream, - const std::unique_ptr &opBaseAtraceInfo) -{ - ReduceScatterOperator operation(pimpl_); - return operation.ReduceScatterOutPlace(tag, inputPtr, outputPtr, count, dataType, op, stream, opBaseAtraceInfo); -} - HcclResult HcclAlg::Send(const std::string &tag, void *inputPtr, u64 count, HcclDataType dataType, u32 destRank, Stream stream) { - SendReceiveOperator operation(pimpl_); + SendReceiveOperator operation(pimpl_, topoMatcher_); return operation.Send(tag, inputPtr, count, dataType, destRank, stream); } HcclResult HcclAlg::SendOutPlace(const std::string &tag, void *inputPtr, u64 count, HcclDataType dataType, u32 destRank, Stream stream) { - SendReceiveOperator operation(pimpl_); + SendReceiveOperator operation(pimpl_, topoMatcher_); return operation.SendOutPlace(tag, inputPtr, count, dataType, destRank, stream); } HcclResult HcclAlg::Receive(const std::string &tag, void *outputPtr, u64 count, HcclDataType dataType, u32 srcRank, Stream stream) { - SendReceiveOperator operation(pimpl_); + SendReceiveOperator operation(pimpl_, topoMatcher_); return operation.Receive(tag, outputPtr, count, dataType, srcRank, stream); } HcclResult HcclAlg::ReceiveOutPlace(const std::string &tag, void *outputPtr, u64 count, HcclDataType dataType, u32 srcRank, Stream stream) { - SendReceiveOperator operation(pimpl_); + SendReceiveOperator operation(pimpl_, topoMatcher_); return operation.ReceiveOutPlace(tag, outputPtr, count, dataType, srcRank, stream); } HcclResult HcclAlg::Gather(const std::string &tag, void *inputPtr, void *outputPtr, u32 rootRank, u64 inputCount, HcclDataType dataType, Stream stream) { - GatherOperator operation(pimpl_); + GatherOperator operation(pimpl_, topoMatcher_); return operation.Gather(tag, inputPtr, outputPtr, rootRank, inputCount, dataType, stream); } HcclResult 
HcclAlg::GetAlltoAllStagedWorkSpaceMemSize(u64 *sendCounts, u64 *sdispls, HcclDataType sendType, u64 *recvCounts, u64 *rdispls, HcclDataType recvType, u64 &memSize) { - AlltoAllOperator operation(pimpl_); + AlltoAllOperator operation(pimpl_, topoMatcher_); return operation.GetAlltoAllStagedWorkSpaceMemSize( sendCounts, sdispls, sendType, recvCounts, rdispls, recvType, memSize); } @@ -217,13 +184,13 @@ HcclResult HcclAlg::GetAlltoAllStagedWorkSpaceMemSize(u64 *sendCounts, u64 *sdis HcclResult HcclAlg::GetAlltoAllStagedWorkSpaceMemSize( std::vector &allMeshAggregationSendRecvInfo, u64 &memSize) { - AlltoAllOperator operation(pimpl_); + AlltoAllOperator operation(pimpl_, topoMatcher_); return operation.GetAlltoAllStagedWorkSpaceMemSize(allMeshAggregationSendRecvInfo, memSize); } HcclResult HcclAlg::GetAllReduceScratchSize(const u32 count, const HcclDataType dataType, u64 &scratchSize) { - AllReduceOperator operation(pimpl_); + AllReduceOperator operation(pimpl_, topoMatcher_); return operation.GetAllReduceScratchSize(count, dataType, scratchSize); } @@ -285,6 +252,29 @@ HcclResult HcclAlg::GetAlgType(AlgType &algType, HcclCMDType opType) return pimpl_->GetAlgType(algType, opType); } +std::string HcclAlg::AlgTypeToStr(const AlgType algType) +{ + AlgTypeLevel1 algTypeLevel1 = AlgTypeLevel1(floor(static_cast(algType) >> HCCL_LEVEL_ALGO_WIDTH)); + AlgTypeLevel0 algTypeLevel0 = AlgTypeLevel0(static_cast(algType) - + (static_cast(algTypeLevel1) << HCCL_LEVEL_ALGO_WIDTH)); + auto level0Iter = HCCL_ALGO_LEVEL0_NAME_MAP.find(algTypeLevel0); + auto level1Iter = HCCL_ALGO_LEVEL1_NAME_MAP.find(algTypeLevel1); + std::string algStrLevel0; + std::string algStrLevel1; + if (level0Iter == HCCL_ALGO_LEVEL0_NAME_MAP.end()) { + algStrLevel0 = "invalid algo type"; + } else { + algStrLevel0 = level0Iter->second; + } + if (level1Iter == HCCL_ALGO_LEVEL1_NAME_MAP.end()) { + algStrLevel1 = "invalid algo type"; + } else { + algStrLevel1 = level1Iter->second; + } + std::string algStr = "level0:" + algStrLevel0 + ",level1:" + algStrLevel1; + return algStr; +} + HcclResult HcclAlg::SupportDeterministicOptim(bool &isDeterministicOptim) { isDeterministicOptim = pimpl_->SupportDeterministicOptim(); @@ -303,12 +293,75 @@ HcclResult HcclAlg::SetHDCModeInfo( u8 HcclAlg::GetDeterministicConfig() const { - return pimpl_->GetDeterministicConfig(); + return topoMatcher_->GetDeterministicConfig(); } HcclResult HcclAlg::SetDeterministicConfig(const u8 deterministic) { - CHK_RET(pimpl_->SetDeterministicConfig(deterministic)); + CHK_RET(topoMatcher_->SetDeterministicConfig(deterministic)); + + return HCCL_SUCCESS; +} + +HcclResult HcclAlg::GetAlltoAllStatus(DeviceMem &tinySendRecvMem, bool &isAlltoAllZCopyMode) +{ + CHK_RET(pimpl_->GetAlltoAllStatus(tinySendRecvMem, isAlltoAllZCopyMode)); + + return HCCL_SUCCESS; +} + +HcclResult HcclAlg::InitExternalEnable() +{ + externalEnable_.enableRdmaSdmaConcurrent = GetExternalInputEnableRdmaSdmaConcurrent(); + externalEnable_.enableFfts = GetExternalInputHcclEnableFfts(); + externalEnable_.deterministic = GetExternalInputHcclDeterministic(); + externalEnable_.highPerfEnable = GetExternalInputHcclHighPerfEnable(); + externalEnable_.intraRoceSwitch = GetExternalInputIntraRoceSwitch(); + externalEnable_.dumpDebug = GetExternalInputHcclDumpDebug(); + + return HCCL_SUCCESS; +} + +HcclResult HcclAlg::InitTopoInfoPartOne(HcclTopoAttr &topoAttr) +{ + topoInfo_.userRank = topoAttr.userRank; + topoInfo_.userRankSize = topoAttr.userRankSize; + topoInfo_.devicePhyId = topoAttr.devicePhyId; + 
topoInfo_.deviceLogicId = topoAttr.deviceLogicId; + topoInfo_.nicList = topoAttr.nicList; + topoInfo_.isSingleMeshAggregation = topoAttr.isSingleMeshAggregation; + topoInfo_.deviceNumPerAggregation = topoAttr.deviceNumPerAggregation; + topoInfo_.devNumInLevel2 = topoAttr.devNumInLevel2; + topoInfo_.deviceType = topoAttr.deviceType; + topoInfo_.serverNum = topoAttr.serverNum; + topoInfo_.meshAggregationRankSize = topoAttr.meshAggregationRankSize; + topoInfo_.multiModuleDiffDeviceNumMode = topoAttr.multiModuleDiffDeviceNumMode; + topoInfo_.pairLinkCounter = topoAttr.pairLinkCounter; + topoInfo_.isDiffDeviceModule = topoAttr.isDiffDeviceModule; + topoInfo_.realUserRank = topoAttr.realUserRank; + topoInfo_.moduleNum = topoAttr.moduleNum; + + return HCCL_SUCCESS; +} + +HcclResult HcclAlg::InitTopoInfoPartTwo() +{ + TopoType topoType; + CHK_RET(pimpl_->GetTopoType(topoType)); + topoInfo_.topoType = topoType; + topoInfo_.is310P3Common = pimpl_->Is310P3Common(); + std::unordered_map isUsedRdmaMap; + CHK_RET(pimpl_->GetIsUsedRdmaMap(isUsedRdmaMap)); + topoInfo_.isUsedRdmaMap = isUsedRdmaMap; + + return HCCL_SUCCESS; +} + +HcclResult HcclAlg::InitAlgoInfo(HcclAlgoAttr &algoAttr) +{ + algoInfo_.identifier = algoAttr.identifier; + algoInfo_.inlineReduceSwitchOn = algoAttr.inlineReduceSwitchOn; + algoInfo_.isUsedRdmaOuter = algoAttr.isUsedRdmaOuter; return HCCL_SUCCESS; } diff --git a/src/domain/collective_communication/algorithm/impl/hccl_alg.h b/src/domain/collective_communication/algorithm/impl/hccl_alg.h index c6c61026835038892cf14e4bc542236e3d726ece..97f53339349670c5598d077d9292553444fd8986 100644 --- a/src/domain/collective_communication/algorithm/impl/hccl_alg.h +++ b/src/domain/collective_communication/algorithm/impl/hccl_alg.h @@ -21,7 +21,7 @@ #include "hccl_impl_pub.h" #include "hccl_opbase_atrace_info_pub.h" #include "resource_manager/queue_notify_manager.h" - +#include "topo_matcher.h" #include "coll_alg_operator.h" namespace hccl { @@ -42,18 +42,7 @@ public: HcclDataType dataType, Stream stream, HcomCollOpInfo *opInfo = nullptr); HcclResult AllGatherOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, u64 inputCount, HcclDataType dataType, Stream stream, const std::unique_ptr &opBaseAtraceInfo = nullptr); - HcclResult AlltoAllV(const void *sendBuf, const void *sendCounts, const void *sdispls, - HcclDataType sendType, const void *recvBuf, const void *recvCounts, const void *rdispls, HcclDataType recvType, - Stream stream, const std::string &tag); - HcclResult AlltoAllVOutPlace(const void *sendBuf, const void *sendCounts, const void *sdispls, - HcclDataType sendType, const void *recvBuf, const void *recvCounts, const void *rdispls, HcclDataType recvType, - Stream stream, const std::string &tag); - HcclResult AlltoAllVC(const void *sendBuf, const void *sendCountMatrix, HcclDataType sendType, - const void *recvBuf, HcclDataType recvType, Stream stream, const std::string &tag); - HcclResult AlltoAllVCOutPlace(const void *sendBuf, const void *sendCountMatrix, HcclDataType sendType, - const void *recvBuf, HcclDataType recvType, Stream stream, const std::string &tag); - HcclResult AlltoAll(const void *sendBuf, u64 sendCount, HcclDataType sendType, - const void *recvBuf, u64 recvCount, HcclDataType recvType, Stream stream, const std::string &tag); + HcclResult Broadcast(const std::string &tag, void *ptr, u64 count, HcclDataType dataType, u32 root, Stream stream); HcclResult BroadcastOutPlace(const std::string &tag, void *ptr, u64 count, HcclDataType dataType, u32 root, @@ -68,11 +57,6 
@@ public: HcclResult ReduceOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, u64 count, HcclDataType dataType, HcclReduceOp op, u32 root, Stream stream, const std::unique_ptr &opBaseAtraceInfo = nullptr); - HcclResult ReduceScatter(const std::string &tag, void *inputPtr, void *outputPtr, u64 count, - HcclDataType dataType, HcclReduceOp op, Stream stream, HcomCollOpInfo *opInfo = nullptr); - HcclResult ReduceScatterOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, u64 count, - HcclDataType dataType, HcclReduceOp op, Stream stream, - const std::unique_ptr &opBaseAtraceInfo = nullptr); HcclResult Send(const std::string &tag, void *inputPtr, u64 count, HcclDataType dataType, u32 destRank, Stream stream); HcclResult SendOutPlace(const std::string &tag, void *inputPtr, u64 count, HcclDataType dataType, @@ -103,6 +87,7 @@ public: void Break(); HcclResult SetAlgType(AlgType algType, HcclCMDType opType); HcclResult GetAlgType(AlgType &algType, HcclCMDType opType); + static std::string AlgTypeToStr(const AlgType algType); HcclResult SupportDeterministicOptim(bool &isDeterministicOptim); HcclResult SetHDCModeInfo( std::unordered_map> &rankDevicePhyIdNicInfoMap, @@ -113,8 +98,18 @@ public: std::unique_ptr GetAlgOperator(const HcclCMDType &opType); + HcclResult GetAlltoAllStatus(DeviceMem &tinySendRecvMem, bool &isAlltoAllZCopyMode); private: std::unique_ptr pimpl_; + HcclResult InitExternalEnable(); + HcclResult InitTopoInfoPartOne(HcclTopoAttr &topoAttr); + HcclResult InitTopoInfoPartTwo(); + HcclResult InitAlgoInfo(HcclAlgoAttr &algoAttr); + HcclTopoInfo topoInfo_; + HcclAlgoInfo algoInfo_; + HcclExternalEnable externalEnable_; + std::unique_ptr topoMatcher_; }; } // namespace hccl + #endif // HCCL_ALG_H diff --git a/src/domain/collective_communication/algorithm/impl/hccl_impl.cc b/src/domain/collective_communication/algorithm/impl/hccl_impl.cc index 9ac1d69a21366f8296f4a1700ec259523a5228d2..52a86f3c7d35f629e0da335e43686cfd73954d35 100644 --- a/src/domain/collective_communication/algorithm/impl/hccl_impl.cc +++ b/src/domain/collective_communication/algorithm/impl/hccl_impl.cc @@ -16,6 +16,7 @@ #include "stream_active_manager.h" #include "profiling_manager_pub.h" #include "heartbeat_pub.h" +#include "hccl_alg.h" #include "hccl_impl.h" using namespace std; @@ -72,6 +73,7 @@ const std::set HCCL_ALGO_TYPE_MAP = { AlgType::ALG_1P_MESH_PLUS_NB, AlgType::ALG_4P_RING_PLUS_NB, AlgType::ALG_NP_SINGLE_RING_PLUS_NB, + AlgType::ALG_NP_DOUBLE_RING_PLUS_NB, AlgType::ALG_NP_MESH_PLUS_NB, AlgType::ALG_WHOLE_NB, AlgType::ALG_NP_STAR, @@ -259,8 +261,8 @@ HcclResult hcclImpl::Init(bool isHeterogComm) HcclResult hcclImpl::CheckAlgType(const AlgType algType) { if (HCCL_ALGO_TYPE_MAP.count(algType) == 0) { - HCCL_ERROR("[Check][AlgType]errNo[0x%016llx] algType[%d] is not supported", HCCL_ERROR_CODE(HCCL_E_PARA), - algType); + HCCL_ERROR("[Check][AlgType]errNo[0x%016llx] algType[%s] is not supported", HCCL_ERROR_CODE(HCCL_E_PARA), + HcclAlg::AlgTypeToStr(algType).c_str()); return HCCL_E_PARA; } return HCCL_SUCCESS; @@ -308,12 +310,13 @@ HcclResult hcclImpl::GetTopoTypeByAlgType(const AlgType &algType, const DevType break; default: HCCL_ERROR("[hcclImpl][GetTopoTypeByAlgType]errNo[0x%016llx] case: device type[%d](0~1:V910)," - " algorithm[%d] is not support", HCCL_ERROR_CODE(HCCL_E_PARA), deviceType, algType); + " algorithm[%s] is not support", + HCCL_ERROR_CODE(HCCL_E_PARA), deviceType, HcclAlg::AlgTypeToStr(algType).c_str()); return HCCL_E_PARA; } - 
HCCL_INFO("[hcclImpl][GetTopoTypeByAlgType]algtype[%d], devicetype[%d],topotype[%d] is selected", - algType, deviceType, topoType); + HCCL_INFO("[hcclImpl][GetTopoTypeByAlgType]algtype[%s], devicetype[%d],topotype[%d] is selected", + HcclAlg::AlgTypeToStr(algType).c_str(), deviceType, topoType); return HCCL_SUCCESS; } @@ -539,6 +542,13 @@ HcclResult hcclImpl::GetDefaultAlgoLevel0Module(AlgTypeLevel0 &algType) HCCL_DEBUG("[GetDefaultAlgoLevel0Module] AlgTypeLevel0 is set to ALG_LEVEL0_NP_MESH (HCCS links is enabled)."); } + if (deviceType_ == DevType::DEV_TYPE_910_73) { + algType = IsHCCSSWNumEqualToTwiceSIONum() ? AlgTypeLevel0::ALG_LEVEL0_NP_DOUBLE_RING : + AlgTypeLevel0::ALG_LEVEL0_NP_SINGLE_RING; + if ((algType == AlgTypeLevel0::ALG_LEVEL0_NP_SINGLE_RING) && GetExternalInputEnableRdmaSdmaConcurrent()) { + SetRdmaSdmaConcurrentDisable(); + } + } return HCCL_SUCCESS; } @@ -814,9 +824,23 @@ HcclResult hcclImpl::InitMultiStreamResource(const std::string &tag, innerStream { if (!isBatchSendRecv) { switch (algType) { + case AlgType::ALG_NP_SINGLE_RING_PLUS_RING: + case AlgType::ALG_NP_SINGLE_RING_PLUS_HD: + case AlgType::ALG_NP_SINGLE_RING_PLUS_NHR: + if (deviceType_ == DevType::DEV_TYPE_910_73) { + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + streamInfo.ringNum + = OUTER_PLANE_NUM_IN_NPRING_SINGLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING; + } else { + streamInfo.ringNum + = OUTER_PLANE_NUM_IN_NPRING_SINGLE; + } + } + break; case AlgType::ALG_DOUBLE_RING_PLUS_RING: case AlgType::ALG_DOUBLE_RING_PLUS_HD: case AlgType::ALG_DOUBLE_RING_PLUS_NHR: + // 当前这两种AlgType只支持910_73场景 if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { streamInfo.ringNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING; } else { @@ -854,6 +878,8 @@ HcclResult hcclImpl::InitMultiStreamResource(const std::string &tag, innerStream if ((GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) && (deviceType_ == DevType::DEV_TYPE_910B) && isSingleMeshAggregation_) { streamInfo.ringNum = deviceNumPerAggregation_; + } else if ((deviceType_ == DevType::DEV_TYPE_910_73) && (isAicpuModeEn == true)) { + streamInfo.ringNum = deviceNumPerAggregation_; } else if ((GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) && (deviceType_ == DevType::DEV_TYPE_910B) && UseInterServerPipelineAlgo(algType)) { streamInfo.ringNum = deviceNumPerAggregation_ + 1; /* pipeline ring场景下性能优化 */ @@ -876,6 +902,10 @@ HcclResult hcclImpl::InitMultiStreamResource(const std::string &tag, innerStream streamInfo.ringNum = 2; } + if (GetExternalInputEnableRdmaSdmaConcurrent() && deviceType_ == DevType::DEV_TYPE_910_73) { + streamInfo.ringNum += RDMA_PLANE_NUM_IN_NPRING_DOUBLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING; + } + if (piplineSliceNum_ > 0) { streamInfo.ringNum++; // 流水并行算法, Server间需要额外一条从流 } @@ -933,6 +963,7 @@ HcclResult hcclImpl::InitMultiStreamResource(const std::string &tag, innerStream for (auto &signal : streamInfo.ringDeviceSignalAux) { signal = nullptr; } + u32 notifyNum = resNum * 2; // 2:Signal + SignalAux std::vector> notifys(notifyNum, nullptr); CHK_RET(queueNotifyManager_->Alloc(tag, notifyNum, notifys, NotifyLoadType::DEVICE_NOTIFY)); @@ -963,9 +994,9 @@ HcclResult hcclImpl::PrepareCommRes(const std::string &tag, DeviceMem &inputMem, HcclUs startut = TIME_NOW(); HCCL_INFO("[HcclImpl][PrepareCommRes] tag[%s], inputMem ptr[%p] size[%llu], outputMem ptr[%p] size[%llu], " - "algType[%d], root[%u], isP2p[%d], isHaveCpuRank[%d], meshSinglePlane[%d], 
aivMode[%d]", - tag.c_str(), inputMem.ptr(), inputMem.size(), outputMem.ptr(), outputMem.size(), algType, root, isP2p, - isHaveCpuRank, meshSinglePlane, aivMode); + "algType[%s], root[%u], isP2p[%d], isHaveCpuRank[%d], meshSinglePlane[%d], aivMode[%d]", + tag.c_str(), inputMem.ptr(), inputMem.size(), outputMem.ptr(), outputMem.size(), + HcclAlg::AlgTypeToStr(algType).c_str(), root, isP2p, isHaveCpuRank, meshSinglePlane, aivMode); CHK_RET(notifyPool_->RegisterOp(tag)); HcclResult ret = HCCL_SUCCESS; @@ -1007,9 +1038,9 @@ HcclResult hcclImpl::PrepareCommRes(const std::string &tag, DeviceMem &inputMem, s32 streamId = 0; (void)hrtGetStreamId(stream.ptr(), streamId); HCCL_ERROR("[HcclImpl][PrepareCommRes] failed, tag[%s], inputMem ptr[%p] size[%llu], outputMem ptr[%p] "\ - "size[%llu], algType[%d], streamId[%d], root[%u], isP2p[%d], isHaveCpuRank[%d], return[0x%016llx]", + "size[%llu], algType[%s], streamId[%d], root[%u], isP2p[%d], isHaveCpuRank[%d], return[0x%016llx]", tag.c_str(), inputMem.ptr(), inputMem.size(), outputMem.ptr(), outputMem.size(), - algType, streamId, root, isP2p, isHaveCpuRank, HCCL_ERROR_CODE(ret)); + HcclAlg::AlgTypeToStr(algType).c_str(), streamId, root, isP2p, isHaveCpuRank, HCCL_ERROR_CODE(ret)); (void)notifyPool_->UnregisterOp(tag); if (!isBatchSendRecv) { UnRegisterToHeartBeat(); @@ -1041,6 +1072,12 @@ HcclResult hcclImpl::CreateComm(const std::string &tag, DeviceMem &inputMem, Dev if (isP2p) { CHK_RET(CreateP2pComm(tag, *commInfo, inputMemComm, root)); + } else if (isAicpuModeEn && deviceType_ == DevType::DEV_TYPE_910_73) { + // level0 mesh通信域 + std::vector > commMeshL0; + CommParaInfo commCombinePara(COMM_MESH_L0, CommType::COMM_TAG_MESH); + commCombinePara.isAicpuModeEn = isAicpuModeEn; + CHK_RET(commFactory_->CreateCommPlane(tag, inputMemComm, outputMemComm, commCombinePara, commInfo->commOuter)); } else { CHK_RET(CreateCommByAlg(tag, algType, *commInfo, inputMemComm, outputMemComm, root, isAicpuModeEn, meshSinglePlane)); @@ -1120,8 +1157,8 @@ HcclResult hcclImpl::GetCommTypeInLevel0(const AlgType algType, const TopoType t } else { commType = CommType::COMM_TAG_RING_INNER; } - HCCL_DEBUG("[Get][CommTypeForLevel0]The algType is %d, topoType is %d, while commType is %d", - algType, topoType, commType); + HCCL_DEBUG("[Get][CommTypeForLevel0]The algType is %s, topoType is %d, while commType is %d", + HcclAlg::AlgTypeToStr(algType).c_str(), topoType, commType); return HCCL_SUCCESS; } @@ -1136,8 +1173,8 @@ HcclResult hcclImpl::GetCommTypeInLevel0(const AlgType algType, const TopoType t } else { commType = CommType::COMM_TAG_RING_INNER; } - HCCL_DEBUG("[Get][CommTypeForLevel0]The algType is %d, topoType is %d, while commType is %d", - algType, topoType, commType); + HCCL_DEBUG("[Get][CommTypeForLevel0]The algType is %s, topoType is %d, while commType is %d", + HcclAlg::AlgTypeToStr(algType).c_str(), topoType, commType); return HCCL_SUCCESS; } @@ -1240,10 +1277,11 @@ HcclResult hcclImpl::GetCommTypeInLevel1(const AlgType algType, CommType &commTy } default: - HCCL_ERROR("[Get][CommTypeInLevel1]algType[%d] is not support", algType); + HCCL_ERROR("[Get][CommTypeInLevel1]algType[%s] is not support", HcclAlg::AlgTypeToStr(algType).c_str()); return HCCL_E_PARA; } - HCCL_DEBUG("[Get][CommTypeInLevel1]The algType is %d, while commType is %d", algType, commType); + HCCL_DEBUG("[Get][CommTypeInLevel1]The algType is %s, while commType is %d", + HcclAlg::AlgTypeToStr(algType).c_str(), commType); return HCCL_SUCCESS; } @@ -1291,6 +1329,10 @@ HcclResult 
hcclImpl::CreateCommByAlg(const std::string &tag, const AlgType algTy HcclResult commThreadResultLevel0Rdma = HCCL_SUCCESS; CHK_RET(GetCommTypeInLevel0(algType, topoType_, commTypeInLevel0)); bool isUsedRdma = false; + if (GetExternalInputEnableRdmaSdmaConcurrent() && deviceType_ == DevType::DEV_TYPE_910_73) { + HCCL_INFO("commInfo create commOuterRdma/commInnerRdma for EnableRdmaSdma start"); + isUsedRdma = true; + } if (Is310P3Common()) { if (isAicpuModeEn) { @@ -1613,27 +1655,31 @@ void hcclImpl::UnRegisterToHeartBeatP2P() } } -HcclResult hcclImpl::ParallelTaskLoaderProcess(const std::string &tag, Stream &stream) +HcclResult hcclImpl::ParallelTaskLoaderProcess(const std::string &tag, Stream &stream, SubCommInfo &outerCommInfo, + std::vector &ringStreams) { u32 streamIndex; std::vector streamsPtr; - streamsPtr.resize(tagStreamInfo_[tag].ringStreams.size() + 1); + streamsPtr.resize(ringStreams.size() + 1); - for (streamIndex = 0; streamIndex < tagStreamInfo_[tag].ringStreams.size(); streamIndex++) { - streamsPtr[streamIndex] = &tagStreamInfo_[tag].ringStreams[streamIndex]; + for (streamIndex = 0; streamIndex < ringStreams.size(); streamIndex++) { // StreamInfo_.ringStreams + streamsPtr[streamIndex] = &ringStreams[streamIndex]; } streamsPtr[streamIndex] = &stream; HCCL_INFO("[ParallelTaskLoaderProcess]main stream[%p], streams size[%u]", stream.ptr(), streamsPtr.size()); // 准备多线程启动参数 - CHK_RET(parallelTaskLoader_->Prepare(streamsPtr, tagCommInfo_[tag].commOuter[0].get())); + CHK_RET(parallelTaskLoader_->Prepare(streamsPtr, outerCommInfo)); // 启动多线程处理 CHK_RET(parallelTaskLoader_->StartTaskLoad()); // 等待多线程处理结果 CHK_RET(parallelTaskLoader_->WaitTaskLoadFinish()); + + // 销毁通信域 + CHK_RET(parallelTaskLoader_->ClearTagCommInfo()); return HCCL_SUCCESS; } @@ -1851,12 +1897,11 @@ u32 hcclImpl::GetInnerCommRank(const u32 ringIdx) return commFactory_->GetInnerCommRank(ringIdx); } -HcclResult hcclImpl::GetAlltoAllStatus(DeviceMem &tinySendRecvMem, bool &isAlltoAllZCopyMode, - std::map &isAlltoAllZCopyModeMap) +HcclResult hcclImpl::GetAlltoAllStatus(DeviceMem &tinySendRecvMem, bool &isAlltoAllZCopyMode) { tinySendRecvMem = tinySendRecvMem_; isAlltoAllZCopyMode = isAlltoAllZCopyMode_; - isAlltoAllZCopyModeMap = isAlltoAllZCopyModeMap_; + return HCCL_SUCCESS; } @@ -1947,6 +1992,11 @@ HcclResult hcclImpl::PrepareInnerCommInfo(u32 &segmentIdx, u32 &commIndex, u64 & hdSize = iter->second[nicIdx]; // 通过nicSendSizeList_得到该网口传输数据量 u32 ringRanks = multRingsSliceZero[0].size(); // 获取单个 ring 上设备的数量 segmentIdx = ringRanks / nicList_.size() * nicIdx; // 通过网口位置得到该网口传输数据的起始位置 + // 910A只有8卡场景,所以commIdx等于devicePhyId_,不由nicIdx决定 + // 910_73场景的ring环内的设备物理ID是是由nicList决定的,需要segmentIdx(由nicIdx决定)更新 + if (deviceType_ == DevType::DEV_TYPE_910_73) { + commIndex = segmentIdx; + } } else { // 如果当前rank不是通信网口,则不发送数据 hdSize = 0; } @@ -1983,6 +2033,11 @@ HcclResult hcclImpl::PrepareInnerCommInfo(u32 &segmentIdx, u32 &commIndex, u64 & hdSize = iter->second[nicIdx]; // 通过nicSendSizeList_得到该网口传输数据量 u32 ringRanks = multRingsSliceZero[0].size(); // 获取单个 ring 上设备的数量 segmentIdx = ringRanks / nicList_.size() * nicIdx; // 通过网口位置得到该网口传输数据的起始位置 + // 910A只有8卡场景,所以commIdx等于devicePhyId_,不由nicIdx决定 + // 910_73场景的ring环内的设备物理ID是是由nicList决定的,需要segmentIdx(由nicIdx决定)更新 + if (deviceType_ == DevType::DEV_TYPE_910_73) { + commIndex = segmentIdx; + } } else { // 如果当前rank不是通信网口,则不发送数据 hdSize = 0; } @@ -2028,38 +2083,50 @@ void hcclImpl::SetHDCModeInfo( isUseRankPort_ = isUseRankPort; } -u8 hcclImpl::GetDeterministicConfig() const +u64 
hcclImpl::GetInCCLbufferSize() const { - return deterministic_; + return cclBufferManager_.GetInCCLbufferSize(); } -HcclResult hcclImpl::SetDeterministicConfig(const u8 deterministic) +HcclResult hcclImpl::GetCommPlaneRanks(std::vector>> &CommPlaneRanks) { - if (deterministic > 1) { - HCCL_ERROR("[SetDeterministicConfig] deterministic[%d] should be 0 or 1."); - return HCCL_E_PARA; - } - deterministic_ = deterministic; + CHK_RET(commFactory_->GetCommPlaneRanks(CommPlaneRanks)); return HCCL_SUCCESS; } -HcclResult hcclImpl::CalcCommPlaneInfo(const std::string &tag, const CommParaInfo &commParaInfo, - std::vector &commTransport, TransportMemType inPutMemType, - TransportMemType outPutMemType) +HcclResult hcclImpl::GetIsBridgeVector(std::vector &isBridgeVector) { - return commFactory_->CalcCommPlaneInfo(tag, commParaInfo, commTransport, inPutMemType, outPutMemType); + CHK_RET(commFactory_->GetIsBridgeVector(isBridgeVector)); + return HCCL_SUCCESS; } -HcclResult hcclImpl::GetTotalTargetRankSet(std::set& totalTargetRankSet) +HcclResult hcclImpl::GetIsUsedRdmaMap(std::unordered_map &isUsedRdmaMap) { - totalTargetRankSet = totalTargetRankSet_; + CHK_RET(commFactory_->GetIsUsedRdmaMap(isUsedRdmaMap)); return HCCL_SUCCESS; } - -HcclResult hcclImpl::UpdateTotalTargetRankSet(std::set& totalTargetRankSet) + +HcclResult hcclImpl::GetDispatcher(HcclDispatcher &dispatcher) +{ + dispatcher = dispatcher_; + return HCCL_SUCCESS; +} +HcclResult hcclImpl::GetVirtualDispatcher(HcclDispatcher &vdispatcher) { - totalTargetRankSet_.clear(); - totalTargetRankSet_ = totalTargetRankSet; + vdispatcher = vDispatcher_; return HCCL_SUCCESS; } + +HcclResult hcclImpl::GetParallelTaskLoader(ParallelTaskLoader* ¶llelTaskLoader) +{ + parallelTaskLoader = parallelTaskLoader_.get(); + return HCCL_SUCCESS; +} + +HcclResult hcclImpl::GetRankVecInfo(std::vector>> &serverAndsuperPodToRank) +{ + CHK_RET(commFactory_->GetRankVecInfo(serverAndsuperPodToRank)); + return HCCL_SUCCESS; +} + } // namespace hccl diff --git a/src/domain/collective_communication/algorithm/impl/hccl_impl.h b/src/domain/collective_communication/algorithm/impl/hccl_impl.h index 35238eb28f7b0f8ba05257cf625c637476cf7e16..4be308847601afb514ded0daa716955dbd1a0a94 100644 --- a/src/domain/collective_communication/algorithm/impl/hccl_impl.h +++ b/src/domain/collective_communication/algorithm/impl/hccl_impl.h @@ -106,7 +106,8 @@ public: HcclResult CreateCommForAlltoAllFullMesh(const std::string &tag, DeviceMem &sendBuf, DeviceMem &recvBuf); HcclResult CreateAlltoAllVCommMem(DeviceMem& inputMem, DeviceMem& outputMem) const; HcclResult BuildAlltoAllVScratchMem(const std::string &tag, u64 workSpaceMemSize); - HcclResult ParallelTaskLoaderProcess(const std::string &tag, Stream &stream); + HcclResult ParallelTaskLoaderProcess(const std::string &tag, Stream &stream, SubCommInfo &outerCommInfo, + std::vector &ringStreams); HcclResult GetTopoType(TopoType &topoType); HcclResult GetAlgoLevel1DefaultSwitch(bool &isAlgoLevel1Default, HcclCMDType opType); @@ -122,8 +123,7 @@ public: std::vector>& threadManager); innerStreamInfo_t* GetStreamInfoWithoutCheck(const std::string &tag); HcclResult SetPipelineSliceNum(u64 piplineSliceNum); - HcclResult GetAlltoAllStatus(DeviceMem &tinySendRecvMem, bool &isAlltoAllZCopyMode, - std::map &isAlltoAllZCopyModeMap); + HcclResult GetAlltoAllStatus(DeviceMem &tinySendRecvMem, bool &isAlltoAllZCopyMode); HcclResult UpdateAlltoAllStatus(bool &isAlltoAllZCopyMode, bool &needRecreateAlltoallComm, std::map &isAlltoAllZCopyModeMap); u64 
GetOtherRankAllocScratchSize( @@ -150,11 +150,16 @@ public: HcclResult CreateComm(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, AlgType algType, u32 root = INVALID_VALUE_RANKID, bool isP2p = false, bool isBatchSendRecv = false, bool meshSinglePlane = false, bool aivMode = false, std::set batchSendRecvtargetRanks = std::set()); - HcclResult CalcCommPlaneInfo(const std::string &tag, const CommParaInfo &commParaInfo, - std::vector &commTransport, TransportMemType inPutMemType, - TransportMemType outPutMemType); - + HcclResult GetCommPlaneRanks(std::vector>> &CommPlaneRanks); + HcclResult GetIsBridgeVector(std::vector &isBridgeVector); HcclResult ClearOpResource(const std::string &tag); + HcclResult GetTopoAttr(HcclTopoAttr &topoAttr); + HcclResult GetAlgoAttr(HcclAlgoAttr &algoAttr); + HcclResult GetDispatcher(HcclDispatcher &dispatcher); + HcclResult GetVirtualDispatcher(HcclDispatcher &vdispatcher); + HcclResult GetParallelTaskLoader(ParallelTaskLoader* ¶llelTaskLoader); + HcclResult GetIsUsedRdmaMap(std::unordered_map &isUsedRdmaMap); + HcclResult GetRankVecInfo(std::vector>> &serverAndsuperPodToRank); void Break() { if (Is310P3Common()) { @@ -206,11 +211,11 @@ public: std::unordered_map> &rankDevicePhyIdNicInfoMap, std::vector &ranksPort, bool isSetHDCModeInfo, bool isUseRankPort); - u8 GetDeterministicConfig() const; // 获取确定性计算配置 - HcclResult SetDeterministicConfig(const u8 deterministic); // 设置确定性计算配置 - // 用于batchsendrecv增量建链获取更新已建链的对端rank号 - HcclResult GetTotalTargetRankSet(std::set& totalTargetRankSet); - HcclResult UpdateTotalTargetRankSet(std::set& totalTargetRankSet); + u64 GetInCCLbufferSize() const; // 获取CCL缓存区大小,用于Executor计算scratch大小 + bool Is310P3Common() + { + return !isHaveCpuRank_ && !Is310PDevice() && deviceType_ == DevType::DEV_TYPE_310P3; + } private: void SetAlgoAttr(HcclAlgoAttr &algoAttr); @@ -280,10 +285,6 @@ private: void UnRegisterToHeartBeat(); void UnRegisterToHeartBeat(const std::string& tag); - bool Is310P3Common() - { - return !isHaveCpuRank_ && !Is310PDevice() && deviceType_ == DevType::DEV_TYPE_310P3; - } /* ---------------以下为私有成员变量定义领域-------------------------- */ HcclTopoAttr topoAttr_; HcclAlgoAttr algoAttr_; @@ -342,8 +343,6 @@ private: bool isAlltoAllZCopyMode_ = false; bool needRecreateAlltoallComm_ = false; std::map isAlltoAllZCopyModeMap_; - // batchsendrecv增量建链,记录已经建链的对端userRank号 - std::set totalTargetRankSet_; // 按照 tag 记录全局所有卡上 alltoall 算子的中转内存大小 std::unordered_map> allRankAlltoallScratchMemSize_; bool isSingleMeshAggregation_ = false; diff --git a/src/domain/collective_communication/algorithm/impl/operator/CMakeLists.txt b/src/domain/collective_communication/algorithm/impl/operator/CMakeLists.txt index 74843c5f5e616ced477366b9320ba02c28483a24..02c8776f8bd0a70942db5f4498f331705c117042 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/CMakeLists.txt +++ b/src/domain/collective_communication/algorithm/impl/operator/CMakeLists.txt @@ -11,6 +11,8 @@ set(src_list ${CMAKE_CURRENT_SOURCE_DIR}/send_receive_operator.cc ${CMAKE_CURRENT_SOURCE_DIR}/alltoall_operator.cc ${CMAKE_CURRENT_SOURCE_DIR}/batchsendrecv_operator.cc + ${CMAKE_CURRENT_SOURCE_DIR}/send_operator.cc + ${CMAKE_CURRENT_SOURCE_DIR}/receive_operator.cc ) target_sources(hccl_alg PRIVATE diff --git a/src/domain/collective_communication/algorithm/impl/operator/all_gather_operator.cc b/src/domain/collective_communication/algorithm/impl/operator/all_gather_operator.cc index 049f2b8bba2cacd0277cf7b80197f05dd8f5cdc0..a26649ad1b8342672f439b7d6bfdb9a35272fac8 
100644 --- a/src/domain/collective_communication/algorithm/impl/operator/all_gather_operator.cc +++ b/src/domain/collective_communication/algorithm/impl/operator/all_gather_operator.cc @@ -12,10 +12,11 @@ #include "device_capacity.h" #include "rank_consistent.h" #include "executor_impl.h" +#include "coll_alg_op_registry.h" namespace hccl { -AllGatherOperator::AllGatherOperator(std::unique_ptr &pImpl) - : CommonOperator(pImpl, HcclCMDType::HCCL_CMD_ALLGATHER) +AllGatherOperator::AllGatherOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher) + : CommonOperator(pImpl, topoMatcher, HcclCMDType::HCCL_CMD_ALLGATHER) { } @@ -33,6 +34,16 @@ HcclResult AllGatherOperator::AllGather(const std::string &tag, void *inputPtr, bool isMeshTopo = topoType_ == TopoType::TOPO_TYPE_NP_MESH || topoType_ == TopoType::TOPO_TYPE_4P_MESH || topoType_ == TopoType::TOPO_TYPE_2P_MESH || topoType_ == TopoType::TOPO_TYPE_1P_MESH; + // 910_73超节点只支持server间ring,NB和NHR,默认需继续使用NHR + if (!(UseInterServerRingAlgo(algType_) || UseInterServerNBAlgo(algType_)) && + deviceType_ == DevType::DEV_TYPE_910_73) { + HcclResult ret = SetInterServerNHRAlgo(algType_); + HCCL_WARNING("[AllGatherOperator][AllGather] only support ring, NB and NHR in AlgoLevel1 yet, "\ + "default is algType=NHR."); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[AllGatherOperator][AllGather]errNo[0x%016llx] tag[%s],AllGather set inter server "\ + "nhr algo failed", HCCL_ERROR_CODE(ret), tag.c_str()), ret); + } /* 屏蔽pytorch子图+静态图场景 */ if (UseInterServerPipelineAlgo(algType_) && ((GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) || !isMeshTopo)) { @@ -278,8 +289,11 @@ HcclResult AllGatherOperator::AllGatherOutPlace(const std::string &tag, void *in bool isPipeLine = (deviceType_ == DevType::DEV_TYPE_910B && isMeshTopo && GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && (UseInterServerPipelineAlgo(algType_) || isSingleMeshAggregation_)); + // 当前allgather 的DMA削减只支持sever内 + bool isDMAreduceOn91073 = (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE + && (deviceType_ == DevType::DEV_TYPE_910_73) && !isMeshTopo); bool isUseDMA = !GetExternalInputEnableRdmaSdmaConcurrent(); - if (isUseDMA && (isPipeLine)) { + if (isUseDMA && (isPipeLine || isDMAreduceOn91073)) { CHK_RET(RankConsistent::GetInstance().RecordOpPara(HcclCMDType::HCCL_CMD_ALLGATHER, newTag, curCount, dataType, inCCLbuffer.size(), outCCLbuffer.size(), HCCL_WORLD_GROUP)); HcomCollOpInfo singleServerOpInfo; @@ -360,7 +374,7 @@ HcclResult AllGatherOperator::AllGatherCommFor310P(const std::string &tag, Devic CHK_RET(commCombined->RunExecutor(executor)); HCCL_INFO("allgather for 310P run success"); - + return HCCL_SUCCESS; } @@ -398,9 +412,19 @@ HcclResult AllGatherOperator::RunAllGather(const std::string& tag, DeviceMem &in } case TopoType::TOPO_TYPE_8P_RING: case TopoType::TOPO_TYPE_NP_SINGLE_RING: - case TopoType::TOPO_TYPE_NP_DOUBLE_RING: - ret = AllGatherRingExecutor(tag, inputMem, outputMem, count, dataType, op, stream, opInfo); - break; + case TopoType::TOPO_TYPE_NP_DOUBLE_RING: // 只存在于910_73场景下 + if (deviceType_ == DevType::DEV_TYPE_910_73) { + if (GetExternalInputEnableRdmaSdmaConcurrent()) { + ret = AllGatherDoubleRingConcurrentExecutor(tag, inputMem, outputMem, count, + dataType, op, stream, opInfo); + } else { + ret = AllGatherDoubleRingExecutor(tag, inputMem, outputMem, count, dataType, op, stream, opInfo); + } + break; + } else { + ret = AllGatherRingExecutor(tag, inputMem, outputMem, count, dataType, op, stream, opInfo); + break; + } 
default: ret = AllGatherComm(tag, inputMem, outputMem, count, dataType, op, stream); break; @@ -868,6 +892,230 @@ HcclResult AllGatherOperator::AllGatherDoubleRingExecutor(const std::string &tag return HCCL_SUCCESS; } +HcclResult AllGatherOperator::AllGatherDoubleRingConcurrentExecutor(const std::string &tag, DeviceMem &inputMem, + DeviceMem &outputMem, u64 count, HcclDataType dataType, + HcclReduceOp op, Stream &stream, const HcomCollOpInfo *opInfo) +{ + HCCL_INFO("[AllGatherOperator][AllGatherDoubleRingConcurrentExecutor] "\ + " AllGatherDoubleRingConcurrentExecutor starts."); + (void)op; + u32 perDataSize = 0; + CHK_RET(SalGetDataTypeSize(dataType, perDataSize)); + CHK_PRT_RET(perDataSize == 0, + HCCL_ERROR("[AllGatherOperator][AllGatherDoubleRingConcurrentExecutor]errNo[0x%016llx] datatype[%d] is invalid", + HCCL_ERROR_CODE(HCCL_E_PARA), dataType), HCCL_E_PARA); + CommInfo *currComm; + hcclImpl_->GetCommInfo(currComm, tag); + + CHK_PRT_RET(currComm->commOuter.empty(), + HCCL_ERROR("[AllGatherOperator][AllGatherDoubleRingConcurrentExecutor]comm outer is empty"), HCCL_E_PARA); + u32 ringNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE; + u32 commIndex = currComm->commOuter[0]->Rank(); + bool bRet = commIndex >= currComm->commInner.size(); + CHK_PRT_RET(bRet, + HCCL_ERROR("[AllGatherOperator][AllGatherDoubleRingConcurrentExecutor]commIndex[%u] >= (tag[%s]) "\ + "comm size[%llu]", commIndex, tag.c_str(), currComm->commInner.size()), HCCL_E_INTERNAL); + + commIndex = RefreshCommIdx(commIndex, nicList_, devicePhyId_); + + // 第一步,将数据从input内存拷贝到output内存的对应位置 + HcclResult ret; + CHK_PRT_RET(currComm->commOuter.empty(), + HCCL_ERROR("[AllGatherOperator][AllGatherDoubleRingConcurrentExecutor]errNo[0x%016llx] comm outer is empty", + HCCL_ERROR_CODE(HCCL_E_PARA)), HCCL_E_PARA); + CHK_SMART_PTR_NULL(currComm->commOuter[COMM_INDEX_0]); + u32 outerRankSize = currComm->commOuter[COMM_INDEX_0]->RankSize(); + u32 serverIndex = 0; + ret = currComm->commInner[commIndex]->GetRankByUserRank(userRank_, serverIndex); + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[AllGatherOperator][AllGatherDoubleRingConcurrentExecutor]Get "\ + " Rank[%u] by User Rank[%u] from CommInner[%u] Failed!", serverIndex, userRank_, commIndex), ret); + + u64 inputMemSize = inputMem.size(); + u64 baseOffset = serverIndex * inputMemSize * outerRankSize; + u64 outerOffset = commIndex * inputMemSize; + DeviceMem dstMem = outputMem.range(baseOffset + outerOffset, inputMemSize); + CHK_SMART_PTR_NULL(dstMem); + if (opInfo == nullptr) { + ret = HcclD2DMemcpyAsync(dispatcher_, dstMem, inputMem, stream); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[AllGatherOperator][AllGatherDoubleRingConcurrentExecutor]all gather double " + "ring memcpy Failed, Offset[%llu], Size[%llu]", + baseOffset + outerOffset, inputMemSize), ret); + } + + // 第二步,各个AI Server 内 multi ring all gather + std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 + std::vector> multRingsSliceZero; // 数据基于该rank上环0的偏移 + bRet = currComm->commOuter.size() < ringNum; + CHK_PRT_RET(bRet, HCCL_ERROR("[[AllGatherOperator][AllGatherDoubleRingConcurrentExecutor]][]ringNum[%u] > "\ + " (tag[%s]), comm outer count[%llu]", ringNum, tag.c_str(), currComm->commOuter.size()), HCCL_E_INTERNAL); + u32 sliceNum = currComm->commOuter[COMM_INDEX_0]->RankSize(); + + CHK_RET(PrepareAllgatherSlice(sliceNum, inputMemSize, dataSegsSlice)); + + // 多环数据切分 + auto mult2RingsSlice = PrepareMultiRingSlice(dataSegsSlice, tag, false, nicList_); + std::vector>> mult4RingsSlice; + // 基于2环数据切分2环SDMA+2环ROH; bool = 
true表示SDMA; + u32 syncTrans = BEST_SPLIT_VALUE; + u64 totalDataSize = inputMemSize * dataSegsSlice.size(); + if (totalDataSize <= HCCL_SDMA_RDMA_SPLIT_SIZE) { + syncTrans = MAX_SPLIT_VALUE; + } + mult4RingsSlice.resize(mult2RingsSlice.size() * SLICES_FACTOR); + for (u32 ringIndex = 0; ringIndex < mult2RingsSlice.size(); ringIndex++) { + std::vector sdmaSlice; + std::vector rdmaSlice; + for (u32 segsIndex = 0; segsIndex < mult2RingsSlice[ringIndex].size(); segsIndex++) { + auto totalSize = mult2RingsSlice[ringIndex][segsIndex].size; + auto sdmaSliceOffset = mult2RingsSlice[ringIndex][segsIndex].offset; + auto sdmaSliceSize = (totalSize <= HCCL_MIN_SLICE_ALIGN_910_73) ? totalSize: + ((syncTrans * totalSize / MAX_SPLIT_VALUE) / HCCL_MIN_SLICE_ALIGN_910_73) * HCCL_MIN_SLICE_ALIGN_910_73; + Slice sdmaSliceTmp; + sdmaSliceTmp.offset = sdmaSliceOffset; + sdmaSliceTmp.size = sdmaSliceSize; + Slice rdmaSliceTmp; + rdmaSliceTmp.offset = sdmaSliceOffset + sdmaSliceSize; + rdmaSliceTmp.size = totalSize - sdmaSliceSize; + sdmaSlice.push_back(sdmaSliceTmp); + rdmaSlice.push_back(rdmaSliceTmp); + HCCL_DEBUG("Ring index:%u, segId:%u, Orignal [offset %llu, size %llu], sdma [offset %llu, size %llu], "\ + "rdma [offset %llu, size %llu]", ringIndex, segsIndex, sdmaSliceOffset, totalSize, + sdmaSliceTmp.offset, sdmaSliceTmp.size, rdmaSliceTmp.offset, rdmaSliceTmp.size); + } + mult4RingsSlice[ringIndex] = std::make_pair(true, sdmaSlice); // true表示使用sdma + mult4RingsSlice[ringIndex + mult2RingsSlice.size()] = std::make_pair(false, rdmaSlice); // false表示rdma + } + if (syncTrans == MAX_SPLIT_VALUE) { + mult4RingsSlice.erase(mult4RingsSlice.end() - mult2RingsSlice.size(), mult4RingsSlice.end()); + } + + // 抽取当前用于多环all gather 的output内存数据 + DeviceMem currentOutputMem = outputMem.range(baseOffset, inputMemSize * outerRankSize); + CHK_SMART_PTR_NULL(currentOutputMem); + CHK_RET(hcclImpl_->ActiveRingStreams(tag, stream)); + + CHK_RET(MultiRingAllGatherConcurrent(tag, inputMem, currentOutputMem, count, dataType, + mult4RingsSlice, stream, PROF_STAGE_1, baseOffset, opInfo)); + + HCCL_INFO("all gather double ring outer run success"); + + // 第三步, AI server 间 recursive halving doubling all gather + u64 hdSize = 0; + std::vector::iterator iterNic = std::find(nicList_.begin(), nicList_.end(), devicePhyId_); + if (iterNic != nicList_.end()) { + hdSize = inputMemSize * outerRankSize; + } + + u64 hdCount = hdSize / perDataSize; + std::unique_ptr innerExecutor; + u64 firstCommInnerSize = ((syncTrans * hdSize / MAX_SPLIT_VALUE) / HCCL_MIN_SLICE_ALIGN_910_73) * + HCCL_MIN_SLICE_ALIGN_910_73; + std::vector sendSize{firstCommInnerSize, hdSize - firstCommInnerSize}; + std::vector sendOffset{0, firstCommInnerSize}; + innerStreamInfo_t *streamInfo = hcclImpl_->GetStreamInfo(tag); + for (int innerCommIndex = 0; innerCommIndex < RDMA_PLANE_NUM_IN_NPRING_DOUBLE; ++innerCommIndex) { + if (sendSize[innerCommIndex] == 0 || (!GetExternalInputEnableRdmaSdmaConcurrent() && innerCommIndex > 0)) { + continue; + } + if (GetExternalInputEnableRdmaSdmaConcurrent() || UseInterServerRingAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllGatherRing(dispatcher_)); + HCCL_INFO("allgather ring: using ring algo inter-server."); + } else if (UseInterServerNHRAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllGatherNHR(dispatcher_)); + HCCL_INFO("allgather ring: using nonuniform-hierarchical-ring algo inter-server."); + } else if (UseInterServerNBAlgo(algType_)) { + innerExecutor.reset(new (std::nothrow) AllGatherNB(dispatcher_)); + 
HCCL_INFO("allgather ring: using nonuniform-bruck algo inter-server."); + } else { + innerExecutor.reset(new (std::nothrow) AllGatherRecursiveHalvingDoubling(dispatcher_)); + HCCL_INFO("allgather ring: using halving-doubling algo inter-server."); + } + + CHK_SMART_PTR_NULL(innerExecutor); + std::unique_ptr &commInner = (innerCommIndex == 0 ? currComm->commInner[commIndex] : + currComm->commInnerRdma[commIndex]); + CHK_SMART_PTR_NULL(currComm->commInner[commIndex]); + CHK_SMART_PTR_NULL(currComm->commInnerRdma[commIndex]); + + if (devNumInLevel2_ <= 1) { + // 此处虽然带入inputMem作为scratch mem, 但inputMem 不能被使用 + u32 rankSize = commInner->RankSize(); + std::vector inputSlices(rankSize, Slice()); + for (u32 i = 0; i < rankSize; i++) { + inputSlices[i].size = sendSize[innerCommIndex]; + inputSlices[i].offset = hdSize * i + sendOffset[innerCommIndex]; + } + auto &innerCommStream = streamInfo->ringStreams[innerCommIndex]; + auto ret = streamInfo->ringSignalAux[innerCommIndex]->Wait(innerCommStream, dispatcher_, PROF_STAGE_2); + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[AllGatherOperator][AllGatherDoubleRingConcurrentExecutor] "\ + " inner wait main [%u] failed", innerCommIndex), ret); + + CHK_RET(innerExecutor->Prepare(outputMem, outputMem, inputMem, hdCount, dataType, innerCommStream, + HCCL_REDUCE_RESERVED, INVALID_VALUE_RANKID, inputSlices, 0)); + + CHK_RET(innerExecutor->RegisterProfiler((rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + commInner->Rank(), + PROF_STAGE_2, HCCL_EXEC_STEP_NOT_SET, innerCommStream)); + + CHK_RET(commInner->RunExecutor(innerExecutor)); + + ret = streamInfo->ringSignal[innerCommIndex]->Post(innerCommStream, dispatcher_, PROF_STAGE_2); + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[AllGatherOperator][AllGatherDoubleRingConcurrentExecutor] "\ + " inner post mains [%u] failed", innerCommIndex), ret); + + ret = streamInfo->ringSignalAux[innerCommIndex]->Post(stream, dispatcher_, PROF_STAGE_2); + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[AllGatherOperator][AllGatherDoubleRingConcurrentExecutor] "\ + " main post inner [%u] failed", innerCommIndex), ret); + } else { + u32 innerRankSize = currComm->commInner[COMM_INDEX_0]->RankSize(); + u64 innerBaseOffset = baseOffset * innerRankSize; + DeviceMem innerInputMem = outputMem.range(innerBaseOffset, inputMemSize * outerRankSize); + DeviceMem innerOutputMem = outputMem.range(innerBaseOffset, inputMemSize * outerRankSize * innerRankSize); + + std::vector inputSlices(innerRankSize, Slice()); + for (u32 i = 0; i < innerRankSize; i++) { + inputSlices[i].size = sendSize[innerCommIndex]; + inputSlices[i].offset = hdSize * i + sendOffset[innerCommIndex]; + } + + auto &innerCommStream = streamInfo->ringStreams[innerCommIndex]; + auto ret = streamInfo->ringSignalAux[innerCommIndex]->Wait(innerCommStream, dispatcher_, PROF_STAGE_2); + + // 此处虽然带入inputMem作为scratch mem, 但inputMem 不能被使用 + CHK_RET(innerExecutor->Prepare(innerInputMem, innerOutputMem, inputMem, hdCount, dataType, + innerCommStream, HCCL_REDUCE_RESERVED, INVALID_VALUE_RANKID, inputSlices, 0)); + + u32 rankSize = commInner->RankSize(); + CHK_RET(innerExecutor->RegisterProfiler((rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + commInner->Rank(), + PROF_STAGE_2, HCCL_EXEC_STEP_NOT_SET, innerCommStream)); + + CHK_RET(commInner->RunExecutor(innerExecutor)); + ret = streamInfo->ringSignal[innerCommIndex]->Post(innerCommStream, dispatcher_, PROF_STAGE_2); + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[AllGatherOperator][AllGatherDoubleRingConcurrentExecutor] "\ + "inner post mains 
[%u] failed", innerCommIndex), ret); + + ret = streamInfo->ringSignalAux[innerCommIndex]->Post(stream, dispatcher_, PROF_STAGE_2); + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[AllGatherOperator][AllGatherDoubleRingConcurrentExecutor] "\ + "main post inner [%u] failed", innerCommIndex), ret); + + // 超节点间做allgather + ret = AllGatherLevel2Executor(tag, inputMem, outputMem, count, dataType, op, stream); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[AllGatherOperator][AllGatherDoubleRingConcurrentExecutor]tag[%s], all_gather failed, "\ + "return[%d]", tag.c_str(), ret), ret); + } + if (sendSize[innerCommIndex] == 0 || (!GetExternalInputEnableRdmaSdmaConcurrent() && innerCommIndex > 0)) { + continue; + } + + auto ret = streamInfo->ringSignal[innerCommIndex]->Wait(stream, dispatcher_, PROF_STAGE_2); + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[AllGatherOperator][AllGatherDoubleRingConcurrentExecutor] "\ + "main wait inner [%u] failed", innerCommIndex), ret); + } + HCCL_INFO("all gather double ring inner run success"); + return HCCL_SUCCESS; +} + HcclResult AllGatherOperator::AllGatherLevel2Executor(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, u64 count, HcclDataType dataType, HcclReduceOp op, Stream &stream, const HcomCollOpInfo *opInfo) { @@ -1198,4 +1446,125 @@ HcclResult AllGatherOperator::CalculateLevel2AllgatherSlice(u64 inputMemSize, u3 return HCCL_SUCCESS; } +HcclResult AllGatherOperator::SelectAlg(const std::string& tag, const OpParam& param, std::string& algName, + std::string& newTag) +{ + if (userRankSize_ == 1 && GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + algName = "AllGatherSingleExecutor"; + return HCCL_SUCCESS; + } + HcclResult ret; + if (deviceType_ == DevType::DEV_TYPE_310P3) { + ret = SelectAlgfor310P3(param, algName); + } else if (deviceType_ == DevType::DEV_TYPE_910) { + ret = SelectAlgfor910A(param, algName); + } else if (deviceType_ == DevType::DEV_TYPE_910B) { + ret = SelectAlgfor910B(param, algName); + } else { + ret = SelectAlgfor91073(param, algName); + } + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { + newTag = tag; + } else if (deviceType_ == DevType::DEV_TYPE_310P3) { + newTag = tag + algName; + } else { + AlgTypeLevel1 algType1 = GetLevel1AlgType(algType_); + auto level1Iter = HCCL_ALGO_LEVEL1_NAME_MAP.find(algType1); + newTag = tag + level1Iter->second + algName; + } + HCCL_INFO("[SelectAlg] all_gather newTag is [%s]", newTag.c_str()); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[AllGatherSelector][SelectAlg]tag[%s], all_gather failed, retrun[%d]", tag.c_str(), ret), ret); + return ret; +} + +HcclResult AllGatherOperator::SelectAlgfor310P3(const OpParam& param, std::string& algName) +{ + algName = "AllGatherFor310PExecutor"; + HCCL_INFO("[SelectAlgfor310P3] all_gather SelectAlgfor310P3 is algName [%s]", algName.c_str()); + return HCCL_SUCCESS; +} + +HcclResult AllGatherOperator::SelectAlgfor910A(const OpParam& param, std::string& algName) +{ + bool isMeshTopo = topoType_ == TopoType::TOPO_TYPE_4P_MESH || topoType_ == TopoType::TOPO_TYPE_2P_MESH; + bool isRingTopo = topoType_ == TopoType::TOPO_TYPE_NP_SINGLE_RING || topoType_ == TopoType::TOPO_TYPE_8P_RING; + + if (isMeshTopo) { + algName = "AllGatherMeshExecutor"; + } else if (isRingTopo) { + algName = "AllGatherRingExecutor"; + } else { + algName = "AllGatherComm"; + } + HCCL_INFO("[SelectAlgfor910A] all_gather SelectAlgfor910A is algName [%s]", algName.c_str()); + return HCCL_SUCCESS; +} + +HcclResult 
AllGatherOperator::SelectAlgfor910B(const OpParam& param, std::string& algName) +{ + u32 unitSize = SIZE_TABLE[param.DataDes.dataType]; + u64 dataSize = param.DataDes.count * unitSize; // 单位:字节 + bool isMeshTopo = topoType_ == TopoType::TOPO_TYPE_NP_MESH || topoType_ == TopoType::TOPO_TYPE_4P_MESH || + topoType_ == TopoType::TOPO_TYPE_2P_MESH || topoType_ == TopoType::TOPO_TYPE_1P_MESH; + bool isRingTopo = topoType_ == TopoType::TOPO_TYPE_NP_SINGLE_RING; + + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && !isSingleMeshAggregation_) { + u64 cclBufferSize = cclBufferManager_.GetOutCCLbufferSize() / userRankSize_; + std::string algTypeLevel1Tag; + CHK_RET(AutoSelectAlgTypeLevel1(HcclCMDType::HCCL_CMD_ALLGATHER, dataSize, cclBufferSize, algTypeLevel1Tag)); + if (param.opBaseAtraceInfo != nullptr) { + CHK_RET(param.opBaseAtraceInfo->SavealgtypeTraceInfo(algTypeLevel1Tag, param.tag)); + } + } + + if (isMeshTopo) { + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + if (isSingleMeshAggregation_) { + algName = "AllGatherMeshOpbaseExecutor"; + } else if (UseInterServerPipelineAlgo(algType_)) { + algName = "AllGatherMeshOpbasePipelineExecutor"; + } + } + if (algName.empty()) { + algName = "AllGatherMeshExecutor"; + } + } else if (isRingTopo) { + algName = "AllGatherRingExecutor"; + } else { + algName = "AllGatherComm"; + } + HCCL_INFO("[SelectAlgfor910B] all_gather SelectAlgfor910B is algName [%s]", algName.c_str()); + return HCCL_SUCCESS; +} + +HcclResult AllGatherOperator::SelectAlgfor91073(const OpParam& param, std::string& algName) +{ + if (GetExternalInputEnableRdmaSdmaConcurrent() && topoType_ == TopoType::TOPO_TYPE_NP_DOUBLE_RING) { + if (!UseInterServerRingAlgo(algType_)) { + HcclResult ret = SetInterServerRingAlgo(algType_); + HCCL_WARNING("[AllGatherOperator][SelectAlgfor91073] concurrent only support ring in AlgoLevel1 yet, "\ + "default is algType=ring."); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[AllGatherOperator][SelectAlgfor91073]errNo[0x%016llx] tag[%s], AllGather concurrent "\ + "set inter server ring algo failed", HCCL_ERROR_CODE(ret), param.tag.c_str()), ret); + } + algName = "AllGatherDoubleRingConcurrentExecutor"; + } else { + if (!(UseInterServerRingAlgo(algType_) || UseInterServerNBAlgo(algType_))) { + HcclResult ret = SetInterServerNHRAlgo(algType_); + HCCL_WARNING("[AllGatherOperator][SelectAlgfor91073] only support ring, NB and NHR in AlgoLevel1 yet, "\ + "default is algType=NHR."); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[AllGatherOperator][SelectAlgfor91073]errNo[0x%016llx] tag[%s], AllGather set inter server "\ + "nhr algo failed", HCCL_ERROR_CODE(ret), param.tag.c_str()), ret); + } + algName = "AllGatherRingFor91073Executor"; + } + HCCL_INFO("[SelectAlgfor91073] all_gather SelectAlgfor91073 is algName [%s]", algName.c_str()); + return HCCL_SUCCESS; +} + +REGISTER_OP(HcclCMDType::HCCL_CMD_ALLGATHER, AllGather, AllGatherOperator); + } \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/operator/all_gather_operator.h b/src/domain/collective_communication/algorithm/impl/operator/all_gather_operator.h index fa371c9a88ef4946dbf5efb948f1c32025893415..8238fd49bf0279f9daa0f5a5aac71c1f4950286e 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/all_gather_operator.h +++ b/src/domain/collective_communication/algorithm/impl/operator/all_gather_operator.h @@ -16,12 +16,13 @@ namespace hccl { class AllGatherOperator : public CommonOperator { public: - 
AllGatherOperator(std::unique_ptr &pImpl); + AllGatherOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher); ~AllGatherOperator(); HcclResult AllGather(const std::string &tag, void *inputPtr, void *outputPtr, u64 inputCount, HcclDataType dataType, Stream stream, HcomCollOpInfo *opInfo = nullptr); HcclResult AllGatherOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, u64 inputCount, HcclDataType dataType, Stream stream, const std::unique_ptr &opBaseAtraceInfo = nullptr); + HcclResult SelectAlg(const std::string& tag, const OpParam& param, std::string& algName, std::string& newTag) override; private: // all gather private HcclResult GetAllGatherOutPlaceSplitLoop(void* commOutputPtr, bool isMeshTopo, const u32 unitSize, @@ -67,6 +68,13 @@ private: HcclResult CalculateLevel2AllgatherSlice(u64 inputMemSize, u32 level0RankSize, u32 level1RankSize, u32 level2RankSize, std::vector dataSegsSlice, std::vector &level0DataSlice) const; + HcclResult SelectAlgfor310P3(const OpParam& param, std::string& algName); + + HcclResult SelectAlgfor910A(const OpParam& param, std::string& algName); + + HcclResult SelectAlgfor910B(const OpParam& param, std::string& algName); + + HcclResult SelectAlgfor91073(const OpParam& param, std::string& algName); }; } diff --git a/src/domain/collective_communication/algorithm/impl/operator/all_reduce_operator.cc b/src/domain/collective_communication/algorithm/impl/operator/all_reduce_operator.cc index a7937a7e8f6b7cdf25da58c103924b3d1dfd42b3..6c9d039f233a00d8ee36cc9ca2cd6af63884b07e 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/all_reduce_operator.cc +++ b/src/domain/collective_communication/algorithm/impl/operator/all_reduce_operator.cc @@ -18,8 +18,9 @@ #include "coll_alg_op_registry.h" namespace hccl { -AllReduceOperator::AllReduceOperator(std::unique_ptr &pImpl) - : CommonOperator(pImpl, HcclCMDType::HCCL_CMD_ALLREDUCE) + +AllReduceOperator::AllReduceOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher) + : CommonOperator(pImpl, topoMatcher, HcclCMDType::HCCL_CMD_ALLREDUCE) { } @@ -124,6 +125,8 @@ HcclResult AllReduceOperator::SelectAlg(const std::string& tag, const OpParam& p ret = SelectAlgfor910A(param, algName); } else if (deviceType_ == DevType::DEV_TYPE_910B) { ret = SelectAlgfor910B(param, algName); + } else if (deviceType_ == DevType::DEV_TYPE_910_73) { + ret = SelectAlgfor91073(param, algName); } else { HCCL_ERROR("[SelectAlg] device type[%d] is out of range for selector.", deviceType_); return HCCL_E_NOT_SUPPORT; @@ -250,7 +253,7 @@ HcclResult AllReduceOperator::SelectAlgfor910B(const OpParam& param, std::string bool isCCLBufferGE16M = !isOpbase || (commInputSize >= HCCL_MID_COUNT_16_MB && commOutputSize >= HCCL_MID_COUNT_16_MB); bool isAivMode = GetExternalInputHcclAivMode() && IsSupportAIVReduce(param.DataDes.dataType, param.reduceType) && - hcclImpl_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE && isMesh && isCCLBufferGE16M && + topoMatcher_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE && isMesh && isCCLBufferGE16M && (isSingleMeshAggregation_ || isSupportAivRdmaSmallCount || isSupportAivRdmaMidCount); if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { @@ -295,7 +298,7 @@ HcclResult AllReduceOperator::SelectAlgfor910B(const OpParam& param, std::string algName = "AllReduceRingExecutor"; } // 多机单卡/两卡 pipeline需单独做判断(pipeline无确定性算法,并只支持单算子模式) - } else if (hcclImpl_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE && + } else if 
(topoMatcher_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE && GetLevel1AlgType(algType_) == AlgTypeLevel1::ALG_LEVEL1_PIPELINE && GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && IsMultiMeshInlineReduce(param.inputPtr, param.outputPtr, param.DataDes.dataType, param.reduceType)) { @@ -317,7 +320,7 @@ HcclResult AllReduceOperator::SelectAlgfor910B(const OpParam& param, std::string GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE)) { algName = "AllReduceMeshExecutor"; // 非确定性算法 - } else if (hcclImpl_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE) { + } else if (topoMatcher_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE) { ret = NonDeterministicSelector(param, algName, dataSize); // 确定性算法 } else { @@ -390,6 +393,23 @@ HcclResult AllReduceOperator::DeterministicSelector(const OpParam& param, std::s return HCCL_SUCCESS; } +HcclResult AllReduceOperator::SelectAlgfor91073(const OpParam& param, std::string& algName) +{ + if (topoType_ == TopoType::TOPO_TYPE_NP_SINGLE_RING) { + algName = "AllReduceRingExecutor"; + } else if (topoType_ == TopoType::TOPO_TYPE_NP_DOUBLE_RING) { + if (GetExternalInputEnableRdmaSdmaConcurrent()) { + algName = "AllReduceDoubleRingConcurrentExecutor"; + } else { + algName = "AllReduceDoubleRingExecutor"; + } + } else { + algName = "AllReduceComm"; + } + HCCL_INFO("[SelectAlgfor91073] all_reduce SelectAlgfor91073 is algName [%s]", algName.c_str()); + return HCCL_SUCCESS; +} + REGISTER_OP(HcclCMDType::HCCL_CMD_ALLREDUCE, AllReduce, AllReduceOperator); } diff --git a/src/domain/collective_communication/algorithm/impl/operator/all_reduce_operator.h b/src/domain/collective_communication/algorithm/impl/operator/all_reduce_operator.h index 4f8c409f5aac6f19e15caf2c569adfc3357b04c0..1f66d4f9d88464e8095e62ed4b9090892e028875 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/all_reduce_operator.h +++ b/src/domain/collective_communication/algorithm/impl/operator/all_reduce_operator.h @@ -24,7 +24,7 @@ enum class HcclDataCountType { class AllReduceOperator : public CommonOperator { public: - AllReduceOperator(std::unique_ptr &pImpl); + AllReduceOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher); ~AllReduceOperator(); HcclResult SelectAlg(const std::string& tag, const OpParam& param, std::string& algName, std::string& newTag); HcclResult GetAllReduceScratchSize(const u32 count, const HcclDataType dataType, u64 &scratchSize); @@ -40,6 +40,8 @@ private: HcclResult SelectAlgfor910B(const OpParam& param, std::string& algName); + HcclResult SelectAlgfor91073(const OpParam& param, std::string& algName); + HcclResult MeshTopoSelector(std::string& algName, u64 unitSize); HcclResult DeterministicSelector(const OpParam& param, std::string& algName); diff --git a/src/domain/collective_communication/algorithm/impl/operator/alltoall_operator.cc b/src/domain/collective_communication/algorithm/impl/operator/alltoall_operator.cc index 673bf65adbe292c3ea5623e7eeaf102076c2aaa3..cc0ea43339a4a75aad28049f2c7c5824179684bb 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/alltoall_operator.cc +++ b/src/domain/collective_communication/algorithm/impl/operator/alltoall_operator.cc @@ -16,402 +16,338 @@ #include #include "allltoall_pipeline_mesh_pairwise_ccl_enough_pub.h" #include "allltoall_pipeline_mesh_pairwise_ping_pong_pub.h" +#include "coll_alg_exec_registry.h" +#include "coll_alg_op_registry.h" namespace hccl { -AlltoAllOperator::AlltoAllOperator(std::unique_ptr &pImpl) - : 
CollAlgOperator(pImpl, HcclCMDType::HCCL_CMD_ALLTOALL) +AlltoAllOperator::AlltoAllOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher) + : CollAlgOperator(pImpl, topoMatcher, HcclCMDType::HCCL_CMD_ALLTOALL) { - hcclImpl_->GetAlltoAllStatus(tinySendRecvMem_, isAlltoAllZCopyMode_, isAlltoAllZCopyModeMap_); + hcclImpl_->GetAlltoAllStatus(tinySendRecvMem_, isAlltoAllZCopyMode_); } AlltoAllOperator::~AlltoAllOperator() { } -HcclResult AlltoAllOperator::AlltoAllVForOneRankSize(const void *sendBuf, const void *sendCounts, const void *sdispls, - HcclDataType sendType, const void *recvBuf, const void *recvCounts, const void *rdispls, HcclDataType recvType, - Stream stream, const std::string &tag) +HcclResult AlltoAllOperator::CheckSendRecvParams( + const std::vector &allMeshAggregationSendRecvInfo) { - u32 sendTypeSize = 0, recvTypeSize = 0; - CHK_RET(SalGetDataTypeSize(sendType, sendTypeSize)); - CHK_RET(SalGetDataTypeSize(recvType, recvTypeSize)); - HCCL_PROFILER_ADD_STREAM(stream.ptr(), tag, 0, algType_); - u64 curSendCount = *(static_cast(sendCounts) + 0) + *(static_cast(sdispls) + 0); - u64 sendCount = 0; - sendCount = std::max(sendCount, curSendCount); - bool hugeData = (sendCount * sendTypeSize ) > SDMA_SEND_MAX_SIZE ; - if (sendBuf == recvBuf) { - // 通过CopyPattern字段区分不同的子图 - auto opMeta = HcclOpMetaInfo::GetOneForAllToAllV(CopyPattern::ZCOPY, sendCount * sendTypeSize, hugeData); - CHK_RET(InitTask(dispatcher_, stream, opMeta.isEnableCache, opMeta.GetCacheKey())); - } else { - auto opMeta = HcclOpMetaInfo::GetOneForAllToAllV(CopyPattern::BCOPY, sendCount * sendTypeSize,hugeData); - CHK_RET(InitTask(dispatcher_, stream, opMeta.isEnableCache, opMeta.GetCacheKey())); - DeviceMem srcMem = DeviceMem::create(const_cast(sendBuf), sendCount * sendTypeSize); - DeviceMem dstMem = DeviceMem::create(const_cast(recvBuf), sendCount * sendTypeSize); - HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, stream); // ranksize = 1; intput、output地址不同,input->output + u32 rankSize = allMeshAggregationSendRecvInfo.size(); + for (u32 i = 0; i < rankSize; i++) { + u32 sendsSize = allMeshAggregationSendRecvInfo[i].sendLength.size(); + u32 recvsSize = allMeshAggregationSendRecvInfo[i].recvLength.size(); + if (rankSize != sendsSize || rankSize != recvsSize) { + HCCL_ERROR( + "[AlltoAllV][CheckSendRecvParam] rankSize[%u], sendsSize[%u], recvsSize[%u] are not match Index[%u]", + rankSize, sendsSize, recvsSize, i); + return HCCL_E_PARA; + } + for (u32 j = 0; j < sendsSize; j++) { + if (allMeshAggregationSendRecvInfo[i].sendLength[j] != allMeshAggregationSendRecvInfo[j].recvLength[i]) { + HCCL_ERROR("SendLength[%u][%u]: %llu and recvLength[%u][%u]: %llu are not match", i, j, + allMeshAggregationSendRecvInfo[i].sendLength[j], j, i, + allMeshAggregationSendRecvInfo[j].recvLength[i]); + return HCCL_E_PARA; + } + } } - CHK_RET(LaunchTask(dispatcher_, stream)); - HCCL_PROFILER_DEL_STREAM(stream.ptr()); return HCCL_SUCCESS; } -HcclResult AlltoAllOperator::AlltoAllV(const void *sendBuf, const void *sendCounts, const void *sdispls, - HcclDataType sendType, const void *recvBuf, const void *recvCounts, const void *rdispls, HcclDataType recvType, - Stream stream, const std::string &tag) +HcclResult AlltoAllOperator::GetAlltoAllvcSendRecvInfo(const void *sendCountMatrix, HcclDataType sendType, + HcclDataType recvType) { - /* ------------集合通信资源准备------------ */ - HcclUs startut = TIME_NOW(); - - auto rtStream = stream.ptr(); - u32 sendTypeSize = 0, recvTypeSize = 0; - CHK_RET(SalGetDataTypeSize(sendType, sendTypeSize)); - 
CHK_RET(SalGetDataTypeSize(recvType, recvTypeSize)); - - if (userRankSize_ == 1 && (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE)) - { - CHK_RET(AlltoAllVForOneRankSize(sendBuf,sendCounts,sdispls,sendType,recvBuf,recvCounts,rdispls,recvType,stream,tag)); - return HCCL_SUCCESS ; - } - - CHK_RET(notifyPool_->RegisterOp(tag)); - u64 sendCount = 0; - u64 recvCount = 0; + allMeshAggregationSendRecvInfo_.clear(); for (u32 i = 0; i < userRankSize_; i++) { - u64 curSendCount = *(static_cast(sendCounts) + i) + *(static_cast(sdispls) + i); - sendCount = std::max(sendCount, curSendCount); - u64 curRecvCount = *(static_cast(recvCounts) + i) + *(static_cast(rdispls) + i); - recvCount = std::max(recvCount, curRecvCount); - } - - // sendCount或recvCount为0时, 使用默认分配的内存空间, 避免sendMem和recvMem为空 - DeviceMem sendMem = sendCount == 0 ? - DeviceMem::create(tinySendRecvMem_.ptr(), tinySendRecvMem_.size()) : - DeviceMem::create(const_cast(sendBuf), sendCount * sendTypeSize); - DeviceMem recvMem = recvCount == 0 ? - DeviceMem::create(tinySendRecvMem_.ptr(), tinySendRecvMem_.size()) : - DeviceMem::create(const_cast(recvBuf), recvCount * recvTypeSize); - - bool useOneLevelAlgorithm = - (GetExternalInputHcclAlgoConfig(HcclCMDType::HCCL_CMD_ALLTOALL)[0] == HcclAlgoType::HCCL_ALGO_TYPE_NA && - (GetExternalInputHcclAlgoConfig(HcclCMDType::HCCL_CMD_ALLTOALL)[1] == HcclAlgoType::HCCL_ALGO_TYPE_PAIRWISE || - NAFullmeshSatisfyHighPerfAlltoallMeshCondition(deviceType_, userRankSize_))); // 用户配置打平 alltoall + SendRecvInfo sendRecvInfo; + sendRecvInfo.sendCounts.resize(userRankSize_); + sendRecvInfo.sendDispls.resize(userRankSize_); + sendRecvInfo.sendLength.resize(userRankSize_); + sendRecvInfo.sendOffset.resize(userRankSize_); + u64 curSendDispls = 0; + u64 curSendOffset = 0; - std::vector allMeshAggregationSendRecvInfo; - CHK_RET(GetAllMeshAggregationSendRecvInfo(sendCounts, sdispls, sendType, recvCounts, rdispls, recvType, - allMeshAggregationSendRecvInfo, stream)); - UpdateAlltoAllZCopyMode(allMeshAggregationSendRecvInfo, tag); - // NA+pairwise算法不支持A+X跨mesh两卡 - bool isSingleDeviceModuleP2p = (userRankSize_ <= HCCL_ALLTOALLV_P2P_SIZE); + sendRecvInfo.recvCounts.resize(userRankSize_); + sendRecvInfo.recvDispls.resize(userRankSize_); + sendRecvInfo.recvLength.resize(userRankSize_); + sendRecvInfo.recvOffset.resize(userRankSize_); + u64 curRecvDispls = 0; + u64 curRecvOffset = 0; + // sendCountMatrix[i * userRankSize_ + j] 代表rank i发送到rank j的count参数 + for (u32 j = 0; j < userRankSize_; j++) { + u64 curSendCounts = *(static_cast(sendCountMatrix) + i * userRankSize_ + j); + u64 curSendLength = curSendCounts * SIZE_TABLE[sendType]; + sendRecvInfo.sendCounts[j] = curSendCounts; + sendRecvInfo.sendDispls[j] = curSendDispls; + sendRecvInfo.sendLength[j] = curSendLength; + sendRecvInfo.sendOffset[j] = curSendOffset; + curSendDispls += curSendCounts; + curSendOffset += curSendLength; - HCCL_PROFILER_ADD_STREAM(rtStream, tag, 0, algType_); + u64 curRecvCounts = *(static_cast(sendCountMatrix) + i + userRankSize_ * j); + u64 curRecvLength = curRecvCounts * SIZE_TABLE[recvType]; + sendRecvInfo.recvCounts[j] = curRecvCounts; + sendRecvInfo.recvDispls[j] = curRecvDispls; + sendRecvInfo.recvLength[j] = curRecvLength; + sendRecvInfo.recvOffset[j] = curRecvOffset; + curRecvDispls += curRecvCounts; + curRecvOffset += curRecvLength; - // 暂时先支持单算子模式 - if (IsSatisfyAlltoallPipelineCondition()) { - HCCL_RUN_INFO("[AlltoAllOperator][AlltoAllV] running alltoallv intra mesh inter pairwise pipeline"); - 
RunAlltoAllVTwoLevelPipeline(sendMem, recvMem, allMeshAggregationSendRecvInfo, stream, tag); - } else if (useOneLevelAlgorithm || isAllRankSamePlane_ || isSingleDeviceModuleP2p || - multiModuleDiffDeviceNumMode_) { - HCCL_INFO("[hcclImpl][AlltoAllV] running alltoallv full-mesh implementation"); - CHK_RET(hcclImpl_->CreateCommForAlltoAllFullMesh(tag, sendMem, recvMem)); - CHK_RET(hcclImpl_->RegisterToHeartBeat()); - HCCL_INFO("resource creation (AlltoAllV Full Mesh) success, take time [%lld]us, tag[%s]", - DURATION_US(TIME_NOW() - startut), tag.c_str()); - CHK_RET(RunAlltoAllVFullMesh( - sendMem, sendType, recvMem, recvType, allMeshAggregationSendRecvInfo, stream, tag)); - } else { // 当前如果是910B的16P场景,单server内跨组网也走分级,但是PCIE - HCCL_INFO("[hcclImpl][AlltoAllV] running alltoallv staged implementation"); - CHK_RET(RunAlltoAllVStaged(sendMem, sendType, recvMem, recvType, - allMeshAggregationSendRecvInfo, stream, tag)); + HCCL_DEBUG("GetAlltoAllvcSendRecvInfo rank[%u], sendCounts[%llu], sendDispls[%llu] "\ + "recvCounts[%llu], recvDispls[%llu]", i, sendRecvInfo.sendCounts[j], sendRecvInfo.sendDispls[j], + sendRecvInfo.recvCounts[j], sendRecvInfo.recvDispls[j]); + } + allMeshAggregationSendRecvInfo_.push_back(sendRecvInfo); } - - CHK_RET(notifyPool_->UnregisterOp(tag)); - - HCCL_INFO("tag[%s],alltoallv run success,take time [%lld]us", tag.c_str(), DURATION_US(TIME_NOW() - startut)); - + CHK_RET(CheckSendRecvParams(allMeshAggregationSendRecvInfo_)); return HCCL_SUCCESS; } -HcclResult AlltoAllOperator::AlltoAllVOutPlace(const void *sendBuf, const void *sendCounts, const void *sdispls, - HcclDataType sendType, const void *recvBuf, const void *recvCounts, const void *rdispls, HcclDataType recvType, - Stream stream, const std::string &tag) +void AlltoAllOperator::UpdateAlltoAllCopyMode(std::vector &allMeshAggregationSendRecvInfo, + std::string& copyMode) { - /* ------------集合通信资源准备------------ */ - HcclUs startut = TIME_NOW(); - auto rtStream = stream.ptr(); - u32 sendTypeSize = 0, recvTypeSize = 0; - CHK_RET(SalGetDataTypeSize(sendType, sendTypeSize)); - CHK_RET(SalGetDataTypeSize(recvType, recvTypeSize)); - - if (userRankSize_ == 1 ) - { - CHK_RET(AlltoAllVForOneRankSize(sendBuf,sendCounts,sdispls,sendType,recvBuf,recvCounts,rdispls,recvType,stream,tag) ); - return HCCL_SUCCESS ; + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + u64 maxSendSize = 0; + u64 maxRecvSize = 0; + for (auto &sendRecvInfo : allMeshAggregationSendRecvInfo) { + for (u32 i = 0; i < userRankSize_; i++) { + u64 curSendSize = sendRecvInfo.sendLength[i] + sendRecvInfo.sendOffset[i]; + maxSendSize = std::max(maxSendSize, curSendSize); + u64 curRecvSize = sendRecvInfo.recvLength[i] + sendRecvInfo.recvOffset[i]; + maxRecvSize = std::max(maxRecvSize, curRecvSize); + } + } + bool isAlltoAllZCopyMode = (maxSendSize <= GetExternalInputCCLBuffSize()) && + (maxRecvSize <= GetExternalInputCCLBuffSize()); + if (isAlltoAllZCopyMode) { + copyMode = "ZCopy"; + } + HCCL_INFO("[AlltoAllOperator][UpdateAlltoAllCopyMode] maxSendSize[%llu], maxRecvSize[%llu], "\ + "cclBufferSize[%llu], CopyMode[%s]", maxSendSize, maxRecvSize, + GetExternalInputCCLBuffSize(), copyMode.c_str()); + } else { + // 图模式走ZCopy实现 + copyMode = "ZCopy"; } +} - CHK_RET(notifyPool_->RegisterOp(tag)); - u64 sendCount = 0; - u64 recvCount = 0; +HcclResult AlltoAllOperator::GetAlltoAllvSendRecvInfo(const OpParam& param, const HostMem &alltoallAddrInfoGathered) +{ + allMeshAggregationSendRecvInfo_.clear(); + u64 stepSize = sizeof(u64) * userRankSize_; + 
const u32 addrItemNum = 4; + const u32 recvLengthStep = 2; + const u32 recvOffsetStep = 3; for (u32 i = 0; i < userRankSize_; i++) { - u64 curSendCount = *(static_cast(sendCounts) + i) + *(static_cast(sdispls) + i); - sendCount = std::max(sendCount, curSendCount); - u64 curRecvCount = *(static_cast(recvCounts) + i) + *(static_cast(rdispls) + i); - recvCount = std::max(recvCount, curRecvCount); + SendRecvInfo sendRecvInfo; + sendRecvInfo.sendLength.resize(userRankSize_); + sendRecvInfo.sendOffset.resize(userRankSize_); + sendRecvInfo.recvLength.resize(userRankSize_); + sendRecvInfo.recvOffset.resize(userRankSize_); + CHK_SAFETY_FUNC_RET(memcpy_s(sendRecvInfo.sendLength.data(), + stepSize, + static_cast(alltoallAddrInfoGathered.ptr()) + i * stepSize * addrItemNum + 0 * stepSize, + stepSize)); + CHK_SAFETY_FUNC_RET(memcpy_s(sendRecvInfo.sendOffset.data(), + stepSize, + static_cast(alltoallAddrInfoGathered.ptr()) + i * stepSize * addrItemNum + stepSize, + stepSize)); + CHK_SAFETY_FUNC_RET(memcpy_s(sendRecvInfo.recvLength.data(), + stepSize, + static_cast(alltoallAddrInfoGathered.ptr()) + i * stepSize * addrItemNum + recvLengthStep * stepSize, + stepSize)); + CHK_SAFETY_FUNC_RET(memcpy_s(sendRecvInfo.recvOffset.data(), + stepSize, + static_cast(alltoallAddrInfoGathered.ptr()) + i * stepSize * addrItemNum + recvOffsetStep * stepSize, + stepSize)); + allMeshAggregationSendRecvInfo_.push_back(std::move(sendRecvInfo)); } - // sendCount或recvCount为0时, 使用默认分配的内存空间, 避免sendMem和recvMem为空 - DeviceMem sendMem = sendCount == 0 ? DeviceMem::create(tinySendRecvMem_.ptr(), tinySendRecvMem_.size()) : - DeviceMem::create(const_cast(sendBuf), sendCount * sendTypeSize); - DeviceMem recvMem = recvCount == 0 ? DeviceMem::create(tinySendRecvMem_.ptr(), tinySendRecvMem_.size()) : - DeviceMem::create(const_cast(recvBuf), recvCount * recvTypeSize); + for (auto &sendRecvInfo : allMeshAggregationSendRecvInfo_) { + for (u32 i = 0; i < userRankSize_; i++) { + sendRecvInfo.sendCounts.push_back(sendRecvInfo.sendLength[i] / SIZE_TABLE[param.All2AllDataDes.sendType]); + sendRecvInfo.sendDispls.push_back(sendRecvInfo.sendOffset[i] / SIZE_TABLE[param.All2AllDataDes.sendType]); + sendRecvInfo.recvCounts.push_back(sendRecvInfo.recvLength[i] / SIZE_TABLE[param.All2AllDataDes.recvType]); + sendRecvInfo.recvDispls.push_back(sendRecvInfo.recvOffset[i] / SIZE_TABLE[param.All2AllDataDes.recvType]); + HCCL_INFO("[GetAllMeshAggregationSendRecvInfo] rank[%u], sendCounts[%llu], sendDispls[%llu], "\ + "recvCounts[%llu], recvDispls[%llu]", i, sendRecvInfo.sendCounts[i], sendRecvInfo.sendDispls[i], + sendRecvInfo.recvCounts[i], sendRecvInfo.recvDispls[i]); + HCCL_INFO("[GetAllMeshAggregationSendRecvInfo] rank[%u], sendLength[%llu], sendOffset[%llu], "\ + "recvLength[%llu], recvOffset[%llu]", i, sendRecvInfo.sendLength[i], sendRecvInfo.sendOffset[i], + sendRecvInfo.recvLength[i], sendRecvInfo.recvOffset[i]); + } + } + + CHK_RET(CheckSendRecvParams(allMeshAggregationSendRecvInfo_)); + + return HCCL_SUCCESS; +} + +HcclResult AlltoAllOperator::SelectAlgforAlltoAll(const OpParam& param, std::string& algName, std::string& copyMode) +{ bool useOneLevelAlgorithm = (GetExternalInputHcclAlgoConfig(HcclCMDType::HCCL_CMD_ALLTOALL)[0] == HcclAlgoType::HCCL_ALGO_TYPE_NA && (GetExternalInputHcclAlgoConfig(HcclCMDType::HCCL_CMD_ALLTOALL)[1] == HcclAlgoType::HCCL_ALGO_TYPE_PAIRWISE || - NAFullmeshSatisfyHighPerfAlltoallMeshCondition(deviceType_, userRankSize_))); // 用户配置打平 alltoall + CollAlgOperator::NAFullmeshSatisfyHighPerfAlltoallMeshCondition(deviceType_, 
userRankSize_))); + // 用户配置打平 alltoall - std::vector allMeshAggregationSendRecvInfo; - CHK_RET(GetAllMeshAggregationSendRecvInfo(sendCounts, sdispls, sendType, recvCounts, rdispls, recvType, - allMeshAggregationSendRecvInfo, stream)); - UpdateAlltoAllZCopyMode(allMeshAggregationSendRecvInfo, tag); - HCCL_PROFILER_ADD_STREAM(rtStream, tag, 0, algType_); - CopyPattern copyPattern = isAlltoAllZCopyMode_? CopyPattern::ZCOPY : CopyPattern::BCOPY; - - bool massTasks = HasMassTasks(allMeshAggregationSendRecvInfo); - /* zcopy拆分4GB以上SDMA任务前,准备好子图不复用标志 */ - bool hugeData = false; - if (copyPattern == CopyPattern::ZCOPY) { - hugeData = sendMem.size() > SDMA_SEND_MAX_SIZE; - } - auto opMeta = HcclOpMetaInfo::GetOneForAllToAllV(copyPattern, sendMem.size(), hugeData); - CHK_RET(InitTask(dispatcher_, stream, opMeta.isEnableCache, opMeta.GetCacheKey())); - if (massTasks) { - CHK_RET(SetNormalMode(dispatcher_)); - } // NA+pairwise算法不支持A+X跨mesh两卡 bool isSingleDeviceModuleP2p = (userRankSize_ <= HCCL_ALLTOALLV_P2P_SIZE); - bool alltoallPingPong = (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && - !multiModuleDiffDeviceNumMode_ && GetAlltoall2LevelPipelineMaxScratchSize910B(allMeshAggregationSendRecvInfo) > - cclBufferManager_.GetInCCLbuffer().size()); - // 暂时先支持单算子模式 + if (IsSatisfyAlltoallPipelineCondition()) { - HCCL_RUN_INFO("[AlltoAllOperator][AlltoAllVOutPlace] running alltoallv intra mesh inter pairwise pipeline"); - auto opMeta = HcclOpMetaInfo::GetOneForAllToAllV(copyPattern, sendMem.size(), - hugeData || alltoallPingPong); - CHK_RET(InitTask(dispatcher_, stream, opMeta.isEnableCache, opMeta.GetCacheKey())); - RunAlltoAllVTwoLevelPipeline(sendMem, recvMem, allMeshAggregationSendRecvInfo, stream, tag); + algName = "RunAlltoAllVTwoLevelPipeline"; } else if (useOneLevelAlgorithm || isAllRankSamePlane_ || isSingleDeviceModuleP2p || multiModuleDiffDeviceNumMode_) { - HCCL_INFO("[hcclImpl][AlltoAllV] running alltoallv full-mesh implementation"); - CHK_RET(hcclImpl_->CreateCommForAlltoAllFullMesh(tag, sendMem, recvMem)); - CHK_RET(hcclImpl_->RegisterToHeartBeat()); - HCCL_INFO("resource creation (AlltoAllV Full Mesh) success, take time [%lld]us, tag[%s]", - DURATION_US(TIME_NOW() - startut), tag.c_str()); - CHK_RET(RunAlltoAllVFullMesh( - sendMem, sendType, recvMem, recvType, allMeshAggregationSendRecvInfo, stream, tag)); - } else { // 当前如果是910B的16P场景,单server内跨组网也走分级,但是PCIE - HCCL_INFO("[hcclImpl][AlltoAllV] running alltoallv staged implementation"); - CHK_RET(RunAlltoAllVStaged(sendMem, sendType, recvMem, recvType, - allMeshAggregationSendRecvInfo, stream, tag)); + algName = "RunAlltoAllVFullMesh"; + } else { + algName = "RunAlltoAllVStaged"; } - CHK_RET(LaunchTask(dispatcher_, stream)); - CHK_RET(notifyPool_->UnregisterOp(tag)); - HCCL_INFO("tag[%s],alltoallv run success,take time [%lld]us", tag.c_str(), DURATION_US(TIME_NOW() - startut)); - return HCCL_SUCCESS; -} - -HcclResult AlltoAllOperator::AlltoAllVCForOneRankSize(const void *sendBuf, const void *sendCountMatrix, HcclDataType sendType, - const void *recvBuf, HcclDataType recvType, Stream stream, const std::string &tag) -{ - u32 sendTypeSize = 0, recvTypeSize = 0; - CHK_RET(SalGetDataTypeSize(sendType, sendTypeSize)); - CHK_RET(SalGetDataTypeSize(recvType, recvTypeSize)); - - HCCL_PROFILER_ADD_STREAM(stream.ptr(), tag, 0, algType_); - u64 sendCounts = *(static_cast(sendCountMatrix) + userRank_ * userRankSize_ + 0); - bool hugeData = (sendCounts * sendTypeSize ) > SDMA_SEND_MAX_SIZE ; - if (sendBuf == recvBuf) { - auto opMeta = 
HcclOpMetaInfo::GetOneForAllToAllVC(CopyPattern::ZCOPY, sendCounts * sendTypeSize, hugeData); - CHK_RET(InitTask(dispatcher_, stream, opMeta.isEnableCache, opMeta.GetCacheKey())); - if (!GetExternalInputHcclEnableFfts()) { - CHK_RET(SetNormalMode(dispatcher_)); - } + if (param.opType == HcclCMDType::HCCL_CMD_ALLTOALLV) { + // alltoallv + CHK_RET(GetAlltoAllvSendRecvInfo(param, hostCollectBuffer_)); + } else if (param.opType == HcclCMDType::HCCL_CMD_ALLTOALLVC || param.opType == HcclCMDType::HCCL_CMD_ALLTOALL){ + // alltoallvc&&alltoall + CHK_RET(GetAlltoAllvcSendRecvInfo(param.All2AllDataDes.sendCountMatrix, param.All2AllDataDes.sendType, + param.All2AllDataDes.recvType)); } else { - auto opMeta = HcclOpMetaInfo::GetOneForAllToAllVC(CopyPattern::BCOPY, sendCounts * sendTypeSize, hugeData); - CHK_RET(InitTask(dispatcher_, stream, opMeta.isEnableCache, opMeta.GetCacheKey())); - if (!GetExternalInputHcclEnableFfts()) { - CHK_RET(SetNormalMode(dispatcher_)); - } - DeviceMem srcMem = DeviceMem::create(const_cast(sendBuf), sendCounts * sendTypeSize); - DeviceMem dstMem = DeviceMem::create(const_cast(recvBuf), sendCounts * sendTypeSize); - HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, stream); // ranksize = 1; intput、output地址不同,input->output + HCCL_ERROR("[AlltoAllOperator][SelectAlgforAlltoAll] get wrong opType"); + return HCCL_E_PARA; } - CHK_RET(LaunchTask(dispatcher_, stream)); - HCCL_PROFILER_DEL_STREAM(stream.ptr()); + UpdateAlltoAllCopyMode(allMeshAggregationSendRecvInfo_, copyMode); + + HCCL_INFO("[SelectAlgforAlltoAll] all_to_all SelectAlgforAlltoAll is algName [%s]", algName.c_str()); return HCCL_SUCCESS; } -HcclResult AlltoAllOperator::AlltoAllVC(const void *sendBuf, const void *sendCountMatrix, HcclDataType sendType, - const void *recvBuf, HcclDataType recvType, Stream stream, const std::string &tag) +HcclResult AlltoAllOperator::SelectAlg(const std::string& tag, const OpParam& param, std::string& algName, + std::string& newTag) { - /* ------------集合通信资源准备------------ */ - HcclUs startut = TIME_NOW(); - - u32 sendTypeSize = 0, recvTypeSize = 0; - CHK_RET(SalGetDataTypeSize(sendType, sendTypeSize)); - CHK_RET(SalGetDataTypeSize(recvType, recvTypeSize)); - - if (userRankSize_ == 1 && (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE)) - { - CHK_RET(AlltoAllVCForOneRankSize(sendBuf,sendCountMatrix,sendType,recvBuf,recvType,stream,tag)); - return HCCL_SUCCESS ; - } + HcclResult ret; + std::string copyMode = "BCopy"; - CHK_RET(notifyPool_->RegisterOp(tag)); - u64 sendCount = 0; - u64 recvCount = 0; - for (u32 i = 0; i < userRankSize_; i++) { - sendCount += *(static_cast(sendCountMatrix) + userRank_ * userRankSize_ + i); - recvCount += *(static_cast(sendCountMatrix) + userRank_ + userRankSize_ * i); - } + ret = SelectAlgforAlltoAll(param, algName, copyMode); - // sendCount或recvCount为0时, 使用默认分配的内存空间, 避免sendMem和recvMem为空 - DeviceMem sendMem = sendCount == 0 ? - DeviceMem::create(tinySendRecvMem_.ptr(), tinySendRecvMem_.size()) : - DeviceMem::create(const_cast(sendBuf), sendCount * sendTypeSize); - DeviceMem recvMem = recvCount == 0 ? 
- DeviceMem::create(tinySendRecvMem_.ptr(), tinySendRecvMem_.size()) : - DeviceMem::create(const_cast(recvBuf), recvCount * recvTypeSize); - - bool useOneLevelAlgorithm = - (GetExternalInputHcclAlgoConfig(HcclCMDType::HCCL_CMD_ALLTOALL)[0] == HcclAlgoType::HCCL_ALGO_TYPE_NA && - (GetExternalInputHcclAlgoConfig(HcclCMDType::HCCL_CMD_ALLTOALL)[1] == HcclAlgoType::HCCL_ALGO_TYPE_PAIRWISE || - NAFullmeshSatisfyHighPerfAlltoallMeshCondition(deviceType_, userRankSize_))); // 用户配置打平 alltoall - - std::vector allMeshAggregationSendRecvInfo; - CHK_RET(GetAlltoAllvcAllSendRecvInfo(sendCountMatrix, sendType, recvType, allMeshAggregationSendRecvInfo)); - UpdateAlltoAllZCopyMode(allMeshAggregationSendRecvInfo, tag); - // NA+pairwise算法不支持A+X跨mesh两卡 - bool isSingleDeviceModuleP2p = (userRankSize_ <= HCCL_ALLTOALLV_P2P_SIZE); - - HCCL_PROFILER_ADD_STREAM(stream.ptr(), tag, 0, algType_); - - // 暂时先支持单算子模式 - if (IsSatisfyAlltoallPipelineCondition()) { - HCCL_INFO("[AlltoAllOperator][AlltoAllVC] running alltoallvc intra mesh inter pairwise pipeline"); - RunAlltoAllVTwoLevelPipeline(sendMem, recvMem, allMeshAggregationSendRecvInfo, stream, tag); - } else if (useOneLevelAlgorithm || isAllRankSamePlane_ || isSingleDeviceModuleP2p || - multiModuleDiffDeviceNumMode_) { - HCCL_INFO("[hcclImpl][AlltoAllVC] running alltoallvc full-mesh implementation"); - CHK_RET(hcclImpl_->CreateCommForAlltoAllFullMesh(tag, sendMem, recvMem)); - CHK_RET(hcclImpl_->RegisterToHeartBeat()); - HCCL_INFO("resource creation (AlltoAllVC Full Mesh) success, take time [%lld]us, tag[%s]", - DURATION_US(TIME_NOW() - startut), tag.c_str()); - CHK_RET(RunAlltoAllVFullMesh( - sendMem, sendType, recvMem, recvType, allMeshAggregationSendRecvInfo, stream, tag)); + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + newTag = tag + algName + copyMode; } else { - HCCL_INFO("[hcclImpl][AlltoAllVC] running alltoallvc staged implementation"); - CHK_RET(RunAlltoAllVStaged(sendMem, sendType, recvMem, recvType, - allMeshAggregationSendRecvInfo, stream, tag)); + newTag = tag; } - - CHK_RET(notifyPool_->UnregisterOp(tag)); - HCCL_PROFILER_DEL_STREAM(stream.ptr()); - HCCL_INFO("tag[%s], alltoallvc run success,take time [%lld]us", tag.c_str(), DURATION_US(TIME_NOW() - startut)); - return HCCL_SUCCESS; + HCCL_INFO("[SelectAlg] all_to_all newTag is [%s]", newTag.c_str()); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[SelectAlgforAlltoAll][SelectAlg]tag[%s], all_reduce failed, return[%d]", tag.c_str(), ret), ret); + CHK_RET(SetExcutorExtraInfo(algName)); + return ret; } -HcclResult AlltoAllOperator::AlltoAllVCOutPlace(const void *sendBuf, const void *sendCountMatrix, HcclDataType sendType, - const void *recvBuf, HcclDataType recvType, Stream stream, const std::string &tag) +HcclResult AlltoAllOperator::GetAlltoAllvAllAddrInfo(u64 *sendLength, u64 *sendOffset, + u64 *recvLength, u64 *recvOffset, Stream &stream, std::unique_ptr &preMetaInfo) { - std::vector allMeshAggregationSendRecvInfo; - CHK_RET(GetAlltoAllvcAllSendRecvInfo(sendCountMatrix, sendType, recvType, allMeshAggregationSendRecvInfo)); - UpdateAlltoAllZCopyMode(allMeshAggregationSendRecvInfo, tag); + const u32 addrItemNum = 4; + u64 stepSize = sizeof(u64) * userRankSize_; - /* ------------集合通信资源准备------------ */ - HcclUs startut = TIME_NOW(); + std::vector alltoallAddrInfo(userRankSize_ * addrItemNum, 0); + const u32 recvLengthStep = 2; + const u32 recvOffsetStep = 3; - u32 sendTypeSize = 0, recvTypeSize = 0; - CHK_RET(SalGetDataTypeSize(sendType, sendTypeSize)); - 
CHK_RET(SalGetDataTypeSize(recvType, recvTypeSize)); + CHK_SAFETY_FUNC_RET(memcpy_s(&alltoallAddrInfo[0], stepSize, sendLength, stepSize)); + CHK_SAFETY_FUNC_RET(memcpy_s(&alltoallAddrInfo[userRankSize_], stepSize, sendOffset, stepSize)); + CHK_SAFETY_FUNC_RET(memcpy_s(&alltoallAddrInfo[recvLengthStep * userRankSize_], stepSize, recvLength, stepSize)); + CHK_SAFETY_FUNC_RET(memcpy_s(&alltoallAddrInfo[recvOffsetStep * userRankSize_], stepSize, recvOffset, stepSize)); - if (userRankSize_ == 1 ) { - CHK_RET(AlltoAllVCForOneRankSize(sendBuf,sendCountMatrix,sendType,recvBuf,recvType,stream,tag)) ; - return HCCL_SUCCESS; - } - u64 sendCount = 0; - u64 recvCount = 0; - for (u32 i = 0; i < userRankSize_; i++) { - sendCount += *(static_cast(sendCountMatrix) + userRank_ * userRankSize_ + i); - recvCount += *(static_cast(sendCountMatrix) + userRank_ + userRankSize_ * i); - } + preMetaInfo->inputData = alltoallAddrInfo; + preMetaInfo->inputSize = stepSize * addrItemNum; + preMetaInfo->outputSize = userRankSize_ * stepSize * addrItemNum; - CHK_RET(notifyPool_->RegisterOp(tag)); + return HCCL_SUCCESS; +} - // sendCount或recvCount为0时, 使用默认分配的内存空间, 避免sendMem和recvMem为空 - DeviceMem sendMem = sendCount == 0 ? DeviceMem::create(tinySendRecvMem_.ptr(), tinySendRecvMem_.size()) : - DeviceMem::create(const_cast(sendBuf), sendCount * sendTypeSize); - DeviceMem recvMem = recvCount == 0 ? DeviceMem::create(tinySendRecvMem_.ptr(), tinySendRecvMem_.size()) : - DeviceMem::create(const_cast(recvBuf), recvCount * recvTypeSize); +HcclResult AlltoAllOperator::PrepareAlltoAllAddrInfo(const void *sendCounts, const void *sdispls, + HcclDataType sendType, const void *recvCounts, const void *rdispls, HcclDataType recvType, + Stream &stream, std::unique_ptr &preMetaInfo) +{ + std::vector vctSendLength(userRankSize_, 0); + std::vector vctSendOffset(userRankSize_, 0); + std::vector vctRecvLength(userRankSize_, 0); + std::vector vctRecvOffset(userRankSize_, 0); - bool useOneLevelAlgorithm = - (GetExternalInputHcclAlgoConfig(HcclCMDType::HCCL_CMD_ALLTOALL)[0] == HcclAlgoType::HCCL_ALGO_TYPE_NA && - (GetExternalInputHcclAlgoConfig(HcclCMDType::HCCL_CMD_ALLTOALL)[1] == HcclAlgoType::HCCL_ALGO_TYPE_PAIRWISE || - NAFullmeshSatisfyHighPerfAlltoallMeshCondition(deviceType_, userRankSize_))); // 用户配置打平 alltoall + for (u32 i = 0; i < userRankSize_; i++) { + vctSendLength[i] = *(static_cast(sendCounts) + i) * SIZE_TABLE[sendType]; + vctSendOffset[i] = *(static_cast(sdispls) + i) * SIZE_TABLE[sendType]; + vctRecvLength[i] = *(static_cast(recvCounts) + i) * SIZE_TABLE[recvType]; + vctRecvOffset[i] = *(static_cast(rdispls) + i) * SIZE_TABLE[recvType]; - bool massTasks = HasMassTasks(allMeshAggregationSendRecvInfo); - // 子图适配,bcopy每次重新生成子图 - HcclOpMetaInfo meta; - bool hugeData = sendMem.size() > SDMA_SEND_MAX_SIZE; - if (isAlltoAllZCopyMode_) { - /* zcopy拆分4GB以上SDMA任务前,准备好子图不复用标志 */ - meta = HcclOpMetaInfo::GetOneForAllToAllVC(CopyPattern::ZCOPY, sendMem.size(), hugeData); - CHK_RET(InitTask(dispatcher_, stream, meta.isEnableCache, meta.GetCacheKey())); + HCCL_DEBUG("[GetAllMeshAggregationSendRecvInfo] rank[%u], SendLength[%llu], SendOffset[%llu], "\ + "RecvLength[%llu], RecvOffset[%llu]", i, vctSendLength[i], vctSendOffset[i], vctRecvLength[i], + vctRecvOffset[i]); + } + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + CHK_RET(GetAlltoAllvAllAddrInfo(vctSendLength.data(), vctSendOffset.data(), + vctRecvLength.data(), vctRecvOffset.data(), stream, preMetaInfo)); } else { - meta = 
HcclOpMetaInfo::GetOneForAllToAllVC(CopyPattern::BCOPY, sendMem.size(), false); - CHK_RET(InitTask(dispatcher_, stream, meta.isEnableCache, meta.GetCacheKey())); - if (massTasks) { - CHK_RET(SetNormalMode(dispatcher_)); - } + HCCL_INFO("Run with Graph, alloc new stream"); + Stream graphStream(StreamType::STREAM_TYPE_ONLINE); + CHK_RET(GetAlltoAllvAllAddrInfo(vctSendLength.data(), vctSendOffset.data(), + vctRecvLength.data(), vctRecvOffset.data(), graphStream, preMetaInfo)); } - // NA+pairwise算法不支持RDMA不使能下时A+X跨mesh两卡 - bool isSingleDeviceModuleP2p = (userRankSize_ <= HCCL_ALLTOALLV_P2P_SIZE); - bool alltoallPingPong = (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && - !multiModuleDiffDeviceNumMode_ && GetAlltoall2LevelPipelineMaxScratchSize910B(allMeshAggregationSendRecvInfo) > - cclBufferManager_.GetInCCLbuffer().size()); - HCCL_PROFILER_ADD_STREAM(stream.ptr(), tag, 0, algType_); + return HCCL_SUCCESS; +} - // 暂时先支持单算子模式 - if (IsSatisfyAlltoallPipelineCondition()) { - HCCL_RUN_INFO("[AlltoAllOperator][AlltoAllVCOutPlace] running alltoallvc intra mesh inter pairwise pipeline"); - meta = HcclOpMetaInfo::GetOneForAllToAllV(CopyPattern::BCOPY, sendMem.size(), - hugeData || alltoallPingPong); - CHK_RET(InitTask(dispatcher_, stream, meta.isEnableCache, meta.GetCacheKey())); - RunAlltoAllVTwoLevelPipeline(sendMem, recvMem, allMeshAggregationSendRecvInfo, stream, tag); - } else if (useOneLevelAlgorithm || isAllRankSamePlane_ || - isSingleDeviceModuleP2p || multiModuleDiffDeviceNumMode_) { // 只走pairWise - HCCL_INFO("[hcclImpl][AlltoAllVC] running alltoallvc full-mesh implementation"); - CHK_RET(hcclImpl_->CreateCommForAlltoAllFullMesh(tag, sendMem, recvMem)); - CHK_RET(hcclImpl_->RegisterToHeartBeat()); - HCCL_INFO("resource creation (AlltoAllVC Full Mesh) success, take time [%lld]us, tag[%s]", - DURATION_US(TIME_NOW() - startut), tag.c_str()); - CHK_RET(RunAlltoAllVFullMesh(sendMem, sendType, recvMem, recvType, - allMeshAggregationSendRecvInfo, stream, tag)); - } else { - HCCL_INFO("[hcclImpl][AlltoAllVC] running alltoallvc staged implementation"); - CHK_RET(RunAlltoAllVStaged(sendMem, sendType, recvMem, recvType, - allMeshAggregationSendRecvInfo, stream, tag)); +HcclResult AlltoAllOperator::PreparePreOpParam(OpParam& preProcessOpParam, + const std::unique_ptr &preMetaInfo, Stream &preProcessStream) +{ + u64 stepSize = sizeof(u64) * userRankSize_; + u32 perDataSize = SIZE_TABLE[HCCL_DATA_TYPE_UINT64]; + + preProcessOpParam.tag = HCCL_ALLTOALL_PARA_ALLGATHER; + preProcessOpParam.inputPtr = cclBufferManager_.GetInAlltoAllvParaBuffer().ptr(); + preProcessOpParam.inputSize = (preMetaInfo->outputSize / stepSize) * perDataSize; + preProcessOpParam.outputPtr = cclBufferManager_.GetOutAlltoAllvParaBuffer().ptr(); + preProcessOpParam.outputSize = (preMetaInfo->outputSize / stepSize) * perDataSize * userRankSize_; + preProcessOpParam.DataDes.count = (preMetaInfo->outputSize / stepSize); + preProcessOpParam.DataDes.dataType = HCCL_DATA_TYPE_UINT64; + preProcessOpParam.stream = preProcessStream; + return HCCL_SUCCESS; +} + +bool AlltoAllOperator::JudgeIfNeedPreProcessAndGetParam(const OpParam& param, + std::unique_ptr &preMetaInfo) +{ + if (param.opType == HcclCMDType::HCCL_CMD_ALLTOALLV) { + CHK_RET(PrepareAlltoAllAddrInfo(param.All2AllDataDes.sendCounts, param.All2AllDataDes.sdispls, + param.All2AllDataDes.sendType, param.All2AllDataDes.recvCounts, param.All2AllDataDes.rdispls, + param.All2AllDataDes.recvType, const_cast(param.stream), preMetaInfo)); + preMetaInfo->opType = 
HcclCMDType::HCCL_CMD_ALLGATHER; + return true; } + return false; +} - CHK_RET(LaunchTask(dispatcher_, stream)); +void AlltoAllOperator::SetPreProcessResult(HostMem hostCollectBuffer) +{ + hostCollectBuffer_ = std::move(hostCollectBuffer); +} - CHK_RET(notifyPool_->UnregisterOp(tag)); - HCCL_PROFILER_DEL_STREAM(stream.ptr()); - HCCL_INFO("tag[%s], alltoallvc run success,take time [%lld]us", tag.c_str(), DURATION_US(TIME_NOW() - startut)); - return HCCL_SUCCESS; +HcclResult AlltoAllOperator::SetExcutorExtraInfo(const std::string& algName) +{ + HCCL_DEBUG("[AlltoAllOperator][SetExcutorExtraInfo]algName[%s]", algName.c_str()); + if (executor_.get() == nullptr) { + executor_ = CollAlgExecRegistry::Instance()->GetAlgExec(algName, dispatcher_, topoMatcher_); + CHK_PRT_RET(executor_.get() == nullptr, + HCCL_ERROR("[CollAlgOperator][CalcResRequest]Fail to find executor for algName[%s]", algName.c_str()), + HCCL_E_PARA); + executor_->SetVirtualDispatcher(vDispatcher_); + ParallelTaskLoader* parallelTaskLoader = nullptr; + hcclImpl_->GetParallelTaskLoader(parallelTaskLoader); + executor_->SetParallelTaskLoader(parallelTaskLoader); + executor_->SetAlgType(algType_); + } + + return executor_->SetExcutorExtraInfo(allMeshAggregationSendRecvInfo_); } bool AlltoAllOperator::HasMassTasks(std::vector &allMeshAggregationSendRecvInfo) @@ -463,20 +399,6 @@ std::vector AlltoAllOperator::GenerateSendCountMatrix(u64 count, u32 rankSi return sendCountMatrix; } -HcclResult AlltoAllOperator::AlltoAll(const void *sendBuf, u64 sendCount, HcclDataType sendType, - const void *recvBuf, u64 recvCount, HcclDataType recvType, Stream stream, const std::string &tag) -{ - // 生成sendCountMatrix矩阵,alltoall的底层实现走alltoallvc - std::vector sendCountMatrix = GenerateSendCountMatrix(sendCount, userRankSize_); - if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && - GetExternalInputHcclEnableFfts()) { - CHK_RET(AlltoAllVCOutPlace(sendBuf, sendCountMatrix.data(), sendType, recvBuf, recvType, stream, tag)); - } else { - CHK_RET(AlltoAllVC(sendBuf, sendCountMatrix.data(), sendType, recvBuf, recvType, stream, tag)); - } - return HCCL_SUCCESS; -} - HcclResult AlltoAllOperator::GetAllMeshAggregationSendRecvInfo(const void *sendCounts, const void *sdispls, HcclDataType sendType, const void *recvCounts, const void *rdispls, HcclDataType recvType, std::vector& allMeshAggregationSendRecvInfo, Stream &stream) @@ -681,7 +603,7 @@ HcclResult AlltoAllOperator::GetAlltoAllvAllSendRecvInfo(u64 *sendLength, u64 *s HcclResult AlltoAllOperator::ExchangeSendRecvInfoFromAllGather(const std::string &tag, void *inputPtr, void *outputPtr, u64 inputCount, HcclDataType dataType, Stream stream) { - AllGatherOperator operation(hcclImpl_); + AllGatherOperator operation(hcclImpl_, topoMatcher_); CHK_RET(operation.AllGatherOutPlace(tag, inputPtr, outputPtr, inputCount, dataType, stream)); CHK_RET(hcclStreamSynchronize(stream.ptr())); return HCCL_SUCCESS; @@ -874,18 +796,18 @@ u64 AlltoAllOperator::GetAlltoall2LevelPipelineScratchSize910B( u32 rank, std::vector &allMeshAggregationSendRecvInfo) { - u64 userRankSize = allMeshAggregationSendRecvInfo.size(); - u64 maxBlockSize = 0; - u64 maxScratchSize = 0; - const SendRecvInfo& info = allMeshAggregationSendRecvInfo[rank]; - for (u64 i = 0; i < userRankSize; i++) { - maxBlockSize = std::max(maxBlockSize, info.sendLength[i]); - maxBlockSize = std::max(maxBlockSize, info.recvLength[i]); - maxScratchSize = std::max(maxScratchSize, info.sendOffset[i] + info.sendLength[i]); - maxScratchSize = 
std::max(maxScratchSize, info.recvOffset[i] + info.recvLength[i]); - } - maxScratchSize = std::max(maxBlockSize * userRankSize, maxScratchSize); - return maxScratchSize; + u64 scratchSize = 0; + u32 meshRankStart = (rank / meshAggregationRankSize_) * meshAggregationRankSize_; + u32 meshRankEnd = meshRankStart + meshAggregationRankSize_ - 1; + u32 rankIntraMesh = rank - meshRankStart; + for (u32 sendRank = rankIntraMesh, userRankSize = allMeshAggregationSendRecvInfo.size(); + sendRank < userRankSize; sendRank += meshAggregationRankSize_) { + const std::vector& remoteSendLength = allMeshAggregationSendRecvInfo[sendRank].sendLength; + const std::vector& remoteSendOffset = allMeshAggregationSendRecvInfo[sendRank].sendOffset; + scratchSize += (remoteSendOffset[meshRankEnd] + remoteSendLength[meshRankEnd] - + remoteSendOffset[meshRankStart]); + } + return scratchSize; } // 计算 alltoall pipeline 910B 的两级流水算法所有卡需要的 scratch 大小的最大值(单算子模式需要) @@ -900,311 +822,8 @@ u64 AlltoAllOperator::GetAlltoall2LevelPipelineMaxScratchSize910B( return maxScratchSize; } -HcclResult AlltoAllOperator::RunAlltoAllVTwoLevelPipeline(DeviceMem &sendBuf, DeviceMem &recvBuf, - std::vector &allMeshAggregationSendRecvInfo, Stream &stream, const std::string &tag) -{ - HCCL_INFO("[AlltoAllOperator][RunAlltoAllVTwoLevelPipeline] alltoall two level pipeline start"); - bool cclEnough = true; - if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && - GetAlltoall2LevelPipelineMaxScratchSize910B(allMeshAggregationSendRecvInfo) > - cclBufferManager_.GetInCCLbuffer().size()) { - cclEnough = false; - } - HCCL_DEBUG("[AlltoAllOperator][RunAlltoAllVTwoLevelPipeline] alltoall pipeline run %s algo", - cclEnough ? "cclEnough" : "ping pong"); - A2aPipelineMemory a2aPipelineMemory; - a2aPipelineMemory.userInput = sendBuf; - a2aPipelineMemory.userOutput = recvBuf; - // 具体传入 A2aPipelineMemory 对象的 alltoall pipeline executor 会根据图模式还是单算子模式 - // 选择使用 ccl 还是 scratch,不会访问空指针 - if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { - CHK_RET(hcclImpl_->CreateCommForNoScratchAlltoall(tag, sendBuf, recvBuf)); - a2aPipelineMemory.cclInBuffer = cclBufferManager_.GetInCCLbuffer(); - a2aPipelineMemory.cclOutBuffer = cclBufferManager_.GetOutCCLbuffer(); - } else { - // 图模式才需要申请 scratch - u64 scratchSize = GetAlltoall2LevelPipelineScratchSize910B(userRank_, allMeshAggregationSendRecvInfo); - CHK_RET(hcclImpl_->BuildAlltoAllVScratchMem(tag, scratchSize)); - DeviceMem scratchMem; - CHK_RET(hcclImpl_->GetScratchMem(scratchMem, tag)); - CHK_RET(hcclImpl_->CreateCommForNoScratchAlltoall(tag, sendBuf, recvBuf, scratchMem)); - a2aPipelineMemory.scratchMem = scratchMem; - } - std::unique_ptr alltoallPipe = nullptr; - if (cclEnough) { - alltoallPipe.reset(new (std::nothrow)AlltoallPipelineMeshPairwiseCCLEnough(dispatcher_, - allMeshAggregationSendRecvInfo, GetWorkflowMode())); - } else { - alltoallPipe.reset(new (std::nothrow)AlltoallPipelineMeshPairwisePingPong(dispatcher_, - allMeshAggregationSendRecvInfo, GetWorkflowMode())); - } - CommInfo *currComm; - hcclImpl_->GetCommInfo(currComm, tag); - CHK_RET(hcclImpl_->RegisterToHeartBeat()); - hcclImpl_->CreateMutiStreamRes(tag, stream, algType_); - innerStreamInfo_t *streamInfo = hcclImpl_->GetStreamInfo(tag); - alltoallPipe->Prepare(userRank_, a2aPipelineMemory, currComm->commOuter[0], currComm->commInner[0], - stream, streamInfo->ringStreams, streamInfo->ringSignal, streamInfo->ringSignalAux); - alltoallPipe->RunAsync(); - HCCL_INFO("[AlltoAllOperator][RunAlltoAllVTwoLevelPipeline] 
alltoall two level pipeline end"); - return HCCL_SUCCESS; -} - -HcclResult AlltoAllOperator::RunAlltoAllVStaged(DeviceMem &sendBuf, HcclDataType sendType, DeviceMem &recvBuf, - HcclDataType recvType, std::vector &allMeshAggregationSendRecvInfo, - Stream &stream, const std::string &tag) -{ - CHK_PRT_RET(userRankSize_ % meshAggregationRankSize_ != 0, - HCCL_ERROR("userRankSize[%u] is not an Integer multiple of MeshAggregation Dev Num[%u]", - userRankSize_, meshAggregationRankSize_), HCCL_E_PARA); - HcclUs startut = TIME_NOW(); - - // 1 申请中转内存,2. 创建第一级通信域,3. 下发第一级alltoallv 4. 创建第二级通信域 5. 下发第二级 alltoallv - AlltoAllUserRankInfo userRankInfo; - userRankInfo.userRank = userRank_; - userRankInfo.userRankSize = userRankSize_; - u64 workSpaceMemSize = 0; - - AlltoAllVStagedCalculator::CalcWorkSpaceMemSize(userRankInfo, allMeshAggregationSendRecvInfo, - workSpaceMemSize, meshAggregationRankSize_); - CHK_RET(hcclImpl_->BuildAlltoAllVScratchMem(tag, workSpaceMemSize)); - hcclImpl_->CheckStagedAlltoAllNeedRecreateComm(allMeshAggregationSendRecvInfo, tag); - - DeviceMem scratchMem; - hcclImpl_->GetScratchMem(scratchMem, tag); - bool alltoallMeshReadOnly = FullmeshPairwiseSatisfyHighPerfAlltoallMeshCondition( - deviceType_, meshAggregationRankSize_); - CHK_RET(hcclImpl_->CreateCommForAlltoallVStaged(tag, sendBuf, recvBuf, scratchMem, alltoallMeshReadOnly)); - CHK_RET(hcclImpl_->RegisterToHeartBeat()); - - // 此处统计只统计与通信域创建相关的耗时 - HCCL_INFO("resource creation (AlltoAllVC Staged) success, take time [%lld]us, tag[%s]", - DURATION_US(TIME_NOW() - startut), tag.c_str()); - - std::map> sendAddrInfosIntra; - std::map> recvAddrInfosIntra; - bool isSingleMesh = GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && - isAlltoAllZCopyMode_ && isSingleMeshAggregation_; - AlltoAllVStagedCalculator::CalcIntraMeshAggregationAlltoAllMemInfo(userRankInfo, allMeshAggregationSendRecvInfo, - sendAddrInfosIntra, recvAddrInfosIntra, meshAggregationRankSize_, isSingleMesh); - - CommInfo *currComm; - hcclImpl_->GetCommInfo(currComm, tag); - - if (alltoallMeshReadOnly) { - HCCL_RUN_INFO("[AlltoAllOperator][RunAlltoAllVStaged] staged 1 read only algo"); - HcclResult ret = hcclImpl_->CreateMutiStreamRes(tag, stream, algType_, false, meshAggregationRankSize_); - CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[AlltoAllOperator][AlltoAllv]errNo[0x%016llx] tag[%s],\ - alltoallv create stream resource", HCCL_ERROR_CODE(ret), tag.c_str()), ret); - - u32 rankSize = meshAggregationRankSize_; - innerStreamInfo_t *streamInfo = hcclImpl_->GetStreamInfo(tag); - CHK_PRT_RET(streamInfo == nullptr, - HCCL_ERROR("[GetStreamInfo]errNo[0x%016llx] tag[%s] can't find in stream info", - HCCL_ERROR_CODE(HCCL_E_NOT_FOUND), tag.c_str()), HCCL_E_PARA); - - if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { - for (u32 streamIndex = 0; streamIndex < rankSize - 1; streamIndex++) { // 从stream 个数 = ranksize -2 - ret = StreamActiveManager::GetInstance(deviceLogicId_).StreamActive( - streamInfo->ringStreams[streamIndex].ptr(), stream.ptr()); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[AlltoAllOperator][ActiveRingStreams]stream[%u] active failed,return[%d]", - streamIndex, ret), ret); - } - } - // 添加从流profiling, 用于维护planID - CHK_RET(hcclImpl_->AddSubStreamToProfiling(tag, HcclCMDType::HCCL_CMD_ALLTOALL)); - std::unique_ptr alltoallReadOnly = nullptr; - if (GetExternalInputHcclEnableFfts()) { - alltoallReadOnly.reset(new (std::nothrow) AlltoAllVMeshReadOnly(dispatcher_, stream, - streamInfo->ringStreams, 
streamInfo->ringSignal, streamInfo->ringSignalAux, userRank_, - meshAggregationRankSize_, currComm->commOuter[0]->TransportInfo(), allMeshAggregationSendRecvInfo)); - } else { - alltoallReadOnly.reset(new (std::nothrow) AlltoAllVMeshReadOnly(dispatcher_, stream, - streamInfo->ringStreams, streamInfo->ringSignal, streamInfo->ringSignalAux, userRank_, - meshAggregationRankSize_, currComm->commOuter[0]->TransportInfo(), allMeshAggregationSendRecvInfo)); - } - - if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { - CHK_RET(alltoallReadOnly->Prepare(sendBuf, (isSingleMeshAggregation_ ? recvBuf : scratchMem), - cclBufferManager_.GetInCCLbuffer(), cclBufferManager_.GetOutCCLbuffer(), sendAddrInfosIntra, - recvAddrInfosIntra, GetWorkflowMode())); - } else { - CHK_RET(alltoallReadOnly->Prepare(sendBuf, (isSingleMeshAggregation_ ? recvBuf : scratchMem), sendBuf, - recvBuf, sendAddrInfosIntra, recvAddrInfosIntra, GetWorkflowMode())); - } - alltoallReadOnly->RunAsync(); - } else { - std::unique_ptr alltoallOuter = nullptr; - - CHK_RET(PrepareAlltoAllVStaged1(sendBuf, recvBuf, scratchMem, sendAddrInfosIntra, - recvAddrInfosIntra, stream, tag, alltoallOuter)); - - innerStreamInfo_t* streamInfo = hcclImpl_->GetStreamInfoWithoutCheck(tag); - if ((streamInfo->ringStreams.size() != 0) && - (!GetExternalInputHcclEnableFfts()) && isAlltoAllZCopyMode_) { - CHK_RET(currComm->commOuter[0]->RunAlltoAllVStagedMesh(alltoallOuter)); - // 多流场景下,并行多线程下发task处理 - CHK_RET(hcclImpl_->ParallelTaskLoaderProcess(tag, stream)); - } else { - CHK_RET(currComm->commOuter[0]->RunAlltoAllVStaged(alltoallOuter)); - } - - HCCL_INFO("[hcclImpl][RunAlltoAllVStaged] stage0 run success!"); - } - std::map> sendAddrInfosInter; - std::map> recvAddrInfosInter; - AlltoAllVStagedCalculator::CalcInterMeshAggregationAlltoAllMemInfo(userRankInfo, - allMeshAggregationSendRecvInfo, sendAddrInfosInter, recvAddrInfosInter, meshAggregationRankSize_); - - if (((GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && - isAlltoAllZCopyMode_) || alltoallMeshReadOnly) && isSingleMeshAggregation_) { - // we don't need to do stage 2 when there is only one mesh aggregation - } else { - std::unique_ptr alltoallInner = nullptr; - PrepareAlltoAllVStaged2(recvBuf, scratchMem, sendAddrInfosInter, recvAddrInfosInter, - stream, tag, alltoallInner); - CHK_RET(currComm->commInner[0]->RunAlltoAllVStaged(alltoallInner)); // 第二级alltoallv - } - - if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && - isAlltoAllZCopyMode_ && !isSingleMeshAggregation_) { - auto outCCLbuffer = cclBufferManager_.GetOutCCLbuffer(); - DeviceMem srcMem = outCCLbuffer.range(0, recvBuf.size()); - CHK_RET(HcclD2DMemcpyAsync(dispatcher_, recvBuf, srcMem, stream)); - } - return HCCL_SUCCESS; -} - -bool AlltoAllOperator::NAFullmeshSatisfyHighPerfAlltoallMeshCondition(DevType deviceType, u32 rankSize) -{ - return false; -} - -bool AlltoAllOperator::FullmeshPairwiseSatisfyHighPerfAlltoallMeshCondition(DevType deviceType, u32 rankSize) -{ - return false; -} - -HcclResult AlltoAllOperator::RunAlltoAllVFullMesh(DeviceMem &sendBuf, HcclDataType sendType, - DeviceMem &recvBuf, HcclDataType recvType, std::vector &allMeshAggregationSendRecvInfo, - Stream &stream, const std::string &tag) -{ - bool ZCopyMode = GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && - isAlltoAllZCopyMode_; - auto inCCLbuffer = cclBufferManager_.GetInCCLbuffer(); - auto outCCLbuffer = cclBufferManager_.GetOutCCLbuffer(); - - // 构造入参 - AlltoAllVBufferInfo sendInfo; 
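Both the removed full-mesh path around this hunk and the new PrepareAlltoAllAddrInfo introduced earlier in the patch reduce the user-supplied element counts and displacements to byte lengths and offsets before any transport work is scheduled. A minimal standalone sketch of that conversion, assuming plain std::vector inputs and an explicit element size in place of HCCL's SIZE_TABLE lookup (RankByteLayout and ToByteLayout are illustrative names, not HCCL API):

#include <cstdint>
#include <vector>

using u64 = std::uint64_t;

// Per-peer byte layout derived from element counts/displacements (illustrative only).
struct RankByteLayout {
    std::vector<u64> length;  // bytes exchanged with peer i
    std::vector<u64> offset;  // byte offset of peer i's block in the user buffer
};

RankByteLayout ToByteLayout(const std::vector<u64> &counts, const std::vector<u64> &displs, u64 typeSize)
{
    RankByteLayout layout;
    layout.length.reserve(counts.size());
    layout.offset.reserve(displs.size());
    for (std::size_t i = 0; i < counts.size(); ++i) {
        layout.length.push_back(counts[i] * typeSize);  // counts are in elements
        layout.offset.push_back(displs[i] * typeSize);  // displacements are in elements
    }
    return layout;
}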
- sendInfo.mem = ZCopyMode ? inCCLbuffer : sendBuf; - sendInfo.counts = &allMeshAggregationSendRecvInfo[userRank_].sendCounts[0]; - sendInfo.displs = &allMeshAggregationSendRecvInfo[userRank_].sendDispls[0]; - sendInfo.dataType = sendType; - - AlltoAllVBufferInfo recvInfo; - recvInfo.mem = ZCopyMode ? outCCLbuffer : recvBuf; - recvInfo.counts = &allMeshAggregationSendRecvInfo[userRank_].recvCounts[0]; - recvInfo.displs = &allMeshAggregationSendRecvInfo[userRank_].recvDispls[0]; - recvInfo.dataType = recvType; - - if (NAFullmeshSatisfyHighPerfAlltoallMeshCondition(deviceType_, userRankSize_)) { - HCCL_INFO("[AlltoAllOperator][RunAlltoAllVFullMesh] one level read only algo"); - HcclResult ret = hcclImpl_->CreateMutiStreamRes(tag, stream, algType_, false, userRankSize_); - CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[AlltoAllOperator][AlltoAllv]errNo[0x%016llx] tag[%s], " - "alltoallv create stream resource", HCCL_ERROR_CODE(ret), tag.c_str()), ret); - innerStreamInfo_t *streamInfo = hcclImpl_->GetStreamInfo(tag); - if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { - for (u32 streamIndex = 0; streamIndex < userRankSize_ - 1; streamIndex++) { // 从stream 个数 = ranksize -2 - ret = StreamActiveManager::GetInstance(deviceLogicId_).StreamActive( - streamInfo->ringStreams[streamIndex].ptr(), stream.ptr()); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[AlltoAllOperator][ActiveRingStreams]stream[%u] active failed,return[%d]", - streamIndex, ret), ret); - } - } - CHK_RET(hcclImpl_->AddSubStreamToProfiling(tag, HcclCMDType::HCCL_CMD_ALLTOALL)); - CHK_PRT_RET(streamInfo == nullptr, - HCCL_ERROR("[GetStreamInfo]errNo[0x%016llx] tag[%s] can't find in stream info", - HCCL_ERROR_CODE(HCCL_E_NOT_FOUND), tag.c_str()), HCCL_E_PARA); - std::unique_ptr alltoallReadOnly = nullptr; - std::unique_ptr &commMeshPtr = (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE ? 
- hcclImpl_->GetCommMesh() : hcclImpl_->GetCommMeshByTag(tag)); - alltoallReadOnly.reset(new (std::nothrow) AlltoAllVMeshReadOnly(dispatcher_, stream, - streamInfo->ringStreams, streamInfo->ringSignal, streamInfo->ringSignalAux, userRank_, userRankSize_, - commMeshPtr->TransportInfo(), allMeshAggregationSendRecvInfo)); - - CHK_SMART_PTR_NULL(alltoallReadOnly); - AlltoAllUserRankInfo userRankInfo; - userRankInfo.userRank = userRank_; - userRankInfo.userRankSize = userRankSize_; - std::map> sendAddrInfosIntra; - std::map> recvAddrInfosIntra; - AlltoAllVStagedCalculator::CalcIntraMeshAggregationAlltoAllMemInfo(userRankInfo, - allMeshAggregationSendRecvInfo, sendAddrInfosIntra, recvAddrInfosIntra, userRankSize_, true); - if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { - CHK_RET(alltoallReadOnly->Prepare(sendBuf, recvBuf, inCCLbuffer, outCCLbuffer, sendAddrInfosIntra, - recvAddrInfosIntra, GetWorkflowMode())); - } else { - CHK_RET(alltoallReadOnly->Prepare(sendBuf, recvBuf, sendBuf, recvBuf, sendAddrInfosIntra, - recvAddrInfosIntra, GetWorkflowMode())); - } - alltoallReadOnly->RunAsync(); - - return HCCL_SUCCESS; - } - - // 执行算法 - std::unique_ptr pairWisePtr = nullptr; - if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && - !isAlltoAllZCopyMode_) { // 单算子 && Buffer Copy模式 - std::unique_ptr &commMeshPtr = hcclImpl_->GetCommMesh(); - pairWisePtr.reset(new (std::nothrow)AlltoAllVPairWise(dispatcher_)); - CHK_SMART_PTR_NULL(pairWisePtr); - CHK_RET(pairWisePtr->Prepare(sendInfo, recvInfo, inCCLbuffer, outCCLbuffer, isAlltoAllZCopyMode_, stream)); - CHK_RET(commMeshPtr->RunAlltoAll(pairWisePtr)); - } else if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && - isAlltoAllZCopyMode_) { - std::map> rankSendDisplsMap; - std::map> rankRecvDisplsMap; - for (u32 i = 0; i < userRankSize_; i++) { - rankSendDisplsMap.insert(std::pair>(i, allMeshAggregationSendRecvInfo[i].sendOffset)); - rankRecvDisplsMap.insert(std::pair>(i, allMeshAggregationSendRecvInfo[i].recvOffset)); - } - - pairWisePtr.reset(new (std::nothrow)AlltoAllVPairWise(dispatcher_, rankSendDisplsMap, rankRecvDisplsMap, - HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE)); - CHK_SMART_PTR_NULL(pairWisePtr); - CHK_SMART_PTR_NULL(inCCLbuffer.ptr()); - CHK_SMART_PTR_NULL(outCCLbuffer.ptr()); - DeviceMem dstMem = inCCLbuffer.range(0, sendBuf.size()); - CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstMem, sendBuf, stream)); - - CHK_RET(pairWisePtr->Prepare(sendInfo, recvInfo, inCCLbuffer, outCCLbuffer, isAlltoAllZCopyMode_, stream)); - std::unique_ptr &commMeshPtr = hcclImpl_->GetCommMesh(); - CHK_RET(commMeshPtr->RunAlltoAll(pairWisePtr)); // inCCLbuffer -> outCCLbuffer - DeviceMem srcMem = outCCLbuffer.range(0, recvBuf.size()); - CHK_RET(HcclD2DMemcpyAsync(dispatcher_, recvBuf, srcMem, stream)); - } else if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { - std::map> rankSendDisplsMap; - std::map> rankRecvDisplsMap; - for (u32 i = 0; i < userRankSize_; i++) { - rankSendDisplsMap.insert(std::pair>(i, allMeshAggregationSendRecvInfo[i].sendOffset)); - rankRecvDisplsMap.insert(std::pair>(i, allMeshAggregationSendRecvInfo[i].recvOffset)); - } - - pairWisePtr.reset(new (std::nothrow)AlltoAllVPairWise(dispatcher_, rankSendDisplsMap, rankRecvDisplsMap, - HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB)); - CHK_SMART_PTR_NULL(pairWisePtr); - CHK_RET(pairWisePtr->Prepare(sendInfo, recvInfo, isAlltoAllZCopyMode_, stream)); - // 保证最新的commMesh是为该次alltoallv创建(不支持多线程) 
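The AlltoAllVPairWise executor used throughout this removed block follows the classic pairwise exchange: over n-1 steps, rank r sends the block destined for peer (r + s) mod n and receives the block arriving from peer (r - s + n) mod n, so each step pairs every rank with exactly one send target and one receive source. A generic sketch of that schedule (not the HCCL implementation):

#include <cstdint>
#include <utility>
#include <vector>

// Returns the (sendTo, recvFrom) peers of `rank` for each of the n - 1 exchange steps;
// step 0, the local copy, is omitted.
std::vector<std::pair<std::uint32_t, std::uint32_t>> PairwiseSchedule(std::uint32_t rank, std::uint32_t n)
{
    std::vector<std::pair<std::uint32_t, std::uint32_t>> steps;
    for (std::uint32_t s = 1; s < n; ++s) {
        steps.emplace_back((rank + s) % n, (rank + n - s) % n);
    }
    return steps;
}

With this pairing, every rank posts exactly one send and one receive per step, which keeps link usage balanced across the whole exchange.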
- std::unique_ptr &commMeshPtr = hcclImpl_->GetCommMeshByTag(tag);
- CHK_RET(commMeshPtr->RunAlltoAll(pairWisePtr));
- } else {
- HCCL_ERROR("[hcclImpl][RunAlltoAllVFullMesh]work flow mode is invalid");
- return HCCL_E_PARA;
- }
- return HCCL_SUCCESS;
-}
+REGISTER_OP(HcclCMDType::HCCL_CMD_ALLTOALLV, AlltoAllV, AlltoAllOperator);
+REGISTER_OP(HcclCMDType::HCCL_CMD_ALLTOALL, AlltoAll, AlltoAllOperator);
+REGISTER_OP(HcclCMDType::HCCL_CMD_ALLTOALLVC, AlltoAllVC, AlltoAllOperator);
 }
\ No newline at end of file
diff --git a/src/domain/collective_communication/algorithm/impl/operator/alltoall_operator.h b/src/domain/collective_communication/algorithm/impl/operator/alltoall_operator.h
index 528cf547a2431ad056b2c7eff035c2035991f03e..5e5eb920d29d9d9516d9a380a669bf718261e9d3 100644
--- a/src/domain/collective_communication/algorithm/impl/operator/alltoall_operator.h
+++ b/src/domain/collective_communication/algorithm/impl/operator/alltoall_operator.h
@@ -12,32 +12,35 @@
 #define ALLTOALL_OPERATOR_H
 #include "coll_alg_operator.h"
-
 namespace hccl {
-constexpr u64 MAX_ALLTOALL_MESH_ALGO_RANK_INTRA_MESH = 16;
-
 class AlltoAllOperator : public CollAlgOperator {
 public:
- AlltoAllOperator(std::unique_ptr &pImpl);
+ AlltoAllOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher);
 ~AlltoAllOperator();
- HcclResult AlltoAllV(const void *sendBuf, const void *sendCounts, const void *sdispls,
- HcclDataType sendType, const void *recvBuf, const void *recvCounts, const void *rdispls, HcclDataType recvType,
- Stream stream, const std::string &tag);
- HcclResult AlltoAllVOutPlace(const void *sendBuf, const void *sendCounts, const void *sdispls,
- HcclDataType sendType, const void *recvBuf, const void *recvCounts, const void *rdispls, HcclDataType recvType,
- Stream stream, const std::string &tag);
- HcclResult AlltoAllVC(const void *sendBuf, const void *sendCountMatrix, HcclDataType sendType,
- const void *recvBuf, HcclDataType recvType, Stream stream, const std::string &tag);
- HcclResult AlltoAllVCOutPlace(const void *sendBuf, const void *sendCountMatrix, HcclDataType sendType,
- const void *recvBuf, HcclDataType recvType, Stream stream, const std::string &tag);
- HcclResult AlltoAll(const void *sendBuf, u64 sendCount, HcclDataType sendType,
- const void *recvBuf, u64 recvCount, HcclDataType recvType, Stream stream, const std::string &tag);
+ HcclResult GetAlltoAllStagedWorkSpaceMemSize(u64 *sendCounts, u64 *sdispls, HcclDataType sendType, u64 *recvCounts, u64 *rdispls, HcclDataType recvType, u64 &memSize);
 HcclResult GetAlltoAllStagedWorkSpaceMemSize(std::vector &allMeshAggregationSendRecvInfo, u64 &memSize);
- static bool NAFullmeshSatisfyHighPerfAlltoallMeshCondition(DevType deviceType, u32 rankSize);
- static bool FullmeshPairwiseSatisfyHighPerfAlltoallMeshCondition(DevType deviceType, u32 rankSize);
+
+ HcclResult CheckSendRecvParams(const std::vector &allMeshAggregationSendRecvInfo);
+ HcclResult GetAlltoAllvSendRecvInfo(const OpParam& param, const HostMem &alltoallAddrInfoGathered);
+ HcclResult GetAlltoAllvcSendRecvInfo(const void *sendCountMatrix, HcclDataType sendType, HcclDataType recvType);
+ void UpdateAlltoAllCopyMode(std::vector &allMeshAggregationSendRecvInfo, std::string& copyMode);
+ HcclResult SelectAlgforAlltoAll(const OpParam& param, std::string& algName, std::string& copyMode);
+ HcclResult SelectAlg(const std::string& tag, const OpParam& param, std::string& algName, std::string& newTag);
+
+ HcclResult GetAlltoAllvAllAddrInfo(u64 *sendLength, u64 *sendOffset, u64 *recvLength, u64
*recvOffset, + Stream &stream, std::unique_ptr &preMetaInfo); + HcclResult PrepareAlltoAllAddrInfo(const void *sendCounts, const void *sdispls, HcclDataType sendType, + const void *recvCounts, const void *rdispls, HcclDataType recvType, Stream &stream, + std::unique_ptr &preMetaInfo); + HcclResult PreparePreOpParam(OpParam& preProcessOpParam, const std::unique_ptr &preMetaInfo, + Stream &preProcessStream); + bool JudgeIfNeedPreProcessAndGetParam(const OpParam& param, std::unique_ptr &preMetaInfo); + void SetPreProcessResult(HostMem hostCollectBuffer); + HcclResult SetExcutorExtraInfo(const std::string& algName); + private: std::vector GenerateSendCountMatrix(u64 count, u32 rankSize); @@ -54,7 +57,6 @@ private: Stream &stream, const std::string &tag); HcclResult RunAlltoAllVStaged(DeviceMem &sendBuf, HcclDataType sendType, DeviceMem &recvBuf, HcclDataType recvType, std::vector &allMeshAggregationSendRecvInfo, Stream &stream, const std::string &tag); - HcclResult PrepareAlltoAllVStaged1(DeviceMem &sendBuf, DeviceMem &recvBuf, DeviceMem &scratchMem, std::map> &sendAddrInfosIntra, std::map> &recvAddrInfosIntra, @@ -85,15 +87,11 @@ private: u64 inputCount, HcclDataType dataType, Stream stream); bool HasMassTasks(std::vector &allMeshAggregationSendRecvInfo); - HcclResult AlltoAllVForOneRankSize(const void *sendBuf, const void *sendCounts, const void *sdispls, - HcclDataType sendType, const void *recvBuf, const void *recvCounts, const void *rdispls, HcclDataType recvType, - Stream stream, const std::string &tag); - HcclResult AlltoAllVCForOneRankSize(const void *sendBuf, const void *sendCountMatrix, HcclDataType sendType, - const void *recvBuf, HcclDataType recvType, Stream stream, const std::string &tag); - bool isAlltoAllZCopyMode_ = false; std::map isAlltoAllZCopyModeMap_; DeviceMem tinySendRecvMem_; // 在sendCount/recvCount全0时, 使用tinySendRecvMem_, 避免使用空deviceMem + HostMem hostCollectBuffer_; + std::vector allMeshAggregationSendRecvInfo_; }; } diff --git a/src/domain/collective_communication/algorithm/impl/operator/batchsendrecv_operator.cc b/src/domain/collective_communication/algorithm/impl/operator/batchsendrecv_operator.cc index a33309e20df00352ab1459c7e1650574020cdecd..4b8d8f75dd38a5dd012986bb12ca6ebcc8c26243 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/batchsendrecv_operator.cc +++ b/src/domain/collective_communication/algorithm/impl/operator/batchsendrecv_operator.cc @@ -11,8 +11,9 @@ #include "batchsendrecv_operator.h" namespace hccl { -BatchSendRecvOperator::BatchSendRecvOperator(std::unique_ptr &pImpl) - : CommonOperator(pImpl, HcclCMDType::HCCL_CMD_BATCH_SEND_RECV) +BatchSendRecvOperator::BatchSendRecvOperator(std::unique_ptr &pImpl, + std::unique_ptr &topoMatcher) + : CommonOperator(pImpl, topoMatcher, HcclCMDType::HCCL_CMD_BATCH_SEND_RECV) { } BatchSendRecvOperator::~BatchSendRecvOperator() { diff --git a/src/domain/collective_communication/algorithm/impl/operator/batchsendrecv_operator.h b/src/domain/collective_communication/algorithm/impl/operator/batchsendrecv_operator.h index 8345483a21ee7be7e46348e121d9c9166e1498ce..e532478fac0b04f3e438b28dfed0737a47dcbf32 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/batchsendrecv_operator.h +++ b/src/domain/collective_communication/algorithm/impl/operator/batchsendrecv_operator.h @@ -18,7 +18,7 @@ namespace hccl { class BatchSendRecvOperator : public CommonOperator { public: - BatchSendRecvOperator(std::unique_ptr &pImpl); + BatchSendRecvOperator(std::unique_ptr &pImpl, std::unique_ptr 
&topoMatcher); ~BatchSendRecvOperator(); HcclResult SelectAlg(const std::string& tag, const OpParam& param, std::string& algName, std::string& newTag); }; diff --git a/src/domain/collective_communication/algorithm/impl/operator/broadcast_operator.cc b/src/domain/collective_communication/algorithm/impl/operator/broadcast_operator.cc index 10651347df0f9832b393c92ca16e3d46ea964f3c..489fa921a7ba12135f55f66c3b5e07d25bf5bbf1 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/broadcast_operator.cc +++ b/src/domain/collective_communication/algorithm/impl/operator/broadcast_operator.cc @@ -13,10 +13,11 @@ #include "rank_consistent.h" #include "executor_impl.h" #include "stream_active_manager.h" +#include "coll_alg_op_registry.h" namespace hccl { -BroadCastOperator::BroadCastOperator(std::unique_ptr &pImpl) - : CommonOperator(pImpl, HcclCMDType::HCCL_CMD_BROADCAST) +BroadCastOperator::BroadCastOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher) + : CommonOperator(pImpl, topoMatcher, HcclCMDType::HCCL_CMD_BROADCAST) { // 由于bcast/allgather/reducescatter/reduce/send/recv暂不支持server间ring,需继续使用HD或NHR if (!UseInterServerNHRAlgo(algType_) && !UseInterServerNHRV1Algo(algType_) && !UseInterServerNBAlgo(algType_)) { @@ -185,6 +186,8 @@ HcclResult BroadCastOperator::BroadcastOutPlace(const std::string &tag, void *pt auto originalAlgTypeLevel0 = GetLevel0AlgType(algType_); bool isMeshTopo = IsAlgTypeLevel0Mesh(originalAlgTypeLevel0); + bool isDMAreduceOn91073 = (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE + && (deviceType_ == DevType::DEV_TYPE_910_73) && !isMeshTopo); std::string newTag = tag; if (UseInterServerHDAlgo(algType_)) { @@ -220,14 +223,23 @@ HcclResult BroadCastOperator::BroadcastOutPlace(const std::string &tag, void *pt HCCL_INFO("BroadcastOutPlace:curPtr[%p], curCount[%llu], curSize[%llu]", curPtr, curCount, curSize); HcclResult ret; /* 入参的正确性由HCCL确保 */ - DeviceMem commMem = inCCLbuffer.range(0, curSize); - DeviceMem userMem(curPtr, curSize); - if (userRank_ == root) { // 本rank为root节点,非root节点不需要拷贝到中转内存 - CHK_RET(HcclD2DMemcpyAsync(dispatcher_, commMem, userMem, stream)); - } - ret = Broadcast(newTag, inCCLbuffer.ptr(), curCount, dataType, root, stream); - if (realUserRank_ != root) { - CHK_RET(HcclD2DMemcpyAsync(dispatcher_, userMem, commMem, stream)); + if (isDMAreduceOn91073) { + HcomCollOpInfo opInfo; + opInfo.inputAddr = curPtr; + opInfo.outputAddr = curPtr; + opInfo.count = count; + opInfo.dataType = dataType; + ret = Broadcast(newTag, inCCLbuffer.ptr(), curCount, dataType, root, stream, &opInfo); + } else { + DeviceMem commMem = inCCLbuffer.range(0, curSize); + DeviceMem userMem(curPtr, curSize); + if (userRank_ == root) { // 本rank为root节点,非root节点不需要拷贝到中转内存 + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, commMem, userMem, stream)); + } + ret = Broadcast(newTag, inCCLbuffer.ptr(), curCount, dataType, root, stream); + if (realUserRank_ != root) { + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, userMem, commMem, stream)); + } } CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[Loop][Broadcast]errNo[0x%016llx] OP_BASE hcclComm broadcast, tag[%s], input_ptr[%p], " @@ -1158,4 +1170,114 @@ HcclResult BroadCastOperator::GetRankSliceSize(HcclDataType dataType, const u64 return HCCL_SUCCESS; } + +HcclResult BroadCastOperator::SelectAlg(const std::string& tag, const OpParam& param, std::string& algName, + std::string& newTag) +{ + HcclResult ret; + if (Is310P3Common()) { + ret = SelectAlgfor310P3(param, algName); + } else if (Is310PDevice() && topoType_ == 
TopoType::TOPO_TYPE_2P_MESH) { + ret = SelectAlgfor310P(param, algName); + } else if (deviceType_ == DevType::DEV_TYPE_910) { + ret = SelectAlgfor910A(param, algName); + } else if (deviceType_ == DevType::DEV_TYPE_910B) { + ret = SelectAlgfor910B(param, algName); + } else if (deviceType_ == DevType::DEV_TYPE_910_73) { + ret = SelectAlgfor91073(param, algName); + } else { + HCCL_ERROR("[SelectAlg] device type[%d] is out of range for selector.", deviceType_); + return HCCL_E_NOT_SUPPORT; + } + if (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + newTag = tag; + } else { + if (UseInterServerHDAlgo(algType_)) { + u32 part1Size = 2 * (moduleNum_ - (1 << static_cast(log2(moduleNum_)))); + u32 rootId = param.root / deviceNumPerAggregation_; + std::string appendTag = std::to_string((rootId >= part1Size) || ((rootId % 2) == 0)); + newTag = newTag + '_' + appendTag; + if (param.opBaseAtraceInfo != nullptr) { + CHK_RET(param.opBaseAtraceInfo->SavealgtypeTraceInfo(appendTag, param.tag)); + } + } else if (Is310P3Common()) { + newTag = tag + algName; + } else { + AlgTypeLevel1 algType1 = GetLevel1AlgType(algType_); + auto level1Iter = HCCL_ALGO_LEVEL1_NAME_MAP.find(algType1); + newTag = tag + level1Iter->second + algName; + } + } + HCCL_INFO("[SelectAlg] broadcast newTag is [%s]", newTag.c_str()); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[BroadCastSelector][SelectAlg]tag[%s], broadcast failed, return[%d]", tag.c_str(), ret), ret); + return ret; +} + +HcclResult BroadCastOperator::SelectAlgfor310P3(const OpParam& param, std::string& algName) +{ + algName = "BroadCastCommFor310P"; + HCCL_INFO("[SelectAlgfor310P3] broadcast SelectAlgfor310P3 is algName [%s]", algName.c_str()); + return HCCL_SUCCESS; +} + +HcclResult BroadCastOperator::SelectAlgfor310P(const OpParam& param, std::string& algName) +{ + algName = "BroadcastPlusBroadcast"; + HCCL_INFO("[SelectAlgfor310P] broadcast SelectAlgfor310P is algName [%s]", algName.c_str()); + return HCCL_SUCCESS; +} + +HcclResult BroadCastOperator::SelectAlgfor910A(const OpParam& param, std::string& algName) +{ + bool isMeshTopo = topoType_ == TopoType::TOPO_TYPE_4P_MESH || topoType_ == TopoType::TOPO_TYPE_2P_MESH; + bool isRingTopo = topoType_ == TopoType::TOPO_TYPE_NP_SINGLE_RING || topoType_ == TopoType::TOPO_TYPE_8P_RING; + + if (isMeshTopo) { + algName = "BroadCastMeshExecutor"; + } else if (topoType_ == TopoType::TOPO_TYPE_4P_RING) { + algName = "BroadCast4pRingExecutor"; + } else if (isRingTopo) { + algName = "BroadCastRingExecutor"; + } else { + algName = "BroadCastComm"; + } + HCCL_INFO("[SelectAlgfor910A] broadcast SelectAlgfor910A is algName [%s]", algName.c_str()); + return HCCL_SUCCESS; +} + +HcclResult BroadCastOperator::SelectAlgfor910B(const OpParam& param, std::string& algName) +{ + bool isMeshTopo = topoType_ == TopoType::TOPO_TYPE_NP_MESH || topoType_ == TopoType::TOPO_TYPE_4P_MESH || + topoType_ == TopoType::TOPO_TYPE_2P_MESH || topoType_ == TopoType::TOPO_TYPE_1P_MESH; + bool isRingTopo = topoType_ == TopoType::TOPO_TYPE_NP_SINGLE_RING || topoType_ == TopoType::TOPO_TYPE_8P_RING; + + if (isMeshTopo) { + algName = "BroadCastMeshExecutor"; + } else if (topoType_ == TopoType::TOPO_TYPE_4P_RING) { + algName = "BroadCast4pRingExecutor"; + } else if (isRingTopo) { + algName = "BroadCastRingExecutor"; + } else { + algName = "BroadCastComm"; + } + HCCL_INFO("[SelectAlgfor910B] broadcast SelectAlgfor910B is algName [%s]", algName.c_str()); + return HCCL_SUCCESS; +} + +HcclResult BroadCastOperator::SelectAlgfor91073(const 
OpParam& param, std::string& algName) +{ + if (topoType_ == TopoType::TOPO_TYPE_NP_SINGLE_RING) { + algName = "BroadCastRingExecutor"; + } else if (topoType_ == TopoType::TOPO_TYPE_NP_DOUBLE_RING) { + algName = "BroadCastDoubleRingExecutor"; + } else { + algName = "BroadCastComm"; + } + HCCL_INFO("[SelectAlgfor91073] broadcast SelectAlgfor91073 is algName [%s]", algName.c_str()); + return HCCL_SUCCESS; +} + +REGISTER_OP(HcclCMDType::HCCL_CMD_BROADCAST, Broadcast, BroadCastOperator); + } \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/operator/broadcast_operator.h b/src/domain/collective_communication/algorithm/impl/operator/broadcast_operator.h index 9c72752aa4debf96cd6101cc5369cddf7f5f5624..6775ee7a9664264d1d00a40105d978da19ebb36d 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/broadcast_operator.h +++ b/src/domain/collective_communication/algorithm/impl/operator/broadcast_operator.h @@ -16,12 +16,13 @@ namespace hccl { class BroadCastOperator : public CommonOperator { public: - BroadCastOperator(std::unique_ptr &pImpl); + BroadCastOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher); ~BroadCastOperator(); HcclResult Broadcast(const std::string &tag, void *ptr, u64 count, HcclDataType dataType, u32 root, Stream stream, HcomCollOpInfo *opInfo = nullptr); HcclResult BroadcastOutPlace(const std::string &tag, void *ptr, u64 count, HcclDataType dataType, u32 root, Stream stream, const std::unique_ptr &opBaseAtraceInfo = nullptr); + HcclResult SelectAlg(const std::string& tag, const OpParam& param, std::string& algName, std::string& newTag); private: // broadcast @@ -58,6 +59,16 @@ private: std::vector &sliceList); bool IsBroadcastSmallData(u64 size); + + HcclResult SelectAlgfor310P3(const OpParam& param, std::string& algName); + + HcclResult SelectAlgfor310P(const OpParam& param, std::string& algName); + + HcclResult SelectAlgfor910A(const OpParam& param, std::string& algName); + + HcclResult SelectAlgfor910B(const OpParam& param, std::string& algName); + + HcclResult SelectAlgfor91073(const OpParam& param, std::string& algName); }; } diff --git a/src/domain/collective_communication/algorithm/impl/operator/coll_alg_operator.cc b/src/domain/collective_communication/algorithm/impl/operator/coll_alg_operator.cc index bf69f2e01bb08b88b0249b08f6f6a807e4d31ade..a8673dba1fe0d7f6010215b83b3ee81e117aa541 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/coll_alg_operator.cc +++ b/src/domain/collective_communication/algorithm/impl/operator/coll_alg_operator.cc @@ -35,11 +35,13 @@ constexpr u64 PIPELINE_MIN_SIZE = 32 * 1024; // 当数据量大于等于32KB时 constexpr u64 PIPELINE_ALLREDUCE_MIN_SIZE = 1024 * 1024; // 当数据量大于等于1MB时,allreduce使能pipeline模式 constexpr u64 PIPELINE_MIN_SIZE_NO_LITE = 2 * 1024 * 1024; // 如不支持RDMALite,当数据量大于等于2MB时,使能pipeline模式 -CollAlgOperator::CollAlgOperator(std::unique_ptr &pImpl, HcclCMDType opType) - : dispatcher_(pImpl->dispatcher_), vDispatcher_(pImpl->vDispatcher_), - cclBufferManager_(pImpl->cclBufferManager_), notifyPool_(pImpl->notifyPool_), - rankInfoList_(pImpl->rankInfoList_), hcclImpl_(pImpl) +CollAlgOperator::CollAlgOperator(std::unique_ptr &pImpl, + std::unique_ptr &topoMatcher, HcclCMDType opType) + : cclBufferManager_(pImpl->cclBufferManager_), notifyPool_(pImpl->notifyPool_), + rankInfoList_(pImpl->rankInfoList_), hcclImpl_(pImpl), topoMatcher_(topoMatcher) { + hcclImpl_->GetDispatcher(dispatcher_); + hcclImpl_->GetVirtualDispatcher(vDispatcher_); SetTopoAttr(); SetAlgoAttr(); 
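The REGISTER_OP entries added in this patch, together with the CollAlgOpRegistry/CollAlgExecRegistry lookups below, rely on operator and executor factories that self-register at static-initialization time. A hedged sketch of that pattern; OpRegistry, REGISTER_OP_SKETCH and BroadcastOperatorSketch are illustrative stand-ins, and the real macros key on HcclCMDType and forward constructor arguments such as the impl and topo matcher:

#include <functional>
#include <map>
#include <memory>
#include <string>

struct Operator {
    virtual ~Operator() = default;
};

class OpRegistry {
public:
    using Factory = std::function<std::unique_ptr<Operator>()>;
    static OpRegistry &Instance()
    {
        static OpRegistry inst;
        return inst;
    }
    void Register(const std::string &name, Factory factory)
    {
        factories_[name] = std::move(factory);
    }
    std::unique_ptr<Operator> Create(const std::string &name) const
    {
        auto it = factories_.find(name);
        return it == factories_.end() ? nullptr : it->second();
    }
private:
    std::map<std::string, Factory> factories_;
};

// Registration happens as a side effect of initializing a namespace-scope constant,
// which is the effect the REGISTER_OP macros in this patch are used for.
#define REGISTER_OP_SKETCH(name, type) \
    static const bool g_reg_##type = \
        (OpRegistry::Instance().Register(name, [] { return std::unique_ptr<Operator>(new type()); }), true)

struct BroadcastOperatorSketch : Operator {};
REGISTER_OP_SKETCH("Broadcast", BroadcastOperatorSketch);

Once operators register themselves this way, callers such as CalcResRequest and Orchestrate in the hunks that follow only need an algorithm name to obtain a matching executor instance.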
hcclImpl_->GetAlgTypeDirect(algType_, opType); @@ -53,15 +55,24 @@ HcclResult CollAlgOperator::SelectAlg(const std::string& tag, return HCCL_SUCCESS; } +bool CollAlgOperator::CheckNeedRecreateComm(const std::string& algName, u64 lastScratchMemSize) +{ + if (executor_.get() == nullptr) { + executor_ = CollAlgExecRegistry::Instance()->GetAlgExec(algName, dispatcher_, topoMatcher_); + SetExecutorAttr(); + } + return executor_->CheckNeedRecreateComm(lastScratchMemSize); +} + HcclResult CollAlgOperator::CalcResRequest(const std::string& algName, const OpParam& param, AlgResourceRequest& resourceRequest) { if (executor_.get() == nullptr) { - executor_ = CollAlgExecRegistry::Instance()->GetAlgExec(algName, hcclImpl_); + executor_ = CollAlgExecRegistry::Instance()->GetAlgExec(algName, dispatcher_, topoMatcher_); CHK_PRT_RET(executor_.get() == nullptr, HCCL_ERROR("[CollAlgOperator][CalcResRequest]Fail to find executor for algName[%s]", algName.c_str()), HCCL_E_PARA); - executor_->SetAlgType(algType_); + SetExecutorAttr(); } return executor_->CalcResRequest(param, resourceRequest); } @@ -71,38 +82,45 @@ HcclResult CollAlgOperator::Orchestrate(const std::string& algName, const OpPara { HCCL_INFO("[CollAlgOperator][Orchestrate]algName[%s]", algName.c_str()); if (executor_.get() == nullptr) { - executor_ = CollAlgExecRegistry::Instance()->GetAlgExec(algName, hcclImpl_); + executor_ = CollAlgExecRegistry::Instance()->GetAlgExec(algName, dispatcher_, topoMatcher_); CHK_PRT_RET(executor_.get() == nullptr, HCCL_ERROR("[CollAlgOperator][Orchestrate]Fail to find executor for algName[%s]", algName.c_str()), HCCL_E_PARA); - executor_->SetAlgType(algType_); + SetExecutorAttr(); } + return executor_->Orchestrate(param, algResource); } -bool CollAlgOperator::NeedIncrCreateLink(const std::string& algName, const OpParam& param) -{ - if (executor_.get() == nullptr) { - executor_ = CollAlgExecRegistry::Instance()->GetAlgExec(algName, hcclImpl_); - CHK_PRT_RET(executor_.get() == nullptr, - HCCL_ERROR("[BatchSendRecvOperator][NeedIncrCreateLink]Fail to find executor for algName[%s]", - algName.c_str()), HCCL_E_PARA); - } - return executor_->NeedIncrCreateLink(param); -} - HcclResult CollAlgOperator::CalcIncreLinkRequest(const std::string& algName, const OpParam& param, AlgResourceRequest& resourceRequest) { if (executor_.get() == nullptr) { - executor_ = CollAlgExecRegistry::Instance()->GetAlgExec(algName, hcclImpl_); + executor_ = CollAlgExecRegistry::Instance()->GetAlgExec(algName, dispatcher_, topoMatcher_); CHK_PRT_RET(executor_.get() == nullptr, - HCCL_ERROR("[BatchSendRecvOperator][NeedIncrCreateLink]Fail to find executor for algName[%s]", + HCCL_ERROR("[BatchSendRecvOperator][CalcIncreLinkRequest]Fail to find executor for algName[%s]", algName.c_str()), HCCL_E_PARA); } return executor_->CalcIncreLinkRequest(param, resourceRequest); } +bool CollAlgOperator::JudgeIfNeedPreProcessAndGetParam(const OpParam& param, + std::unique_ptr &preMetaInfo) +{ + return false; +} + +HcclResult CollAlgOperator::PreparePreOpParam(OpParam& preProcessOpParam, + const std::unique_ptr &preMetaInfo, Stream &preProcessStream) +{ + return HCCL_SUCCESS; +} + +void CollAlgOperator::SetPreProcessResult(HostMem hostCollectBuffer) +{ + return; +} + void CollAlgOperator::SetTopoAttr() { serverNum_= hcclImpl_->serverNum_; @@ -141,6 +159,17 @@ void CollAlgOperator::SetAlgoAttr() return; } +void CollAlgOperator::SetExecutorAttr() +{ + executor_->SetAlgType(algType_); + executor_->SetVirtualDispatcher(vDispatcher_); + 
executor_->SetCCLInBuffer(hcclImpl_->GetInCCLbufferSize()); + ParallelTaskLoader* parallelTaskLoader = nullptr; + hcclImpl_->GetParallelTaskLoader(parallelTaskLoader); + executor_->SetParallelTaskLoader(parallelTaskLoader); + return; +} + std::string CollAlgOperator::GenerateNewTagByAlgTypeLevel1(std::string tag, std::string algTypeLevel1Tag) const { if (algTypeLevel1Tag == "") { @@ -234,7 +263,7 @@ HcclResult CollAlgOperator::GetDefaultAlgoLevel1V2(HcclCMDType hcclCMDType, u64 // 对于不支持Rdma Lite的场景,下发性能较差,RS和AG需要一个很大的数据量(AR的一半)才能掩盖下发时间 u64 pipelineMinSize = (isSupportRdmaLite_) ? (PIPELINE_MIN_SIZE) : (PIPELINE_MIN_SIZE_NO_LITE); if (((hcclCMDType == HcclCMDType::HCCL_CMD_REDUCE_SCATTER && isInlineReduce && isRdmaReduce && - hcclImpl_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE) || + topoMatcher_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE) || hcclCMDType == HcclCMDType::HCCL_CMD_ALLGATHER) && deviceNumPerAggregation_ != 1 && curSize >= pipelineMinSize && IsAlgTypeLevel0Mesh(originalAlgTypeLevel0)) { algType = AlgTypeLevel1::ALG_LEVEL1_PIPELINE; @@ -247,7 +276,7 @@ HcclResult CollAlgOperator::GetDefaultAlgoLevel1V2(HcclCMDType hcclCMDType, u64 // 计算每个slice的大小 u64 allreduceCurSize = 0; allreduceCurSize = curSize / (moduleNum_ * deviceNumPerAggregation_); - if ((isInlineReduce && isRdmaReduce) && hcclImpl_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE && + if ((isInlineReduce && isRdmaReduce) && topoMatcher_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE && deviceNumPerAggregation_ != 1 && allreduceCurSize >= pipelineMinSize && !isAivMode && IsAlgTypeLevel0Mesh(originalAlgTypeLevel0)) { algType = AlgTypeLevel1::ALG_LEVEL1_PIPELINE; @@ -378,6 +407,7 @@ HcclResult CollAlgOperator::SetInterServerHDAlgo(AlgType &algType) const case AlgType::ALG_NP_DOUBLE_RING_PLUS_PIPELINE: case AlgType::ALG_DOUBLE_RING_PLUS_RING: + case AlgType::ALG_NP_DOUBLE_RING_PLUS_NB: algType = AlgType::ALG_DOUBLE_RING_PLUS_HD; break; default: @@ -657,4 +687,42 @@ HcclResult CollAlgOperator::SelectAlgoTypeForReduce(float delay, u64 curSize, fl return HCCL_SUCCESS; } +bool CollAlgOperator::NAFullmeshSatisfyHighPerfAlltoallMeshCondition(DevType deviceType, u32 rankSize) +{ + bool rankSizeSupport = (rankSize <= MAX_ALLTOALL_MESH_ALGO_RANK_INTRA_MESH); + bool isDevice91073 = (deviceType == DevType::DEV_TYPE_910_73); + bool oneLevelUseMesh = + (GetExternalInputHcclAlgoConfig(HcclCMDType::HCCL_CMD_ALLTOALL)[0] == HcclAlgoType::HCCL_ALGO_TYPE_NA && + GetExternalInputHcclAlgoConfig(HcclCMDType::HCCL_CMD_ALLTOALL)[1] == HcclAlgoType::HCCL_ALGO_TYPE_FULLMESH); + bool isHCCS = !GetExternalInputInterHccsDisable(); + HCCL_DEBUG("[CollAlgOperator][AlltoAllVCOutPlace]isDevice91073 %u oneLevelUseMesh %u isHCCS %u", + isDevice91073, oneLevelUseMesh, isHCCS); + CHK_PRT_CONT(!(oneLevelUseMesh && !isDevice91073), + HCCL_WARNING("[CollAlgOperator][NAFullmeshSatisfyHighPerfAlltoallMeshCondition] alltoall read only algorithm only " + "support 91073 device type, use default algorithm type")); + CHK_PRT_CONT(!(oneLevelUseMesh && !isHCCS), + HCCL_WARNING("[CollAlgOperator][NAFullmeshSatisfyHighPerfAlltoallMeshCondition] alltoall read only algorithm depends " + "on HCCS, use default algorithm type")); + return (isDevice91073 && oneLevelUseMesh && rankSizeSupport && isHCCS); +} + +bool CollAlgOperator::FullmeshPairwiseSatisfyHighPerfAlltoallMeshCondition(DevType deviceType, u32 rankSize) +{ + bool rankSizeSupport = (rankSize <= MAX_ALLTOALL_MESH_ALGO_RANK_INTRA_MESH); + bool isDevice91073 = 
(deviceType == DevType::DEV_TYPE_910_73); + bool twoLevelIntraUseMesh = + (GetExternalInputHcclAlgoConfig(HcclCMDType::HCCL_CMD_ALLTOALL)[0] == HcclAlgoType::HCCL_ALGO_TYPE_FULLMESH && + GetExternalInputHcclAlgoConfig(HcclCMDType::HCCL_CMD_ALLTOALL)[1] == HcclAlgoType::HCCL_ALGO_TYPE_PAIRWISE); + bool isHCCS = !GetExternalInputInterHccsDisable(); + HCCL_DEBUG("[CollAlgOperator][AlltoAllVCOutPlace]isDevice91073 %u twoLevelIntraUseMesh %u isHCCS %u", + isDevice91073, twoLevelIntraUseMesh, isHCCS); + CHK_PRT_CONT(!(twoLevelIntraUseMesh && !isDevice91073), + HCCL_WARNING("[CollAlgOperator][FullmeshPairwiseSatisfyHighPerfAlltoallMeshCondition] alltoall read only algorithm only " + "support 91073 device type, use default algorithm type")); + CHK_PRT_CONT(!(twoLevelIntraUseMesh && !isHCCS), + HCCL_WARNING("[CollAlgOperator][FullmeshPairwiseSatisfyHighPerfAlltoallMeshCondition] alltoall read only algorithm depends " + "on HCCS, use default algorithm type")); + return (isDevice91073 && twoLevelIntraUseMesh && rankSizeSupport && isHCCS); +} + } // namesapce hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/operator/coll_alg_operator.h b/src/domain/collective_communication/algorithm/impl/operator/coll_alg_operator.h index c39967746b4b2220a6806e536d82e91f0736bee3..d6db31165f8f69f52048db99cae7315c3ac0a457 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/coll_alg_operator.h +++ b/src/domain/collective_communication/algorithm/impl/operator/coll_alg_operator.h @@ -18,14 +18,22 @@ #include "ccl_buffer_manager.h" #include "hccl_opbase_atrace_info_pub.h" #include "device_capacity.h" +#include "topo_matcher.h" #include "coll_alg_param.h" #include "coll_executor_base.h" namespace hccl { -class CollAlgOperator { +struct PreProcessMetaInfo { + HcclCMDType opType; + std::vector inputData; + u64 inputSize; + u64 outputSize; +}; +constexpr u64 MAX_ALLTOALL_MESH_ALGO_RANK_INTRA_MESH = 16; +class CollAlgOperator { public: - CollAlgOperator(std::unique_ptr &pImpl, HcclCMDType opType); + CollAlgOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher, HcclCMDType opType); virtual ~CollAlgOperator() = default; virtual HcclResult SelectAlg(const std::string& tag, @@ -35,10 +43,18 @@ public: virtual HcclResult Orchestrate(const std::string& algName, const OpParam& param, const AlgResourceResponse& algResource); // batchsendrecv判断是否需要增量建链 - bool NeedIncrCreateLink(const std::string& algName, const OpParam& param); HcclResult CalcIncreLinkRequest(const std::string& algName, const OpParam& param, AlgResourceRequest& resourceRequest); + virtual bool JudgeIfNeedPreProcessAndGetParam(const OpParam& param, + std::unique_ptr &preMetaInfo); + virtual HcclResult PreparePreOpParam(OpParam& preProcessOpParam, + const std::unique_ptr &preMetaInfo, Stream &preProcessStream); + virtual void SetPreProcessResult(HostMem hostCollectBuffer); + virtual bool CheckNeedRecreateComm(const std::string& algName, u64 lastScratchMemSize); + static bool NAFullmeshSatisfyHighPerfAlltoallMeshCondition(DevType deviceType, u32 rankSize); + static bool FullmeshPairwiseSatisfyHighPerfAlltoallMeshCondition(DevType deviceType, u32 rankSize); + protected: bool IsAlgTypeLevel0Mesh(AlgTypeLevel0 &originalAlgTypeLevel0) const; std::string GenerateNewTagByAlgTypeLevel1(std::string tag, std::string algTypeLevel1Tag) const; @@ -161,8 +177,6 @@ protected: std::string identifier_; OpMode opMode; - const HcclDispatcher dispatcher_; // dispatcher放到最后析构 - const HcclDispatcher vDispatcher_; // 
virtualDispatcher放到最后析构 CCLBufferManager &cclBufferManager_; const std::unique_ptr ¬ifyPool_; @@ -189,6 +203,10 @@ protected: std::unordered_map pairLinkCounter_; // server内所有device间的链路类型计数 std::vector &rankInfoList_; // world group内rank的信息, 按照rank id递增依次排列 std::unique_ptr &hcclImpl_; + std::unique_ptr executor_; + std::unique_ptr &topoMatcher_; + HcclDispatcher dispatcher_; // dispatcher放到最后析构 + HcclDispatcher vDispatcher_; // virtualDispatcher放到最后析构 private: virtual HcclResult SelectAlgoTypeForReduceScatter(float delay, u64 recvCurSize, float bandWidth, @@ -211,6 +229,7 @@ private: AlgTypeLevel1 &algType, bool isInlineReduce = false, bool isRdmaReduce = false, bool isAivMode = false); void SetAlgoAttr(); void SetTopoAttr(); + void SetExecutorAttr(); std::map> selectFuncMap_ = { {HcclCMDType::HCCL_CMD_REDUCE_SCATTER, @@ -223,8 +242,6 @@ private: std::bind(&CollAlgOperator::SelectAlgoTypeForAllReduce, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4)}, }; - - std::unique_ptr executor_; }; } // namespace hccl diff --git a/src/domain/collective_communication/algorithm/impl/operator/common_operator.cc b/src/domain/collective_communication/algorithm/impl/operator/common_operator.cc index 3c97fdc968f8d9325610a75480089d0b701c2698..1e79f79ac1b356dbe5024c57627940b7bdd4762a 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/common_operator.cc +++ b/src/domain/collective_communication/algorithm/impl/operator/common_operator.cc @@ -15,8 +15,9 @@ namespace hccl { -CommonOperator::CommonOperator(std::unique_ptr &pImpl, HcclCMDType opType) - : CollAlgOperator(pImpl, opType) +CommonOperator::CommonOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher, + HcclCMDType opType) + : CollAlgOperator(pImpl, topoMatcher, opType) { } @@ -30,10 +31,11 @@ HcclResult CommonOperator::CalUserMemSlices(const HcclDataType dataType, const H std::vector &userMemSlices) { if (opInfo == nullptr || opInfo->inputAddr == nullptr || opInfo->outputAddr == nullptr) { + // 910_73场景下,allreduce算子的userMem上的slice信息 userMemSlices = singleRingSliceZero; return HCCL_SUCCESS; } - + // 910_73场景下,reduce scatter和all gather算子的userMem上的slice信息 std::vector ring0 = multiRingsOrder[0]; for (u32 sliceIdx = 0; sliceIdx < singleRingSliceZero.size(); sliceIdx++) { Slice userMemSlice; @@ -68,7 +70,11 @@ HcclResult CommonOperator::GetSubStreamInfoOnOneRing(const innerStreamInfo_t &st std::vector> &mainSignalsInOneRing, std::vector> &subSignalsInOneRing) { - if (streamInfo.ringNum == OUTER_PLANE_NUM_IN_NPRING_DOUBLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING) { + if (GetExternalInputEnableRdmaSdmaConcurrent() && deviceType_ == DevType::DEV_TYPE_910_73) { + subStreamsInOneRing.push_back(streamInfo.ringStreams[ringIndex + RDMA_ADD_STREAMS_NUM]); + mainSignalsInOneRing.push_back(streamInfo.ringSignal[ringIndex + RDMA_ADD_STREAMS_NUM]); + subSignalsInOneRing.push_back(streamInfo.ringSignalAux[ringIndex + RDMA_ADD_STREAMS_NUM]); + } else if (streamInfo.ringNum == OUTER_PLANE_NUM_IN_NPRING_DOUBLE * STREAM_NUM_FOR_DMAREDUCE_ONE_RING) { // double ring subStreamsInOneRing.push_back(streamInfo.ringStreams[ringIndex + 1]); mainSignalsInOneRing.push_back(streamInfo.ringSignal[ringIndex + 1]); @@ -184,112 +190,6 @@ HcclResult CommonOperator::MultiRingGather(const std::string &tag, DeviceMem inp return HCCL_SUCCESS; } -HcclResult CommonOperator::MultiRingAllReduce(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, - const u64 count, const HcclDataType dataType, const HcclReduceOp 
reductionOp, - const std::vector> &multRingsSliceZero, Stream stream, s32 profStage, - const u64 baseOffset) -{ - HcclResult ret = HCCL_SUCCESS; - u32 ringNum = multRingsSliceZero.size(); - innerStreamInfo_t *streamInfo = hcclImpl_->GetStreamInfo(tag); - CHK_PRT_RET(streamInfo == nullptr, - HCCL_ERROR("[GetStreamInfo]errNo[0x%016llx] tag[%s] can't find in stream info", - HCCL_ERROR_CODE(HCCL_E_NOT_FOUND), tag.c_str()), HCCL_E_PARA); - - u64 reduceAttr = GetReduceAttr(inputMem, outputMem, dataType, reductionOp); - - CommInfo *currComm; - ret = hcclImpl_->GetCommInfo(currComm, tag); - std::vector> ringNics; - CHK_RET(hcclImpl_->GetRingNics(tag, ringNics)); - - for (u32 ringIndex = 0; ringIndex < ringNum; ringIndex++) { - std::vector singleRingSliceZero = multRingsSliceZero[ringIndex]; - CHK_PRT_RET(singleRingSliceZero.empty(), - HCCL_ERROR("[CommonOperator][MultiRingAllReduce]singleRingSliceZero is empty"), HCCL_E_INTERNAL); - - std::unique_ptr &commRing = currComm->commOuter[ringIndex]; - CHK_SMART_PTR_NULL(commRing); - - u32 rankSize = commRing->RankSize(); - u32 ringIndexOp = ringIndex; - std::unique_ptr executor; - executor.reset(new (std::nothrow) AllReduceRing(dispatcher_, reduceAttr)); - CHK_SMART_PTR_NULL(executor); - - if (ringIndex != (ringNum - 1)) { // 0~ringNum-2的环 - if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { // offline - CHK_RET(StreamActiveManager::GetInstance(deviceLogicId_).StreamActive( - streamInfo->ringStreams[ringIndex].ptr(), stream.ptr())); - } - - ret = LocalNotify::Wait(streamInfo->ringStreams[ringIndex], dispatcher_, - streamInfo->ringSignalAux[ringIndex], profStage); - CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[CommonOperator][MultiRingAllReduce]stream[%u] wait failed", \ - ringIndex), ret); - ret = executor->Prepare(inputMem, outputMem, outputMem, count, dataType, - streamInfo->ringStreams[ringIndex], reductionOp, OUTER_BRIDGE_RANK_ID, singleRingSliceZero, - baseOffset, ringNics[ringIndex]); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingAllReduce]stream[%u], allreduce(ring) prepare failed,"\ - "return[%d]", ringIndex, ret), ret); - - ret = executor->RegisterProfiler( - ((ringIndexOp + 1) << PROF_RINGINDEX_OFFSET_OF_PLANEID) + - (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + commRing->Rank(), - profStage, HCCL_EXEC_STEP_NOT_SET, streamInfo->ringStreams[ringIndex]); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingAllReduce]stream[%u], allreduce(ring) register Profiler "\ - "failed,return[%d]", ringIndex, ret), ret); - - ret = commRing->RunExecutor(executor); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingAllReduce]stream[%u], allreduce(ring) run failed,"\ - "return[%d]", ringIndex, ret), ret); - - ret = LocalNotify::Post(streamInfo->ringStreams[ringIndex], dispatcher_, streamInfo->ringSignal[ringIndex], - profStage); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingAllReduce]stream[%u] record failed", ringIndex), ret); - - ret = LocalNotify::Post(stream, dispatcher_, streamInfo->ringSignalAux[ringIndex], profStage); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingAllReduce]stream[%u] record failed", ringIndex), ret); - } else { // 主环 - executor.reset(new (std::nothrow) AllReduceRing(dispatcher_, reduceAttr)); - CHK_SMART_PTR_NULL(executor); - ret = executor->Prepare(inputMem, outputMem, outputMem, count, dataType, stream, - reductionOp, OUTER_BRIDGE_RANK_ID, singleRingSliceZero, 
baseOffset, ringNics[ringIndex]); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingAllReduce]stream[%u], allreduce(ring) prepare failed, "\ - "return[%d]", ringIndex, ret), ret); - - ret = executor->RegisterProfiler( - ((ringIndexOp + 1) << PROF_RINGINDEX_OFFSET_OF_PLANEID) + - (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + commRing->Rank(), - profStage, HCCL_EXEC_STEP_NOT_SET, stream); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingAllReduce]stream[%u], allreduce(ring) register Profiler "\ - "failed,return[%d]", ringIndex, ret), ret); - - ret = commRing->RunExecutor(executor); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingAllReduce]stream[%u], allreduce(ring) run failed, "\ - "return[%d]", ringIndex, ret), ret); - - for (u32 ring = 0; ring < (ringNum - 1); ring++) { - /* 等待executor执行完毕 */ - ret = LocalNotify::Wait(stream, dispatcher_, streamInfo->ringSignal[ring], profStage); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingAllReduce]stream[%u] wait failed", ring), ret); - } - } - } - // 添加空task,保证执行时不乱序 - CHK_RET(ExecutorBase::ExecEmptyTask(inputMem, outputMem, stream, dispatcher_)); - return HCCL_SUCCESS; -} - HcclResult CommonOperator::MultiRingAllGather(const std::string &tag, DeviceMem inputMem, DeviceMem outputMem, const u64 count, const HcclDataType dataType, const std::vector > multRingsSliceZero, Stream stream, s32 profStage, const u64 baseOffset, const HcomCollOpInfo *opInfo, @@ -319,6 +219,7 @@ HcclResult CommonOperator::MultiRingAllGather(const std::string &tag, DeviceMem CHK_PRT_RET(singleRingSliceZero.empty(), HCCL_ERROR("[CommonOperator][MultiRingAllGather]"\ "singleRingSliceZero is empty"), HCCL_E_INTERNAL); + // 910_73场景 生成userMemOut_上对应的slices std::vector userMemOutputSlices; if (multRingsUserMemSlice.size() == 0) { CHK_RET(CalUserMemSlices(dataType, opInfo, singleRingSliceZero, ringIndex, multiRingsOrder, @@ -334,6 +235,7 @@ HcclResult CommonOperator::MultiRingAllGather(const std::string &tag, DeviceMem u32 rankSize = commRing->RankSize(); u32 ringIndexOp = ringIndex; + // 910_73场景 准备环中的从流 std::vector subStreamsInOneRing; std::vector> mainSignalsInOneRing; std::vector> subSignalsInOneRing; @@ -483,6 +385,7 @@ HcclResult CommonOperator::MultiRingAllGatherConcurrent(const std::string &tag, CHK_PRT_RET(singleRingSliceZero.empty(), HCCL_ERROR("[CommonOperator][MultiRingAllGatherConcurrent]"\ "singleRingSliceZero is empty"), HCCL_E_INTERNAL); + // 910_73场景 生成userMemOut_上对应的slices std::vector userMemOutputSlices; CHK_RET( CalUserMemSlices(dataType, opInfo, singleRingSliceZero, ringIndex, multiRingsOrder, userMemOutputSlices)); @@ -496,6 +399,7 @@ HcclResult CommonOperator::MultiRingAllGatherConcurrent(const std::string &tag, u32 rankSize = commRing->RankSize(); u32 ringIndexOp = ringIndex; + // 910_73场景 准备环中的从流 std::vector subStreamsInOneRing; std::vector> mainSignalsInOneRing; std::vector> subSignalsInOneRing; @@ -784,186 +688,6 @@ HcclResult CommonOperator::MultiRingReduceScatter(const std::string &tag, Device return HCCL_SUCCESS; } -HcclResult CommonOperator::MultiRingReduceScatterConcurrent(const std::string &tag, DeviceMem inputMem, - DeviceMem outputMem, const u64 count, const HcclDataType dataType, const HcclReduceOp reductionOp, - const std::vector>> &multRingsSliceZero, Stream stream, s32 profStage, - const u64 baseOffset, const HcomCollOpInfo *opInfo) -{ - HcclResult ret = HCCL_SUCCESS; - u32 ringNum = multRingsSliceZero.size(); - innerStreamInfo_t 
*streamInfo = hcclImpl_->GetStreamInfo(tag); - CHK_PRT_RET(streamInfo == nullptr, - HCCL_ERROR("[GetStreamInfo]errNo[0x%016llx] tag[%s] can't find in stream info", - HCCL_ERROR_CODE(HCCL_E_NOT_FOUND), tag.c_str()), HCCL_E_PARA); - - CommInfo *currComm; - ret = hcclImpl_->GetCommInfo(currComm, tag); - - std::vector> ringNics; - CHK_RET(hcclImpl_->GetRingNics(tag, ringNics)); - - u32 halfRingSize = ringNum; - if (ringNum > RDMA_PLANE_NUM_IN_NPRING_DOUBLE) { - halfRingSize = ringNum / DEVICE_TWO; - } - - // 拿到ring环映射关系 - u32 ranksSize = currComm->commOuter[COMM_INDEX_0]->RankSize(); - std::vector> multiRingsOrder = GetRingsOrderByTopoType(ranksSize, topoType_, nicList_); - - u64 reduceAttr = GetReduceAttr(inputMem, outputMem, dataType, reductionOp); - - // 空拷贝用于后续操作附着 - CHK_RET(ExecutorBase::ExecEmptyTask(inputMem, outputMem, stream, dispatcher_)); - for (u32 ringIndex = 0; ringIndex < ringNum; ringIndex++) { - std::vector singleRingSliceZero = multRingsSliceZero[ringIndex].second; - CHK_PRT_RET(singleRingSliceZero.empty(), - HCCL_ERROR("[CommonOperator][MultiRingReduceScatterConcurrent]singleRingSliceZero is empty"), - HCCL_E_INTERNAL); - - // 生成userMemIn_上对应的slices - std::vector userMemInputSlices; - u32 commIndex = ringIndex % halfRingSize; - CHK_RET( - CalUserMemSlices(dataType, opInfo, singleRingSliceZero, ringIndex, multiRingsOrder, userMemInputSlices)); - std::vector rankOrder; - CHK_RET(GetRankOrder(multiRingsOrder, commIndex, rankOrder)); - std::unique_ptr &commRing = multRingsSliceZero[ringIndex].first ? currComm->commOuter[commIndex] : - currComm->commOuterRdma[commIndex]; - CHK_SMART_PTR_NULL(commRing); - - u32 rankSize = commRing->RankSize(); - u32 ringIndexOp = ringIndex; - - std::vector subStreamsInOneRing; - std::vector> mainSignalsInOneRing; - std::vector> subSignalsInOneRing; - if (opInfo != nullptr) { - CHK_RET(GetSubStreamInfoOnOneRing(*streamInfo, ringIndex, subStreamsInOneRing, mainSignalsInOneRing, - subSignalsInOneRing)); - } - - if (ringIndex != (ringNum - 1)) { // 0~ringNum-2的环 - if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { // offline - ret = StreamActiveManager::GetInstance(deviceLogicId_).StreamActive( - streamInfo->ringStreams[ringIndex].ptr(), stream.ptr()); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingReduceScatterConcurrent]active stream[%u], failed", - ringIndex), ret); - } - - if (!GetExternalInputHcclEnableFfts() && - GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { - /* 更新线程参数 */ - if (opInfo != nullptr) { - streamInfo->ringThreadsManage[ringIndex]->Prepare( - inputMem, inputMem, outputMem, count, dataType, streamInfo->ringStreams[ringIndex], reductionOp, - OUTER_BRIDGE_RANK_ID, singleRingSliceZero, baseOffset, ringNics[ringIndex % halfRingSize], tag, - profStage, commRing.get(), streamInfo->ringSignalAux[ringIndex], - streamInfo->ringSignal[ringIndex], ringIndex, ExecutorType::REDUCE_SCATTER_RING_DIRECT, - reduceAttr, opInfo, subStreamsInOneRing, mainSignalsInOneRing, - subSignalsInOneRing, rankOrder, userMemInputSlices); - } else { - streamInfo->ringThreadsManage[ringIndex]->Prepare(inputMem, inputMem, outputMem, count, dataType, - streamInfo->ringStreams[ringIndex], reductionOp, OUTER_BRIDGE_RANK_ID, singleRingSliceZero, - baseOffset, ringNics[ringIndex % halfRingSize], tag, profStage, commRing.get(), - streamInfo->ringSignalAux[ringIndex], streamInfo->ringSignal[ringIndex], ringIndex, - ExecutorType::REDUCE_SCATTER_RING, reduceAttr); - } - - 
streamInfo->ringThreadsManage[ringIndex]->NotifyStart(); // 给线程发通知启动线程执行 - } else { - std::unique_ptr executor; - if (opInfo != nullptr) { - executor.reset(new (std::nothrow) ReduceScatterRingConcurrentDirect( - dispatcher_, reduceAttr, opInfo, commRing->UserRank(), subStreamsInOneRing, - mainSignalsInOneRing, subSignalsInOneRing, rankOrder, userMemInputSlices)); - } else { - executor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); - } - CHK_SMART_PTR_NULL(executor); - - ret = LocalNotify::Wait(streamInfo->ringStreams[ringIndex], dispatcher_, - streamInfo->ringSignalAux[ringIndex], profStage); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingReduceScatterConcurrent]stream[%u] wait failed", ringIndex), - ret); - ret = executor->Prepare(inputMem, inputMem, outputMem, count, dataType, - streamInfo->ringStreams[ringIndex], reductionOp, OUTER_BRIDGE_RANK_ID, - singleRingSliceZero, baseOffset, ringNics[ringIndex % halfRingSize]); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingReduceScatterConcurrent]stream[%u],reduce scatter(ring) "\ - "prepare failed,return[%d]", ringIndex, ret), ret); - ret = executor->RegisterProfiler( - ((ringIndexOp + 1) << PROF_RINGINDEX_OFFSET_OF_PLANEID) + - (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + commRing->Rank(), - profStage, HCCL_EXEC_STEP_NOT_SET, streamInfo->ringStreams[ringIndex]); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingReduceScatterConcurrent]stream[%u],reduce scatter(ring) "\ - "register Profiler failed,return[%d]", ringIndex, ret), ret); - - ret = commRing->RunExecutor(executor); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingReduceScatterConcurrent]stream[%u],reduce scatter(ring) run "\ - "failed,return[%d]", ringIndex, ret), ret); - - ret = LocalNotify::Post(streamInfo->ringStreams[ringIndex], dispatcher_, - streamInfo->ringSignal[ringIndex], profStage); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingReduceScatterConcurrent]stream[%u] record failed", ringIndex), - ret); - } - /* 主环record启动从环 */ - ret = LocalNotify::Post(stream, dispatcher_, streamInfo->ringSignalAux[ringIndex], profStage); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingReduceScatterConcurrent]stream[%u] record failed", ringIndex), - ret); - } else { // 主环 最后一个环 - std::unique_ptr executor; - if (opInfo != nullptr) { - executor.reset(new (std::nothrow) ReduceScatterRingConcurrentDirect( - dispatcher_, reduceAttr, opInfo, commRing->UserRank(), subStreamsInOneRing, mainSignalsInOneRing, - subSignalsInOneRing, rankOrder, userMemInputSlices)); - } else { - executor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); - } - CHK_SMART_PTR_NULL(executor); - ret = executor->Prepare(inputMem, inputMem, outputMem, count, dataType, stream, - reductionOp, OUTER_BRIDGE_RANK_ID, singleRingSliceZero, baseOffset, ringNics[ringIndex % halfRingSize]); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingReduceScatterConcurrent]stream[%u],reduce scatter(ring) prepare "\ - "failed,return[%d]", ringIndex, ret), ret); - - ret = executor->RegisterProfiler( - ((ringIndexOp + 1) << PROF_RINGINDEX_OFFSET_OF_PLANEID) + - (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + commRing->Rank(), - profStage, HCCL_EXEC_STEP_NOT_SET, stream); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingReduceScatterConcurrent]stream[%u],reduce scatter(ring) register"\ - " Profiler 
failed,return[%d]", ringIndex, ret), ret); - - ret = commRing->RunExecutor(executor); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingReduceScatterConcurrent]stream[%u],reduce scatter(ring) run "\ - "failed,return[%d]", ringIndex, ret), ret); - for (u32 ring = 0; ring < (ringNum - 1); ring++) { - if (!GetExternalInputHcclEnableFfts() && - GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { - streamInfo->ringThreadsManage[ring]->WaitDone(); - } - /* 等待executor执行完毕 */ - ret = LocalNotify::Wait(stream, dispatcher_, streamInfo->ringSignal[ring], profStage); - - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingReduceScatterConcurrent]stream[%u] wait failed", ring), ret); - } - } - } - // 添加空task,保证子图执行时不乱序 - CHK_RET(ExecutorBase::ExecEmptyTask(inputMem, outputMem, stream, dispatcher_)); - return HCCL_SUCCESS; -} - HcclResult CommonOperator::MultiRingScatter(const std::string &tag, DeviceMem inputMem, DeviceMem outputMem, const u64 count, const HcclDataType dataType, const std::vector > multRingsSliceZero, u32 root, Stream stream, const HcomCollOpInfo *opInfo) @@ -1107,115 +831,6 @@ HcclResult CommonOperator::MultiRingScatter(const std::string &tag, DeviceMem in return HCCL_SUCCESS; } -HcclResult CommonOperator::MultiRingMultiRootScatter(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, - const u64 count, const HcclDataType dataType, const std::vector> &multRingsSliceZero, - u32 root, Stream stream, const u64 baseOffset) -{ - HcclResult ret = HCCL_SUCCESS; - u32 ringNum = multRingsSliceZero.size(); - innerStreamInfo_t *streamInfo = hcclImpl_->GetStreamInfo(tag); - CHK_PRT_RET(streamInfo == nullptr, - HCCL_ERROR("[GetStreamInfo]errNo[0x%016llx] tag[%s] can't find in stream info", - HCCL_ERROR_CODE(HCCL_E_NOT_FOUND), tag.c_str()), HCCL_E_PARA); - - CommInfo *currComm; - ret = hcclImpl_->GetCommInfo(currComm, tag); - - std::vector> ringNics; - CHK_RET(hcclImpl_->GetRingNics(tag, ringNics)); - - for (u32 ringIndex = 0; ringIndex < ringNum; ringIndex++) { - std::vector singleRingSliceZero = multRingsSliceZero[ringIndex]; - CHK_PRT_RET(singleRingSliceZero.empty(), - HCCL_ERROR("[CommonOperator][MultiRingMultiRootScatter]singleRingSliceZero is empty"), HCCL_E_INTERNAL); - - std::unique_ptr &commRing = currComm->commOuter[ringIndex]; - CHK_SMART_PTR_NULL(commRing); - u32 rankSize = commRing->RankSize(); - std::unique_ptr executor; - executor.reset(new (std::nothrow) MultiRootScatterRing(dispatcher_)); - CHK_SMART_PTR_NULL(executor); - - if (ringIndex != (ringNum - 1)) { - if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { // offline - CHK_RET(StreamActiveManager::GetInstance(deviceLogicId_).StreamActive( - streamInfo->ringStreams[ringIndex].ptr(), stream.ptr())); - } - } - - u32 rootRank = 0; - ret = commRing->GetRankByUserRank(root, rootRank); - CHK_PRT_RET(ret == HCCL_E_PARA, - HCCL_ERROR("[CommonOperator][MultiRingMultiRootScatter]invalid root [%u] to get userrank", root), ret); - - if (ringIndex != (ringNum - 1)) { // 0~ringNum-2的环 - ret = LocalNotify::Wait(streamInfo->ringStreams[ringIndex], dispatcher_, - streamInfo->ringSignalAux[ringIndex], PROF_STAGE_0); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingMultiRootScatter]in stream[%u] wait failed", ringIndex), ret); - - ret = executor->Prepare(inputMem, outputMem, outputMem, count, dataType, - streamInfo->ringStreams[ringIndex], HcclReduceOp::HCCL_REDUCE_RESERVED, OUTER_BRIDGE_RANK_ID, - 
singleRingSliceZero, baseOffset, ringNics[ringIndex]); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingMultiRootScatter]stream[%u],multirootscatter(ring) "\ - "prepare failed,return[%d]", ringIndex, ret), ret); - - ret = executor->RegisterProfiler( - ((ringIndex + 1) << PROF_RINGINDEX_OFFSET_OF_PLANEID) + (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + - commRing ->Rank(), PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, - streamInfo->ringStreams[ringIndex]); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingMultiRootScatter]stream[%u], multirootscatter(ring) "\ - "register profiler failed,return[%d]", ringIndex, ret), ret); - - ret = commRing->RunExecutor(executor); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingMultiRootScatter]stream[%u],multirootscatter(ring) "\ - "failed,return[%d]", ringIndex, ret), ret); - - ret = LocalNotify::Post(streamInfo->ringStreams[ringIndex], dispatcher_, streamInfo->ringSignal[ringIndex], - PROF_STAGE_0); - - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingMultiRootScatter]stream[%u] record failed", ringIndex), ret); - - ret = LocalNotify::Post(stream, dispatcher_, streamInfo->ringSignalAux[ringIndex], PROF_STAGE_0); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingMultiRootScatter]stream[%u] record failed", ringIndex), ret); - } else { // 主环 - executor.reset(new (std::nothrow) MultiRootScatterRing(dispatcher_)); - CHK_SMART_PTR_NULL(executor); - ret = executor->Prepare(inputMem, outputMem, outputMem, count, dataType, stream, - HCCL_REDUCE_RESERVED, OUTER_BRIDGE_RANK_ID, singleRingSliceZero, baseOffset, ringNics[ringIndex]); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingMultiRootScatter]stream[%u],multirootscatter(ring) "\ - "prepare failed,return[%d]", ringIndex, ret), ret); - - ret = executor->RegisterProfiler( - ((ringIndex + 1) << PROF_RINGINDEX_OFFSET_OF_PLANEID) + (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) - + commRing ->Rank(), PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, stream); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingMultiRootScatter]stream[%u], multirootscatter(ring) "\ - "register profiler failed,return[%d]", ringIndex, ret), ret); - - ret = commRing->RunExecutor(executor); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingMultiRootScatter]stream[%u],multirootscatter(ring) run "\ - "failed,return[%d]", ringIndex, ret), ret); - for (u32 ring = 0; ring < (ringNum - 1); ring++) { - /* 等待executor执行完毕 , 当前环没有分配数据,跳过此环处理,继续下一个环 */ - ret = LocalNotify::Wait(stream, dispatcher_, streamInfo->ringSignal[ring], PROF_STAGE_0); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[CommonOperator][MultiRingMultiRootScatter]stream[%u] wait failed", ring), ret); - } - } - } - // 添加空task,保证子图执行时不乱序 - CHK_RET(ExecutorBase::ExecEmptyTask(inputMem, outputMem, stream, dispatcher_)); - return HCCL_SUCCESS; -} - HcclResult CommonOperator::MultiStreamReduceScatterMeshAtomic(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, const u64 count, const HcclDataType dataType, const HcclReduceOp reductionOp, const std::vector &dataSliceVct, Stream &stream, @@ -1592,7 +1207,7 @@ bool CommonOperator::Is910BSingleMesh() bool CommonOperator::NeedCreateSingleMeshPlane(const bool isInlineReduce) { // 910B 图模式非确定计算,inlineReduce使能,MESH拓扑场景下,创建一个mesh平面 - bool meshSinglePlane = Is910BSingleMesh() && hcclImpl_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE && + 
bool meshSinglePlane = Is910BSingleMesh() && topoMatcher_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE && isInlineReduce && (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE); return meshSinglePlane; diff --git a/src/domain/collective_communication/algorithm/impl/operator/common_operator.h b/src/domain/collective_communication/algorithm/impl/operator/common_operator.h index 2ebf0bc5abd6993ea7efed384bd2695212f54c35..fff6bac4420496aac6a9dcb1ae8b8c8177493f18 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/common_operator.h +++ b/src/domain/collective_communication/algorithm/impl/operator/common_operator.h @@ -16,7 +16,7 @@ namespace hccl { class CommonOperator : public CollAlgOperator { public: - CommonOperator(std::unique_ptr &pImpl, HcclCMDType opType); + CommonOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher, HcclCMDType opType); ~CommonOperator(); // CCL Op Share @@ -25,23 +25,12 @@ public: const std::vector > multRingsSliceZero, HcclReduceOp op, u32 root, Stream stream, s32 profStage); - HcclResult MultiRingAllReduce(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, - const u64 count, const HcclDataType dataType, - const HcclReduceOp reductionOp, - const std::vector> &multRingsSliceZero, Stream stream, - s32 profStage, const u64 baseOffset = 0); - HcclResult MultiRingReduceScatter(const std::string &tag, DeviceMem inputMem, DeviceMem outputMem, const u64 count, const HcclDataType dataType, const HcclReduceOp reductionOp, const std::vector> multRingsSliceZero, Stream stream, s32 profStage, const u64 baseOffset = 0, const HcomCollOpInfo *opInfo = nullptr, const std::vector> multRingsUserMemSlice = std::vector> (0)); - HcclResult MultiRingReduceScatterConcurrent(const std::string &tag, DeviceMem inputMem, DeviceMem outputMem, - const u64 count, const HcclDataType dataType, const HcclReduceOp reductionOp, - const std::vector>> &multRingsSliceZero, Stream stream, - s32 profStage, const u64 baseOffset = 0, const HcomCollOpInfo *opInfo = nullptr); - HcclResult MultiRingAllGather(const std::string &tag, DeviceMem inputMem, DeviceMem outputMem, const u64 count, const HcclDataType dataType, const std::vector > multRingsSliceZero, Stream stream, @@ -57,10 +46,6 @@ public: const std::vector > multRingsSliceZero, u32 root, Stream stream, const HcomCollOpInfo *opInfo); - HcclResult MultiRingMultiRootScatter(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, - const u64 count, const HcclDataType dataType, const std::vector> &multRingsSliceZero, - u32 root, Stream stream, const u64 baseOffset); - HcclResult MultiStreamReduceScatterMesh(const std::string &tag, DeviceMem inputMem, DeviceMem outputMem, const u64 count, const HcclDataType dataType, const HcclReduceOp reductionOp, diff --git a/src/domain/collective_communication/algorithm/impl/operator/gather_operator.cc b/src/domain/collective_communication/algorithm/impl/operator/gather_operator.cc index 1f5462f96ded9943c9c3890cdf9b241eb8c8ff7d..34a2ba5b8699e0493933a3b3743eac6261568026 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/gather_operator.cc +++ b/src/domain/collective_communication/algorithm/impl/operator/gather_operator.cc @@ -12,8 +12,8 @@ #include "executor_impl.h" namespace hccl { -GatherOperator::GatherOperator(std::unique_ptr &pImpl) - : CollAlgOperator(pImpl, HcclCMDType::HCCL_CMD_GATHER) +GatherOperator::GatherOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher) + : CollAlgOperator(pImpl, topoMatcher, 
HcclCMDType::HCCL_CMD_GATHER) { } diff --git a/src/domain/collective_communication/algorithm/impl/operator/gather_operator.h b/src/domain/collective_communication/algorithm/impl/operator/gather_operator.h index 0cdbf67691df4938339a77595f17668eec130474..52735f81249a2715acf1757a5668d45c795a79ce 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/gather_operator.h +++ b/src/domain/collective_communication/algorithm/impl/operator/gather_operator.h @@ -16,7 +16,7 @@ namespace hccl { class GatherOperator : public CollAlgOperator { public: - GatherOperator(std::unique_ptr &pImpl); + GatherOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher); ~GatherOperator(); HcclResult Gather(const std::string &tag, void *inputPtr, void *outputPtr, u32 rootRank, u64 inputCount, HcclDataType dataType, Stream stream); diff --git a/src/domain/collective_communication/algorithm/impl/operator/receive_operator.cc b/src/domain/collective_communication/algorithm/impl/operator/receive_operator.cc new file mode 100644 index 0000000000000000000000000000000000000000..98850be3b1fc997e42e50cfaa39fae4af8cd7a64 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/operator/receive_operator.cc @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "receive_operator.h" +#include "rank_consistent.h" +#include "executor_impl.h" + +namespace hccl { +ReceiveOperator::ReceiveOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher) + : CollAlgOperator(pImpl, topoMatcher, HcclCMDType::HCCL_CMD_RECEIVE) +{ +} + +ReceiveOperator::~ReceiveOperator() +{ +} + +HcclResult ReceiveOperator::SelectAlg(const std::string& tag, const OpParam& param, std::string& algName, + std::string& newTag) +{ + algName = "ReceiveExecutor"; + if (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + newTag = tag; + } else { + newTag = tag + algName; + } + HCCL_INFO("[SelectAlg] receive newTag is [%s]", newTag.c_str()); + return HCCL_SUCCESS; +} + +REGISTER_OP(HcclCMDType::HCCL_CMD_RECEIVE, Receive, ReceiveOperator); +} \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/operator/receive_operator.h b/src/domain/collective_communication/algorithm/impl/operator/receive_operator.h new file mode 100644 index 0000000000000000000000000000000000000000..c9c00ba03460e7690ceb90f041071dcd92d6e00b --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/operator/receive_operator.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef RECEIVE_OPERATOR_H +#define RECEIVE_OPERATOR_H + +#include "common_operator.h" +#include +#include "coll_alg_op_registry.h" + +namespace hccl { +class ReceiveOperator : public CollAlgOperator { +public: + ReceiveOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher); + ~ReceiveOperator(); + + HcclResult SelectAlg(const std::string& tag, const OpParam& param, std::string& algName, std::string& newTag); +}; +} + + +#endif /** __RECEIVE_OPERATOR_H__ */ \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/operator/reduce_operator.cc b/src/domain/collective_communication/algorithm/impl/operator/reduce_operator.cc index 5f865f2245886cc69bc7f09dc06458596cff2dc4..b8a8ce93dd8f26b395ddd3b148817975b5d079c0 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/reduce_operator.cc +++ b/src/domain/collective_communication/algorithm/impl/operator/reduce_operator.cc @@ -16,8 +16,8 @@ namespace hccl { -ReduceOperator::ReduceOperator(std::unique_ptr &pImpl) - : CommonOperator(pImpl, HcclCMDType::HCCL_CMD_REDUCE) +ReduceOperator::ReduceOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher) + : CommonOperator(pImpl, topoMatcher, HcclCMDType::HCCL_CMD_REDUCE) { if (UseInterServerNHRAlgo(algType_) || UseInterServerNHRV1Algo(algType_) || UseInterServerNBAlgo(algType_) || UseInterServerPipelineAlgo(algType_)) { @@ -85,7 +85,7 @@ HcclResult ReduceOperator::Reduce(const std::string &tag, void *inputPtr, void * HCCL_ERROR("[ReduceOperator][Reduce]errNo[0x%016llx] tag[%s],reduce run failed", HCCL_ERROR_CODE(ret), tag.c_str()), ret); - HCCL_INFO("tag[%s],reduce run success,take time [%lld]us.", tag.c_str(), DURATION_US(TIME_NOW() - startut)); + HCCL_INFO("tag[%s], rank[%u] root[%u] reduce run success,take time [%lld]us.", tag.c_str(), userRank_, root, DURATION_US(TIME_NOW() - startut)); return HCCL_SUCCESS; } @@ -389,7 +389,7 @@ HcclResult ReduceOperator::ReduceMeshExecutor(const std::string &tag, DeviceMem } } std::vector> &commMeshVec = currComm->commOuter; - if (hcclImpl_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE && (dataType != HCCL_DATA_TYPE_INT64) && + if (topoMatcher_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE && (dataType != HCCL_DATA_TYPE_INT64) && (deviceType_ == DevType::DEV_TYPE_910B && op != HCCL_REDUCE_PROD)) { CHK_RET(MultiStreamReduceScatterMeshAtomic(tag, inputMem, outputMem, count, dataType, op, dataSegsSlice, stream, commMeshVec)); @@ -503,12 +503,11 @@ HcclResult ReduceOperator::ReduceDoubleRingExecutor(const std::string &tag, Devi "!=multiRingsSliceZero size[%llu]", ringNum, multiRingsSliceZero.size()), HCCL_E_INTERNAL); CHK_RET(MultiRingReduceScatter(tag, inputMem, outputMem, count, dataType, op, multiRingsSliceZero, stream, PROF_STAGE_0)); - HCCL_INFO("[ReduceDoubleRingExecutor]reduce double ring stage0 run success"); + HCCL_INFO("[ReduceDoubleRingExecutor]stage0 run success"); u32 commIndex = 0; u64 level1Size = 0; u32 segmentIdx = 0; - CHK_RET(hcclImpl_->PrepareInnerCommInfo(segmentIdx, commIndex, level1Size, - currComm->commOuter, multiRingsSliceZero, tag)); + CHK_RET(hcclImpl_->PrepareInnerCommInfo(segmentIdx, commIndex, level1Size, currComm->commOuter, multiRingsSliceZero, 
tag)); u64 level1Count = level1Size / perDataSize; if (devNumInLevel2_ <= 1) { bRet = commIndex >= currComm->commInner.size(); @@ -522,10 +521,10 @@ HcclResult ReduceOperator::ReduceDoubleRingExecutor(const std::string &tag, Devi std::unique_ptr innerExecutor; if (UseInterServerRingAlgo(algType_)) { innerExecutor.reset(new (std::nothrow) ReduceRing(dispatcher_, reduceAttr)); - HCCL_INFO("[ReduceDoubleRingExecutor]reduce ring: using ring algo inter-server."); + HCCL_INFO("[ReduceDoubleRingExecutor]using ring algo inter-server."); } else { innerExecutor.reset(new (std::nothrow) ReduceRecursiveHalvingDoubling(dispatcher_, reduceAttr)); - HCCL_INFO("[ReduceDoubleRingExecutor]reduce ring: using Recursive halving-doubling algo inter-server."); + HCCL_INFO("[ReduceDoubleRingExecutor]using Recursive halving-doubling algo inter-server."); } CHK_SMART_PTR_NULL(currComm->commInner[commIndex]); u32 rankSize = (currComm->commInner[commIndex]->RankSize()); @@ -539,13 +538,14 @@ HcclResult ReduceOperator::ReduceDoubleRingExecutor(const std::string &tag, Devi // 节点间的hd 使用环0来记录 CHK_SMART_PTR_NULL(innerExecutor); CHK_RET(innerExecutor->Prepare( - reduceInput, reduceOutput, reduceOutput, level1Count, dataType, stream, op, OUTER_BRIDGE_RANK_ID, + reduceInput, reduceOutput, reduceOutput, level1Count, dataType, stream, op, planeRoot, std::vector(0), dataSegsSlice[segmentIdx].offset)); CHK_RET(innerExecutor->RegisterProfiler( (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + commInner->Rank(), PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, stream)); CHK_RET(commInner->RunExecutor(innerExecutor)); } else { + //节点间 reduce scatter CHK_RET(ExecutorBase::PrepareSliceData(level1Count, perDataSize, sliceNum, 0, dataSegsSlice)); bRet = commIndex >= currComm->commInner.size(); CHK_PRT_RET(bRet, HCCL_ERROR("[ReduceDoubleRingExecutor] commIndex[%u] >= ",\ @@ -555,19 +555,23 @@ HcclResult ReduceOperator::ReduceDoubleRingExecutor(const std::string &tag, Devi DeviceMem reducescatterOutput = outputMem.range(dataSegsSlice[segmentIdx].offset, level1Size); u64 reduceAttr = GetReduceAttr(reducescatterInput, reducescatterOutput, dataType, op); std::unique_ptr level1RSExecutor; + u32 subUserrankRoot = hcclImpl_->GetSubRootUserRank(userRank_, root); + u32 planeRoot = 0; + std::unique_ptr &commInner = currComm->commInner[commIndex]; + CHK_RET(commInner->GetRankByUserRank(subUserrankRoot, planeRoot)); if (UseInterServerRingAlgo(algType_)) { level1RSExecutor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); CHK_SMART_PTR_NULL(level1RSExecutor); CHK_RET(level1RSExecutor->Prepare( reducescatterInput, reducescatterInput, reducescatterOutput, level1Count, dataType, stream, op, - OUTER_BRIDGE_RANK_ID, dataSegsSlice, dataSegsSlice[segmentIdx].offset)); + planeRoot, dataSegsSlice, dataSegsSlice[segmentIdx].offset)); HCCL_INFO("[ReduceDoubleRingExecutor]reducescatter ring: using ring algo inter-server."); } else { level1RSExecutor.reset(new (std::nothrow) ReduceScatterRecursiveHalvingDoubling(dispatcher_, reduceAttr)); CHK_SMART_PTR_NULL(level1RSExecutor); CHK_RET(level1RSExecutor->Prepare( reducescatterInput, reducescatterOutput, reducescatterOutput, level1Count, dataType, stream, op, - OUTER_BRIDGE_RANK_ID, dataSegsSlice, dataSegsSlice[segmentIdx].offset)); + planeRoot, dataSegsSlice, dataSegsSlice[segmentIdx].offset)); HCCL_INFO("[ReduceDoubleRingExecutor]reducescatter ring: using halving-doubling algo inter-server."); } CHK_RET(level1RSExecutor->RegisterProfiler( @@ -584,7 +588,11 @@ HcclResult 
ReduceOperator::ReduceDoubleRingExecutor(const std::string &tag, Devi bRet = commIndex >= currComm->commLevel2.size(); CHK_PRT_RET(bRet, HCCL_ERROR("[ReduceDoubleRingExecutor] commIndex[%u] >= ",\ "(tag[%s])comm size[%llu]", commIndex, tag.c_str(), currComm->commLevel2.size()), HCCL_E_INTERNAL); - CHK_SMART_PTR_NULL(currComm->commLevel2[commIndex]); + std::unique_ptr &commSuperpod = currComm->commLevel2[commIndex]; + CHK_PTR_NULL(commSuperpod); + u32 subUserrankRootSupperPod = hcclImpl_->GetSubRootUserRankWithSuperPod(userRank_, root); + u32 planeRootSupperPod = 0; + CHK_RET(commSuperpod->GetRankByUserRank(subUserrankRootSupperPod,planeRootSupperPod)); u32 rankSize = currComm->commLevel2[COMM_INDEX_0]->RankSize(); DeviceMem reduceInput = inputMem.range(dataSegsSlice[segmentIdx].offset, rSize); DeviceMem reduceOutput = outputMem.range(dataSegsSlice[segmentIdx].offset, rSize); @@ -598,7 +606,7 @@ HcclResult ReduceOperator::ReduceDoubleRingExecutor(const std::string &tag, Devi HCCL_INFO("[ReduceDoubleRingExecutor]reducescatter ring: using halving-doubling algo inter-server."); } CHK_RET(level2RExecutor->Prepare( - reduceInput, reduceOutput, reduceOutput, arCount, dataType, stream, op, OUTER_BRIDGE_RANK_ID, + reduceInput, reduceOutput, reduceOutput, arCount, dataType, stream, op, planeRootSupperPod, std::vector(0), dataSegsSlice[segmentIdx].offset)); CHK_SMART_PTR_NULL(level2RExecutor); CHK_RET(level2RExecutor->RegisterProfiler( @@ -613,7 +621,7 @@ HcclResult ReduceOperator::ReduceDoubleRingExecutor(const std::string &tag, Devi level1GExecutor.reset(new (std::nothrow) GatherRing(dispatcher_)); CHK_SMART_PTR_NULL(level1GExecutor); CHK_RET(level1GExecutor->Prepare(gatherOutput, gatherOutput, gatherOutput, arCount, dataType, stream, - HcclReduceOp::HCCL_REDUCE_RESERVED, OUTER_BRIDGE_RANK_ID, dataSegsSlice, + HcclReduceOp::HCCL_REDUCE_RESERVED, planeRoot, dataSegsSlice, dataSegsSlice[segmentIdx].offset)); CHK_RET(level1GExecutor->RegisterProfiler( (sliceNum << PROF_RANKSIZE_OFFSET_OF_PLANEID) + currComm->commInner[commIndex]->Rank(), @@ -621,7 +629,7 @@ HcclResult ReduceOperator::ReduceDoubleRingExecutor(const std::string &tag, Devi CHK_RET(currComm->commInner[commIndex]->RunExecutor(level1GExecutor)); HCCL_INFO("[ReduceDoubleRingExecutor]reduce double ring [superpod] level1 gather run success"); } - HCCL_INFO("[ReduceDoubleRingExecutor]reduce double ring stage1 run success"); + HCCL_INFO("[ReduceDoubleRingExecutor]stage1 run success"); u32 rootRank = 0; std::unique_ptr &commOuter = currComm->commOuter[COMM_INDEX_0]; CHK_SMART_PTR_NULL(commOuter); @@ -635,4 +643,85 @@ HcclResult ReduceOperator::ReduceDoubleRingExecutor(const std::string &tag, Devi HCCL_INFO("[ReduceDoubleRingExecutor]reduce double ring stage2 run success"); return HCCL_SUCCESS; } + +HcclResult ReduceOperator::SelectAlg(const std::string &tag, const OpParam &param, std::string &algName, + std::string &newTag) +{ + HcclResult ret = HCCL_SUCCESS; + + if (userRankSize_ == 1 && GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + algName = "ReduceSingleExecutor"; + return HCCL_SUCCESS; + } + + if (deviceType_ == DevType::DEV_TYPE_910) { + ret = SelectAlgfor910A(param, algName); + } else if (deviceType_ == DevType::DEV_TYPE_910B) { + ret = SelectAlgfor910B(param, algName); + } else if (deviceType_ == DevType::DEV_TYPE_910_73) { + ret = SelectAlgfor91073(param, algName); + } else { + HCCL_ERROR("[SelectAlg] device type[%d] is out of range for selector.", deviceType_); + return HCCL_E_NOT_SUPPORT; + } + + AlgTypeLevel1
algType1 = GetLevel1AlgType(algType_); + auto level1Iter = HCCL_ALGO_LEVEL1_NAME_MAP.find(algType1); + newTag = tag + level1Iter->second + algName; + + HCCL_INFO("[SelectAlg] reduce newTag is [%s]", newTag.c_str()); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[ReduceSelector][SelectAlg]tag[%s], reduce failed, return[%d]", tag.c_str(), ret), ret); + return ret; +} + +HcclResult ReduceOperator::SelectAlgfor910A(const OpParam& param, std::string& algName) +{ + bool isMeshTopo = topoType_ == TopoType::TOPO_TYPE_4P_MESH || topoType_ == TopoType::TOPO_TYPE_2P_MESH; + bool isRingTopo = topoType_ == TopoType::TOPO_TYPE_NP_SINGLE_RING || topoType_ == TopoType::TOPO_TYPE_8P_RING; + + if (isMeshTopo) { + algName = "ReduceMeshExecutor"; + } else if (isRingTopo) { + algName = "ReduceRingPlusHd"; + } else { + algName = "ReduceComm"; + } + + HCCL_INFO("[SelectAlgfor910A] reduce SelectAlgfor910A is algName [%s]", algName.c_str()); + return HCCL_SUCCESS; +} + +HcclResult ReduceOperator::SelectAlgfor910B(const OpParam& param, std::string& algName) +{ + bool isMeshTopo = topoType_ == TopoType::TOPO_TYPE_NP_MESH || topoType_ == TopoType::TOPO_TYPE_4P_MESH || + topoType_ == TopoType::TOPO_TYPE_2P_MESH || topoType_ == TopoType::TOPO_TYPE_1P_MESH; + bool isRingTopo = topoType_ == TopoType::TOPO_TYPE_NP_SINGLE_RING; + + if (isMeshTopo) { + algName = "ReduceMeshExecutor"; + } else if (isRingTopo) { + algName = "ReduceRingPlusHd"; + } else { + algName = "ReduceComm"; + } + + HCCL_INFO("[SelectAlgfor910B] reduce SelectAlgfor910B is algName [%s]", algName.c_str()); + return HCCL_SUCCESS; +} + +HcclResult ReduceOperator::SelectAlgfor91073(const OpParam& param, std::string& algName) +{ + // 当前double ring算法不支持,与single ring保持一致 + if (topoType_ == TopoType::TOPO_TYPE_NP_DOUBLE_RING) { + algName = "ReduceDoubleRingExecutor"; + } else { + algName = "ReduceComm"; + } + HCCL_INFO("[SelectAlgfor91073] reduce SelectAlgfor91073 is algName [%s]", algName.c_str()); + return HCCL_SUCCESS; +} + +REGISTER_OP(HcclCMDType::HCCL_CMD_REDUCE, Reduce, ReduceOperator); + } \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/operator/reduce_operator.h b/src/domain/collective_communication/algorithm/impl/operator/reduce_operator.h index 1f1f3ea7096c433ba4a15ac6dcf5868d44632d89..fec83b303973f7ea4d9f9bb7a4dc315875dddb52 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/reduce_operator.h +++ b/src/domain/collective_communication/algorithm/impl/operator/reduce_operator.h @@ -12,11 +12,12 @@ #define REDUCE_OPERATOR_H #include "common_operator.h" +#include "coll_alg_op_registry.h" namespace hccl { class ReduceOperator : public CommonOperator { public: - ReduceOperator(std::unique_ptr &pImpl); + ReduceOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher); ~ReduceOperator(); HcclResult Reduce(const std::string &tag, void *inputPtr, void *outputPtr, u64 count, HcclDataType dataType, HcclReduceOp op, u32 root, Stream stream); @@ -43,6 +44,12 @@ private: HcclResult ReduceOutPlaceForOneRankSize(const std::string &tag, void *inputPtr, void *outputPtr, u64 count, HcclDataType dataType, HcclReduceOp op, u32 root, Stream stream,bool isRootRank,ReduceType reduceType, const std::unique_ptr &opBaseAtraceInfo = nullptr); + + // 算法选择 + HcclResult SelectAlg(const std::string &tag, const OpParam &param, std::string &algName, std::string &newTag); + HcclResult SelectAlgfor910A(const OpParam& param, std::string& algName); // 算法选择 - 910A + HcclResult SelectAlgfor910B(const OpParam& param,
std::string& algName); // 算法选择 - 910B + HcclResult SelectAlgfor91073(const OpParam& param, std::string& algName); // 算法选择 - 91073 }; } diff --git a/src/domain/collective_communication/algorithm/impl/operator/reduce_scatter_operator.cc b/src/domain/collective_communication/algorithm/impl/operator/reduce_scatter_operator.cc index 1ed06f8fb9c169a44c108f1aeaf8725030bc2b1f..8c86a8da993034eea60da7f10bad3fbde91eecd6 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/reduce_scatter_operator.cc +++ b/src/domain/collective_communication/algorithm/impl/operator/reduce_scatter_operator.cc @@ -10,14 +10,11 @@ #include "reduce_scatter_operator.h" #include "device_capacity.h" -#include "rank_consistent.h" -#include "executor_impl.h" -#include "stream_active_manager.h" - namespace hccl { -ReduceScatterOperator::ReduceScatterOperator(std::unique_ptr &pImpl) - : CommonOperator(pImpl, HcclCMDType::HCCL_CMD_REDUCE_SCATTER) +ReduceScatterOperator::ReduceScatterOperator(std::unique_ptr &pImpl, + std::unique_ptr &topoMatcher) : + CommonOperator(pImpl, topoMatcher, HcclCMDType::HCCL_CMD_REDUCE_SCATTER) { } @@ -25,1260 +22,162 @@ ReduceScatterOperator::~ReduceScatterOperator() { } -HcclResult ReduceScatterOperator::ReduceScatterCommFor310P(const std::string &tag, DeviceMem &inputMem, - DeviceMem &outputMem, u64 count, HcclDataType dataType, HcclReduceOp op, Stream &stream) -{ - CommInfo *currComm; - hcclImpl_->GetCommInfo(currComm, tag); - std::unique_ptr &commCombined = currComm->commIntraServer; - std::unique_ptr executor; - bool isInlineReduce = IsSupportSDMAReduce(inputMem.ptr(), outputMem.ptr(), dataType, op); - - u64 reduceAttr = 0; - if (isInlineReduce) { - SalSetBitOne(reduceAttr, ATTR_POS_INLINE_REDUCE); - } - executor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); - CHK_SMART_PTR_NULL(executor); - - CHK_RET(executor->Prepare(inputMem, outputMem, outputMem, count, dataType, stream, op)); - - u32 rankSize = commCombined->RankSize(); - CHK_RET(executor->RegisterProfiler( - (rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + commCombined->Rank(), - PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, stream)); - - CHK_RET(commCombined->RunExecutor(executor)); - - return HCCL_SUCCESS; -} - -HcclResult ReduceScatterOperator::ReduceScatterDMAReduceRingExecutorMiddlelayer( - const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, DeviceMem &scratchMem, u64 count, - HcclDataType dataType, HcclReduceOp op, Stream &stream, HcomCollOpInfo *opInfo) -{ - HcclResult ret = HCCL_SUCCESS; - u32 unitSize = SIZE_TABLE[dataType]; - // 中转内存单次最多能够接受的output count,放开ranksize限制 - u64 maxCountPerLoop = cclBufferManager_.GetInCCLbuffer().size() / (userRankSize_ * unitSize); - - u8 *curInputPtr = static_cast(opInfo->inputAddr); - u8 *curOutputPtr = static_cast(opInfo->outputAddr); - CHK_PTR_NULL(curInputPtr); - CHK_PTR_NULL(curOutputPtr); - - ReduceType reduceType = ((op != HCCL_REDUCE_PROD) && (dataType != HCCL_DATA_TYPE_INT64)) ? - ReduceType::INLINE_REDUCE : ReduceType::TBE_REDUCE; - - auto originalAlgTypeLevel1 = static_cast(algType_) >> HCCL_LEVEL_ALGO_WIDTH; - - u64 curCount = 0; - for (u64 countLeft = count, inputOffset = 0, outputOffset = 0; countLeft > 0; countLeft -= curCount) { - curInputPtr += inputOffset; - curOutputPtr += outputOffset; - opInfo->inputAddr = curInputPtr; - opInfo->outputAddr = curOutputPtr; - // 判断剩余数据量对应的input size是否大于中转input size - curCount = (countLeft > maxCountPerLoop) ? 
maxCountPerLoop : countLeft; - CHK_PRT_RET( - (curCount == 0), - HCCL_ERROR( - "[ReduceScatterOperator][ReduceScatterDMAReduceRingExecutorMiddlelayer]In OP_BASE curCount is zero"), - HCCL_E_PARA); - u64 curSize = curCount * unitSize; // 单位:字节 - DeviceMem curInputMem = inputMem.range(0, curSize * userRankSize_); - DeviceMem curOutputMem = outputMem.range(0, curSize); - DeviceMem curScratchMem = scratchMem.range(0, curSize * userRankSize_); - - /* 下沉子图reset,保证子图不复用标志生效 */ - bool hugeData = curSize > SDMA_SEND_MAX_SIZE; - auto meta = HcclOpMetaInfo::GetOneForReduceScatter(originalAlgTypeLevel1, dataType, reduceType, hugeData); - CHK_RET(InitTask(dispatcher_, stream, meta.isEnableCache, meta.GetCacheKey())); - ret = ReduceScatterDoubleRingExecutor(tag, curInputMem, curOutputMem, curScratchMem, curCount, dataType, op, - stream, opInfo); - inputOffset = curSize; - outputOffset = curSize; - CHK_RET(LaunchTask(dispatcher_, stream)); - } - return ret; -} - -HcclResult ReduceScatterOperator::RunReduceScatter(const std::string &tag, DeviceMem& inputMem, DeviceMem& outputMem, - DeviceMem& scratchMem, u64 count, HcclDataType dataType, HcclReduceOp op, Stream& stream, HcomCollOpInfo *opInfo) +HcclResult ReduceScatterOperator::SelectAlg(const std::string& tag, const OpParam& param, std::string& algName, + std::string& newTag) { - HcclResult ret; - if (Is310P3Common()) { - ret = ReduceScatterCommFor310P(tag, inputMem, outputMem, count, dataType, op, stream); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[Run][ReduceScatter]tag[%s], reduce_scatter_run failed, return[%d]", tag.c_str(), ret), ret); + if (userRankSize_ == 1 && GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + algName = "ReduceScatterSingleExecutor"; return HCCL_SUCCESS; } - - switch (topoType_) { - case TopoType::TOPO_TYPE_NP_MESH: - case TopoType::TOPO_TYPE_4P_MESH: - case TopoType::TOPO_TYPE_2P_MESH: - case TopoType::TOPO_TYPE_1P_MESH: - if (opInfo != nullptr) { - if (isSingleMeshAggregation_) { - ret = ReduceScatterMeshOpbaseExecutorMiddlelayer(tag, inputMem, outputMem, scratchMem, - count, dataType, op, stream, opInfo); - } else { - ret = ReduceScatterMeshOpbasePipelineExecutor(tag, inputMem, count, dataType, - op, stream, opInfo); - } - break; - } else { - ret = ReduceScatterMeshExecutor(tag, inputMem, outputMem, scratchMem, count, dataType, op, stream, - opInfo); - break; - } - case TopoType::TOPO_TYPE_8P_RING: - case TopoType::TOPO_TYPE_NP_SINGLE_RING: - case TopoType::TOPO_TYPE_NP_DOUBLE_RING: - ret = ReduceScatterRingExecutor(tag, inputMem, outputMem, scratchMem, count, dataType, op, stream); - break; - default: - ret = ReduceScatterComm(tag, inputMem, outputMem, scratchMem, count, dataType, op, stream); - break; - } - CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[Run][ReduceScatter]tag[%s], reduce_scatter failed, retrun[%d]", - tag.c_str(), ret), ret); - - return ret; -} - -HcclResult ReduceScatterOperator::ReduceScatter(const std::string &tag, void *inputPtr, void *outputPtr, u64 count, - HcclDataType dataType, HcclReduceOp op, Stream stream, HcomCollOpInfo *opInfo) -{ - /* ------------集合通信资源准备------------ */ - u32 perDataSize = SIZE_TABLE[dataType]; - u64 sendSize = userRankSize_ * count * perDataSize; - DeviceMem inputMem(inputPtr, sendSize); - DeviceMem outputMem(outputPtr, count * perDataSize); - bool isInlineReduce = IsSupportSDMAReduce(inputPtr, outputPtr, dataType, op); - bool isRdmaReduce = IsSupportRDMAReduce(dataType, op); - bool isMeshTopo = topoType_ == TopoType::TOPO_TYPE_NP_MESH || topoType_ == 
TopoType::TOPO_TYPE_4P_MESH || - topoType_ == TopoType::TOPO_TYPE_2P_MESH || topoType_ == TopoType::TOPO_TYPE_1P_MESH; - - if (UseInterServerPipelineAlgo(algType_) && - (!(isRdmaReduce && isInlineReduce) || (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) || - hcclImpl_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_ENABLE || !isMeshTopo)) { - // 屏蔽不支持inlinreduce场景和pytorch子图+静态图场景 - HcclResult ret = SetInterServerHDAlgo(algType_); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[ReduceScatterOperator][ReduceScatter]errNo[0x%016llx] tag[%s], reduceScatter "\ - "set inter server halving-doubling algo failed", HCCL_ERROR_CODE(ret), tag.c_str()), ret); - HCCL_WARNING("Pipeline algorithm is not supported because not inlineReduce, inter server is set to HD."); - } - - HcomCollOpInfo newopInfo; - bool deterministicOptimize = - hcclImpl_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_ENABLE && deviceNumPerServer_ > DEVICE_TWO; - bool enableSdmaGraph = - SingleMeshInlineReduce(inputPtr, outputPtr, dataType, op) && - (deterministicOptimize) && - (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE); - if (enableSdmaGraph) { - newopInfo.inputAddr = inputPtr; - newopInfo.outputAddr = outputPtr; - newopInfo.count = count; - } - - DeviceMem scratchMem; - if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE - && (deviceType_ == DevType::DEV_TYPE_910B) - && IsSupportSDMAReduce(inputPtr, outputPtr, dataType, op) && IsSupportRDMAReduce(dataType, op)) { - scratchMem = DeviceMem::create(outputPtr, - isSingleMeshAggregation_ && hcclImpl_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_ENABLE ? - cclBufferManager_.GetInCCLbufferSize() : - sendSize); - } else if (Is310P3Common()) { - scratchMem = DeviceMem::create(outputPtr, outputMem.size()); + HcclResult ret; + if (deviceType_ == DevType::DEV_TYPE_310P3) { + ret = SelectAlgfor310P3(param, algName); + } else if (deviceType_ == DevType::DEV_TYPE_910) { + ret = SelectAlgfor910A(param, algName); + } else if (deviceType_ == DevType::DEV_TYPE_910B) { + ret = SelectAlgfor910B(param, algName); } else { - u64 allocMemSize = sendSize + 2 * CCE_REDUCE_ALIGN_SIZE; /* cce reduce数据大小32字节对齐 2是指前后各有 */ - u64 allocWorkSpaceMemSize; - if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { - /* cce reduce数据大小32字节对齐 2是指前后各有 */ - allocWorkSpaceMemSize = cclBufferManager_.GetInCCLbufferSize() + 2 * CCE_REDUCE_ALIGN_SIZE; - } else { - allocWorkSpaceMemSize = allocMemSize; - } - DeviceMem tmpScratchMem; - CHK_RET(hcclImpl_->SetScratchMem(tmpScratchMem, tag, allocWorkSpaceMemSize)); - - DeviceMem MemMapValue = DeviceMem::create(tmpScratchMem.ptr(), allocMemSize); - CHK_SMART_PTR_NULL(MemMapValue); - u32 add_offset = (reinterpret_cast(MemMapValue.ptr())) % CCE_REDUCE_ALIGN_SIZE; // cce reduce地址32字节对齐 - scratchMem = MemMapValue.range(add_offset, sendSize); // 截取32字节对齐后的内存地址 + ret = SelectAlgfor91073(param, algName); } - meshSinglePlane_ = NeedCreateSingleMeshPlane(isInlineReduce); - - CHK_RET(hcclImpl_->PrepareCommRes(tag, inputMem, scratchMem, algType_, stream, INVALID_VALUE_RANKID, false, false, - false, meshSinglePlane_)); - - HCCL_PROFILER_ADD_STREAM(stream.ptr(), tag, 0, algType_); - - // 添加从流profiling, 用于维护planID - CHK_RET(hcclImpl_->AddSubStreamToProfiling(tag, HcclCMDType::HCCL_CMD_REDUCE_SCATTER)); - - /* ------------执行算法-------------- */ - HcclResult ret = HCCL_SUCCESS; - HcclUs startut = TIME_NOW(); - if (enableSdmaGraph) { - ret = RunReduceScatter(tag, inputMem, outputMem, scratchMem, count, 
dataType, op, stream, &newopInfo); + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { + newTag = tag; } else { - ret = RunReduceScatter(tag, inputMem, outputMem, scratchMem, count, dataType, op, stream, opInfo); - } - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[ReduceScatterOperator][ReduceScatter]errNo[0x%016llx] tag[%s],reduceScatter run failed", - HCCL_ERROR_CODE(ret), tag.c_str()), ret); - HCCL_INFO("tag[%s],reduce_scatter run success,take time [%lld]us.", tag.c_str(), DURATION_US(TIME_NOW() - startut)); - return HCCL_SUCCESS; -} - -HcclResult ReduceScatterOperator::ReduceScatterOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, - u64 count, HcclDataType dataType, HcclReduceOp op, Stream stream, - const std::unique_ptr &opBaseAtraceInfo) -{ - HcclResult ret; - auto rtStream = stream.ptr(); - - u8 *curInputPtr = static_cast(inputPtr); - u8 *curOutputPtr = static_cast(outputPtr); - CHK_PTR_NULL(curInputPtr); - CHK_PTR_NULL(curOutputPtr); - - auto inCCLbuffer = cclBufferManager_.GetInCCLbuffer(); - auto outCCLbuffer = cclBufferManager_.GetOutCCLbuffer(); - u32 unitSize = SIZE_TABLE[dataType]; - u64 maxCountPerLoop = inCCLbuffer.size() / (userRankSize_ * unitSize); // 中转内存单次最多能够接受的output count - u64 curCount = 0; - ReduceType reduceType = ((op != HCCL_REDUCE_PROD) && (dataType != HCCL_DATA_TYPE_INT64)) ? - ReduceType::INLINE_REDUCE : ReduceType::TBE_REDUCE; - - // 判断是否使用mesh算法,避免mesh物理链路下使用非mesh算法勿入SDMA消减流程 - // isSingleMeshAggregation_只是指示了物理链路为mesh,而SDMA消减只在mesh算法下使用 - bool isMeshTopo = topoType_ == TopoType::TOPO_TYPE_NP_MESH || topoType_ == TopoType::TOPO_TYPE_4P_MESH || - topoType_ == TopoType::TOPO_TYPE_2P_MESH || topoType_ == TopoType::TOPO_TYPE_1P_MESH; - bool isInlineReduce = IsSupportSDMAReduce(inCCLbuffer.ptr(), outCCLbuffer.ptr(), dataType, op); - bool isRdmaReduce = IsSupportRDMAReduce(dataType, op); - - u64 countSize = count * unitSize; // 单位:字节 - u64 cclBufferSize = inCCLbuffer.size() / userRankSize_; - std::string algTypeLevel1Tag; - CHK_RET(AutoSelectAlgTypeLevel1(HcclCMDType::HCCL_CMD_REDUCE_SCATTER, countSize, cclBufferSize, algTypeLevel1Tag, - isInlineReduce, isRdmaReduce)); - if (opBaseAtraceInfo != nullptr) { - CHK_RET(opBaseAtraceInfo->SavealgtypeTraceInfo(algTypeLevel1Tag, tag)); - } - bool isPipeLine = ((deviceType_ == DevType::DEV_TYPE_910B) && (userRankSize_ != 1) && isMeshTopo && - (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) && isInlineReduce && - ((UseInterServerPipelineAlgo(algType_) && isRdmaReduce && - hcclImpl_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE) || - isSingleMeshAggregation_)); - bool isUseDMA = !GetExternalInputEnableRdmaSdmaConcurrent(); - if (userRankSize_ == 1 ) { - HCCL_PROFILER_ADD_TAG(tag, identifier_, GetWorkflowMode()); - HCCL_PROFILER_ADD_STREAM(rtStream, tag, 0, algType_); - HCCL_PROFILER_ADD_OPDATA(tag, count, inputPtr, outputPtr, dataType, INVALID_VALUE_RANKID, identifier_); - HCCL_PROFILER_ADD_GROUPRANK(identifier_, userRankSize_, userRank_); - auto originalAlgTypeLevel1 = static_cast(algType_) >> HCCL_LEVEL_ALGO_WIDTH; - bool hugeData = (count * unitSize) > SDMA_SEND_MAX_SIZE; - bool smallData = (count * unitSize) <= HCCL_SMALL_COUNT_32_KB; - if (inputPtr == outputPtr) { - auto opMeta = HcclOpMetaInfo::GetOneForReduceScatter(originalAlgTypeLevel1, dataType, reduceType, hugeData, - smallData, CopyPattern::ZCOPY); - CHK_RET(InitTask(dispatcher_, stream, opMeta.isEnableCache, opMeta.GetCacheKey())); + if (deviceType_ == DevType::DEV_TYPE_310P3) 
{ + newTag = tag + algName; } else { - auto opMeta = HcclOpMetaInfo::GetOneForReduceScatter(originalAlgTypeLevel1, dataType, reduceType, hugeData, - smallData, CopyPattern::BCOPY); - CHK_RET(InitTask(dispatcher_, stream, opMeta.isEnableCache, opMeta.GetCacheKey())); - DeviceMem srcMem(inputPtr, count*SIZE_TABLE[dataType]); - DeviceMem dstMem(outputPtr, count*SIZE_TABLE[dataType]); - HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, stream); // ranksize = 1; intput、output地址不同,input->output - } - CHK_RET(LaunchTask(dispatcher_, stream)); - HCCL_PROFILER_DEL_STREAM(rtStream); - HCCL_PROFILER_DEL_TAG(tag); - HCCL_PROFILER_DEL_OPDATA(tag); - HCCL_PROFILER_DEL_GROUPRANK(tag); - } else if (isUseDMA && (isPipeLine)) { - HcomCollOpInfo opInfo; - opInfo.inputAddr = inputPtr; - opInfo.outputAddr = outputPtr; - opInfo.count = count; - opInfo.dataType = dataType; - opInfo.reduceOp = op; - std::string newTag = tag; - if (!isSingleMeshAggregation_) { - newTag= GenerateNewTagByAlgTypeLevel1(tag, algTypeLevel1Tag); + AlgTypeLevel1 algType1 = GetLevel1AlgType(algType_); + auto level1Iter = HCCL_ALGO_LEVEL1_NAME_MAP.find(algType1); + newTag = tag + level1Iter->second + algName; } - HCCL_PROFILER_ADD_TAG(newTag, identifier_, GetWorkflowMode()); - HCCL_PROFILER_ADD_STREAM(rtStream, newTag, 0, algType_); - HCCL_PROFILER_ADD_OPDATA(newTag, count, inputPtr, outputPtr, dataType, INVALID_VALUE_RANKID, identifier_); - HCCL_PROFILER_ADD_GROUPRANK(identifier_, userRankSize_, userRank_); - HCCL_DEBUG("ReduceScatterOutPlace: curInputPtr[%p], curOutputPtr[%p], op[%s], recvCount[%llu], " - "dataType[%s], tag[%s]", curInputPtr, curOutputPtr, GetReduceOpEnumStr(op).c_str(), count, - GetDataTypeEnumStr(dataType).c_str(), newTag.c_str()); - CHK_RET(RankConsistent::GetInstance().RecordOpPara(HcclCMDType::HCCL_CMD_REDUCE_SCATTER, newTag, count, - dataType, op, inCCLbuffer.size(), outCCLbuffer.size())); - - ret = ReduceScatter(newTag, inCCLbuffer.ptr(), outCCLbuffer.ptr(), count, dataType, op, stream, &opInfo); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[Loop][ReduceScatter]errNo[0x%016llx] op_base hcclComm reduce_scatter error, tag[%s], " - "input_ptr[%p], output_ptr[%p], count[%llu], data_type[%s], op[%s]", - HCCL_ERROR_CODE(ret), newTag.c_str(), inCCLbuffer.ptr(), outCCLbuffer.ptr(), count, - GetDataTypeEnumStr(dataType).c_str(), GetReduceOpEnumStr(op).c_str()), - ret); - CHK_RET(RankConsistent::GetInstance().DelOpPara(newTag)); - HCCL_PROFILER_DEL_STREAM(rtStream); - HCCL_PROFILER_DEL_TAG(newTag); - HCCL_PROFILER_DEL_OPDATA(newTag); - HCCL_PROFILER_DEL_GROUPRANK(newTag); - } else { - u64 countPerLoop = count > maxCountPerLoop ? maxCountPerLoop : count; - std::string newTag = GenerateNewTagByAlgTypeLevel1(tag, algTypeLevel1Tag); + bool isInlineReduce = + IsSupportSDMAReduce(param.inputPtr, param.outputPtr, param.DataDes.dataType, param.reduceType); + bool isRdmaReduce = IsSupportRDMAReduce(param.DataDes.dataType, param.reduceType); const std::string REDUCE_SCATTER_NO_INLINE = "_no_inline"; newTag = (isInlineReduce && isRdmaReduce) ? 
newTag : newTag + REDUCE_SCATTER_NO_INLINE; + } - HCCL_PROFILER_ADD_TAG(newTag, identifier_, GetWorkflowMode()); - HCCL_PROFILER_ADD_STREAM(rtStream, newTag, 0, algType_); - HCCL_PROFILER_ADD_OPDATA(newTag, count, inputPtr, outputPtr, dataType, INVALID_VALUE_RANKID, identifier_); - HCCL_PROFILER_ADD_GROUPRANK(identifier_, userRankSize_, userRank_); - - HcomCollOpInfo opInfo = {"", inputPtr, outputPtr, countPerLoop, dataType, 0, op}; - CHK_RET(hcclImpl_->CreateOpBasedResources(HcclCMDType::HCCL_CMD_REDUCE_SCATTER, newTag, opInfo)); - - for (u64 countLeft = count, inputOffset = 0, outputOffset = 0; countLeft > 0; countLeft -= curCount) { - curInputPtr += inputOffset; - curOutputPtr += outputOffset; - HCCL_INFO("-OP_BASE-ReduceScatterLoop:inputOffset[%llu], outputOffset[%llu]", inputOffset, outputOffset); - // 判断剩余数据量对应的input size是否大于中转input size - curCount = countLeft > maxCountPerLoop ? maxCountPerLoop : countLeft; - u64 curSize = curCount * unitSize; // 单位:字节 - - auto autoSelectedAlgTypeLevel1 = static_cast(algType_) >> HCCL_LEVEL_ALGO_WIDTH; - bool hugeData = (curSize * userRankSize_ / HCCL_INTERNODE_MAX_DATA_RATE > RDMA_SEND_MAX_SIZE) || - (curSize > SDMA_SEND_MAX_SIZE); - u32 dataSplit = 0; - u64 dataValue = curCount * unitSize * userRankSize_; - if ((serverNum_ > 1) && ((dataValue / serverNum_) <= HCCL_SDMA_RDMA_SPLIT_SIZE)) { - dataSplit = 1; - } else if (dataValue <= HCCL_SDMA_RDMA_SPLIT_SIZE) { - dataSplit = HCCL_SPLIT_FLAG; - } - auto meta = HcclOpMetaInfo::GetOneForReduceScatter(autoSelectedAlgTypeLevel1, dataType, reduceType, - hugeData); - meta.dataSplit = dataSplit; - CHK_RET(InitTask(dispatcher_, stream, meta.isEnableCache, meta.GetCacheKey())); - HCCL_DEBUG("ReduceScatterOutPlace: curInputPtr[%p], curOutputPtr[%p], op[%s], recvCount[%llu], " - "dataType[%s], tag[%s]", curInputPtr, curOutputPtr, GetReduceOpEnumStr(op).c_str(), curCount, - GetDataTypeEnumStr(dataType).c_str(), newTag.c_str()); - - DeviceMem dstMem; - DeviceMem srcMem; - for (u32 i = 0; i < userRankSize_; i++) { - // 拷贝input上每个slice的数据到中转内存,源端每个slice的size固定为output的size - dstMem = inCCLbuffer.range(curSize * i, curSize); - srcMem = DeviceMem::create(curInputPtr + count * unitSize * i, curSize); - CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, stream)); - } - CHK_RET(RankConsistent::GetInstance().RecordOpPara(HcclCMDType::HCCL_CMD_REDUCE_SCATTER, newTag, - curCount, dataType, op, inCCLbuffer.size(), outCCLbuffer.size())); - ret = ReduceScatter(newTag, inCCLbuffer.ptr(), outCCLbuffer.ptr(), curCount, dataType, op, stream); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[Loop][ReduceScatter]errNo[0x%016llx] op_base hcclComm reduce_scatter error, tag[%s], " - "input_ptr[%p], output_ptr[%p], count[%llu], data_type[%s], op[%s]", - HCCL_ERROR_CODE(ret), newTag.c_str(), inCCLbuffer.ptr(), outCCLbuffer.ptr(), curCount, - GetDataTypeEnumStr(dataType).c_str(), GetReduceOpEnumStr(op).c_str()), - ret); - CHK_RET(RankConsistent::GetInstance().DelOpPara(newTag)); - - srcMem = outCCLbuffer.range(0, curSize); - dstMem = DeviceMem::create(curOutputPtr, curSize); - CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, stream)); - - CHK_PRT_RET((curCount == 0), HCCL_ERROR("[Loop][ReduceScatter]In OP_BASE curCount is zero"), HCCL_E_PARA); - inputOffset = curSize; - outputOffset = curSize; - CHK_RET(LaunchTask(dispatcher_, stream)); - } - HCCL_PROFILER_DEL_STREAM(rtStream); - HCCL_PROFILER_DEL_TAG(newTag); - HCCL_PROFILER_DEL_OPDATA(newTag); - HCCL_PROFILER_DEL_GROUPRANK(newTag); - } - return HCCL_SUCCESS; -} - 
-HcclResult ReduceScatterOperator::ReduceScatterComm(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, - DeviceMem &scratchMem, u64 count, HcclDataType dataType, HcclReduceOp op, Stream &stream) -{ - CommInfo *currComm; - hcclImpl_->GetCommInfo(currComm, tag); - - bool bRet = currComm->commInner.size() <= 0; - CHK_PRT_RET(bRet, HCCL_ERROR("[ReduceScatterOperator][ReduceScatterComm]tag[%s],reduce scatter op comm is empty", - tag.c_str()), HCCL_E_INTERNAL); - - std::unique_ptr &commCombine = currComm->commInner[COMM_INDEX_0]; - CHK_SMART_PTR_NULL(commCombine); - - u64 reduceAttr = GetReduceAttr(inputMem, outputMem, dataType, op); - - // 构造ring algorithm对应的reduce-scatter实例 - std::unique_ptr executor; - if (UseInterServerNHRAlgo(algType_)) { - executor.reset(new (std::nothrow) ReduceScatterNHR(dispatcher_, reduceAttr)); - HCCL_INFO("reducescatter comm: using nhr algo inter-server."); - CHK_SMART_PTR_NULL(executor); - CHK_RET(executor->Prepare(inputMem, outputMem, scratchMem, count, dataType, stream, op)); - } else if (UseInterServerNHRV1Algo(algType_)) { - executor.reset(new (std::nothrow) ReduceScatterNHRV1(dispatcher_, reduceAttr)); - HCCL_INFO("reducescatter comm: using nhr_v1 algo inter-server."); - CHK_SMART_PTR_NULL(executor); - CHK_RET(executor->Prepare(inputMem, outputMem, scratchMem, count, dataType, stream, op)); - CHK_RET(commCombine->RunExecutor(executor)); - } else if (UseInterServerNBAlgo(algType_)) { - executor.reset(new (std::nothrow) ReduceScatterNB(dispatcher_, reduceAttr)); - HCCL_INFO("reducescatter comm: using nonuniform-bruck algo inter-server."); - CHK_SMART_PTR_NULL(executor); - CHK_RET(executor->Prepare(inputMem, outputMem, scratchMem, count, dataType, stream, op)); - CHK_RET(commCombine->RunExecutor(executor)); - } else { - executor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); - HCCL_INFO("reducescatter comm: using ring algo inter-server."); - CHK_SMART_PTR_NULL(executor); - CHK_RET(executor->Prepare(inputMem, inputMem, scratchMem, count, dataType, stream, op)); - CHK_RET(commCombine->RunExecutor(executor)); - // 将cclInBuffer中与userRank_对应的部分拷贝至cclOutBuffer - u64 dataSize = count * SIZE_TABLE[dataType]; - DeviceMem srcMem = inputMem.range(dataSize * userRank_, dataSize); - DeviceMem dstMem = outputMem.range(0, dataSize); - CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, stream)); - } - - return HCCL_SUCCESS; -} - -HcclResult ReduceScatterOperator::ReduceScatterMeshOpbaseExecutorMiddlelayer(const std::string &tag, - DeviceMem &inputMem, DeviceMem &outputMem, DeviceMem &scratchMem, u64 count, HcclDataType dataType, - HcclReduceOp op, Stream &stream, HcomCollOpInfo *opInfo) -{ - HcclResult ret = HCCL_SUCCESS; - - u32 unitSize = SIZE_TABLE[dataType]; - u64 maxCountPerLoop = cclBufferManager_.GetInCCLbuffer().size() / - unitSize; // 中转内存单次最多能够接受的output count,放开ranksize限制 - - if (hcclImpl_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_ENABLE) { - maxCountPerLoop = (cclBufferManager_.GetInCCLbuffer().size() - - HCCL_MIN_SLICE_ALIGN_910B * deviceNumPerAggregation_) / unitSize / (deviceNumPerAggregation_ - 1); - maxCountPerLoop = maxCountPerLoop / HCCL_MIN_SLICE_ALIGN_910B; - maxCountPerLoop = maxCountPerLoop * HCCL_MIN_SLICE_ALIGN_910B; - } - - if (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { - maxCountPerLoop = count; - } - - u8 *curInputPtr = static_cast(opInfo->inputAddr); - u8 *curOutputPtr = static_cast(opInfo->outputAddr); - CHK_PTR_NULL(curInputPtr); - CHK_PTR_NULL(curOutputPtr); - - 
ReduceType reduceType = ((op != HCCL_REDUCE_PROD) && (dataType != HCCL_DATA_TYPE_INT64)) ? - ReduceType::INLINE_REDUCE : ReduceType::TBE_REDUCE; - - auto originalAlgTypeLevel1 = static_cast(algType_) >> HCCL_LEVEL_ALGO_WIDTH; - - u64 curCount = 0; - for (u64 countLeft = count, inputOffset = 0, outputOffset = 0; countLeft > 0; countLeft -= curCount) { - curInputPtr += inputOffset; - curOutputPtr += outputOffset; - opInfo->inputAddr = curInputPtr; - opInfo->outputAddr = curOutputPtr; - // 判断剩余数据量对应的input size是否大于中转input size - curCount = (countLeft > maxCountPerLoop) ? maxCountPerLoop : countLeft; - u64 curSize = curCount * unitSize; // 单位:字节 - - /* 下沉子图reset,保证子图不复用标志生效 */ - bool hugeData = curSize > SDMA_SEND_MAX_SIZE; - auto meta = HcclOpMetaInfo::GetOneForReduceScatter(originalAlgTypeLevel1, dataType, reduceType, hugeData, - curSize <= HCCL_SMALL_COUNT_32_KB); - CHK_RET(InitTask(dispatcher_, stream, meta.isEnableCache, meta.GetCacheKey())); - - if (hcclImpl_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_ENABLE) { - ret = ReduceScatterDeterExecutor(tag, inputMem, outputMem, scratchMem, curCount, dataType, op, stream, - opInfo); - } else { - ret = ReduceScatterMeshExecutor(tag, inputMem, outputMem, scratchMem, curCount, dataType, op, stream, - opInfo); - } - CHK_PRT_RET((curCount == 0), HCCL_ERROR("[Loop][ReduceScatter]In OP_BASE curCount is zero"), HCCL_E_PARA); - inputOffset = curSize; - outputOffset = curSize; - if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { - CHK_RET(LaunchTask(dispatcher_, stream)); - } - } + HCCL_INFO("[SelectAlg] reduce_scatter newTag is [%s]", newTag.c_str()); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[ReduceScatterSelector][SelectAlg]tag[%s], reduce_scatter failed, return[%d]", + tag.c_str(), ret), ret); return ret; } -HcclResult ReduceScatterOperator::ReduceScatterDeterExecutor(const std::string &tag, DeviceMem& inputMem, - DeviceMem& outputMem, DeviceMem& scratchMem, u64 count, HcclDataType dataType, HcclReduceOp op, Stream& stream, - HcomCollOpInfo *opInfo) +HcclResult ReduceScatterOperator::SelectAlgfor310P3(const OpParam& param, std::string& algName) { - u32 unitSize = SIZE_TABLE[dataType]; - std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 - std::unique_ptr outerExecutor; - CommInfo *currComm; - hcclImpl_->GetCommInfo(currComm, tag); - bool bRet = currComm->commOuter.size() == 0; - CHK_PRT_RET( - bRet, HCCL_ERROR("[ReduceScatterOperator][ReduceScatterDeter]tag[%s],comm outer is empty", tag.c_str()), - HCCL_E_INTERNAL); - - std::unique_ptr &commOuter = currComm->commOuter[COMM_INDEX_0]; - CHK_SMART_PTR_NULL(commOuter); - - CHK_RET(hcclImpl_->ActiveRingStreams(tag, stream)); - innerStreamInfo_t *streamInfo = hcclImpl_->GetStreamInfo(tag); - CHK_PRT_RET(streamInfo == nullptr, - HCCL_ERROR("[GetStreamInfo]errNo[0x%016llx] tag[%s] can't find in stream info", - HCCL_ERROR_CODE(HCCL_E_NOT_FOUND), tag.c_str()), HCCL_E_PARA); - - u64 reduceAttr = GetReduceAttr(inputMem, outputMem, dataType, op); - - if (((opInfo -> count) * unitSize > HCCL_SMALL_COUNT_32_KB) || - (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) || - ((deviceNumPerAggregation_ != DEVICE_EIGHT) && (deviceNumPerAggregation_ != DEVICE_FOUR))) { - outerExecutor.reset(new (std::nothrow) ReduceScatterLocalReduce(dispatcher_, reduceAttr, - streamInfo->ringStreams, streamInfo->ringSignal, streamInfo->ringSignalAux, commOuter->UserRank(), opInfo)); - } else { - outerExecutor.reset(new (std::nothrow) ReduceScatterHDStage(dispatcher_, 
reduceAttr, streamInfo->ringStreams, - streamInfo->ringSignal, streamInfo->ringSignalAux, commOuter->UserRank(), opInfo)); - } - - CHK_SMART_PTR_NULL(outerExecutor); - CHK_RET(outerExecutor->Prepare( - inputMem, scratchMem, outputMem, count, dataType, stream, op, OUTER_BRIDGE_RANK_ID, dataSegsSlice, 0)); - - CHK_RET( - outerExecutor->RegisterProfiler((commOuter->RankSize() << PROF_RANKSIZE_OFFSET_OF_PLANEID) + commOuter->Rank(), - PROF_STAGE_2, HCCL_EXEC_STEP_NOT_SET, stream)); - - CHK_RET(commOuter->RunExecutor(outerExecutor)); - HCCL_INFO("reducescatter mesh deter run success"); + algName = "ReduceScatterRing"; + HCCL_INFO("[SelectAlgfor310P3] reduce_scatter SelectAlgfor310P3 is algName [%s]", algName.c_str()); return HCCL_SUCCESS; } -HcclResult ReduceScatterOperator::ReduceScatterMeshExecutor(const std::string &tag, DeviceMem& inputMem, - DeviceMem& outputMem, DeviceMem& scratchMem, u64 count, HcclDataType dataType, HcclReduceOp op, Stream& stream, - HcomCollOpInfo *opInfo) +HcclResult ReduceScatterOperator::SelectAlgfor910A(const OpParam& param, std::string& algName) { - u32 perDataSize = SIZE_TABLE[dataType]; - - CommInfo *currComm; - hcclImpl_->GetCommInfo(currComm, tag); - bool bRet = currComm->commOuter.size() == 0; - CHK_PRT_RET(bRet, HCCL_ERROR("[ReduceScatterOperator][ReduceScatterMeshExecutor]tag[%s],comm outer is empty", - tag.c_str()), HCCL_E_INTERNAL); - - std::unique_ptr &commOuter = currComm->commOuter[COMM_INDEX_0]; - CHK_SMART_PTR_NULL(commOuter); + bool isMeshTopo = topoType_ == TopoType::TOPO_TYPE_4P_MESH || topoType_ == TopoType::TOPO_TYPE_2P_MESH; + bool isRingTopo = topoType_ == TopoType::TOPO_TYPE_NP_SINGLE_RING || topoType_ == TopoType::TOPO_TYPE_8P_RING; - /* ******************第一步: 节点间reducescatter *******************************/ - u32 commIndex = commOuter->Rank(); // 找到rank所在的节点间平面 - HCCL_DEBUG("commIndex:%u tagCommInfo_[tag].commInner.size():%llu", commIndex, currComm->commInner.size()); - bRet = commIndex >= currComm->commInner.size(); - CHK_PRT_RET(bRet, - HCCL_ERROR("[ReduceScatterOperator][ReduceScatterMeshExecutor]commIndex[%u] >=(tag[%s])comm size[%llu]", - commIndex, tag.c_str(), currComm->commInner.size()), HCCL_E_INTERNAL); - - CHK_SMART_PTR_NULL(currComm->commInner[commIndex]); - - u32 innerRankSize = currComm->commInner[commIndex]->RankSize(); - if (innerRankSize > 1) { - u64 reduceAttr = GetReduceAttr(inputMem, outputMem, dataType, op); - std::unique_ptr innerExecutor; - if (UseInterServerRingAlgo(algType_)) { - innerExecutor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); - CHK_SMART_PTR_NULL(innerExecutor); - HCCL_INFO("reducescatter mesh: using ring algo inter-server."); - u64 ringSize = inputMem.size() / innerRankSize; - u64 ringCount = ringSize / perDataSize; - // 申请临时内存作为scratch内存 - CHK_RET(innerExecutor->Prepare(inputMem, inputMem, scratchMem, ringCount, dataType, - stream, op, OUTER_BRIDGE_RANK_ID, std::vector(0))); - } else if (UseInterServerNHRAlgo(algType_)) { - innerExecutor.reset(new (std::nothrow) ReduceScatterNHR(dispatcher_, reduceAttr)); - HCCL_INFO("reducescatter mesh: using nhr algo inter-server."); - CHK_SMART_PTR_NULL(innerExecutor); - u64 ringSize = inputMem.size() / innerRankSize; - u64 ringCount = ringSize / perDataSize; - // 申请临时内存作为scratch内存 - CHK_RET(innerExecutor->Prepare(inputMem, inputMem, scratchMem, ringCount, dataType, - stream, op, OUTER_BRIDGE_RANK_ID, std::vector(0))); - } else if (UseInterServerNHRV1Algo(algType_)) { - innerExecutor.reset(new (std::nothrow) ReduceScatterNHRV1(dispatcher_, 
reduceAttr)); - HCCL_INFO("reducescatter mesh: using nhr_v1 algo inter-server."); - CHK_SMART_PTR_NULL(innerExecutor); - u64 ringSize = inputMem.size() / innerRankSize; - u64 ringCount = ringSize / perDataSize; - // 申请临时内存作为scratch内存 - CHK_RET(innerExecutor->Prepare(inputMem, inputMem, scratchMem, ringCount, dataType, - stream, op, OUTER_BRIDGE_RANK_ID, std::vector(0))); - } else if (UseInterServerNBAlgo(algType_)) { - innerExecutor.reset(new (std::nothrow) ReduceScatterNB(dispatcher_, reduceAttr)); - HCCL_INFO("reducescatter mesh: using nonuniform-bruck algo inter-server."); - CHK_SMART_PTR_NULL(innerExecutor); - u64 ringSize = inputMem.size() / innerRankSize; - u64 ringCount = ringSize / perDataSize; - // 申请临时内存作为scratch内存 - CHK_RET(innerExecutor->Prepare(inputMem, inputMem, scratchMem, ringCount, dataType, - stream, op, OUTER_BRIDGE_RANK_ID, std::vector(0))); - } else { - innerExecutor.reset(new (std::nothrow) ReduceScatterRecursiveHalvingDoubling(dispatcher_, reduceAttr)); - CHK_SMART_PTR_NULL(innerExecutor); - HCCL_INFO("reducescatter mesh: using halving-doubling algo inter-server."); - // 申请临时内存作为scratch内存 - u64 inputDataCount = inputMem.size() / perDataSize; - CHK_RET(innerExecutor->Prepare(inputMem, inputMem, scratchMem, inputDataCount, dataType, - stream, op, OUTER_BRIDGE_RANK_ID, std::vector(0))); // count是output的数据个数 - } - CHK_RET(innerExecutor->RegisterProfiler( - (innerRankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + currComm->commInner[commIndex]->Rank(), - PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, stream)); - - CHK_RET(currComm->commInner[commIndex]->RunExecutor(innerExecutor)); - } - - /* *******************第二步: 节点内reducescatter ******************************************/ - CHK_RET(hcclImpl_->ActiveRingStreams(tag, stream)); - - u32 sliceNum = currComm->commOuter[COMM_INDEX_0]->RankSize(); - // 根据数据量算每个环上数据的偏移和大小,把做完hd的slice均分成RankSize份 - std::vector dataSegsSlice; - CHK_RET(PrepareReduceScatterSliceData(count, perDataSize, sliceNum, dataSegsSlice)); - - // 每个server分配的slice大小 - u64 serverSliceSize = inputMem.size() / innerRankSize; - // 每个服务器对应的偏移 - u64 serverSliceOffset = serverSliceSize * currComm->commInner[commIndex]->Rank(); - - HCCL_DEBUG("inputMem.size()=%llu, commOuter->RankSize()=%u, serverSliceSize=%llu, serverSliceOffset=%llu "\ - "commIndex=%u commInner[commIndex]->rank=%u", inputMem.size(), commOuter->RankSize(), serverSliceSize, - serverSliceOffset, commIndex, currComm->commInner[commIndex]->Rank()); - - DeviceMem reduceScatterMeshInput = inputMem.range(serverSliceOffset, serverSliceSize); - CHK_SMART_PTR_NULL(reduceScatterMeshInput); - DeviceMem reduceScatterMeshOutput = scratchMem.range(serverSliceOffset, serverSliceSize); - CHK_SMART_PTR_NULL(reduceScatterMeshOutput); - - std::vector > &commMeshVec = currComm->commOuter; - if (hcclImpl_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE && (dataType != HCCL_DATA_TYPE_INT64) && - (deviceType_ == DevType::DEV_TYPE_910B && op != HCCL_REDUCE_PROD)) { - CHK_RET(MultiStreamReduceScatterMeshAtomic(tag, reduceScatterMeshInput, reduceScatterMeshOutput, - count, dataType, op, dataSegsSlice, stream, commMeshVec, serverSliceOffset, opInfo)); - } else { - std::vector > multiStreamSlice; // 每个stream使用的数据基于用户buffer的偏移 - // mesh算法stream数量为rank数减1 - CHK_RET(ExecutorBase::PrepareSliceMeshStreams(dataSegsSlice, sliceNum - 1, multiStreamSlice)); - CHK_RET(MultiStreamReduceScatterMesh(tag, reduceScatterMeshInput, reduceScatterMeshOutput, - count, dataType, op, multiStreamSlice, stream, commMeshVec, serverSliceOffset)); - } - - 
bool isInlineReduce = IsSupportSDMAReduce(inputMem.ptr(), outputMem.ptr(), dataType, op); - if (isSingleMeshAggregation_ && (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) && - isInlineReduce && deviceType_ == DevType::DEV_TYPE_910B && opInfo != nullptr) { - /* 使用SDMA direct 拷贝,不需要再做DMAOUT->USROUT */ + if (isMeshTopo) { + algName = "ReduceScatterMeshExecutor"; + } else if (isRingTopo) { + algName = "ReduceScatterRingExecutor"; } else { - DeviceMem srcMem = inputMem.range(serverSliceOffset + dataSegsSlice[commIndex].offset, count * perDataSize); - CHK_SMART_PTR_NULL(srcMem); - CHK_RET(HcclD2DMemcpyAsync(dispatcher_, outputMem, srcMem, stream)); + algName = "ReduceScatterComm"; } - + HCCL_INFO("[SelectAlgfor910A] reduce_scatter SelectAlgfor910A is algName [%s]", algName.c_str()); return HCCL_SUCCESS; } -HcclResult ReduceScatterOperator::ReduceScatterDoubleRingExecutor(const std::string &tag, DeviceMem &inputMem, - DeviceMem &outputMem, DeviceMem &scratchMem, - u64 count, HcclDataType dataType, HcclReduceOp op, - Stream &stream, const HcomCollOpInfo *opInfo) +HcclResult ReduceScatterOperator::SelectAlgfor910B(const OpParam& param, std::string& algName) { - HCCL_INFO("[ReduceScatterOperator][ReduceScatterDoubleRingExecutor] The ReduceScatterDoubleRingExecutor starts."); - u32 perDataSize = 0; - CHK_RET(SalGetDataTypeSize(dataType, perDataSize)); - - CommInfo *currComm; - hcclImpl_->GetCommInfo(currComm, tag); - - bool bRet = currComm->commOuter.size() == 0; - CHK_PRT_RET(bRet, HCCL_ERROR("[ReduceScatterOperator][ReduceScatterDoubleRingExecutor]tag[%s],comm outer is empty", - tag.c_str()), HCCL_E_INTERNAL); - - std::unique_ptr &commOuter = currComm->commOuter[COMM_INDEX_0]; - CHK_SMART_PTR_NULL(commOuter); - - u32 ringNum; - if (topoType_ == TopoType::TOPO_TYPE_NP_DOUBLE_RING) { - ringNum = OUTER_PLANE_NUM_IN_NPRING_DOUBLE; - } else { - ringNum = OUTER_PLANE_NUM_IN_NPRING_SINGLE; - } - - std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 - std::vector > multiStreamSlice; // 每个stream使用的数据基于用户buffer的偏移 - u32 sliceNum = currComm->commOuter[COMM_INDEX_0]->RankSize(); - Slice sliceTemp; - u32 commIndex = currComm->commOuter[0]->Rank(); - commIndex = RefreshCommIdx(commIndex, nicList_, devicePhyId_); - - /* 超节点间通信域是commLevel2 */ - CHK_SMART_PTR_NULL(currComm->commLevel2[0]); - - u32 level2RankSize = currComm->commLevel2[0]->RankSize(); - if (level2RankSize > 1) { - /* ****************** 超节点间 reducescatter *******************************/ - u64 reduceAttr = GetReduceAttr(inputMem, scratchMem, dataType, op); - std::unique_ptr level2Executor; - - if (UseLevel2RingAlgo(algType_)) { - level2Executor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); - HCCL_INFO("reducescatter ring: using ring algo inter-superPod."); - CHK_SMART_PTR_NULL(level2Executor); - - u64 ringCount = inputMem.size() / (level2RankSize * perDataSize); - CHK_RET(level2Executor->Prepare(inputMem, inputMem, scratchMem, ringCount, dataType, - stream, op, OUTER_BRIDGE_RANK_ID, std::vector(0))); - } else { - level2Executor.reset(new (std::nothrow) ReduceScatterRecursiveHalvingDoubling(dispatcher_, reduceAttr)); - HCCL_INFO("reducescatter ring: using halving-doubling algo inter-superPod."); - - CHK_SMART_PTR_NULL(level2Executor); - u64 inputDataCount = inputMem.size() / perDataSize; - CHK_RET(level2Executor->Prepare(inputMem, inputMem, scratchMem, inputDataCount, dataType, - stream, op, OUTER_BRIDGE_RANK_ID, std::vector(0))); // count是output的数据个数 - } - CHK_RET(level2Executor->RegisterProfiler( - 
(level2RankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + currComm->commLevel2[0]->Rank(), - PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, stream)); - CHK_RET(currComm->commLevel2[0]->RunExecutor(level2Executor)); - - /* ****************** 节点间 reducescatter *******************************/ - HCCL_DEBUG("commIndex:%u tagCommInfo_[tag].commInner.size():%llu", commIndex, currComm->commInner.size()); - bRet = commIndex >= currComm->commInner.size(); - CHK_PRT_RET(bRet, HCCL_ERROR("[ReduceScatterOperator][ReduceScatterDoubleRingExecutor]commIndex[%u]" \ - " >=(tag[%s])comm size[%llu]", commIndex, tag.c_str(), currComm->commInner.size()), HCCL_E_INTERNAL); - - CHK_SMART_PTR_NULL(currComm->commInner[commIndex]); - - u32 innerRankSize = currComm->commInner[commIndex]->RankSize(); - if (innerRankSize > 1) { - std::unique_ptr innerExecutor; - u32 level1Index = currComm->commInner[commIndex]->Rank(); + u32 unitSize = SIZE_TABLE[param.DataDes.dataType]; - if (UseInterServerRingAlgo(algType_)) { - innerExecutor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); - HCCL_INFO("reducescatter ring: using ring algo inter-server."); - CHK_SMART_PTR_NULL(innerExecutor); - - u64 ringSize = inputMem.size() / (innerRankSize * level2RankSize); - u64 ringCount = ringSize / perDataSize; - u64 level1SliceOffset = ringSize * level1Index; - DeviceMem level1InputMem = inputMem.range(level1SliceOffset, ringSize); - CHK_SMART_PTR_NULL(level1InputMem.ptr()); - - CHK_RET(innerExecutor->Prepare(level1InputMem, level1InputMem, scratchMem, ringCount, dataType, - stream, op, OUTER_BRIDGE_RANK_ID, std::vector(0), level1SliceOffset)); - } else { - innerExecutor.reset(new (std::nothrow) ReduceScatterRecursiveHalvingDoubling(dispatcher_, reduceAttr)); - HCCL_INFO("reducescatter ring: using halving-doubling algo inter-server."); - - CHK_SMART_PTR_NULL(innerExecutor); - u64 inputDataCount = inputMem.size() / (perDataSize * level2RankSize); - u64 level1SliceSize = inputMem.size() / level2RankSize; - u64 level1SliceOffset = level1SliceSize * level1Index; + bool isInlineReduce = + IsSupportSDMAReduce(param.inputPtr, param.outputPtr, param.DataDes.dataType, param.reduceType); + bool isRdmaReduce = IsSupportRDMAReduce(param.DataDes.dataType, param.reduceType); + bool isMeshTopo = topoType_ == TopoType::TOPO_TYPE_NP_MESH || topoType_ == TopoType::TOPO_TYPE_4P_MESH || + topoType_ == TopoType::TOPO_TYPE_2P_MESH || topoType_ == TopoType::TOPO_TYPE_1P_MESH; + bool isRingTopo = topoType_ == TopoType::TOPO_TYPE_NP_SINGLE_RING; - DeviceMem level1InputMem = inputMem.range(level1SliceOffset, level1SliceSize); - // count是output的数据个数 - CHK_RET(innerExecutor->Prepare(level1InputMem, level1InputMem, scratchMem, inputDataCount, dataType, - stream, op, OUTER_BRIDGE_RANK_ID, std::vector(0), level1SliceOffset)); - } - CHK_RET(innerExecutor->RegisterProfiler( - (innerRankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + currComm->commInner[commIndex]->Rank(), - PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, stream)); - CHK_RET(currComm->commInner[commIndex]->RunExecutor(innerExecutor)); + u64 dataSize = param.DataDes.count * unitSize; // 单位:字节 + u64 cclBufferSize = cclBufferManager_.GetInCCLbufferSize() / userRankSize_; + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + std::string algTypeLevel1Tag; + CHK_RET(AutoSelectAlgTypeLevel1(HcclCMDType::HCCL_CMD_REDUCE_SCATTER, dataSize, cclBufferSize, algTypeLevel1Tag, + isInlineReduce, isRdmaReduce)); + if (param.opBaseAtraceInfo != nullptr) { + 
CHK_RET(param.opBaseAtraceInfo->SavealgtypeTraceInfo(algTypeLevel1Tag, param.tag)); } - - /* *********** 节点内reducescatter (正常场景) *****************************/ - CHK_RET(hcclImpl_->ActiveRingStreams(tag, stream)); - - bool useInlineRduce = false; - bool isInlineReduce = IsSupportSDMAReduce(inputMem.ptr(), scratchMem.ptr(), dataType, op); - useInlineRduce = isInlineReduce && inlineReduceSwitchOn_; - multiStreamSlice = ReduceScatterRingSlicePrepare(ringNum, sliceNum, useInlineRduce, outputMem, dataSegsSlice, - tag); - bRet = (multiStreamSlice.size() != ringNum); - CHK_PRT_RET(bRet, - HCCL_ERROR("[ReduceScatterOperator][ReduceScatterRingExecutor]sliceNum-1[%u] != multiStreamSlice" \ - "size[%llu]", sliceNum - 1, multiStreamSlice.size()), HCCL_E_INTERNAL); - - DeviceMem srcMem; - // 每个server分配的slice大小 - u64 serverSliceSize = inputMem.size() / (innerRankSize * level2RankSize); - // 每个服务器对应的偏移 - u32 serverIndex = currComm->commInner[commIndex]->Rank(); - u64 serverSliceOffset = serverSliceSize * serverIndex; - HCCL_DEBUG("inputMem.size()=%llu, commOuter->RankSize()=%u, serverSliceSize=%llu, serverSliceOffset=%llu "\ - "commIndex=%u commInner[commIndex]->rank=%u", inputMem.size(), commOuter->RankSize(), serverSliceSize, - serverSliceOffset, commIndex, currComm->commInner[commIndex]->Rank()); - DeviceMem reduceScatterRingInput = inputMem.range(serverSliceOffset, serverSliceSize); - DeviceMem reduceScatterRingOutput = scratchMem.range(serverSliceOffset, serverSliceSize); - - u64 countLocal = serverSliceSize / perDataSize; - CHK_RET(MultiRingReduceScatter(tag, reduceScatterRingInput, reduceScatterRingOutput, countLocal, dataType, op, - multiStreamSlice, stream, PROF_STAGE_1, serverSliceOffset)); - - srcMem = inputMem.range(serverSliceOffset + dataSegsSlice[commIndex].offset, count * perDataSize); - CHK_SMART_PTR_NULL(srcMem.ptr()); - - CHK_RET(HcclD2DMemcpyAsync(dispatcher_, outputMem, srcMem, stream)); - HCCL_INFO("reducescatter double ring run success"); - return HCCL_SUCCESS; } - // 节点内reduce scatter - CHK_RET(hcclImpl_->ActiveRingStreams(tag, stream)); - - HCCL_DEBUG("commIndex:%u tagCommInfo_[tag].commInner.size():%llu", commIndex, currComm->commInner.size()); - bRet = commIndex >= currComm->commInner.size(); - CHK_PRT_RET(bRet, HCCL_ERROR("[ReduceScatterOperator][ReduceScatterDoubleRingExecutor]commIndex[%u]" \ - " >=(tag[%s])comm size[%llu]", commIndex, tag.c_str(), currComm->commInner.size()), HCCL_E_INTERNAL); - - CHK_SMART_PTR_NULL(currComm->commInner[commIndex]); - u32 innerRankSize = currComm->commInner[commIndex]->RankSize(); - - // 计算slice - std::vector > level0DataSegsSlice; - bool useInlineRduce = false; - bool isInlineReduce = IsSupportSDMAReduce(inputMem.ptr(), scratchMem.ptr(), dataType, op); - useInlineRduce = isInlineReduce && inlineReduceSwitchOn_; - multiStreamSlice = ReduceScatterRingSlicePrepare(ringNum, sliceNum, useInlineRduce, outputMem, dataSegsSlice, tag); - for (u32 ringIndex = 0; ringIndex < multiStreamSlice.size(); ringIndex++) { - std::vector dataSlice; - for (u32 level0Idx = 0; level0Idx < sliceNum; level0Idx++) { - Slice sliceTemp; - for (u32 level1Idx = 0; level1Idx < innerRankSize; level1Idx++) { - sliceTemp.size = multiStreamSlice[ringIndex][level0Idx].size; - sliceTemp.offset = - multiStreamSlice[ringIndex][level0Idx].offset + level1Idx * sliceNum * outputMem.size(); - dataSlice.push_back(sliceTemp); - } - } - level0DataSegsSlice.push_back(dataSlice); - } - std::vector> multRingsUserMemSlice; - if (opInfo == nullptr) { - multRingsUserMemSlice = 
level0DataSegsSlice; - } else { - for (u32 ringIndex = 0; ringIndex < level0DataSegsSlice.size(); ringIndex++) { - std::vector level1UserMemSlice; - for (auto &cclSlice : level0DataSegsSlice[ringIndex]) { - Slice tmpSlice; - tmpSlice.size = cclSlice.size; - tmpSlice.offset = - (cclSlice.offset / outputMem.size()) * opInfo->count * perDataSize + - multiStreamSlice[ringIndex][0].offset; - level1UserMemSlice.push_back(tmpSlice); - HCCL_DEBUG("rank[%u], ringIndex[%u], tmpSlice.offset=[%llu], size=[%llu]", - userRank_, ringIndex, tmpSlice.offset, tmpSlice.size); + if (isMeshTopo) { + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + if (SingleMeshInlineReduce(param.inputPtr, param.outputPtr, param.DataDes.dataType, param.reduceType)) { + if (topoMatcher_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_ENABLE) { + algName = "ReduceScatterDeterExecutor"; + } else { + algName = "ReduceScatterMeshDmaEliminationExecutor"; + } + } else if (topoMatcher_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_DISABLE && + GetLevel1AlgType(algType_) == AlgTypeLevel1::ALG_LEVEL1_PIPELINE && + IsMultiMeshInlineReduce(param.inputPtr, param.outputPtr, param.DataDes.dataType, param.reduceType)) { + algName = "ReduceScatterMeshOpbasePipelineExecutor"; } - multRingsUserMemSlice.push_back(level1UserMemSlice); - } - } - // 区分消减拷贝场景 - if (opInfo != nullptr && innerRankSize > 1) { - HcomCollOpInfo opInfoByReduceScatterDMAreduce = *opInfo; - opInfoByReduceScatterDMAreduce.outputAddr = nullptr; - CHK_RET(MultiRingReduceScatter(tag, inputMem, scratchMem, count, dataType, op, level0DataSegsSlice, - stream, PROF_STAGE_1, 0, &opInfoByReduceScatterDMAreduce, multRingsUserMemSlice)); - } else { - CHK_RET(MultiRingReduceScatter(tag, inputMem, scratchMem, count, dataType, op, - level0DataSegsSlice, stream, PROF_STAGE_1, 0, opInfo, multRingsUserMemSlice)); - } - // 对于单server图模式场景最后一步需要把数据从ccl input拷贝到ccl output上 - if (innerRankSize == 1 && opInfo == nullptr) { - DeviceMem srcMem = inputMem.range(userRank_ * outputMem.size(), outputMem.size()); - CHK_RET(HcclD2DMemcpyAsync(dispatcher_, outputMem, srcMem, stream)); - } - if (innerRankSize > 1) { - // 节点间做reduce scatter(ring/NHR) - u64 reduceAttr = GetReduceAttr(inputMem, scratchMem, dataType, op); - std::unique_ptr innerExecutor; - - // 计算slice - u32 level0ServerIndex = 0; - HcclResult ret = currComm->commOuter[COMM_INDEX_0]->GetRankByUserRank(userRank_, level0ServerIndex); - CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[ReduceScatterOperator][ReduceScatterDoubleRingExecutor]Get "\ - "Rank[%u] by User Rank[%u] from CommOuter[%u] Failed!", level0ServerIndex, userRank_, commIndex), ret); - - std::vector level1DataSegsSlice; - for (u32 i = 0; i < innerRankSize; i++) { - sliceTemp.size = outputMem.size(); - u32 level1UserRank; - CHK_RET(currComm->commInner[commIndex]->GetUserRankByRank(i, level1UserRank)); - sliceTemp.offset = level1UserRank * outputMem.size(); - level1DataSegsSlice.push_back(sliceTemp); - HCCL_DEBUG("rank[%u], level1DataSegsSlice[%u].offset=%llu, size=[%llu]", userRank_, i, - sliceTemp.offset, sliceTemp.size); - } - if (UseInterServerRingAlgo(algType_)) { - innerExecutor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); - HCCL_INFO("reducescatter ring: using ring algo inter-server."); - } else if (UseInterServerNBAlgo(algType_)) { - innerExecutor.reset(new (std::nothrow) ReduceScatterNB(dispatcher_, reduceAttr)); - HCCL_INFO("reducescatter ring: using nonuniform-bruck algo inter-server."); } else { - 
innerExecutor.reset(new (std::nothrow) ReduceScatterNHR(dispatcher_, reduceAttr)); - HCCL_INFO("reducescatter ring: using nonuniform-hierarchical-ring algo inter-server."); - } - CHK_SMART_PTR_NULL(innerExecutor); - - CHK_RET(innerExecutor->Prepare(inputMem, inputMem, scratchMem, count, dataType, - stream, op, OUTER_BRIDGE_RANK_ID, level1DataSegsSlice)); - CHK_RET(innerExecutor->RegisterProfiler( - (innerRankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + currComm->commInner[commIndex]->Rank(), - PROF_STAGE_2, HCCL_EXEC_STEP_NOT_SET, stream)); - CHK_RET(currComm->commInner[commIndex]->RunExecutor(innerExecutor)); - - // 区分消减拷贝场景(消减拷贝数据需要拷贝到user output上) - DeviceMem srcMem = inputMem.range(userRank_ * outputMem.size(), outputMem.size()); - if (opInfo != nullptr) { - DeviceMem dstMem = DeviceMem::create(static_cast(opInfo->outputAddr), outputMem.size()); - CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, stream)); - } else { - CHK_RET(HcclD2DMemcpyAsync(dispatcher_, outputMem, srcMem, stream)); - } - } - HCCL_INFO("reducescatter double ring run success"); - return HCCL_SUCCESS; -} - -HcclResult ReduceScatterOperator::ReduceScatterMeshOpbasePipelineExecutor(const std::string &tag, DeviceMem& scratchMem, - u64 count, HcclDataType dataType, HcclReduceOp op, Stream& stream, HcomCollOpInfo *opInfo) -{ - HCCL_INFO("[ReduceScatterOperator][ReduceScatterMeshOpbasePipelineExecutor] begins."); - ReduceType reduceType = ((op != HCCL_REDUCE_PROD) && (dataType != HCCL_DATA_TYPE_INT64)) ? - ReduceType::INLINE_REDUCE : ReduceType::TBE_REDUCE; - // step 1 先获取 comm inner \ comm outer 的value - CommInfo *currComm; - hcclImpl_->GetCommInfo(currComm, tag); - - CHK_PRT_RET(currComm->commOuter.empty(), - HCCL_ERROR("[ReduceScatterOperator][ReduceScatterMeshOpbasePipelineExecutor]errNo[0x%016llx]", \ - " comm outer is empty", HCCL_ERROR_CODE(HCCL_E_PARA)), HCCL_E_PARA); - u32 commIndex = 0; - u32 serverIndex = 0; - - CHK_SMART_PTR_NULL(currComm->commOuter[COMM_INDEX_0]); - commIndex = currComm->commOuter[COMM_INDEX_0]->Rank(); - bool bRet = commIndex >= currComm->commInner.size(); - CHK_PRT_RET(bRet, - HCCL_ERROR("[ReduceScatterOperator][ReduceScatterMeshOpbasePipelineExecutor]errNo[0x%016llx] commIndex[%u]" - " >= (tag:[%s]) comm_inner.size[%llu]", - HCCL_ERROR_CODE(HCCL_E_INTERNAL), commIndex, tag.c_str(), currComm->commInner.size()), HCCL_E_INTERNAL); - CHK_SMART_PTR_NULL(currComm->commInner[commIndex]); - HcclResult ret = currComm->commInner[commIndex]->GetRankByUserRank(userRank_, serverIndex); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[ReduceScatterOperator][ReduceScatterMeshOpbasePipelineExecutor]Get Rank[%u] by User Rank[%u]" \ - "from CommInner[%u] Failed!", serverIndex, userRank_, devicePhyId_), ret); - bRet = currComm->commOuter.size() == 0; - CHK_PRT_RET(bRet, - HCCL_ERROR("[ReduceScatterOperator][ReduceScatterMeshOpbasePipelineExecutor]tag[%s],comm outer is empty", - tag.c_str()), HCCL_E_INTERNAL); - std::unique_ptr &commOuter = currComm->commOuter[COMM_INDEX_0]; - CHK_SMART_PTR_NULL(commOuter); - - innerStreamInfo_t *streamInfo = hcclImpl_->GetStreamInfo(tag); - CHK_PRT_RET(streamInfo == nullptr, - HCCL_ERROR("[GetStreamInfo]errNo[0x%016llx] tag[%s] can't find in stream info", - HCCL_ERROR_CODE(HCCL_E_NOT_FOUND), tag.c_str()), HCCL_E_PARA); - - bRet = commIndex >= currComm->commInner.size(); - CHK_PRT_RET(bRet, - HCCL_ERROR("[ReduceScatterOperator][ReduceScatterMeshOpbasePipelineExecutor]errNo[0x%016llx]" - " commIndex[%u] >= (tag:[%s])comm_inner.size[%llu]", - 
HCCL_ERROR_CODE(HCCL_E_INTERNAL), commIndex, tag.c_str(), currComm->commInner.size()), HCCL_E_INTERNAL); - std::unique_ptr &commInner = currComm->commInner[commIndex]; - - u32 unitSize = SIZE_TABLE[dataType]; - DeviceMem userInMem = DeviceMem::create(opInfo->inputAddr, count * unitSize); - u64 reduceAttr = GetReduceAttr(userInMem, scratchMem, dataType, op); - auto bufferPtr = scratchMem.ptr(); - u64 bufferSize = 0; - CHK_RET(cclBufferManager_.GetInCCLbuffer(bufferPtr, bufferSize)); - u64 maxCountPerLoop = ((bufferSize / (HCCL_MIN_SLICE_ALIGN_910B * PIPELINE_DEPTH)) \ - * HCCL_MIN_SLICE_ALIGN_910B - HCCL_MIN_SLICE_ALIGN_910B) / unitSize; - - auto originalAlgTypeLevel1 = static_cast(algType_) >> HCCL_LEVEL_ALGO_WIDTH; - u64 curCount = 0; - u64 curOffset = 0; - u64 curSize = 0; - u8 *curInputPtr = static_cast(opInfo->inputAddr); - u8 *curOutputPtr = static_cast(opInfo->outputAddr); - CHK_PTR_NULL(curInputPtr); - CHK_PTR_NULL(curOutputPtr); - HCCL_INFO("[ReduceScatterOperator][ReduceScatterMeshOpbasePipelineExecutor]maxCountPerLoop[%llu]", maxCountPerLoop); - for (u64 countLeft = count; countLeft > 0; countLeft -= curCount) { - curInputPtr += curSize; - curOutputPtr += curSize; - opInfo->inputAddr = curInputPtr; - opInfo->outputAddr = curOutputPtr; - curCount = (countLeft > maxCountPerLoop) ? maxCountPerLoop : countLeft; - curSize = curCount * unitSize; - bool hugeData = curSize > RDMA_SEND_MAX_SIZE || curSize > SDMA_SEND_MAX_SIZE; - auto meta = HcclOpMetaInfo::GetOneForReduceScatter(originalAlgTypeLevel1, dataType, reduceType, - hugeData); - CHK_RET(InitTask(dispatcher_, stream, meta.isEnableCache, meta.GetCacheKey())); - std::unique_ptr executor; - executor.reset(new (std::nothrow) ReduceScatterPipeline(dispatcher_, reduceAttr)); - CHK_SMART_PTR_NULL(executor); - CHK_RET(executor->Prepare(opInfo, scratchMem, curCount, bufferSize, curOffset, commOuter, commInner, stream, - streamInfo->ringStreams, streamInfo->ringSignal, streamInfo->ringSignalAux)); - CHK_RET(executor->RunAsync()); - CHK_RET(LaunchTask(dispatcher_, stream)); - curOffset += curSize; - } - return HCCL_SUCCESS; -} - -HcclResult ReduceScatterOperator::ReduceScatterRingExecutor(const std::string &tag, DeviceMem& inputMem, - DeviceMem& outputMem, DeviceMem& scratchMem, u64 count, HcclDataType dataType, HcclReduceOp op, Stream& stream, - const HcomCollOpInfo *opInfo) -{ - HCCL_INFO("[ReduceScatterOperator][ReduceScatterRingExecutor] The ReduceScatterRingExecutor starts."); - u32 perDataSize = 0; - CHK_RET(SalGetDataTypeSize(dataType, perDataSize)); - - CommInfo *currComm; - hcclImpl_->GetCommInfo(currComm, tag); - bool bRet = currComm->commOuter.size() == 0; - CHK_PRT_RET(bRet, HCCL_ERROR("[ReduceScatterOperator][ReduceScatterRingExecutor]tag[%s],comm outer is empty", - tag.c_str()), HCCL_E_INTERNAL); - - std::unique_ptr &commOuter = currComm->commOuter[COMM_INDEX_0]; - CHK_SMART_PTR_NULL(commOuter); - u32 ringNum = (topoType_ == TopoType::TOPO_TYPE_8P_RING) ? OUTER_PLANE_NUM_IN_8PRING : - OUTER_PLANE_NUM_IN_NPRING_SINGLE; - - /* ******************网口裁剪步骤: 节点内allreduce *******************************/ - std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 - std::vector > multiStreamSlice; // 每个stream使用的数据基于用户buffer的偏移 - u32 sliceNum = currComm->commOuter[COMM_INDEX_0]->RankSize(); - Slice sliceTemp; - u32 commIndex = (ringNum == OUTER_PLANE_NUM_IN_8PRING) ? 
devicePhyId_ : currComm->commOuter[0]->Rank(); - bool isMultiNic = topoType_ == TopoType::TOPO_TYPE_8P_RING && nicList_.size() != DEVICE_EIGHT; - if (isMultiNic) { - u64 inputDataCount = inputMem.size() / perDataSize; - CHK_RET(ExecutorBase::PrepareSliceData(inputDataCount, perDataSize, sliceNum, 0, dataSegsSlice)); - multiStreamSlice = PrepareMultiRingSlice(dataSegsSlice, tag); - CHK_PRT_RET(multiStreamSlice.size() != ringNum, - HCCL_ERROR("[ReduceScatterOperator][ReduceScatterRingExecutor]ringNum[%u] !=multiStreamSlice size[%llu]", - ringNum, multiStreamSlice.size()), HCCL_E_INTERNAL); - - CHK_RET(MultiRingAllReduce(tag, inputMem, scratchMem, inputDataCount, dataType, op, multiStreamSlice, - stream, PROF_STAGE_0)); - - CHK_RET(HcclD2DMemcpyAsync(dispatcher_, inputMem, scratchMem, stream)); - } - - std::vector::iterator iterNic = std::find(nicList_.begin(), nicList_.end(), devicePhyId_); - bool innRunRet = isMultiNic && (iterNic == nicList_.end()); - if (!innRunRet) { // 1. 8P ring的拓扑。2. 网口不满配。3. 当前device不出网口。 的情况下不进行节点间的reduce scatter - /* ******************第一步: 节点间reducescatter *******************************/ - HCCL_DEBUG("commIndex:%u tagCommInfo_[tag].commInner.size():%llu", commIndex, currComm->commInner.size()); - bRet = commIndex >= currComm->commInner.size(); - CHK_PRT_RET(bRet, - HCCL_ERROR("[ReduceScatterOperator][ReduceScatterRingExecutor]commIndex[%u] >=(tag[%s])comm size[%llu]", \ - commIndex, tag.c_str(), currComm->commInner.size()), HCCL_E_INTERNAL); - - CHK_SMART_PTR_NULL(currComm->commInner[commIndex]); - - u32 innerRankSize = currComm->commInner[commIndex]->RankSize(); - if (innerRankSize > 1) { - u64 reduceAttr = GetReduceAttr(inputMem, scratchMem, dataType, op); - std::unique_ptr innerExecutor; - - if (UseInterServerRingAlgo(algType_)) { - innerExecutor.reset(new (std::nothrow) ReduceScatterRing(dispatcher_, reduceAttr)); - HCCL_INFO("reducescatter ring: using ring algo inter-server."); - CHK_SMART_PTR_NULL(innerExecutor); - - u64 ringSize = inputMem.size() / innerRankSize; - u64 ringCount = ringSize / perDataSize; - - CHK_RET(innerExecutor->Prepare(inputMem, inputMem, scratchMem, ringCount, dataType, - stream, op, OUTER_BRIDGE_RANK_ID, std::vector(0))); - } else if (UseInterServerNHRAlgo(algType_)) { - innerExecutor.reset(new (std::nothrow) ReduceScatterNHR(dispatcher_, reduceAttr)); - HCCL_INFO("reducescatter ring: using nhr algo inter-server."); - CHK_SMART_PTR_NULL(innerExecutor); - - u64 ringSize = inputMem.size() / innerRankSize; - u64 ringCount = ringSize / perDataSize; - CHK_RET(innerExecutor->Prepare(inputMem, inputMem, scratchMem, ringCount, dataType, - stream, op, OUTER_BRIDGE_RANK_ID, std::vector(0))); - } else if (UseInterServerNHRV1Algo(algType_)) { - innerExecutor.reset(new (std::nothrow) ReduceScatterNHRV1(dispatcher_, reduceAttr)); - HCCL_INFO("reducescatter ring: using nhr_v1 algo inter-server."); - CHK_SMART_PTR_NULL(innerExecutor); - - u64 ringSize = inputMem.size() / innerRankSize; - u64 ringCount = ringSize / perDataSize; - CHK_RET(innerExecutor->Prepare(inputMem, inputMem, scratchMem, ringCount, dataType, - stream, op, OUTER_BRIDGE_RANK_ID, std::vector(0))); - } else if (UseInterServerNBAlgo(algType_)) { - innerExecutor.reset(new (std::nothrow) ReduceScatterNB(dispatcher_, reduceAttr)); - HCCL_INFO("reducescatter ring: using nonuniform-bruck algo inter-server."); - CHK_SMART_PTR_NULL(innerExecutor); - - u64 ringSize = inputMem.size() / innerRankSize; - u64 ringCount = ringSize / perDataSize; - CHK_RET(innerExecutor->Prepare(inputMem, inputMem, 
scratchMem, ringCount, dataType, - stream, op, OUTER_BRIDGE_RANK_ID, std::vector(0))); - } else { - innerExecutor.reset(new (std::nothrow) ReduceScatterRecursiveHalvingDoubling(dispatcher_, reduceAttr)); - HCCL_INFO("reducescatter ring: using halving-doubling algo inter-server."); - - CHK_SMART_PTR_NULL(innerExecutor); - u64 inputDataCount = inputMem.size() / perDataSize; - CHK_RET(innerExecutor->Prepare(inputMem, inputMem, scratchMem, inputDataCount, dataType, - stream, op, OUTER_BRIDGE_RANK_ID, std::vector(0))); // count是output的数据个数 + if (SingleMeshInlineReduce(param.inputPtr, param.outputPtr, param.DataDes.dataType, param.reduceType)) { + if (topoMatcher_->GetDeterministicConfig() == DETERMINISTIC_CONFIG_ENABLE) { + algName = "ReduceScatterDeterExecutor"; + } else { + algName = "ReduceScatterMeshExecutor"; + } } - CHK_RET(innerExecutor->RegisterProfiler( - (innerRankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + currComm->commInner[commIndex]->Rank(), - PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, stream)); - CHK_RET(currComm->commInner[commIndex]->RunExecutor(innerExecutor)); } - } - - /* ***********第二步: 节点内reducescatter(正常场景), 节点内多根结点scatter(网口裁剪)*****************************/ - CHK_RET(hcclImpl_->ActiveRingStreams(tag, stream)); - - bool useInlineRduce = false; - bool isInlineReduce = IsSupportSDMAReduce(inputMem.ptr(), scratchMem.ptr(), dataType, op); - useInlineRduce = isInlineReduce && inlineReduceSwitchOn_; - multiStreamSlice = ReduceScatterRingSlicePrepare(ringNum, sliceNum, useInlineRduce, outputMem, dataSegsSlice, tag); - bRet = (multiStreamSlice.size() != ringNum); - CHK_PRT_RET(bRet, - HCCL_ERROR("[ReduceScatterOperator][ReduceScatterRingExecutor]sliceNum-1[%u] != multiStreamSlice size[%llu]", \ - sliceNum - 1, multiStreamSlice.size()), HCCL_E_INTERNAL); - - if (isMultiNic) { // 网口裁剪情况下需要改变slice最终在rank上位置 - PrepareMultiRingSlice(dataSegsSlice, tag, false, nicList_); // 刷新多环ringRankList信息 - std::vector> ringNics; - CHK_RET(hcclImpl_->GetRingNics(tag, ringNics)); - - for (u32 ringIdx = 0; ringIdx < ringNum; ringIdx++) { // 按第一个网口位置改变slice最终在rank上的位置 - u32 firstNicIdx = ringNics[ringIdx][0]; - std::rotate(multiStreamSlice[ringIdx].begin(), multiStreamSlice[ringIdx].begin() + firstNicIdx, - multiStreamSlice[ringIdx].end()); - } - } - - DeviceMem srcMem; - if (isMultiNic) { - u32 userRankSize = currComm->commOuter[0]->UserRankSize(); - u32 innerRankSize = userRankSize / DEVICE_EIGHT; - // 每个server分配的slice大小 - CHK_PRT_RET(innerRankSize == 0, - HCCL_ERROR("[ReduceScatterOperator][ReduceScatterRingExecutor]innerRankSize is illegal"), HCCL_E_PARA); - u64 serverSliceSize = inputMem.size() / innerRankSize; - // 每个服务器对应的偏移 - u32 serverIndex = hcclImpl_->GetInnerCommRank(commIndex); - CHK_PRT_RET(serverIndex == INVALID_VALUE_RANKID, - HCCL_ERROR("[ReduceScatterOperator][ReduceScatterRingExecutor]get rank of "\ - "bridgeRank failed, commIdx[%u]", commIndex), HCCL_E_PARA); - u64 serverSliceOffset = serverSliceSize * serverIndex; - if (UseInterServerRingAlgo(algType_)) { - CHK_RET(HcclD2DMemcpyAsync(dispatcher_, scratchMem, inputMem, stream)); + if (algName.empty()) { + algName = "ReduceScatterMeshExecutor"; } - DeviceMem reduceScatterRingOutput = scratchMem.range(serverSliceOffset, serverSliceSize); - CHK_SMART_PTR_NULL(reduceScatterRingOutput.ptr()); - u64 countLocal = serverSliceSize / perDataSize; - CHK_RET(MultiRingMultiRootScatter(tag, reduceScatterRingOutput, reduceScatterRingOutput, countLocal, dataType, - multiStreamSlice, serverIndex * DEVICE_EIGHT, stream, serverSliceOffset)); - - srcMem = 
reduceScatterRingOutput.range(dataSegsSlice[devicePhyId_].offset, count * perDataSize); - CHK_SMART_PTR_NULL(srcMem.ptr()); + } else if (isRingTopo) { + algName = "ReduceScatterRingExecutor"; } else { - u32 innerRankSize = currComm->commInner[commIndex]->RankSize(); - // 每个server分配的slice大小 - u64 serverSliceSize = inputMem.size() / innerRankSize; - // 每个服务器对应的偏移 - u32 serverIndex = currComm->commInner[commIndex]->Rank(); - u64 serverSliceOffset = serverSliceSize * serverIndex; - HCCL_DEBUG("inputMem.size()=%llu, commOuter->RankSize()=%u, serverSliceSize=%llu, serverSliceOffset=%llu "\ - "commIndex=%u commInner[commIndex]->rank=%u", inputMem.size(), commOuter->RankSize(), serverSliceSize, - serverSliceOffset, commIndex, currComm->commInner[commIndex]->Rank()); - DeviceMem reduceScatterRingInput = inputMem.range(serverSliceOffset, serverSliceSize); - CHK_SMART_PTR_NULL(reduceScatterRingInput.ptr()); - DeviceMem reduceScatterRingOutput = scratchMem.range(serverSliceOffset, serverSliceSize); - CHK_SMART_PTR_NULL(reduceScatterRingOutput.ptr()); - u64 countLocal = serverSliceSize / perDataSize; - CHK_RET(MultiRingReduceScatter(tag, reduceScatterRingInput, reduceScatterRingOutput, countLocal, dataType, op, - multiStreamSlice, stream, PROF_STAGE_1, serverSliceOffset, opInfo)); - - srcMem = inputMem.range(serverSliceOffset + dataSegsSlice[commIndex].offset, count * perDataSize); - CHK_SMART_PTR_NULL(srcMem.ptr()); + algName = "ReduceScatterComm"; } - - CHK_RET(HcclD2DMemcpyAsync(dispatcher_, outputMem, srcMem, stream)); - + HCCL_INFO("[SelectAlgfor910B] reduce_scatter SelectAlgfor910B is algName [%s]", algName.c_str()); return HCCL_SUCCESS; } -std::vector> ReduceScatterOperator::ReduceScatterRingSlicePrepare(u32 ringNum, u32 sliceNum, - bool useInlineReduce, DeviceMem& outputMem, std::vector& dataSegsSlice, const std::string &tag) +HcclResult ReduceScatterOperator::SelectAlgfor91073(const OpParam& param, std::string& algName) { - std::vector> multiStreamSlice; - u64 outputMenSize = outputMem.size(); - dataSegsSlice.clear(); - Slice sliceTemp; - for (u32 i = 0; i < sliceNum; i++) { // 根据数据量算每个环上数据的偏移和大小 - sliceTemp.size = outputMenSize; - sliceTemp.offset = outputMenSize * i; - dataSegsSlice.push_back(sliceTemp); - } - - // 再将每个 slice 划分为 ringNum 份 - if (ringNum == OUTER_PLANE_NUM_IN_8PRING) { - if (useInlineReduce) { - multiStreamSlice = PrepareMultiRingSlice(dataSegsSlice, tag); - } else if (outputMem.size() % CCE_REDUCE_ALIGN_SIZE == 0) { - multiStreamSlice = PrepareMultiRingSlice(dataSegsSlice, tag); - } else { - multiStreamSlice = PrepareMultiRingSlice(dataSegsSlice, tag, true); - } - } else if (ringNum == OUTER_PLANE_NUM_IN_NPRING_DOUBLE) { - // 双环场景,需要传入正确的 niclist (不涉及网口裁剪) - if (useInlineReduce) { - multiStreamSlice = PrepareMultiRingSlice(dataSegsSlice, tag, false, nicList_); - } else if (outputMem.size() % CCE_REDUCE_ALIGN_SIZE == 0) { - multiStreamSlice = PrepareMultiRingSlice(dataSegsSlice, tag, false, nicList_); + if (topoType_ == TopoType::TOPO_TYPE_NP_SINGLE_RING) { + algName = "ReduceScatterRingFor91073Executor"; + } else if (topoType_ == TopoType::TOPO_TYPE_NP_DOUBLE_RING) { + if (GetExternalInputEnableRdmaSdmaConcurrent()) { + HcclResult ret = SetInterServerRingAlgo(algType_); + HCCL_WARNING("[ReduceScatterOperator][SelectAlgfor91073] env HCCL_CONCURRENT_ENABLE is set, " + "set interserver algo to ring."); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[ReduceScatterOperator][SelectAlgfor91073]errNo[0x%016llx] tag[%s], ReduceScatter set inter " + "server ring algo failed", 
HCCL_ERROR_CODE(ret), param.tag.c_str()), ret); + algName = "ReduceScatterDoubleRingConcurrentExecutor"; } else { - multiStreamSlice = PrepareMultiRingSlice(dataSegsSlice, tag, true, nicList_); + algName = "ReduceScatterRingFor91073Executor"; } } else { - multiStreamSlice.push_back(dataSegsSlice); + algName = "ReduceScatterComm"; } - return multiStreamSlice; + // 91073超节点只支持server间ring,NB和NHR,默认需继续使用NHR + if (!(UseInterServerRingAlgo(algType_) || UseInterServerNBAlgo(algType_))) { + HcclResult ret = SetInterServerNHRAlgo(algType_); + HCCL_WARNING("[ReduceScatterOperator][SelectAlgfor91073] only support ring, NB and NHR in AlgoLevel1 yet, "\ + "default is algType=NHR."); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[ReduceScatterOperator][SelectAlgfor91073]errNo[0x%016llx] tag[%s], ReduceScatter set inter "\ + "server nhr algo failed", HCCL_ERROR_CODE(ret), param.tag.c_str()), ret); + } + + HCCL_INFO("[SelectAlgfor91073] reduce_scatter SelectAlgfor91073 is algName [%s]", algName.c_str()); + return HCCL_SUCCESS; } +REGISTER_OP(HcclCMDType::HCCL_CMD_REDUCE_SCATTER, ReduceScatter, ReduceScatterOperator); + } \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/operator/reduce_scatter_operator.h b/src/domain/collective_communication/algorithm/impl/operator/reduce_scatter_operator.h index ae96435aec657f9aa9cd430e2516d548a29f1556..699024919b506464e676339429a0574f25519da4 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/reduce_scatter_operator.h +++ b/src/domain/collective_communication/algorithm/impl/operator/reduce_scatter_operator.h @@ -12,63 +12,23 @@ #define REDUCE_SCATTER_OPERATOR_H #include "common_operator.h" +#include "coll_alg_op_registry.h" namespace hccl { class ReduceScatterOperator : public CommonOperator { public: - ReduceScatterOperator(std::unique_ptr &pImpl); + ReduceScatterOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher); ~ReduceScatterOperator(); - HcclResult ReduceScatter(const std::string &tag, void *inputPtr, void *outputPtr, u64 count, - HcclDataType dataType, HcclReduceOp op, Stream stream, HcomCollOpInfo *opInfo = nullptr); - HcclResult ReduceScatterOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, u64 count, - HcclDataType dataType, HcclReduceOp op, Stream stream, - const std::unique_ptr &opBaseAtraceInfo = nullptr); + HcclResult SelectAlg(const std::string& tag, const OpParam& param, std::string& algName, std::string& newTag); private: - // reducescatter - HcclResult RunReduceScatter(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, - DeviceMem &scratchMem, u64 count, HcclDataType dataType, HcclReduceOp op, Stream &stream, - HcomCollOpInfo *opInfo = nullptr); + HcclResult SelectAlgfor310P3(const OpParam& param, std::string& algName); - HcclResult ReduceScatterComm(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, - DeviceMem &scratchMem, u64 count, HcclDataType dataType, HcclReduceOp op, Stream &stream); + HcclResult SelectAlgfor910A(const OpParam& param, std::string& algName); - HcclResult ReduceScatterDMAReduceRingExecutorMiddlelayer(const std::string &tag, DeviceMem &inputMem, - DeviceMem &outputMem, DeviceMem &scratchMem, u64 count, HcclDataType dataType, HcclReduceOp op, Stream &stream, - HcomCollOpInfo *opInfo); + HcclResult SelectAlgfor910B(const OpParam& param, std::string& algName); - HcclResult ReduceScatterMeshOpbaseExecutorMiddlelayer(const std::string &tag, DeviceMem &inputMem, - DeviceMem &outputMem, DeviceMem &scratchMem, u64 
count, HcclDataType dataType, HcclReduceOp op, Stream &stream, - HcomCollOpInfo *opInfo = nullptr); - - HcclResult ReduceScatterDeterExecutor(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, - DeviceMem &scratchMem, u64 count, HcclDataType dataType, HcclReduceOp op, Stream &stream, - HcomCollOpInfo *opInfo = nullptr); - - HcclResult ReduceScatterMeshExecutor(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, - DeviceMem &scratchMem, u64 count, HcclDataType dataType, HcclReduceOp op, Stream &stream, - HcomCollOpInfo *opInfo = nullptr); - - HcclResult ReduceScatterDoubleRingExecutor(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, - DeviceMem &scratchMem, u64 count, HcclDataType dataType, HcclReduceOp op, - Stream &stream, const HcomCollOpInfo *opInfo = nullptr); - - HcclResult ReduceScatterDoubleRingConcurrentExecutor(const std::string &tag, DeviceMem &inputMem, - DeviceMem &outputMem, DeviceMem &scratchMem, u64 count, HcclDataType dataType, - HcclReduceOp op, Stream &stream, const HcomCollOpInfo *opInfo = nullptr); - - HcclResult ReduceScatterRingExecutor(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, - DeviceMem &scratchMem, u64 count, HcclDataType dataType, HcclReduceOp op, - Stream &stream, const HcomCollOpInfo *opInfo = nullptr); - - HcclResult ReduceScatterMeshOpbasePipelineExecutor(const std::string &tag, DeviceMem &scratchMem, - u64 count, HcclDataType dataType, HcclReduceOp op, Stream &stream, HcomCollOpInfo *opInfo); - - HcclResult ReduceScatterCommFor310P(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, - u64 count, HcclDataType dataType, HcclReduceOp op, Stream &stream); - - std::vector> ReduceScatterRingSlicePrepare(u32 ringNum, u32 sliceNum, bool useInlineReduce, - DeviceMem& outputMem, std::vector& dataSegsSlice, const std::string &tag); + HcclResult SelectAlgfor91073(const OpParam& param, std::string& algName); }; } diff --git a/src/domain/collective_communication/algorithm/impl/operator/registry/coll_alg_op_registry.cc b/src/domain/collective_communication/algorithm/impl/operator/registry/coll_alg_op_registry.cc index 81fd5efb2ef376d05873e9f3297acd0c8b7822b3..8b5e3b2dc363724fa3bc433c4de1d9281c79444c 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/registry/coll_alg_op_registry.cc +++ b/src/domain/collective_communication/algorithm/impl/operator/registry/coll_alg_op_registry.cc @@ -10,7 +10,6 @@ #include "coll_alg_op_registry.h" - namespace hccl { CollAlgOpRegistry *CollAlgOpRegistry::Instance() @@ -31,13 +30,13 @@ HcclResult CollAlgOpRegistry::Register(const HcclCMDType &opType, const CollAlgO } std::unique_ptr CollAlgOpRegistry::GetAlgOp( - const HcclCMDType &opType, std::unique_ptr &pImpl) + const HcclCMDType &opType, std::unique_ptr &pImpl, std::unique_ptr &topoMatcher) { if (opCreators_.find(opType) == opCreators_.end()) { HCCL_ERROR("[CollAlgOpRegistry]Creator for op type[%d] has not registered.", opType); return nullptr; } - return std::unique_ptr(opCreators_[opType](pImpl)); + return std::unique_ptr(opCreators_[opType](pImpl, topoMatcher)); } } // namespace Hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/operator/registry/coll_alg_op_registry.h b/src/domain/collective_communication/algorithm/impl/operator/registry/coll_alg_op_registry.h index 542fa76c5e31d5f2cefa9dcc6368318f33438b0a..4da7b2b186edacb49e9e5617dd7b7bc008b84841 100644 --- 
a/src/domain/collective_communication/algorithm/impl/operator/registry/coll_alg_op_registry.h +++ b/src/domain/collective_communication/algorithm/impl/operator/registry/coll_alg_op_registry.h @@ -19,19 +19,21 @@ namespace hccl { -using CollAlgOpCreator = std::function &)>; +using CollAlgOpCreator = std::function &, std::unique_ptr &)>; -template static CollAlgOperator *DefaultOpCreator(std::unique_ptr &pImpl) +template static CollAlgOperator *DefaultOpCreator(std::unique_ptr &pImpl, + std::unique_ptr &topoMatcher) { static_assert(std::is_base_of::value, "CollAlgOp type must derived from Hccl::CollAlgOperator"); - return new (std::nothrow) P(pImpl); + return new (std::nothrow) P(pImpl, topoMatcher); } class CollAlgOpRegistry { public: static CollAlgOpRegistry *Instance(); HcclResult Register(const HcclCMDType &opType, const CollAlgOpCreator &collAlgOpCreator); - std::unique_ptr GetAlgOp(const HcclCMDType &opType, std::unique_ptr &pImpl); + std::unique_ptr GetAlgOp(const HcclCMDType &opType, std::unique_ptr &pImpl, + std::unique_ptr &topoMatcher); private: std::map opCreators_; diff --git a/src/domain/collective_communication/algorithm/impl/operator/scatter_operator.cc b/src/domain/collective_communication/algorithm/impl/operator/scatter_operator.cc index ff6e2e86ce79f76773b61cec2a1c68bb4576d0a2..441eb2e77d9d5c527a30d7315f742ae5cf506019 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/scatter_operator.cc +++ b/src/domain/collective_communication/algorithm/impl/operator/scatter_operator.cc @@ -12,16 +12,18 @@ #include "device_capacity.h" #include "rank_consistent.h" #include "executor_impl.h" +#include "hccl_alg.h" namespace hccl { -ScatterOperator::ScatterOperator(std::unique_ptr &pImpl) - : CommonOperator(pImpl, HcclCMDType::HCCL_CMD_SCATTER) +ScatterOperator::ScatterOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher) + : CommonOperator(pImpl, topoMatcher, HcclCMDType::HCCL_CMD_SCATTER) { - // 由于scatter只支持server间ring和nb,如果非nb需要重定向到ring - if (!UseInterServerNHRAlgo(algType_) && !UseInterServerNBAlgo(algType_)) { - HCCL_INFO("[ScatterOperator][ScatterOperator] algType[%d] is not supported, reset algType=ring", algType_); + // 由于scatter只支持server间ring、nb和nhr,其他算法需要重定向到ring + if (!UseInterServerNHRAlgo(algType_) && !UseInterServerNBAlgo(algType_) && !UseInterServerRingAlgo(algType_)) { + HCCL_INFO("[ScatterOperator][ScatterOperator] algType[%s] is not supported, reset algType=ring", + HcclAlg::AlgTypeToStr(algType_).c_str()); SetInterServerRingAlgo(algType_); } } @@ -469,4 +471,50 @@ HcclResult ScatterOperator::ScatterRingExecutor(const std::string &tag, DeviceMe return HCCL_SUCCESS; } + +HcclResult ScatterOperator::SelectAlg(const std::string& tag, const OpParam& param, std::string& algName, + std::string& newTag) +{ + HcclResult ret = HCCL_SUCCESS; + newTag = param.tag; + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && UseInterServerHDAlgo(algType_)) { + u32 part1Size = 2 * (moduleNum_ - (1 << static_cast(log2(moduleNum_)))); + u32 rootId = param.root / deviceNumPerAggregation_; + std::string appendTag = std::to_string((rootId >= part1Size) || ((rootId % 2) == 0)); + newTag = newTag + '_' + appendTag; + if (param.opBaseAtraceInfo != nullptr) { + CHK_RET(param.opBaseAtraceInfo->SavealgtypeTraceInfo(appendTag, param.tag)); + } + } + + // 由于scatter只支持server间ring,nb和NHR,如果不是需要重定向到ring + if (!UseInterServerNHRAlgo(algType_) && !UseInterServerNBAlgo(algType_) && !UseInterServerRingAlgo(algType_)) { + HCCL_INFO("[ScatterOperator][Scatter] 
algType[%s] is not supported, reset algType=ring", + HcclAlg::AlgTypeToStr(algType_).c_str()); + ret = SetInterServerRingAlgo(algType_); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[ScatterOperator][Scatter]errNo[0x%016llx] tag[%s],scatter set inter server "\ + "algo failed", HCCL_ERROR_CODE(ret), newTag.c_str()), ret); + } + + bool isMeshTopo = topoType_ == TopoType::TOPO_TYPE_NP_MESH || topoType_ == TopoType::TOPO_TYPE_4P_MESH || + topoType_ == TopoType::TOPO_TYPE_2P_MESH || topoType_ == TopoType::TOPO_TYPE_1P_MESH; + bool isRingTopo = topoType_ == TopoType::TOPO_TYPE_NP_SINGLE_RING || topoType_ == TopoType::TOPO_TYPE_8P_RING || + topoType_ == TopoType::TOPO_TYPE_NP_DOUBLE_RING; + + if (isMeshTopo) { + algName = "ScatterMeshExecutor"; + } else if (isRingTopo) { + algName = "ScatterRingExecutor"; + } else { + algName = "ScatterCommExecutor"; + } + if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + newTag = newTag + algName; + HCCL_INFO("[SelectAlg] Scatter newTag is [%s] algName is [%s]", newTag.c_str(), algName.c_str()); + } + return HCCL_SUCCESS; +} + +REGISTER_OP(HcclCMDType::HCCL_CMD_SCATTER, Scatter, ScatterOperator); } diff --git a/src/domain/collective_communication/algorithm/impl/operator/scatter_operator.h b/src/domain/collective_communication/algorithm/impl/operator/scatter_operator.h index 4d871bb4a99f278c2402b890203910b16ca04c6a..627bd0e42a53f89439c4fdf521219a7d9425aa25 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/scatter_operator.h +++ b/src/domain/collective_communication/algorithm/impl/operator/scatter_operator.h @@ -12,17 +12,20 @@ #define SCATTER_OPERATOR_H #include "common_operator.h" +#include "coll_alg_op_registry.h" namespace hccl { class ScatterOperator : public CommonOperator { public: - ScatterOperator(std::unique_ptr &pImpl); + ScatterOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher); ~ScatterOperator(); HcclResult Scatter(const std::string &tag, void *inputPtr, void *outputPtr, u64 recvCount, HcclDataType dataType, u32 root, Stream stream); HcclResult ScatterOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, u64 recvCount, HcclDataType dataType, u32 root, Stream stream, const std::unique_ptr &opBaseAtraceInfo = nullptr); + HcclResult SelectAlg(const std::string& tag, const OpParam& param, std::string& algName, + std::string& newTag); private: HcclResult RunScatter(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, u64 count, HcclDataType dataType, u32 root, Stream &stream); diff --git a/src/domain/collective_communication/algorithm/impl/operator/send_operator.cc b/src/domain/collective_communication/algorithm/impl/operator/send_operator.cc new file mode 100644 index 0000000000000000000000000000000000000000..7f7a2586ac52f20e61b6944e4fc3ac5391aa2943 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/operator/send_operator.cc @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "send_operator.h" +#include "rank_consistent.h" +#include "executor_impl.h" + +namespace hccl { +SendOperator::SendOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher) + : CollAlgOperator(pImpl, topoMatcher, HcclCMDType::HCCL_CMD_SEND) +{ +} + +SendOperator::~SendOperator() +{ +} + +HcclResult SendOperator::SelectAlg(const std::string& tag, const OpParam& param, std::string& algName, + std::string& newTag) +{ + algName = "SendExecutor"; + if (GetWorkflowMode() != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + newTag = tag; + } else { + newTag = tag + algName; + } + HCCL_INFO("[SelectAlg] send newTag is [%s]", newTag.c_str()); + return HCCL_SUCCESS; +} + +REGISTER_OP(HcclCMDType::HCCL_CMD_SEND, Send, SendOperator); +} \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/operator/send_operator.h b/src/domain/collective_communication/algorithm/impl/operator/send_operator.h new file mode 100644 index 0000000000000000000000000000000000000000..3b2cfe3a1adc9667d87192d57a898fa91d4298d6 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/operator/send_operator.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef SEND_OPERATOR_H +#define SEND_OPERATOR_H + +#include "common_operator.h" +#include "coll_alg_op_registry.h" +#include + +namespace hccl { +class SendOperator : public CollAlgOperator { +public: + SendOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher); + ~SendOperator(); + + HcclResult SelectAlg(const std::string& tag, const OpParam& param, std::string& algName, std::string& newTag); +}; +} + +#endif /** __SEND_OPERATOR_H__ */ \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/operator/send_receive_operator.cc b/src/domain/collective_communication/algorithm/impl/operator/send_receive_operator.cc index d2228268193d8fafa4966a828069cd5938d104c7..64cc6d1fd1470316827a5c9a78ac6f9cec278271 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/send_receive_operator.cc +++ b/src/domain/collective_communication/algorithm/impl/operator/send_receive_operator.cc @@ -15,8 +15,8 @@ #define BATCH_SEND_RECV_TAG "_targetRanksHash_" namespace hccl { -SendReceiveOperator::SendReceiveOperator(std::unique_ptr &pImpl) - : CollAlgOperator(pImpl, HcclCMDType::HCCL_CMD_BATCH_SEND_RECV) +SendReceiveOperator::SendReceiveOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher) + : CollAlgOperator(pImpl, topoMatcher, HcclCMDType::HCCL_CMD_BATCH_SEND_RECV) { } diff --git a/src/domain/collective_communication/algorithm/impl/operator/send_receive_operator.h b/src/domain/collective_communication/algorithm/impl/operator/send_receive_operator.h index 0d7d5c6d53f27634c1903aef5fc4e783cca7522d..ed0808829c6124bb73c2d2d04f76892a014510c0 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/send_receive_operator.h +++ b/src/domain/collective_communication/algorithm/impl/operator/send_receive_operator.h @@ -17,7 +17,7 @@ namespace hccl { class SendReceiveOperator : public CollAlgOperator { public: - SendReceiveOperator(std::unique_ptr &pImpl); + SendReceiveOperator(std::unique_ptr &pImpl, std::unique_ptr &topoMatcher); ~SendReceiveOperator(); HcclResult SendRun(const std::string &tag, DeviceMem& inputPtr, u64 count, HcclDataType dataType, u32 destUserRank, Stream stream); diff --git a/src/domain/collective_communication/algorithm/impl/resource_manager/hccl_socket_manager.h b/src/domain/collective_communication/algorithm/impl/resource_manager/hccl_socket_manager.h index aa782c33cd51b671e88ebbabd4629882b9de7ff1..5e048826900ef3ab28e6193635bf18b28d020c33 100644 --- a/src/domain/collective_communication/algorithm/impl/resource_manager/hccl_socket_manager.h +++ b/src/domain/collective_communication/algorithm/impl/resource_manager/hccl_socket_manager.h @@ -1,6 +1,11 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. - * Description: HcclSocketManager 类实现 + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
*/ #ifndef HCCL_SOCKET_MANAGER_H diff --git a/src/domain/collective_communication/algorithm/impl/task/parallel_task_loader.cc b/src/domain/collective_communication/algorithm/impl/task/parallel_task_loader.cc index cb2bfe3166ac2d0d7e958c2b8f5d65f8edb85dfe..023dc41d5ec38e858d8732f9566444a7fd021c5e 100644 --- a/src/domain/collective_communication/algorithm/impl/task/parallel_task_loader.cc +++ b/src/domain/collective_communication/algorithm/impl/task/parallel_task_loader.cc @@ -19,14 +19,14 @@ ParallelTaskLoader::ParallelTaskLoader(const s32 deviceLogicId, const HcclDispat ParallelTaskLoader::~ParallelTaskLoader() {} -HcclResult ParallelTaskLoader::Prepare(std::vector streamsPtr, void *commPtr) +HcclResult ParallelTaskLoader::Prepare(std::vector streamsPtr, SubCommInfo outerCommInfo) { // 参数保存 streamsPtr_.resize(streamsPtr.size()); for (u32 streamIndex = 0; streamIndex < streamsPtr.size(); streamIndex++) { streamsPtr_[streamIndex] = streamsPtr[streamIndex]; } - commPtr_ = commPtr; + commInfo_ = outerCommInfo; HCCL_INFO("[ParallelTaskLoader]Prepare streams size[%d], taskLoaderNum_[%u]", streamsPtr_.size(), taskLoaderNum_); // 当前现有的taskLoader线程可以满足业务多流的使用 @@ -57,7 +57,7 @@ HcclResult ParallelTaskLoader::StartTaskLoad() // 配置线程启动参数 for (u32 streamIndex = 0; streamIndex < streamsPtr_.size(); streamIndex++) { - streamTaskLoader_[streamIndex]->Prepare(streamsPtr_[streamIndex], commPtr_); + streamTaskLoader_[streamIndex]->Prepare(streamsPtr_[streamIndex], commInfo_); // 获取线程ID tidInfo_[streamIndex] = streamTaskLoader_[streamIndex]->GetTid(); } @@ -80,4 +80,13 @@ HcclResult ParallelTaskLoader::WaitTaskLoadFinish() } return HCCL_SUCCESS; } + +HcclResult ParallelTaskLoader::ClearTagCommInfo() +{ + commInfo_ = SubCommInfo{}; + for (u32 streamIndex = 0; streamIndex < streamsPtr_.size(); streamIndex++) { + CHK_RET(streamTaskLoader_[streamIndex]->ClearTagCommInfo()); + } + return HCCL_SUCCESS; +} } // namespace hccl diff --git a/src/domain/collective_communication/algorithm/impl/task/parallel_task_loader.h b/src/domain/collective_communication/algorithm/impl/task/parallel_task_loader.h index 8d463f2ba74d390dd31fbbd8f9444d68e3cc3bf2..63713d3715cb1c93d6aef77dca1a028209e850ba 100644 --- a/src/domain/collective_communication/algorithm/impl/task/parallel_task_loader.h +++ b/src/domain/collective_communication/algorithm/impl/task/parallel_task_loader.h @@ -23,16 +23,17 @@ public: explicit ParallelTaskLoader(const s32 deviceLogicId, const HcclDispatcher dispatcher); ~ParallelTaskLoader(); - HcclResult Prepare(std::vector streamsPtr, void *commPtr); + HcclResult Prepare(std::vector streamsPtr, SubCommInfo outerCommInfo); HcclResult StartTaskLoad(); HcclResult WaitTaskLoadFinish(); + HcclResult ClearTagCommInfo(); protected: private: s32 deviceLogicId_; // 当前设备的device id const HcclDispatcher dispatcher_; // dispatcher引用 - void *commPtr_ = nullptr; + SubCommInfo commInfo_; std::vector streamsPtr_; std::vector> streamTaskLoader_; diff --git a/src/domain/collective_communication/algorithm/impl/task/task_loader.cc b/src/domain/collective_communication/algorithm/impl/task/task_loader.cc index 307f4f9ac4b46836bf3e55e40ff70dfcfb444bde..8386fc45dc9fe11f58586b4d8cc8ffd923bf6f5e 100644 --- a/src/domain/collective_communication/algorithm/impl/task/task_loader.cc +++ b/src/domain/collective_communication/algorithm/impl/task/task_loader.cc @@ -25,12 +25,12 @@ TaskLoader::~TaskLoader() } } -void TaskLoader::Prepare(Stream *stream, void *commPtr) +void TaskLoader::Prepare(Stream *stream, SubCommInfo outerCommInfo) { // 参数保存 
stream_ = stream; HCCL_INFO("[TaskLoader] Prepare stream[%p]", stream_->ptr()); - commPtr_ = commPtr; + commInfo_ = outerCommInfo; executeResult_ = HCCL_SUCCESS; } @@ -98,10 +98,18 @@ void TaskLoader::WaitDone() HcclResult TaskLoader::ExecuteTransPortTaskInfo(TaskLogicInfo &info) { u32 index = info.taskLogicCmd.index; - hccl::CommBase *comm = static_cast(commPtr_); - CHK_SMART_PTR_NULL(comm); - std::shared_ptr destTransport = comm->GetTrasportInfoByVTransportInfoIndex(index); + std::shared_ptr destTransport = nullptr; + if (commInfo_.virtualLinks.size() <= index) { + HCCL_ERROR("[ExecuteTransPortTaskInfo]index[%u] is bigger than vlink size[%llu]", index, + commInfo_.virtualLinks.size()); + } else if (commInfo_.links.size() <= index) { + HCCL_ERROR("[ExecuteTransPortTaskInfo]index[%u] is bigger than link size[%llu]", index, + commInfo_.links.size()); + } else { + destTransport = commInfo_.links[index]; + } + CHK_SMART_PTR_NULL(destTransport); switch (info.taskFuncType) { @@ -222,4 +230,11 @@ uint32_t TaskLoader::GetTid() HCCL_INFO("[TaskLoader][GetTid]deviceLogicId_[%d], threadId_[%u]", deviceLogicId_, threadId_); return threadId_; } + +HcclResult TaskLoader::ClearTagCommInfo() +{ + commInfo_ = SubCommInfo{}; + return HCCL_SUCCESS; +} + } // namespace hccl diff --git a/src/domain/collective_communication/algorithm/impl/task/task_loader.h b/src/domain/collective_communication/algorithm/impl/task/task_loader.h index 8739c2f29b584bb26f6de681bdaee984a900127a..35b80758f19a6ffe72cf6d8e0770ee50dd83adfd 100644 --- a/src/domain/collective_communication/algorithm/impl/task/task_loader.h +++ b/src/domain/collective_communication/algorithm/impl/task/task_loader.h @@ -16,6 +16,7 @@ #include #include "stream_pub.h" #include "dispatcher.h" +#include "coll_alg_param.h" namespace hccl { class TaskLoader { @@ -23,7 +24,7 @@ public: explicit TaskLoader(const s32 deviceLogicId, const HcclDispatcher dispatcher); ~TaskLoader(); - void Prepare(Stream *stream, void *commPtr); + void Prepare(Stream *stream, SubCommInfo outerCommInfo); HcclResult Init(); HcclResult Finalize(); @@ -33,6 +34,7 @@ public: void NotifyDone(); void WaitDone(); uint32_t GetTid(); + HcclResult ClearTagCommInfo(); protected: private: @@ -48,7 +50,7 @@ private: u32 userRank_; const HcclDispatcher dispatcher_; // dispatcher引用 Stream *stream_; // 执行线程对应的stream - void *commPtr_ = nullptr; + SubCommInfo commInfo_; std::mutex startMtx_; std::mutex doneMtx_; std::condition_variable startCv_; diff --git a/src/domain/collective_communication/algorithm/impl/topo_matcher.cc b/src/domain/collective_communication/algorithm/impl/topo_matcher.cc new file mode 100644 index 0000000000000000000000000000000000000000..6a6e2ce5e4b3ac43bb9b285902ab5ad4bebc406c --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/topo_matcher.cc @@ -0,0 +1,421 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +// 内部依赖头文件 +#include +#include "dispatcher.h" +#include "comm_base_pub.h" +#include "externalinput_pub.h" +#include "coll_alg_param.h" +#include "topo_matcher.h" +#include "search_path.h" +#include "calc_p2p_transport_req.h" +namespace hccl { + +TopoMatcher::TopoMatcher(const std::vector>> CommPlaneRanks, + std::vector isBridgeVector, + HcclTopoInfo &topoInfo, + HcclAlgoInfo &algoInfo, + HcclExternalEnable &externalEnable, + std::vector>> &serverAndsuperPodToRank) + : CommPlaneVector_(CommPlaneRanks), isBridgeVector_(isBridgeVector), + topoInfo_(topoInfo), algoInfo_(algoInfo), externalEnable_(externalEnable), userRank_(topoInfo.userRank), + serverAndsuperPodToRank_(serverAndsuperPodToRank) +{ + SetRankMap(); +} + +HcclResult TopoMatcher::CalcCommPlaneInfo(const std::string &tag, const CommParaInfo &commParaInfo, + std::vector &commTransport, TransportMemType inputMemType, TransportMemType outputMemType) +{ + HcclUs startut = TIME_NOW(); + HcclResult ret = HCCL_SUCCESS; + HCCL_INFO("[Calc][CommPlane]tag[%s], commPlane[%d], commType[%d]", + tag.c_str(), commParaInfo.commPlane, commParaInfo.commType); + + u32 subUserRankRoot = INVALID_VALUE_RANKID; + if (commParaInfo.root != INVALID_VALUE_RANKID) { + subUserRankRoot = GetSubRootUserRank(userRank_, commParaInfo.root); + if (subUserRankRoot == INVALID_VALUE_RANKID) { + HCCL_ERROR("[TopoMatcher][CalcCommPlaneInfo]get sub root userrank value[%u] invalid.", subUserRankRoot); + return HCCL_E_PARA; + } + } + + std::unique_ptr calcTransportReq; + switch (commParaInfo.commType) { + case CommType::COMM_TAG_RING_INNER: + case CommType::COMM_TAG_RING_COMBINED: { + calcTransportReq.reset(new (std::nothrow) CalcRingTransportReq(CommPlaneVector_[commParaInfo.commPlane], + isBridgeVector_, userRank_)); + ret = calcTransportReq->CalcTransportRequest(tag, inputMemType, outputMemType, commParaInfo, commTransport); + break; + } + case CommType::COMM_TAG_HALVING_DOUBLING: { + calcTransportReq.reset(new (std::nothrow) CalcHDTransportReq(CommPlaneVector_[commParaInfo.commPlane], + isBridgeVector_, userRank_)); + ret = calcTransportReq->CalcTransportRequest(tag, inputMemType, outputMemType, commParaInfo, commTransport, + subUserRankRoot); + break; + } + case CommType::COMM_TAG_NONUNIFORM_HIERARCHICAL_RING: + case CommType::COMM_TAG_WHOLE_NHR:{ + calcTransportReq.reset(new (std::nothrow) CalcNHRTransportReq(CommPlaneVector_[commParaInfo.commPlane], + isBridgeVector_, userRank_)); + ret = calcTransportReq->CalcTransportRequest(tag, inputMemType, outputMemType, commParaInfo, commTransport); + break; + } + case CommType::COMM_TAG_NONUNIFORM_HIERARCHICAL_RING_V1: + case CommType::COMM_TAG_WHOLE_NHR_V1: { + calcTransportReq.reset(new (std::nothrow) CalcNHRV1TransportReq(CommPlaneVector_[commParaInfo.commPlane], + isBridgeVector_, userRank_)); + ret = calcTransportReq->CalcTransportRequest(tag, inputMemType, outputMemType, commParaInfo, commTransport); + break; + } + case CommType::COMM_TAG_NONUNIFORM_BRUCK: + case CommType::COMM_TAG_WHOLE_NB: { + calcTransportReq.reset(new (std::nothrow) CalcNBTransportReq(CommPlaneVector_[commParaInfo.commPlane], + isBridgeVector_, userRank_)); + ret = calcTransportReq->CalcTransportRequest(tag, inputMemType, outputMemType, commParaInfo, commTransport); + break; + } + case CommType::COMM_TAG_MESH: { + calcTransportReq.reset(new (std::nothrow) CalcMeshTransportReq(CommPlaneVector_[commParaInfo.commPlane], + isBridgeVector_, userRank_)); + ret = calcTransportReq->CalcTransportRequest(tag, inputMemType, outputMemType, commParaInfo, 
commTransport); + break; + } + case CommType::COMM_TAG_PARTIAL_MESH_COMBINED: { + calcTransportReq.reset(new (std::nothrow) CalcPartialMeshTransportReq + (CommPlaneVector_[commParaInfo.commPlane], isBridgeVector_, userRank_)); + ret = calcTransportReq->CalcTransportRequest(tag, inputMemType, outputMemType, commParaInfo, commTransport); + break; + } + case CommType::COMM_TAG_P2P: { + calcTransportReq.reset(new (std::nothrow) CalcP2PTransportReq(CommPlaneVector_[commParaInfo.commPlane], + isBridgeVector_, userRank_)); + ret = calcTransportReq->CalcTransportRequest(tag, inputMemType, outputMemType, commParaInfo, commTransport); + break; + } + default: { + HCCL_ERROR("[Calc][CommPlane]commType[%d] is invalid", commParaInfo.commType); + return HCCL_E_PARA; + } + } + + CHK_RET(SetIsUsedRdma(commParaInfo, commTransport)); + CHK_RET(GetRankMap(commParaInfo, commTransport)); + + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[Calc][CommPlane]failed, tag[%s], commPlane[%d], commType[%d]", + tag.c_str(), commParaInfo.commPlane, commParaInfo.commType), ret); + + HCCL_INFO("complete commPlane[%d] commType[%d] Calculation, Time:%lld us", + commParaInfo.commPlane, commParaInfo.commType, DURATION_US(TIME_NOW() - startut)); + return HCCL_SUCCESS; +} + +HcclResult TopoMatcher::GetRankMap(const CommParaInfo &commParaInfo, std::vector &commTransport) +{ + u32 ringSize = commTransport.size(); + + for (u32 ringIndex = 0; ringIndex < ringSize; ringIndex++) { + SingleSubCommTransport &subCommTransport = commTransport[ringIndex]; + // 有建链诉求,则记录从userRank到subCommRank 和 从subCommRank到userRank的映射 + if (subCommTransport.transportRequests.size() != 0) { + if (commParaInfo.commType == CommType::COMM_TAG_PARTIAL_MESH_COMBINED) { + CHK_RET(GetSub2UserRankMap(commParaInfo.commPlane, 0, subCommTransport.subCommRank2UserRank)); + CHK_RET(GetUserRank2SubMap(commParaInfo.commPlane, 0, subCommTransport.userRank2subCommRank)); + } else { + CHK_RET(GetSub2UserRankMap(commParaInfo.commPlane, ringIndex, subCommTransport.subCommRank2UserRank)); + CHK_RET(GetUserRank2SubMap(commParaInfo.commPlane, ringIndex, subCommTransport.userRank2subCommRank)); + } + } + } + return HCCL_SUCCESS; +} + +HcclResult TopoMatcher::SetRankMap() +{ + // 构建由UserRank到子通信域的映射 + subCommRank2UserRank_.resize(static_cast(COMM_LEVEL_RESERVED)); + userRank2subCommRank_.resize(static_cast(COMM_LEVEL_RESERVED)); + + for (u32 levelIndex = 0; levelIndex < CommPlaneVector_.size(); levelIndex++) { + u32 ringSize = CommPlaneVector_[levelIndex].size(); + subCommRank2UserRank_[levelIndex].resize(ringSize); + userRank2subCommRank_[levelIndex].resize(ringSize); + for (u32 ringIndex = 0; ringIndex < ringSize; ringIndex++) { + u32 rankSize = CommPlaneVector_[levelIndex][ringIndex].size(); + for (u32 rankIndex = 0; rankIndex < rankSize; rankIndex++) { + u32 userRank = CommPlaneVector_[levelIndex][ringIndex][rankIndex]; + subCommRank2UserRank_[levelIndex][ringIndex][rankIndex] = userRank; + userRank2subCommRank_[levelIndex][ringIndex][userRank] = rankIndex; + } + } + } + return HCCL_SUCCESS; +} + +HcclResult TopoMatcher::GetIsUsedRdma(const CommParaInfo &commParaInfo, bool &isUsedRdma) +{ + std::vector > commP2PPlaneVec; + if (commParaInfo.commType == CommType::COMM_TAG_P2P) { + // P2P只需要判断两张卡之间的连接关系 + bool invalidcheck = (topoInfo_.isUsedRdmaMap.size() <= topoInfo_.userRank) || + (topoInfo_.isUsedRdmaMap.size() <= commParaInfo.peerUserRank); + CHK_PRT_RET(invalidcheck, HCCL_ERROR("[GetIsUsedRdma]dstUserRank[%u] or userRank[%u] is bigger than "\ + "rankVector size[%u]", 
commParaInfo.peerUserRank, topoInfo_.userRank, topoInfo_.isUsedRdmaMap.size()), + HCCL_E_PARA); + + std::vector commP2PRankVec; + commP2PRankVec.push_back(topoInfo_.userRank); + commP2PRankVec.push_back(commParaInfo.peerUserRank); + commP2PPlaneVec.push_back(commP2PRankVec); + } + + std::vector > &commPlaneVec = (commParaInfo.commType == CommType::COMM_TAG_P2P) ? + commP2PPlaneVec : CommPlaneVector_[commParaInfo.commPlane]; + + for (const std::vector &commPlane : commPlaneVec) { + for (const u32 dstRank : commPlane) { + if (topoInfo_.isUsedRdmaMap[dstRank]) { + isUsedRdma = true; + return HCCL_SUCCESS; + } + } + } + isUsedRdma = false; + return HCCL_SUCCESS; +} + +HcclResult TopoMatcher::SetIsUsedRdma(const CommParaInfo &commParaInfo, + std::vector &commTransport) +{ + bool isUsedRdma = false; + CHK_RET(GetIsUsedRdma(commParaInfo, isUsedRdma)); + isUsedRdma = (GetExternalInputEnableRdmaSdmaConcurrent() && topoInfo_.deviceType == DevType::DEV_TYPE_910_73) ? + commParaInfo.forceRdma : isUsedRdma; + u32 ringSize = commTransport.size(); + + for (u32 ringIndex = 0; ringIndex < ringSize; ringIndex++) { + SingleSubCommTransport &subCommTransport = commTransport[ringIndex]; + subCommTransport.isUsedRdma = isUsedRdma; + } + HCCL_INFO("[TopoMatcher][SetIsUsedRdma] commPlane[%d] isUsedRdma[%d]", commParaInfo.commPlane, isUsedRdma); + return HCCL_SUCCESS; +} + +HcclResult TopoMatcher::GetSub2UserRankMap(CommPlane commPlane, u32 ringIndex, + std::map &subCommRank2UserRank) +{ + subCommRank2UserRank = subCommRank2UserRank_[static_cast(commPlane)][ringIndex]; + return HCCL_SUCCESS; +} + +HcclResult TopoMatcher::GetUserRank2SubMap(CommPlane commPlane, u32 ringIndex, + std::map &userRank2subCommRank) +{ + userRank2subCommRank = userRank2subCommRank_[static_cast(commPlane)][ringIndex]; + return HCCL_SUCCESS; +} + +HcclTopoInfo TopoMatcher::GetTopoInfo() +{ + return topoInfo_; +} + +HcclAlgoInfo TopoMatcher::GetAlgoInfo() +{ + return algoInfo_; +} + +u32 TopoMatcher::GetExternalInputEnableRdmaSdmaConcurrent() +{ + return externalEnable_.enableRdmaSdmaConcurrent; +} + +u32 TopoMatcher::GetExternalInputHcclEnableFfts() +{ + return externalEnable_.enableFfts; +} + +u32 TopoMatcher::GetExternalInputHcclDeterministic() +{ + return externalEnable_.deterministic; +} + +u32 TopoMatcher::GetExternalInputHcclHighPerfEnable() +{ + return externalEnable_.highPerfEnable; +} + +u32 TopoMatcher::GetExternalInputIntraRoceSwitch() +{ + return externalEnable_.intraRoceSwitch; +} + +u32 TopoMatcher::GetExternalInputHcclDumpDebug() +{ + return externalEnable_.dumpDebug; +} + +bool CheckRankNeighbors(const std::vector &nicList) +{ + // 组成ROH环路必须偶数个,且2节点不能组成双环? 
+    if (nicList.size() % 2 != 0 || nicList.size() < HCCL_DEVICE_NUM_FOUR) {
+        return false;
+    }
+
+    std::vector tmpNicList(nicList);
+    std::sort(tmpNicList.begin(), tmpNicList.end());
+    u32 halfNum = 2;
+    for (u32 i = 0; i < tmpNicList.size() / halfNum; i++) {
+        auto nicIndex = i * halfNum;
+        // check that the devIDs of adjacent entries are consecutive
+        if (tmpNicList[nicIndex] + 1 != tmpNicList[nicIndex + 1]) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+// adapt to ROH plane subnet isolation: odd ranks reach odd ranks, even ranks reach even ranks, odd and even do not interconnect
+bool TopoMatcher::CheckSdmaWithRohTopo(const std::vector &nicList, std::vector &topoList)
+{
+    std::vector tmpNicList(nicList);
+    std::sort(tmpNicList.begin(), tmpNicList.end());
+    SearchPath searchPath;
+    topoList = searchPath.Search(tmpNicList);
+    if (topoList.empty()) {
+        return false;
+    }
+    return true;
+}
+
+const u32 TopoMatcher::GetSubCollectiveRank(const std::vector &vecPara) const
+{
+    // look up this user rank in vecPara; the vector index where it is found is the rank value
+    u32 tmpRank = INVALID_VALUE_RANKID;
+
+    for (u32 rankIndex = 0; rankIndex < vecPara.size(); rankIndex++) {
+        if (userRank_ == vecPara[rankIndex]) {
+            tmpRank = rankIndex;
+            break;
+        }
+    }
+
+    return tmpRank;
+}
+
+u32 TopoMatcher::GetSubRootForScatter(const u32 root)
+{
+    // find the ringIndex via root, and find the rank within the inner plane via userRank
+    u32 subRoot = INVALID_VALUE_RANKID;
+    u32 planeIdx = INVALID_VALUE_RANKID;
+    u32 ringSize = CommPlaneVector_[COMM_LEVEL1_INDEX].size();
+
+    CHK_PRT_RET(ringSize == 0, HCCL_ERROR("[GET][GetSubRootForScatter]bridgeRankVector size is zero."), HCCL_E_PARA);
+
+    u32 rank = INVALID_VALUE_RANKID;
+    for (u32 ringIndex = 0; ringIndex < ringSize; ringIndex++) {
+        if (isBridgeVector_[ringIndex]) {
+            rank = GetSubCollectiveRank(CommPlaneVector_[COMM_LEVEL1_INDEX][ringIndex]); // determine the rank of userRank within the inner plane
+        }
+        for (u32 idx = 0; idx < CommPlaneVector_[COMM_LEVEL1_INDEX][ringIndex].size(); idx++) {
+            if (root == CommPlaneVector_[COMM_LEVEL1_INDEX][ringIndex][idx]) { // find the plane that root belongs to
+                planeIdx = ringIndex;
+            }
+        }
+    }
+    CHK_PRT_RET(rank == INVALID_VALUE_RANKID,
+        HCCL_ERROR("[GET][GetSubRootForScatter]get rankId in inner failed."), HCCL_E_PARA);
+    CHK_PRT_RET(planeIdx == INVALID_VALUE_RANKID,
+        HCCL_ERROR("[GET][GetSubRootForScatter]get root[%u] planeIdx[%u] failed.", root, planeIdx), HCCL_E_PARA);
+    subRoot = CommPlaneVector_[COMM_LEVEL1_INDEX][planeIdx][rank];
+    HCCL_DEBUG("[GetSubRootForScatter] userRank_:[%u] subRoot:[%u]", userRank_, subRoot);
+    return subRoot;
+}
+
+u32 TopoMatcher::GetSubRootUserRank(const u32 userRank, const u32 rootUserRank)
+{
+    u32 tmpUserRank = INVALID_VALUE_RANKID;
+
+    u32 serverIdx = INVALID_VALUE_RANKID;
+    for (u32 i = 0; i < serverAndsuperPodToRank_[0].size(); i++) {
+        for (u32 j = 0; j < serverAndsuperPodToRank_[0][0].size(); j++) {
+            if (serverAndsuperPodToRank_[0][i][j] == rootUserRank) {
+                serverIdx = i;
+                break;
+            }
+        }
+    }
+    u32 rankIdx = INVALID_VALUE_RANKID;
+    for (u32 i = 0; i < serverAndsuperPodToRank_[0].size(); i++) {
+        for (u32 j = 0; j < serverAndsuperPodToRank_[0][0].size(); j++) {
+            if (serverAndsuperPodToRank_[0][i][j] == userRank) {
+                rankIdx = j;
+                break;
+            }
+        }
+    }
+
+    if (serverIdx != INVALID_VALUE_RANKID && rankIdx != INVALID_VALUE_RANKID) {
+        tmpUserRank = serverAndsuperPodToRank_[0][serverIdx][rankIdx];
+    }
+    return tmpUserRank;
+}
+
+u32 TopoMatcher::GetSubRootUserRankWithSuperPod(const u32 userRank, const u32 rootUserRank)
+{
+    u32 tmpUserRank = INVALID_VALUE_RANKID;
+
+    u32 superPodIdx = INVALID_VALUE_RANKID;
+    for (u32 i = 0; i < serverAndsuperPodToRank_[1].size(); i++) {
+        for (u32 j = 0; j < serverAndsuperPodToRank_[1][0].size(); j++) {
+            if (serverAndsuperPodToRank_[1][i][j] == rootUserRank) {
+                superPodIdx = i;
+                break;
+            }
+        }
+    }
+    u32 rankIdx = INVALID_VALUE_RANKID;
+    for (u32 i = 0; i < serverAndsuperPodToRank_[1].size(); i++) {
+        for (u32 j = 0; j < serverAndsuperPodToRank_[1][0].size(); j++) {
+            if (serverAndsuperPodToRank_[1][i][j] == userRank) {
+                rankIdx = j;
+                break;
+            }
+        }
+    }
+
+    if (superPodIdx != INVALID_VALUE_RANKID && rankIdx != INVALID_VALUE_RANKID) {
+        tmpUserRank = serverAndsuperPodToRank_[1][superPodIdx][rankIdx];
+    }
+    return tmpUserRank;
+}
+
+HcclResult TopoMatcher::SetDeterministicConfig(const u8 deterministic)
+{
+    if (deterministic > 1) {
+        HCCL_ERROR("[SetDeterministicConfig] deterministic[%d] should be 0 or 1.", deterministic);
+        return HCCL_E_PARA;
+    }
+    externalEnable_.deterministic = deterministic;
+    return HCCL_SUCCESS;
+}
+
+u8 TopoMatcher::GetDeterministicConfig() const
+{
+    return externalEnable_.deterministic;
+}
+
+}
\ No newline at end of file
diff --git a/src/domain/collective_communication/algorithm/impl/topo_matcher.h b/src/domain/collective_communication/algorithm/impl/topo_matcher.h
new file mode 100644
index 0000000000000000000000000000000000000000..565480cdfdaf0524f87bef30eef15dd186e419e1
--- /dev/null
+++ b/src/domain/collective_communication/algorithm/impl/topo_matcher.h
@@ -0,0 +1,162 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#ifndef TOPO_MATCHER_H
+#define TOPO_MATCHER_H
+
+#include
+#include "dispatcher.h"
+#include "comm_base_pub.h"
+#include "externalinput_pub.h"
+#include "coll_alg_param.h"
+#include "comm_factory_pub.h"
+#include "nonuniform_hierarchical_ring_v1_base_pub.h"
+#include "hccl_common.h"
+#include "calc_impl.h"
+
+namespace hccl {
+constexpr u32 COMM_LEVEL1_INDEX = 2;
+using HcclAlgoInfo = struct HcclAlgoInfoDef {
+    bool inlineReduceSwitchOn;  // perform the Reduce computation while data is being received
+    std::string identifier;
+    bool isUsedRdmaOuter;
+
+    HcclAlgoInfoDef()
+        : inlineReduceSwitchOn(true),
+          identifier(""),
+          isUsedRdmaOuter(false)
+    {}
+};
+
+using HcclTopoInfo = struct HcclTopoInfoDef {
+    u32 userRank;  // rank ID within the communicator
+    u32 userRankSize;  // number of ranks in the communicator
+    u32 devicePhyId;
+    s32 deviceLogicId;
+    std::vector nicList;
+    bool isSingleMeshAggregation;
+    u32 deviceNumPerAggregation;  // number of devices in each module
+    u32 devNumInLevel2;  // total number of superPods in the cluster
+    DevType deviceType;
+    TopoType topoType;
+    bool is310P3Common;
+    u32 serverNum;
+    u32 meshAggregationRankSize;
+    u32 multiModuleDiffDeviceNumMode;
+    u32 realUserRank;
+    bool isDiffDeviceModule;
+    u32 moduleNum;
+    std::unordered_map isUsedRdmaMap;
+    std::unordered_map pairLinkCounter;  // count of link types between all devices within a server
+
+    HcclTopoInfoDef()
+        : userRank(0),
+          userRankSize(0),
+          devicePhyId(0),
+          deviceLogicId(0),
+          nicList(0),
+          isSingleMeshAggregation(false),
+          deviceNumPerAggregation(0),
+          devNumInLevel2(0),
+          deviceType(DevType::DEV_TYPE_COUNT),
+          topoType(TopoType::TOPO_TYPE_COMMON),
+          is310P3Common(false),
+          serverNum(0),
+          meshAggregationRankSize(0),
+          multiModuleDiffDeviceNumMode(0),
+          realUserRank(0),
+          isDiffDeviceModule(false),
+          moduleNum(0)
+    {}
+};
+
+using HcclExternalEnable = struct HcclExternalEnableDef {
+    u32 enableRdmaSdmaConcurrent;
+    u32 enableFfts;
+    u32 deterministic;
+    u32 highPerfEnable;
+    u32 intraRoceSwitch;
+    u32 dumpDebug;
+
+    HcclExternalEnableDef()
+        : enableRdmaSdmaConcurrent(0),
+          enableFfts(1),
+          deterministic(0),
+          highPerfEnable(0),
+          intraRoceSwitch(0),
+          dumpDebug(0)
+    {}
+};
+
+bool CheckRankNeighbors(const std::vector &nicList);
+bool CheckSdmaWithRohTopo(const std::vector &nicList, std::vector &topoList);
+
+class TopoMatcher {
+public:
+    explicit TopoMatcher(const std::vector>> CommPlaneRanks,
+                         const std::vector isBridgeVector,
+                         HcclTopoInfo &topoInfo,
+                         HcclAlgoInfo &algoInfo,
+                         HcclExternalEnable &externalEnable,
+                         std::vector>> &serverAndsuperPodToRank);
+    HcclResult CalcCommPlaneInfo(const std::string &tag, const CommParaInfo &commParaInfo,
+        std::vector &commTransport, TransportMemType inPutMemType,
+        TransportMemType outPutMemType);
+    HcclTopoInfo GetTopoInfo();
+    HcclAlgoInfo GetAlgoInfo();
+    u32 GetExternalInputEnableRdmaSdmaConcurrent();
+    u32 GetExternalInputHcclEnableFfts();
+    u32 GetExternalInputHcclDeterministic();
+    u32 GetExternalInputHcclHighPerfEnable();
+    u32 GetExternalInputIntraRoceSwitch();
+    u32 GetExternalInputHcclDumpDebug();
+    bool CheckSdmaWithRohTopo(const std::vector &nicList, std::vector &topoList);
+    u32 GetSubRootForScatter(const u32 root);
+    u32 GetSubRootUserRank(const u32 userRank, const u32 rootUserRank);
+    u32 GetSubRootUserRankWithSuperPod(const u32 userRank, const u32 rootUserRank);
+    HcclResult SetDeterministicConfig(const u8 deterministic);
+    u8 GetDeterministicConfig() const;
+
+protected:
+
+private:
+
+    HcclResult GetRankMap(const CommParaInfo &commParaInfo, std::vector &commTransport);
+
+    HcclResult SetRankMap();
+
+    HcclResult SetIsUsedRdma(const CommParaInfo &commParaInfo, std::vector
&commTransport); + + HcclResult GetSub2UserRankMap(CommPlane commPlane, u32 ringIndex, std::map &subCommRank2UserRank); + + HcclResult GetUserRank2SubMap(CommPlane commPlane, u32 ringIndex, std::map &userRank2subCommRank); + + HcclResult GetIsUsedRdma(const CommParaInfo &commParaInfo, bool &isUsedRdma); + + const u32 GetSubCollectiveRank(const std::vector &vecPara) const; + + std::vector>> CommPlaneVector_; + std::vector isBridgeVector_; + HcclTopoInfo &topoInfo_; + HcclAlgoInfo &algoInfo_; + HcclExternalEnable &externalEnable_; + u32 userRank_; + std::vector>> subCommRank2UserRank_; + std::vector>> userRank2subCommRank_; + + // serverAndsuperPodToRank_[0]: 通信域在当前superPod内, 按照serverIdx划分的所有rank信息 + // serverAndsuperPodToRank_[1]: 通信域所有rank的信息, 按照superPodId -> RankInfo 的结构划分 + std::vector>> serverAndsuperPodToRank_; + + u32 userRankIdx_; +}; +} // namespace hccl + +#endif /* * TOPO_MATCHER_H */ \ No newline at end of file diff --git a/src/domain/collective_communication/framework/common/src/config.cc b/src/domain/collective_communication/framework/common/src/config.cc index d9def4dfd24d0b79bc41963960f8f2c81b99d433..8a0ec96e4f1da57a562ce358405faaa2a37efbb7 100644 --- a/src/domain/collective_communication/framework/common/src/config.cc +++ b/src/domain/collective_communication/framework/common/src/config.cc @@ -530,30 +530,20 @@ HcclResult CfgGetRoleTableInfo(const std::string &rankTableM, RoleTableInfo &rol return HCCL_SUCCESS; } -HcclResult SetRetryEnable( - const u32 &superPodNum, const u32 &serverNum, const u32 &deviceNumPerAggregation, bool &retryEnable) +void SetRetryEnable(DevType deviceType, const u32 &superPodNum, const u32 &serverNum, + const u32 &deviceNumPerAggregation, bool &retryEnable) { - if (superPodNum > 1 && GetExternalInputInterSuperPodRetryEnable()) { - retryEnable = true; - HCCL_INFO("[Config][SetRetryEnable] superPodNum[%u], retryEnable[%d].", superPodNum, retryEnable); - return HCCL_SUCCESS; - } - if (serverNum > 1 && GetExternalInputInterServerRetryEnable()) { - retryEnable = true; - HCCL_INFO("[Config][SetRetryEnable] serverNum[%u], retryEnable[%d].", serverNum, retryEnable); - return HCCL_SUCCESS; - } - if (deviceNumPerAggregation > 1 && GetExternalInputIntraServerRetryEnable()) { - retryEnable = true; - HCCL_INFO("[Config][SetRetryEnable] deviceNumPerAggregation[%u], retryEnable[%d].", - deviceNumPerAggregation, - retryEnable); - return HCCL_SUCCESS; - } - HCCL_INFO("[Config][SetRetryEnable] superPodNum[%u], serverNum[%u], deviceNumPerAggregation[%u], retryEnable[%d].", - superPodNum, - serverNum, - deviceNumPerAggregation, - retryEnable); - return HCCL_SUCCESS; + retryEnable = false; + if (deviceType != DevType::DEV_TYPE_910_73) { + retryEnable = false; + } else if (superPodNum > 1) { // L2重执行 + retryEnable = GetExternalInputInterSuperPodRetryEnable(); + } else if (serverNum > 1) { // L1重执行 + retryEnable = GetExternalInputInterServerRetryEnable(); + } else if (deviceNumPerAggregation > 1) { // L0重执行 + retryEnable = GetExternalInputIntraServerRetryEnable(); + } + + HCCL_INFO("[Config][SetRetryEnable]deviceType[%d], superPodNum[%u], serverNum[%u], deviceNum[%u], retryEnable[%d].", + deviceType, superPodNum, serverNum, deviceNumPerAggregation, retryEnable); } \ No newline at end of file diff --git a/src/domain/collective_communication/framework/common/src/config.h b/src/domain/collective_communication/framework/common/src/config.h index 61b75dc78a2b737385828b03c56ed82483b1e925..ba28fa944a53e3e26a6f4c8052bf31c7b2f89b37 100644 --- 
a/src/domain/collective_communication/framework/common/src/config.h +++ b/src/domain/collective_communication/framework/common/src/config.h @@ -46,6 +46,6 @@ HcclResult GetDevNum(const std::vector &rankList, u32 &devNum); HcclResult GetServerNum(const std::vector &rankList, u32 &serverNum); HcclResult GetSuperPodNum(const std::vector &rankList, u32 &superPodNum); HcclResult GetSuperPodNum(const std::vector &rankList, u32 &superPodNum); -HcclResult SetRetryEnable( - const u32 &superPodNum, const u32 &serverNum, const u32 &deviceNumPerAggregation, bool &retryEnable); +void SetRetryEnable(DevType deviceType, const u32 &superPodNum, const u32 &serverNum, + const u32 &deviceNumPerAggregation, bool &retryEnable); #endif // CONFIG_H diff --git a/src/domain/collective_communication/framework/common/src/param_check.cc b/src/domain/collective_communication/framework/common/src/param_check.cc index 4ea279e72633a03f756ed97a6e1baa557c7b1efd..9fe65dc303a32cc37c7553d16e2623a4beec6dc9 100644 --- a/src/domain/collective_communication/framework/common/src/param_check.cc +++ b/src/domain/collective_communication/framework/common/src/param_check.cc @@ -264,7 +264,7 @@ HcclResult HcomCheckReductionOp(const HcclReduceOp op) HcclResult HcomCheckReduceDataType(const HcclDataType dataType, const HcclReduceOp op, DevType deviceType) { - if (deviceType == DevType::DEV_TYPE_910B) { + if ((deviceType == DevType::DEV_TYPE_910B) || (deviceType == DevType::DEV_TYPE_910_73)) { if ((op == HCCL_REDUCE_PROD) && ((dataType == HCCL_DATA_TYPE_INT16) || (dataType == HCCL_DATA_TYPE_BFP16))) { RPT_INPUT_ERR(true, "EI0003", std::vector({"ccl_op", "parameter", "value", "tips"}),\ diff --git a/src/domain/collective_communication/framework/common/src/topo/topoinfo_detect.cc b/src/domain/collective_communication/framework/common/src/topo/topoinfo_detect.cc index fac6845c1272ff230905ba106496cfa8164316af..cca1f0cfff819d190ec9bae6ff0962c32b858815 100644 --- a/src/domain/collective_communication/framework/common/src/topo/topoinfo_detect.cc +++ b/src/domain/collective_communication/framework/common/src/topo/topoinfo_detect.cc @@ -28,11 +28,33 @@ TopoInfoDetect::TopoInfoDetect() : deviceLogicID_(INVALID_INT), localRankInfo_() TopoInfoDetect::~TopoInfoDetect() { + if (exchangeServerThreadPtr_ && exchangeServerThreadPtr_->joinable()) { + exchangeServerThreadPtr_->join(); + } + exchangeServerThreadPtr_ = nullptr; + pTopoExchangeServer_ = nullptr; + (void)Teardown(); + return; +} + +HcclResult TopoInfoDetect::GetServerConnections(std::map> &connectSockets) +{ + if (pTopoExchangeServer_) { + return pTopoExchangeServer_->GetConnections(connectSockets); + } else { + return HCCL_SUCCESS; + } +} + +HcclResult TopoInfoDetect::GetAgentConnection(std::shared_ptr &connectSocket) +{ + CHK_SMART_PTR_NULL(pTopoExchangeAgent_); + return pTopoExchangeAgent_->GetConnection(connectSocket); } void TopoInfoDetect::SetupTopoExchangeServer(s32 devicePhysicID, s32 deviceLogicID, HcclIpAddress hostIP, u32 hostPort, - vector whitelist, HcclNetDevCtx netDevCtx, std::unique_ptr listenSocket, - bool isMasterInfo) const + vector whitelist, HcclNetDevCtx netDevCtx, std::shared_ptr listenSocket, + bool isMasterInfo) { HcclResult ret = hrtSetDevice(deviceLogicID); if (ret != HCCL_SUCCESS) { @@ -41,23 +63,19 @@ void TopoInfoDetect::SetupTopoExchangeServer(s32 devicePhysicID, s32 deviceLogic return; } - unique_ptr pTopoExchangeServer; - pTopoExchangeServer.reset( - new (nothrow) TopoInfoExchangeServer(hostIP, hostPort, whitelist, netDevCtx, listenSocket)); - if 
(!pTopoExchangeServer) { + pTopoExchangeServer_.reset(new (nothrow) TopoInfoExchangeServer(hostIP, hostPort, whitelist, netDevCtx, + listenSocket, rootInfo_.identifier)); + if (!pTopoExchangeServer_) { topoExchangeServerStatus_[hostPort] = TOPO_EXCHANGE_SERVER_STATUS_ERROR; HCCL_ERROR("[Setup][TopoExchangeServer]build topoExchangeServer failed. "); } else { - ret = isMasterInfo ? pTopoExchangeServer->SetupByMasterInfo() : pTopoExchangeServer->Setup(); + ret = isMasterInfo ? pTopoExchangeServer_->SetupByMasterInfo() : pTopoExchangeServer_->Setup(); if (ret != HCCL_SUCCESS) { topoExchangeServerStatus_[hostPort] = TOPO_EXCHANGE_SERVER_STATUS_ERROR; HCCL_ERROR("[Setup][TopoExchangeServer]setup topoExchangeServer failed, ret[%u]", ret); } - pTopoExchangeServer = nullptr; } - (void)HcclNetDeInit(NICDeployment::NIC_DEPLOYMENT_HOST, devicePhysicID, deviceLogicID); - ret = hrtResetDevice(deviceLogicID); if (ret != HCCL_SUCCESS) { topoExchangeServerStatus_[hostPort] = TOPO_EXCHANGE_SERVER_STATUS_ERROR; @@ -66,7 +84,7 @@ void TopoInfoDetect::SetupTopoExchangeServer(s32 devicePhysicID, s32 deviceLogic } topoExchangeServerStatus_[hostPort] = TOPO_EXCHANGE_SERVER_STATUS_IDLE; } -HcclResult TopoInfoDetect::SetupServerByMasterInfo(const HcclIpAddress& masterIP, u32 masterPort) +HcclResult TopoInfoDetect::SetupServerByMasterInfo(const HcclIpAddress& masterIP, u32 masterPort, const HcclRootHandle &rootInfo) { CHK_RET(hrtGetDevice(&deviceLogicID_)); CHK_RET(hrtGetDevicePhyIdByIndex(deviceLogicID_, devicePhysicID_)); @@ -74,12 +92,13 @@ HcclResult TopoInfoDetect::SetupServerByMasterInfo(const HcclIpAddress& masterIP if (GetExternalInputHcclEnableWhitelist() == HCCL_WHITELIST_ON) { CHK_RET(ReadHostSocketWhitelist(whitelist)); } + rootInfo_ = rootInfo; CHK_RET(HcclNetInit(NICDeployment::NIC_DEPLOYMENT_HOST, devicePhysicID_, deviceLogicID_, true)); CHK_RET(StartRootNetwork(whitelist, masterIP, masterPort)); topoExchangeServerStatus_[GetExternalInputMasterInfo().port] = TOPO_EXCHANGE_SERVER_STATUS_RUNING; thread threadHandle(&TopoInfoDetect::SetupTopoExchangeServer, this, devicePhysicID_, deviceLogicID_, - masterIP, GetExternalInputMasterInfo().port, whitelist, serverPortCtx_, std::move(listenSocket_), true); + masterIP, GetExternalInputMasterInfo().port, whitelist, serverPortCtx_, listenSocket_, true); threadHandle.detach(); return HCCL_SUCCESS; @@ -115,24 +134,28 @@ HcclResult TopoInfoDetect::SetupServer(HcclRootHandle &rootInfo) } else { hostPort = GetExternalInputHcclIfBasePort() + devicePhysicID_; } - HCCL_INFO("[Setup][hcclIfBasePort], hostPort[%u]", hostPort); - + CHK_RET(GenerateRootInfo(hostIP, hostPort, devicePhysicID_, rootInfo_)); CHK_RET(StartRootNetwork(whitelist, hostIP, hostPort)); - topoExchangeServerStatus_[hostPort] = TOPO_EXCHANGE_SERVER_STATUS_RUNING; - thread threadHandle(&TopoInfoDetect::SetupTopoExchangeServer, this, devicePhysicID_, deviceLogicID_, hostIP, - hostPort, whitelist, serverPortCtx_, std::move(listenSocket_), false); - threadHandle.detach(); + exchangeServerThreadPtr_.reset(new (nothrow) thread(&TopoInfoDetect::SetupTopoExchangeServer, this, devicePhysicID_, + deviceLogicID_, hostIP, hostPort, whitelist, serverPortCtx_, listenSocket_, false)); + CHK_SMART_PTR_NULL(exchangeServerThreadPtr_); + + rootInfo = rootInfo_; + HCCL_INFO("setup topo exchange server complete, identifier[%s]", rootInfo.identifier); + return HCCL_SUCCESS; +} +HcclResult TopoInfoDetect::GenerateRootInfo(const HcclIpAddress &hostIP, u32 hostPort, u32 devicePhysicID, HcclRootHandle &rootInfo) +{ u64 timestamp 
= 0; CHK_RET(SalGetCurrentTimestamp(timestamp)); string identifier = hostIP.GetReadableAddress(); - identifier.append("_"); identifier.append(to_string(hostPort)); identifier.append("_"); - identifier.append(to_string(devicePhysicID_)); + identifier.append(to_string(devicePhysicID)); identifier.append("_"); identifier.append(to_string(timestamp)); CHK_PRT_RET((identifier.length() >= ROOTINFO_INDENTIFIER_MAX_LENGTH), @@ -152,6 +175,21 @@ HcclResult TopoInfoDetect::SetupServer(HcclRootHandle &rootInfo) return HCCL_SUCCESS; } +HcclResult TopoInfoDetect::TeardownServer() +{ + if(pTopoExchangeServer_) { + CHK_RET(pTopoExchangeServer_->Teardown()); + } + + if (serverPortCtx_) { + HcclNetCloseDev(serverPortCtx_); + serverPortCtx_ = nullptr; + CHK_RET(HcclNetDeInit(NICDeployment::NIC_DEPLOYMENT_HOST, devicePhysicID_, deviceLogicID_)); + } + HCCL_INFO("TopoInfoDetect TeardownServer ok, identifier[%s].", rootInfo_.identifier); + return HCCL_SUCCESS; +} + HcclResult TopoInfoDetect::WaitTopoExchangeServerCompelte(u32 idx) const { const auto start = chrono::steady_clock::now(); @@ -220,29 +258,40 @@ HcclResult TopoInfoDetect::SetupAgent(u32 rankSize, u32 myrank, const HcclRootHa localRankInfo_.hostIP.GetReadableAddress(), localRankInfo_.deviceLogicID, localRankInfo_.devicePhysicID, localRankInfo_.deviceIP[0].GetReadableIP()) ; - unique_ptr pTopoExchangeAgent; - pTopoExchangeAgent.reset(new (nothrow) TopoInfoExchangeAgent(rootIP, rootInfo.port, + pTopoExchangeAgent_.reset(new (nothrow) TopoInfoExchangeAgent(rootIP, rootInfo.port, rootInfo.identifier, agentPortCtx_, localRankInfo_)); - CHK_SMART_PTR_NULL(pTopoExchangeAgent); - CHK_RET(pTopoExchangeAgent->Setup()); + CHK_SMART_PTR_NULL(pTopoExchangeAgent_); + CHK_RET(pTopoExchangeAgent_->Setup()); + CHK_RET(pTopoExchangeAgent_->GetClusterTopoInfo(clusterTopoInfo_)); + rootInfo_ = rootInfo; + HCCL_INFO("topo detect completed. myrank[%u], totalranks[%u], myhost[%s], totalservers[%u].", + myrank, rankSize, localRankInfo_.hostIP.GetReadableAddress(), clusterTopoInfo_.serverNum); + return HCCL_SUCCESS; +} - CHK_RET(pTopoExchangeAgent->Teardown()); +HcclResult TopoInfoDetect::TeardownAgent() +{ + if (!pTopoExchangeAgent_) { + return HCCL_SUCCESS; + } + CHK_RET(pTopoExchangeAgent_->Teardown()); - ret = StopNetwork(hostIP, bInitDevNic); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[Setup][Agent]topo detect agent stop network failed! rank[%u]", myrank), ret); - CHK_RET(pTopoExchangeAgent->GetClusterTopoInfo(clusterTopoInfo_)); + bool bInitDevNic = clusterTopoInfo_.rankNum != 1 ? true : false; + HcclIpAddress hostIP = GetBootstrapHostIP(); - HCCL_INFO("topo detect completed. myrank[%u], totalranks[%u], myhost[%s], totalservers[%u].", - myrank, rankSize, localRankInfo_.hostIP.GetReadableAddress(), clusterTopoInfo_.serverNum); + auto ret = StopNetwork(hostIP, bInitDevNic); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[Setup][Agent]topo detect agent stop network failed!"), ret); + HCCL_INFO("TopoInfoDetect TeardownAgent ok, identifier[%s].", rootInfo_.identifier); return HCCL_SUCCESS; } + HcclResult TopoInfoDetect::SetupAgentByMasterInfo(HcclIpAddress &localHostIp, const HcclRootHandle &rootInfo) { CHK_RET(hrtGetDevice(&deviceLogicID_)); SetBootstrapHostIP(localHostIp); CHK_RET(hrtGetDevicePhyIdByIndex(deviceLogicID_, devicePhysicID_)); - + rootInfo_ = rootInfo; bool bInitDevNic = GetExternalInputMasterInfo().rankSize != 1 ? 
true : false; HcclResult ret = StartNetwork(localHostIp, bInitDevNic); CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[Setup][Agent]topo detect agent start network failed!"), ret); @@ -293,11 +342,18 @@ HcclResult TopoInfoDetect::SetupAgentByMasterInfo(HcclIpAddress &localHostIp, co return HCCL_SUCCESS; } -HcclResult TopoInfoDetect::Teardown(const HcclRootHandle &rootInfo) +HcclResult TopoInfoDetect::WaitComplete(const HcclRootHandle &rootInfo) { return WaitTopoExchangeServerCompelte(rootInfo.port); } +HcclResult TopoInfoDetect::Teardown() +{ + CHK_RET(TeardownAgent()); + CHK_RET(TeardownServer()); + return HCCL_SUCCESS; +} + HcclResult TopoInfoDetect::ReadHostSocketWhitelist(vector &whitelist) const { RPT_ENV_ERR((GetExternalInputHcclWhiteListFile().length() == 0), "EI0001", @@ -380,8 +436,7 @@ HcclResult TopoInfoDetect::StartRootNetwork(const vector &whiteli CHK_RET(HcclNetOpenDev(&serverPortCtx_, NicType::HOST_NIC_TYPE, devicePhysicID_, deviceLogicID_, hostIP)); CHK_PTR_NULL(serverPortCtx_); - EXECEPTION_CATCH((listenSocket_ = std::make_unique( - serverPortCtx_, usePort)), return HCCL_E_PTR); + listenSocket_.reset(new (nothrow) HcclSocket(serverPortCtx_, usePort)); CHK_SMART_PTR_NULL(listenSocket_); CHK_RET(listenSocket_->Init()); @@ -404,7 +459,7 @@ HcclResult TopoInfoDetect::AddSocketWhiteList(u32 port, wlistInfo.connLimit = HOST_SOCKET_CONN_LIMIT; wlistInfo.remoteIp.addr = ip.GetBinaryAddress().addr; wlistInfo.remoteIp.addr6 = ip.GetBinaryAddress().addr6; - string tag = TOPO_DETECT_TAG + "_" + to_string(port); + string tag = TOPO_DETECT_TAG + "_" + rootInfo_.identifier + "_" + to_string(port); s32 sRet = memcpy_s(&wlistInfo.tag[0], sizeof(wlistInfo.tag), tag.c_str(), tag.size() + 1); if (sRet != EOK) { HCCL_ERROR("[Add][SocketWhiteList]memory copy failed. errorno[%d]", sRet); @@ -467,6 +522,19 @@ HcclResult TopoInfoDetect::GenerateLocalRankInfo(u32 rankSize, u32 rankID, HcclB CHK_RET(hrtGetDeviceType(localRankInfo.deviceType)); CHK_RET(hrtGetDevice(reinterpret_cast(&localRankInfo.deviceLogicID))); CHK_RET(hrtGetDevicePhyIdByIndex(static_cast(localRankInfo.deviceLogicID), localRankInfo.devicePhysicID)); + + if (localRankInfo.deviceType == DevType::DEV_TYPE_910_73) { + s64 superPodId = 0; + s64 superDeviceId = 0; + CHK_RET(hrtGetDeviceInfo(localRankInfo.deviceLogicID, HcclRtDeviceModuleType::HCCL_RT_MODULE_TYPE_SYSTEM, + HcclRtDeviceInfoType::HCCL_INFO_TYPE_SUPER_POD_ID, superPodId)); + CHK_RET(hrtGetDeviceInfo(localRankInfo.deviceLogicID, HcclRtDeviceModuleType::HCCL_RT_MODULE_TYPE_SYSTEM, + HcclRtDeviceInfoType::HCCL_INFO_TYPE_SDID, superDeviceId)); + localRankInfo.superPodId = std::to_string(superPodId); + localRankInfo.superDeviceId = static_cast(superDeviceId); + HCCL_INFO("[Generate][LocalRankInfo]deviceLogicID[%d], superPodId[%s], superDeviceId[%u]", + localRankInfo.deviceLogicID, localRankInfo.superPodId.c_str(), localRankInfo.superDeviceId); + } localRankInfo.deviceIP.clear(); if (localRankInfo.nicDeploy == NICDeployment::NIC_DEPLOYMENT_DEVICE && rankSize != 1) { @@ -584,7 +652,7 @@ HcclResult TopoInfoDetect::Struct2JsonRankTable(const RankTable_t &clusterInfo, ClusterJson[PROP_SUPER_POD_LIST] = superPodListJson; ClusterJson[PROP_STATUS] = "completed"; - ClusterJson[PROP_VERSION] = "1.0"; + ClusterJson[PROP_VERSION] = (localRankInfo_.deviceType == DevType::DEV_TYPE_910_73) ? 
"1.2" : "1.0"; return HCCL_SUCCESS; } diff --git a/src/domain/collective_communication/framework/common/src/topo/topoinfo_detect.h b/src/domain/collective_communication/framework/common/src/topo/topoinfo_detect.h index 4477fd1247890b1041991d17f1d13be54ea4d48f..cdde3fbe50d45c072359387cb9c7f588f41e6981 100644 --- a/src/domain/collective_communication/framework/common/src/topo/topoinfo_detect.h +++ b/src/domain/collective_communication/framework/common/src/topo/topoinfo_detect.h @@ -29,15 +29,21 @@ public: HcclResult SetupAgent(u32 rankSize, u32 myrank, const HcclRootHandle &rootInfo); HcclResult SetupAgentByMasterInfo(HcclIpAddress &localHostIp, const HcclRootHandle &rootInfo); HcclResult SetupServer(HcclRootHandle &rootInfo); - HcclResult SetupServerByMasterInfo(const HcclIpAddress &masterIP, u32 masterPort); + HcclResult SetupServerByMasterInfo(const HcclIpAddress &masterIP, u32 masterPort, const HcclRootHandle &rootInfo); + HcclResult Teardown(); + HcclResult WaitComplete(const HcclRootHandle &rootInfo); HcclResult GetCluterInfo(RankTable_t &clusterInfo); HcclResult GetLocalRankInfo(HcclBasicRankInfo &rankInfo); - HcclResult Teardown(const HcclRootHandle &rootInfo); HcclResult GetRankId(u32 &rankId); HcclResult TransformRankTableStr(const RankTable_t &clusterInfo, std::string &ranktableStr); + HcclResult GetAgentConnection(std::shared_ptr &connectSocket); + HcclResult GetServerConnections(std::map> &connectSockets); + HcclResult GenerateRootInfo(const HcclIpAddress &hostIP, u32 hostPort, u32 devicePhysicID, HcclRootHandle &rootInfo); protected: private: + HcclResult TeardownAgent(); + HcclResult TeardownServer(); HcclResult Struct2JsonRankTable(const RankTable_t &clusterInfo, nlohmann::json &ClusterJson); HcclResult GetRootHostIP(const std::vector &whitelist, HcclIpAddress &ip, u32 devPhyId); HcclResult StartNetwork(HcclIpAddress &hostIP, bool bInitDevNic); @@ -54,8 +60,8 @@ private: nlohmann::json &perServerJson, u32 serverIndex); HcclResult TransformSuperPodList(const std::vector &rankInfo, nlohmann::json &superPodListJson) const; void SetupTopoExchangeServer(s32 devicePhysicID, s32 deviceLogicID, HcclIpAddress hostIP, u32 hostPort, - std::vector whitelist, HcclNetDevCtx netDevCtx, std::unique_ptr listenSocket, - bool isMasterInfo = false) const; + std::vector whitelist, HcclNetDevCtx netDevCtx, std::shared_ptr listenSocket, + bool isMasterInfo = false); HcclResult WaitTopoExchangeServerCompelte(u32 idx) const; void SetBootstrapHostIP(HcclIpAddress &ip) const; HcclIpAddress GetBootstrapHostIP() const; @@ -69,7 +75,11 @@ private: HcclNetDevCtx agentPortCtx_{nullptr}; HcclNetDevCtx devNicCtx_{nullptr}; u32 devicePhysicID_{INVALID_UINT}; - std::unique_ptr listenSocket_{nullptr}; + std::shared_ptr listenSocket_{nullptr}; + HcclRootHandle rootInfo_; + std::shared_ptr pTopoExchangeAgent_{nullptr}; + std::shared_ptr pTopoExchangeServer_{nullptr}; + std::unique_ptr exchangeServerThreadPtr_{nullptr}; }; } // namespace hccl #endif /* TOPOINFO_DETECT_H */ \ No newline at end of file diff --git a/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_agent.cc b/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_agent.cc index 789ea65568d31b8c27ccb4cfc2661b04f6aef254..e1bd453ed27d87398efcfdebc8dd627fcc590c5f 100644 --- a/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_agent.cc +++ b/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_agent.cc @@ -18,16 +18,15 @@ namespace hccl { 
constexpr s32 DEVICE_LOGIC_ID_LENGTH = 4; -TopoInfoExchangeAgent::TopoInfoExchangeAgent(HcclIpAddress &serverIp, u32 serverPort, - std::string identifier, HcclNetDevCtx netDevCtx, HcclBasicRankInfo localRankInfo) +TopoInfoExchangeAgent::TopoInfoExchangeAgent(HcclIpAddress &serverIp, u32 serverPort, std::string identifier, + HcclNetDevCtx netDevCtx, HcclBasicRankInfo localRankInfo) : serverIP_(serverIp), serverPort_(serverPort), identifier_(identifier), localRankInfo_(localRankInfo), clusterTopoInfo_(), netDevCtx_(netDevCtx) -{ -} +{} TopoInfoExchangeAgent::~TopoInfoExchangeAgent() { @@ -36,16 +35,13 @@ TopoInfoExchangeAgent::~TopoInfoExchangeAgent() HcclResult TopoInfoExchangeAgent::Setup() { - std::shared_ptr socket; - HcclResult ret = Connect(serverIP_, serverPort_, socket); + HcclResult ret = Connect(serverIP_, serverPort_, socket_); CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[TopoInfoExchangeAgent][Setup]TopoExchangeAgent: "\ "connect server[%s : %u] failed", serverIP_.GetReadableAddress(), serverPort_), ret); HCCL_INFO("TopoExchangeAgent: client connect with server ip[%s] port[%u] success.", serverIP_.GetReadableAddress(), serverPort_); - CHK_RET(DetectClusterTopoInfo(socket, clusterTopoInfo_)); - - CHK_RET(Disconnect(socket)); + CHK_RET(DetectClusterTopoInfo(socket_, clusterTopoInfo_)); CHK_RET(SaveClusterInfo(clusterTopoInfo_)); @@ -53,14 +49,23 @@ HcclResult TopoInfoExchangeAgent::Setup() return HCCL_SUCCESS; } -HcclResult TopoInfoExchangeAgent::SetupByMasterInfo() + +HcclResult TopoInfoExchangeAgent::Teardown() { - isByMasterInfo_ = true; - CHK_RET(Setup()); + CHK_RET(Disconnect(socket_)); return HCCL_SUCCESS; } -HcclResult TopoInfoExchangeAgent::Teardown() const + +HcclResult TopoInfoExchangeAgent::GetConnection(std::shared_ptr &socket) { + socket = socket_; + return HCCL_SUCCESS; +} + +HcclResult TopoInfoExchangeAgent::SetupByMasterInfo() +{ + isByMasterInfo_ = true; + CHK_RET(Setup()); return HCCL_SUCCESS; } @@ -132,7 +137,7 @@ HcclResult TopoInfoExchangeAgent::GetIdentifier(u32 &indentify) HcclResult TopoInfoExchangeAgent::Connect(HcclIpAddress &serverIp, u32 port, std::shared_ptr &socket) { - std::string tag = TOPO_DETECT_TAG + "_" + std::to_string(port); + std::string tag = TOPO_DETECT_TAG + "_" + identifier_ + "_" + std::to_string(port); EXECEPTION_CATCH((socket = std::make_shared(tag, netDevCtx_, serverIp, port, HcclSocketRole::SOCKET_ROLE_CLIENT)), return HCCL_E_PTR); CHK_SMART_PTR_NULL(socket); @@ -231,6 +236,7 @@ void TopoInfoExchangeAgent::GenerateAgentID(HcclBasicRankInfo &localRankInfo, st HcclResult TopoInfoExchangeAgent::Disconnect(std::shared_ptr &socket) { CHK_RET(DisconnectSocket(socket)); + socket = nullptr; return HCCL_SUCCESS; } @@ -365,6 +371,8 @@ HcclResult TopoInfoExchangeAgent::VerifyClusterInfo(const RankTable_t &clusterIn CHK_RET(CheckRankIpFamily(clusterInfo.rankList)); } + // 超节点校验 + CHK_RET(VerifyClusterSuperPodInfo(clusterInfo.rankList)); return HCCL_SUCCESS; } @@ -436,4 +444,58 @@ HcclResult TopoInfoExchangeAgent::VerifyServerDevicePhysicID(const std::vector &rankInfo) const +{ + DevType deviceType; + CHK_RET(hrtGetDeviceType(deviceType)); + CHK_PRT_RET(deviceType != DevType::DEV_TYPE_910_73, + HCCL_DEBUG("[Verify][SuperPodInfo]deviceType[%d] does not need verify superPod info", deviceType), + HCCL_SUCCESS); + + // 获取每个超节点内的serverId + std::map> superPodSrvIdMap; // super_pod_id -> serverId + std::map> superPodSdidMap; // super_pod_id -> superDeviceId + for (u32 i = 0; i < rankInfo.size(); i++) { + auto iter = 
superPodSrvIdMap.find(rankInfo[i].superPodId); + if (iter == superPodSrvIdMap.end()) { + std::set serverIdSet; + serverIdSet.insert(rankInfo[i].serverId); + superPodSrvIdMap.insert({rankInfo[i].superPodId, serverIdSet}); + } else if (iter->second.find(rankInfo[i].serverId) == iter->second.end()) { + iter->second.insert(rankInfo[i].serverId); + } + + auto it = superPodSdidMap.find(rankInfo[i].superPodId); + if (it == superPodSdidMap.end()) { + std::set superDeviceIdSet; + superDeviceIdSet.insert(rankInfo[i].superDeviceId); + superPodSdidMap.insert({rankInfo[i].superPodId, superDeviceIdSet}); + } else if (it->second.find(rankInfo[i].superDeviceId) == it->second.end()) { + it->second.insert(rankInfo[i].superDeviceId); + } else { + // 超节点内superDeviceId在超节点内唯一 + CHK_PRT_RET(it->second.find(rankInfo[i].superDeviceId) != it->second.end(), + HCCL_ERROR("[Verify][SuperPodInfo]superDeviceId[0x%x] in superPod[%s]" + "is already exist.", + rankInfo[i].superDeviceId, it->first.c_str()), + HCCL_E_PARA); + } + } + + // 校验每个超节点内的server数量一致 + u32 serverNumPerPod = 0; + for (auto iter = superPodSrvIdMap.begin(); iter != superPodSrvIdMap.end(); ++iter) { + if (iter == superPodSrvIdMap.begin()) { + serverNumPerPod = superPodSrvIdMap.begin()->second.size(); + } + u32 serverNumCurPod = iter->second.size(); + CHK_PRT_RET(serverNumPerPod != serverNumCurPod, + HCCL_ERROR("[Verify][SuperPodInfo]serverNum[%u] in superPod[%s] and serverNum[%u] in superPod[%s] "\ + "are different.", serverNumPerPod, superPodSrvIdMap.begin()->first.c_str(), + serverNumCurPod, iter->first.c_str()), HCCL_E_PARA); + } + + return HCCL_SUCCESS; +} } diff --git a/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_agent.h b/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_agent.h index e1f4fa70c926eb0dc78a1fadb7397a1b1b9084cf..0bb65efee912c45812d690cba51b4a93ed358b26 100644 --- a/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_agent.h +++ b/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_agent.h @@ -41,9 +41,10 @@ public: ~TopoInfoExchangeAgent() override; HcclResult Setup(); HcclResult SetupByMasterInfo(); - HcclResult Teardown() const; + HcclResult Teardown(); HcclResult GetClusterTopoInfo(RankTable_t &clusterInfo); HcclResult GetIdentifier(u32 &indentify); + HcclResult GetConnection(std::shared_ptr &socket); private: HcclResult DetectClusterTopoInfo(std::shared_ptr socket, RankTable_t &clusterTopoInfo); @@ -62,6 +63,7 @@ private: HcclResult VerifyClusterDeviceIP(const RankTable_t &clusterInfo); HcclResult VerifyClusterRankID(const RankTable_t &clusterInfo) const; HcclResult VerifyServerDevicePhysicID(const std::vector &serverInfo) const; + HcclResult VerifyClusterSuperPodInfo(const std::vector &rankInfo) const; bool HasRepeatedIP(const std::vector &deviceAIP, const std::vector &deviceBIP) const; HcclResult DetectTransportType(const RankInfo_t &localRankInfo, const RankInfo_t &remoteRankInfo, @@ -74,6 +76,7 @@ private: HcclBasicRankInfo localRankInfo_; RankTable_t clusterTopoInfo_; HcclNetDevCtx netDevCtx_{nullptr}; + std::shared_ptr socket_; }; } // namespace hccl diff --git a/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_base.cc b/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_base.cc index 89b6bb69d19cd2e872359a4b18344099c98f2cd2..1aef275d14bc58eb508cd32e00103c9ab3106837 100644 --- 
a/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_base.cc +++ b/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_base.cc @@ -40,7 +40,9 @@ HcclResult TopoInfoExchangeBase::SaveClusterInfo(const RankTable_t &clusterInfo) HcclResult TopoInfoExchangeBase::DisconnectSocket(std::shared_ptr socket) const { - socket->Close(); + if (socket) { + socket->Close(); + } return HCCL_SUCCESS; } diff --git a/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_server.cc b/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_server.cc index 9975b7237b8a522f40f4a3a56b1e15d1a639db55..21afe5bf22a14ce296a54d43ab975cc4d64665de 100644 --- a/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_server.cc +++ b/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_server.cc @@ -23,12 +23,13 @@ const u32 DISPLAY_RANKNUM_PERLINE = 8; using namespace std; TopoInfoExchangeServer::TopoInfoExchangeServer(HcclIpAddress &hostIP, u32 hostPort, const std::vector whitelist, HcclNetDevCtx netDevCtx, - const std::unique_ptr &listenSocket) + std::shared_ptr listenSocket, const std::string &identifier) : hostIP_(hostIP), hostPort_(hostPort), whitelist_(whitelist), netDevCtx_(netDevCtx), - listenSocket_(listenSocket) + listenSocket_(listenSocket), + identifier_(identifier) { } @@ -40,38 +41,55 @@ HcclResult TopoInfoExchangeServer::Setup() { HcclResult ret; HcclResult error = HCCL_SUCCESS; - std::map> connectSockets; + do { - ret = Connect(connectSockets); + ret = Connect(connectSockets_); CHK_PRT_BREAK(ret != HCCL_SUCCESS, - HCCL_ERROR("[TopoInfoExchangeServer][Setup]cluster topo exchange server connect "\ - "client failed"), error = ret); + HCCL_ERROR("[TopoInfoExchangeServer][Setup]cluster topo exchange server connect client failed"), + error = ret); HCCL_INFO("cluster topo exchange server connect with all agent success."); RankTable_t rankTable; - ret = GetRanksBasicInfo(connectSockets, rankTable); - CHK_PRT_BREAK(ret != HCCL_SUCCESS, - HCCL_ERROR("[TopoInfoExchangeServer][Setup]GetRanksBasicInfo failed"), error = ret); + ret = GetRanksBasicInfo(connectSockets_, rankTable); + CHK_PRT_BREAK(ret != HCCL_SUCCESS, HCCL_ERROR("[TopoInfoExchangeServer][Setup]GetRanksBasicInfo failed"), + error = ret); HCCL_INFO("cluster topo exchange server get rank basic info from all agent success."); TopoInfoExchangeDispather dispatcher(this); - ret = dispatcher.BroadcastRankTable(connectSockets, rankTable); + ret = dispatcher.BroadcastRankTable(connectSockets_, rankTable); CHK_PRT_BREAK(ret != HCCL_SUCCESS, HCCL_ERROR("[TopoInfoExchangeServer][Setup]Broadcast Rank Basic Infos failed"), error = ret); HCCL_INFO("cluster topo exchange server send rank basic info to all agent success."); - ret = Disconnect(connectSockets); + ret = StopSocketListen(whitelist_, hostIP_, hostPort_); CHK_PRT_BREAK(ret != HCCL_SUCCESS, - HCCL_ERROR("[TopoInfoExchangeServer][Setup]topo exchange server disconnect "\ - "socket failed."), error = ret); + HCCL_ERROR("[TopoInfoExchangeServer][Setup]topo exchange server stop socket listen failed."), error = ret); } while (0); - CHK_RET(StopNetwork(whitelist_, hostIP_, hostPort_)); + if (error) { + CHK_RET(Disconnect(connectSockets_)); + CHK_RET(StopNetwork(whitelist_, hostIP_, hostPort_)); + } HCCL_INFO("cluster topo exchange server completed, exit[%u].", error); return error; } + +HcclResult TopoInfoExchangeServer::Teardown() +{ + 
CHK_RET(Disconnect(connectSockets_)); + CHK_RET(StopNetwork(whitelist_, hostIP_, hostPort_)); + return HCCL_SUCCESS; +} + +HcclResult TopoInfoExchangeServer::GetConnections(std::map> &connectSockets) +{ + connectSockets = connectSockets_; + return HCCL_SUCCESS; +} + + HcclResult TopoInfoExchangeServer::SetupByMasterInfo() { isByMasterInfo_ = true; @@ -93,7 +111,7 @@ HcclResult TopoInfoExchangeServer::Connect(std::map socket; - std::string tag = TOPO_DETECT_TAG + "_" + std::to_string(hostPort_); + std::string tag = TOPO_DETECT_TAG + "_" + identifier_ + "_" + std::to_string(hostPort_); HcclResult ret = listenSocket_->Accept(tag, socket); if (ret == HCCL_SUCCESS) { HCCL_INFO("listenSocket_->Accept completed."); @@ -164,9 +182,11 @@ HcclResult TopoInfoExchangeServer::DisplayConnectionedRank( HcclResult TopoInfoExchangeServer::Disconnect(std::map> &connectSockets) { + std::unique_lock lock(lock_); for (auto &socket : connectSockets) { CHK_RET(DisconnectSocket(socket.second)); } + connectSockets.clear(); return HCCL_SUCCESS; } @@ -179,7 +199,7 @@ HcclResult TopoInfoExchangeServer::DeleteSocketWhiteList(u32 port, wlistInfo.connLimit = HOST_SOCKET_CONN_LIMIT; wlistInfo.remoteIp.addr = ip.GetBinaryAddress().addr; wlistInfo.remoteIp.addr6 = ip.GetBinaryAddress().addr6; - std::string tag = TOPO_DETECT_TAG + "_" + std::to_string(port); + std::string tag = TOPO_DETECT_TAG + "_" + identifier_ + "_" + std::to_string(port); s32 sRet = memcpy_s(&wlistInfo.tag[0], sizeof(wlistInfo.tag), tag.c_str(), tag.size() + 1); if (sRet != EOK) { HCCL_ERROR("[Delete][SocketWhiteList]memory copy failed. errorno[%d]", sRet); @@ -194,21 +214,26 @@ HcclResult TopoInfoExchangeServer::DeleteSocketWhiteList(u32 port, return HCCL_SUCCESS; } -HcclResult TopoInfoExchangeServer::StopNetwork(const std::vector &whitelist, +HcclResult TopoInfoExchangeServer::StopSocketListen(const std::vector &whitelist, HcclIpAddress &hostIP, u32 hostPort) { - if (GetExternalInputHcclEnableWhitelist() == HCCL_WHITELIST_ON) { - CHK_RET(DeleteSocketWhiteList(hostPort, whitelist)); - } - if (listenSocket_) { + if (GetExternalInputHcclEnableWhitelist() == HCCL_WHITELIST_ON) { + CHK_RET(DeleteSocketWhiteList(hostPort, whitelist)); + } CHK_RET(listenSocket_->DeInit()); + listenSocket_ = nullptr; } + return HCCL_SUCCESS; +} - if (netDevCtx_) { - HcclNetCloseDev(netDevCtx_); - netDevCtx_ = nullptr; - } +HcclResult TopoInfoExchangeServer::StopNetwork(const std::vector &whitelist, + HcclIpAddress &hostIP, u32 hostPort) +{ + std::unique_lock lock(lock_); + CHK_RET(StopSocketListen(whitelist, hostIP, hostPort)); + + netDevCtx_ = nullptr; return HCCL_SUCCESS; } diff --git a/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_server.h b/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_server.h index a9009f4f6b17c478ed10aba02be7625133bc46e5..80db7004373b53d3123963e456f6335bd9b03185 100644 --- a/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_server.h +++ b/src/domain/collective_communication/framework/common/src/topo/topoinfo_exchange_server.h @@ -25,10 +25,12 @@ namespace hccl { class TopoInfoExchangeServer : public TopoInfoExchangeBase { public: explicit TopoInfoExchangeServer(HcclIpAddress &hostIP, u32 hostPort, const std::vector whitelist, - HcclNetDevCtx netDevCtx, const std::unique_ptr &listenSocket); + HcclNetDevCtx netDevCtx, const std::shared_ptr listenSocket, const std::string &identifier); ~TopoInfoExchangeServer() override; HcclResult Setup(); HcclResult 
SetupByMasterInfo(); + HcclResult Teardown(); + HcclResult GetConnections(std::map> &connectSockets); private: HcclResult Connect(std::map> &connectSockets); @@ -37,6 +39,8 @@ private: HcclResult DeleteSocketWhiteList(u32 port, const std::vector &whitelist); HcclResult StopNetwork(const std::vector &whitelist, HcclIpAddress &hostIP, u32 hostPort); + HcclResult StopSocketListen(const std::vector &whitelist, + HcclIpAddress &hostIP, u32 hostPort); HcclResult GetRanksBasicInfo( const std::map> &connectSockets, RankTable_t &rankTable); HcclResult GetRanksTransInfo( @@ -57,8 +61,11 @@ private: SocketHandle socketHandle_; std::vector whitelist_; HcclNetDevCtx netDevCtx_{nullptr}; - const std::unique_ptr &listenSocket_; + std::shared_ptr listenSocket_; friend class TopoInfoExchangeDispather; + std::map> connectSockets_; + std::mutex lock_; + std::string identifier_; }; } // namespace hccl diff --git a/src/domain/collective_communication/framework/common/src/topo/topoinfo_ranktableConcise.cc b/src/domain/collective_communication/framework/common/src/topo/topoinfo_ranktableConcise.cc index 504724eeea17525b50705b30f28f0921d4787d65..9537ae69a20af50c8e0ba4b992e09092f3e6be0c 100644 --- a/src/domain/collective_communication/framework/common/src/topo/topoinfo_ranktableConcise.cc +++ b/src/domain/collective_communication/framework/common/src/topo/topoinfo_ranktableConcise.cc @@ -317,14 +317,14 @@ HcclResult TopoinfoRanktableConcise::GetSingleDevice(const nlohmann::json &devic u32 devicePhyId = 0; CHK_RET(SalStrToULong(strDevid, HCCL_BASE_DECIMAL, devicePhyId)); - if ((deviceType == DevType::DEV_TYPE_310P3 || deviceType == DevType::DEV_TYPE_910B) && - devicePhyId > (MAX_MODULE_DEVICE_NUM - 1)) { + if ((deviceType == DevType::DEV_TYPE_310P3 || deviceType == DevType::DEV_TYPE_910B || + deviceType == DevType::DEV_TYPE_910_73) && devicePhyId > (MAX_MODULE_DEVICE_NUM - 1)) { // deviceid in 0 ~ 15 HCCL_ERROR("[Get][SingleDevice]errNo[0x%016llx] device_id[%u] more than 15 is invalid", HCOM_ERROR_CODE(HCCL_E_PARA), devicePhyId); return HCCL_E_PARA; - } else if ((deviceType != DevType::DEV_TYPE_310P3 && deviceType != DevType::DEV_TYPE_910B) && - devicePhyId > (HCCL_AISERVER_DEVICE_NUM - 1)) { + } else if ((deviceType != DevType::DEV_TYPE_310P3 && deviceType != DevType::DEV_TYPE_910B && + deviceType != DevType::DEV_TYPE_910_73) && devicePhyId > (HCCL_AISERVER_DEVICE_NUM - 1)) { // deviceid in 0 ~ 7 HCCL_ERROR("[Get][SingleDevice]errNo[0x%016llx] device_id[%u] more than 7 is invalid", HCOM_ERROR_CODE(HCCL_E_PARA), devicePhyId); diff --git a/src/domain/collective_communication/framework/common/src/topo/topoinfo_ranktableStandard.cc b/src/domain/collective_communication/framework/common/src/topo/topoinfo_ranktableStandard.cc index 528b0ed7b180d4f08fabbfe9adc0da6a9bb2e517..71747959fe4acfcbf3c40ba4e2afcd5dc2d2ea97 100644 --- a/src/domain/collective_communication/framework/common/src/topo/topoinfo_ranktableStandard.cc +++ b/src/domain/collective_communication/framework/common/src/topo/topoinfo_ranktableStandard.cc @@ -154,8 +154,8 @@ HcclResult TopoinfoRanktableStandard::GetHcomInfo(hccl::HcclCommParams ¶ms, return HCCL_E_PARA; } - if ((params.deviceType == DevType::DEV_TYPE_910 || params.deviceType == DevType::DEV_TYPE_910B) && - paraPlaneLocation != "device") { + if ((params.deviceType == DevType::DEV_TYPE_910 || params.deviceType == DevType::DEV_TYPE_910B || + params.deviceType == DevType::DEV_TYPE_910_73) && paraPlaneLocation != "device") { HCCL_ERROR("[Get][HcomInfo]errNo[0x%016llx] paraPlaneLocation should be 
'device'", HCOM_ERROR_CODE(HCCL_E_PARA)); return HCCL_E_PARA; @@ -569,7 +569,8 @@ HcclResult TopoinfoRanktableStandard::GetDevList(nlohmann::json &instanceList, u u32 devicePhyId = 0; CHK_RET(SalStrToULong(strDevid, HCCL_BASE_DECIMAL, devicePhyId)); if ((params.deviceType != DevType::DEV_TYPE_310P3 && - params.deviceType != DevType::DEV_TYPE_910B) && + params.deviceType != DevType::DEV_TYPE_910B && + params.deviceType != DevType::DEV_TYPE_910_73) && (devicePhyId > (HCCL_AISERVER_DEVICE_NUM - 1))) { HCCL_ERROR("[Get][DevList]errNo[0x%016llx] device_id[%u] more than 7 is invalid", HCOM_ERROR_CODE(HCCL_E_PARA), devicePhyId); diff --git a/src/domain/collective_communication/framework/communicator/hccl_comm.cc b/src/domain/collective_communication/framework/communicator/hccl_comm.cc index 231083e00a094e7d1407dc9d1c2b7b216e496a5b..21bc16caffd9d8165cbb4823083d7995511d4ce1 100644 --- a/src/domain/collective_communication/framework/communicator/hccl_comm.cc +++ b/src/domain/collective_communication/framework/communicator/hccl_comm.cc @@ -112,10 +112,9 @@ HcclResult hcclComm::init(HcclCommParams ¶ms, const RankTable_t &rankTable) communicator_->AtomicInitClear(); return ret; } - - if (params.totalRanks != 1) { - CHK_RET(communicator_->InitCCLbuffer(inCCLbufferSize_, outCCLbufferSize_)); - } + + CHK_RET(communicator_->InitCCLbuffer(inCCLbufferSize_, outCCLbufferSize_)); + HCCL_RUN_INFO("hcclCommInitInfo:commId[%s], rank[%u], totalRanks[%u], serverId[%s], deviceType[%d]," \ "logicDevId[%d], identifier[%s]", params.id.internal, params.rank, params.totalRanks, params.serverId.c_str(), params.deviceType, params.logicDevId, params.identifier.c_str()); @@ -219,7 +218,7 @@ HcclResult hcclComm::DestroyGroup(const std::string &group) const HcclResult hcclComm::GetAlgType(AlgType &algType, HcclCMDType opType) { /* 增加输出日志关键字 */ - HCCL_DEBUG("algType[%d]", algType); + HCCL_DEBUG("algType[%s]", HcclAlg::AlgTypeToStr(algType).c_str()); return communicator_->GetAlgType(algType, opType); } diff --git a/src/domain/collective_communication/framework/communicator/impl/hccl_communicator.cc b/src/domain/collective_communication/framework/communicator/impl/hccl_communicator.cc index 3d132247c9ceef65b2411d03bfc5aa6c28ffea5c..edf580b5777a93740119da35e893a1469a6648c2 100644 --- a/src/domain/collective_communication/framework/communicator/impl/hccl_communicator.cc +++ b/src/domain/collective_communication/framework/communicator/impl/hccl_communicator.cc @@ -82,12 +82,25 @@ HcclCommunicator::~HcclCommunicator() for (auto &res :resMap_) { DestroyAlgResource(res.second); } + + if (opRetryManager_ != nullptr) { + opRetryManager_->UnRegisterOpRetryManager(identifier_); + opRetryManager_ = nullptr; + } resMap_.clear(); tagCommInfo_.clear(); tagWorkSpaceMem_.clear(); + tagStreamInfo_.clear(); + if (opRetryStreamPtr_ != nullptr) { + opRetryStreamPtr_->clear(); + opRetryStreamPtr_ = nullptr; + } (void)UnRegistTaskExceptionHandler(); + kfcControlTransferH2D_ = nullptr; + kfcStatusTransferD2H_ = nullptr; + MrManagerDeInit(); /* 网络资源销毁 */ @@ -185,7 +198,8 @@ HcclResult HcclCommunicator::Init(HcclCommParams ¶ms, const RankTable_t &ran if (GetExternalInputHcclAivMode() && deviceType_ == DevType::DEV_TYPE_910B) { CHK_RET(RegisterKernel(deviceType_)); } - + CHK_RET(InitHDCommunicate()); + CHK_RET(InitOpRetry()); return HCCL_SUCCESS; } @@ -202,7 +216,8 @@ HcclResult HcclCommunicator::Init(HcclCommParams ¶ms, const std::vector(devicePhyId_) == HOST_DEVICE_ID) { HCCL_ERROR("[HcclCommunicator][Init]not support cpu rank"); return 
HCCL_E_NOT_SUPPORT; } else { - HCCL_DEBUG("[HcclCommunicator][Init]devicePhyId[%u] != HOST_DEVICE_ID", devicePhyId_); + HCCL_DEBUG("HcclCommunicator::Init devicePhyId[%u] != HOST_DEVICE_ID", devicePhyId_); CHK_RET(hrtGetDevice(&deviceLogicId_)); } @@ -405,7 +420,6 @@ HcclResult HcclCommunicator::CheckDataType(const HcclDataType dataType, bool nee return HCCL_E_NOT_SUPPORT; } } - return HCCL_SUCCESS; } @@ -434,7 +448,6 @@ HcclResult HcclCommunicator::InitNotifyManager() queueNotifyManager_.reset(new (std::nothrow) QueueNotifyManager()); CHK_SMART_PTR_NULL(queueNotifyManager_); CHK_RET(queueNotifyManager_->Init()); - queueNotifyManagerRefac_.reset(new (std::nothrow) QueueNotifyManager()); CHK_SMART_PTR_NULL(queueNotifyManagerRefac_); CHK_RET(queueNotifyManagerRefac_->Init()); @@ -445,7 +458,7 @@ HcclResult HcclCommunicator::InitNotifyManager() HcclResult HcclCommunicator::InitDispatcher() { // 根据设备ID创建dispatcher - if ((deviceType_ == DevType::DEV_TYPE_910B) && + if ((deviceType_ == DevType::DEV_TYPE_910B || deviceType_ == DevType::DEV_TYPE_910_73) && GetExternalInputHcclEnableFfts()) { CHK_PRT_CONT(GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE, HCCL_RUN_INFO("Will use ffts mode.")); @@ -582,7 +595,7 @@ HcclResult HcclCommunicator::InitRankInfoSubGroup(const std::vector &r CHK_RET(GetSuperPodNum(rankListNew, superPodNum_)); // 按通信域配置是否使用算子级重执行 - CHK_RET(SetRetryEnable(superPodNum_, serverNum_, deviceNumPerAggregation_, retryEnable_)); + SetRetryEnable(deviceType_, superPodNum_, serverNum_, deviceNumPerAggregation_, retryEnable_); // 获取module相关信息,moduleNum_, isDiffDeviceModule_, multiModuleDiffDeviceNumMode_; CHK_RET(GetModuleInfo(rankListNew)); @@ -720,12 +733,14 @@ HcclResult HcclCommunicator::ClearOpResource(const std::string &tag) CHK_RET(StreamActiveManager::GetInstance(deviceLogicId_).StreamsUnactive(iterStream->second.ringStreams)); } tagStreamInfo_.erase(tag); + if (opRetryStreamPtr_ != nullptr) { + opRetryStreamPtr_->erase(tag); + } if (implAlg_ != nullptr) { CHK_RET(implAlg_->ClearOpResource(tag)); } DestroyWorkspaceResource(tag); - return HCCL_SUCCESS; } @@ -768,13 +783,12 @@ HcclResult HcclCommunicator::GetBandWidthPerNPU(u32 level, float &bandWidth) HcclResult HcclCommunicator::GetDeviceNumPerAggregation(u32 &deviceNumPerAggregation) { deviceNumPerAggregation = deviceNumPerAggregation_; - return HCCL_SUCCESS; } HcclResult HcclCommunicator::CheckReduceDataType(const HcclDataType dataType, const HcclReduceOp op) { - if (deviceType_ == DevType::DEV_TYPE_910B) { + if ((deviceType_ == DevType::DEV_TYPE_910B) || (deviceType_ == DevType::DEV_TYPE_910_73)) { if ((op == HCCL_REDUCE_PROD) && ((dataType == HCCL_DATA_TYPE_INT16) || (dataType == HCCL_DATA_TYPE_BFP16))) { RPT_INPUT_ERR(true, "EI0003", std::vector({"ccl_op", "parameter", "value", "tips"}),\ @@ -824,7 +838,6 @@ HcclResult HcclCommunicator::CheckReduceDataType(const HcclDataType dataType, co return HCCL_E_NOT_SUPPORT; } } - return HCCL_SUCCESS; } @@ -838,7 +851,6 @@ HcclResult HcclCommunicator::GetServerNum(const std::vector &ranks) } } serverNum_ = serverIds.size(); - return HCCL_SUCCESS; } @@ -857,7 +869,6 @@ HcclResult HcclCommunicator::GetServerId(const RankTable_t &rankTable) HCCL_ERROR("[Get][ServerId]GetServerId fail"); return HCCL_E_PARA; } - return HCCL_SUCCESS; } @@ -901,7 +912,6 @@ HcclResult HcclCommunicator::GetInnerServerAverageDevice(const RankTable_t &rank } else { deviceNumPerAggregation_ = deviceNumPerServer_; } - return HCCL_SUCCESS; } @@ -936,7 +946,6 @@ HcclResult 
HcclCommunicator::GetInnerServerAverageDevice(const std::vector(LinkTypeInServer::HCCS_TYPE)].size() > 0 && devNum > HCCL_DEVICE_NUM_TWO && - (deviceType_ != DevType::DEV_TYPE_910B && !Is310P3Common())) { + (deviceType_ != DevType::DEV_TYPE_910B && deviceType_ != DevType::DEV_TYPE_910_73 && !Is310P3Common())) { CHK_PRT_RET(CheckDevCount(devNum) != HCCL_SUCCESS, HCCL_ERROR("[Check][RankTable]errNo[0x%016llx] devnum is invaild in server.", HCCL_ERROR_CODE(HCCL_E_PARA)), HCCL_E_PARA); } - return HCCL_SUCCESS; } @@ -1086,7 +1092,6 @@ HcclResult HcclCommunicator::CheckDevPhyId(const s32 &devicePhyId) const return HCCL_E_PARA; } - return HCCL_SUCCESS; } @@ -1102,7 +1107,6 @@ HcclResult HcclCommunicator::SortRankInfoList() HCCL_ERROR("[HcclCommunicator][SortRankInfoList]errNo[0x%016llx] index[%u] != rankInfoList.userRank[%u]", HCCL_ERROR_CODE(HCCL_E_PARA), index, rankInfoList_[index].userRank), HCCL_E_PARA); } - return HCCL_SUCCESS; } @@ -1129,8 +1133,8 @@ HcclResult HcclCommunicator::GetRankInfoList(const RankTable_t &rankTable) rankInfo.deviceType = deviceType; CHK_RET(CheckDeviceType(deviceType)); - if (deviceType != DevType::DEV_TYPE_910B) { - // 910B形态不做devicePhyId最大值的判断 + if (deviceType != DevType::DEV_TYPE_910B || deviceType_ != DevType::DEV_TYPE_910_73) { + // 910B、910_73形态不做devicePhyId最大值的判断 CHK_RET(CheckDevPhyId(orgRankInfo.deviceInfo.devicePhyId)); } rankInfo.devicePhyId = orgRankInfo.deviceInfo.devicePhyId; @@ -1158,7 +1162,6 @@ HcclResult HcclCommunicator::GetRankInfoList(const RankTable_t &rankTable) } // 将rank id从小到大的顺序返回 CHK_RET(SortRankInfoList()); - return HCCL_SUCCESS; } @@ -1202,7 +1205,6 @@ HcclResult HcclCommunicator::CalAndSetMeshAggRankSize() size = servRankInfo_.begin()->second.size(); } CHK_RET(SetMeshAggregationRankSize(size)); - return HCCL_SUCCESS; } @@ -1323,7 +1325,6 @@ HcclResult HcclCommunicator::InitPara() workSpaceRes_, cclBufferManager_, dispatcher_, vDispatcher_, notifyPool_, netDevCtxMap_, queueNotifyManager_, algoAttr, topoAttr, false)); - return HCCL_SUCCESS; } @@ -1343,7 +1344,6 @@ HcclResult HcclCommunicator::GetModuleIdx(const RankInfo_t &rankInfo, u32 &modul } CHK_PRT_RET(moduleIdx == INVALID_UINT, HCCL_ERROR("GetModuleIdx failed. 
moduleIdx:[%d], rankId:[%u]", moduleIdx, rankInfo.rankId), HCCL_E_PARA); - return HCCL_SUCCESS; } @@ -1392,7 +1392,6 @@ HcclResult HcclCommunicator::GetModuleInfo(const std::vector &rankLi rankInfo.deviceInfo.devicePhyId); } } - return HCCL_SUCCESS; } @@ -1448,7 +1447,6 @@ HcclResult HcclCommunicator::CheckSingleServerComm(const std::vector } } } - return HCCL_SUCCESS; } @@ -1468,7 +1466,6 @@ HcclResult HcclCommunicator::TransformRankList( rankInfoTmp.superPodId = rankListIn[index].superPodId; rankListOut.push_back(rankInfoTmp); } - return HCCL_SUCCESS; } @@ -1484,6 +1481,33 @@ bool HcclCommunicator::IsStandardCard() (pairLinkInfo_[static_cast(LinkTypeInServer::SIO_TYPE)].size() == 0)); } +HcclResult HcclCommunicator::InitHDCommunicate() +{ + if ((GetExternalInputHcclAicpuUnfold() == true) || + ((deviceType_ == DevType::DEV_TYPE_910_73) || (deviceType_ == DevType::DEV_TYPE_910B) || Is310P3Common())) { + kfcControlTransferH2D_ = std::make_shared(deviceLogicId_, HCCL_HDC_TYPE_H2D, sizeof(KfcExecControl)); + CHK_SMART_PTR_NULL(kfcControlTransferH2D_); + + CHK_RET(kfcControlTransferH2D_->InitHost()); + kfcStatusTransferD2H_ = std::make_shared(deviceLogicId_, HCCL_HDC_TYPE_D2H, sizeof(KfcExecStatus)); + CHK_SMART_PTR_NULL(kfcStatusTransferD2H_); + CHK_RET(kfcStatusTransferD2H_->InitHost()); + } + return HCCL_SUCCESS; +} + +HcclResult HcclCommunicator::InitOpRetry() +{ + opRetryStreamPtr_ = std::make_shared(); + opRetryManager_.reset(new (std::nothrow) OpRetryManagerPub()); + if (retryEnable_) { + CHK_RET(opRetryManager_->RegisterOpRetryMachine(identifier_, userRank_, commConnections_.isRoot, + commConnections_.agentConnection, commConnections_.serverConnections, kfcControlTransferH2D_, + kfcStatusTransferD2H_, opRetryStreamPtr_, notifyPool_)); + } + return HCCL_SUCCESS; +} + bool HcclCommunicator::CompareWithDevicePhyId(const RankInfo_t &left, const RankInfo_t &right) { return left.deviceInfo.devicePhyId < right.deviceInfo.devicePhyId; @@ -1526,7 +1550,6 @@ HcclResult HcclCommunicator::GetNicInfo(const NICDeployment &nicDeploy, const u3 rankInfo.nicIp.push_back(curRankInfo.deviceInfo.deviceIp[0]); } - return HCCL_SUCCESS; } @@ -1556,7 +1579,6 @@ HcclResult HcclCommunicator::InitPreResource(const RankTable_t &rankTable) } drvInit_ = true; - return HCCL_SUCCESS; } @@ -1581,7 +1603,6 @@ HcclResult HcclCommunicator::InitTcpMode(const RankTable_t &rankTable) const CHK_RET(InitExternalInputHeterog()); RankConsistent::GetInstance().RecordProtocolType(GetExternalInputProtocolType()); - return HCCL_SUCCESS; } @@ -1589,10 +1610,12 @@ bool HcclCommunicator::IsSupportEnableRoce() { // 910B单机两种使能roce场景:1、a+x同时使用两module 2.标卡 bool roceSwitch = false; - HCCL_INFO("[HcclCommunicator]IsSupportEnableRoce"); + HCCL_INFO("[HcclCommunicator]IsSupportEnableRoce log"); if (deviceType_ == DevType::DEV_TYPE_910B) { roceSwitch = (GetExternalInputIntraRoceSwitch() && (!isSingleMeshAggregation_ || isStandardCard_)) || multiModuleDiffDeviceNumMode_; + } else if (deviceType_ == DevType::DEV_TYPE_910_73) { + roceSwitch = GetExternalInputEnableRdmaSdmaConcurrent(); } else { // 其他单机场景为了防止用户误用roce开关 roceSwitch = isStandardCard_ ? 
GetExternalInputIntraRoceSwitch() : false; } @@ -1607,6 +1630,12 @@ bool HcclCommunicator::IsEnableRoce() interServer_, isSingleMeshAggregation_, roceSwitch); bool isInterServerVnic = false; + // 910_73超节点内节点间走HCCS通信 && Vnic建链, 不需要使能NIC + if (deviceType_ == DevType::DEV_TYPE_910_73 && superPodNum_ == 1 && + GetExternalInputInterHccsDisable() == false && GetExternalInputInterVnicDisable() == false) { + isInterServerVnic = true; + HCCL_INFO("IsEnableRoce isInterServerVnic set %d", isInterServerVnic); + } if ((interServer_ && !isInterServerVnic) || roceSwitch) { return true; } @@ -1674,7 +1703,6 @@ HcclResult HcclCommunicator::InitRaResource() raResourceInit_ = true; // 全局通信域会初始化,子通信域不会初始化,但是析构均会进入此逻辑,需要标记 isSupportRdmaLite_ = IsSupportRDMALite(deviceLogicId_); // 是否支持Rdma Lite - return HCCL_SUCCESS; } @@ -1690,7 +1718,6 @@ HcclResult HcclCommunicator::DisablePreResource() HCCL_ERROR("[Disable][PreResource]Disable all P2P Failed, deviceLogicId[%d], ret[%u]", deviceLogicId_, ret), ret); enableP2PDevices_.clear(); - return HCCL_SUCCESS; } @@ -1722,9 +1749,9 @@ HcclResult HcclCommunicator::GetWorkspaceSubStreamNum(u64 &streamNum, u64 dataSi break; } - if (AlltoAllOperator::NAFullmeshSatisfyHighPerfAlltoallMeshCondition(deviceType_, userRankSize_)) { + if (CollAlgOperator::NAFullmeshSatisfyHighPerfAlltoallMeshCondition(deviceType_, userRankSize_)) { streamNum = std::max(static_cast(userRankSize_ - 1u), streamNum); - } else if (AlltoAllOperator::FullmeshPairwiseSatisfyHighPerfAlltoallMeshCondition(deviceType_, + } else if (CollAlgOperator::FullmeshPairwiseSatisfyHighPerfAlltoallMeshCondition(deviceType_, meshAggregationRankSize_)) { streamNum = std::max(static_cast(meshAggregationRankSize_ - 1u), streamNum); } @@ -1741,7 +1768,6 @@ HcclResult HcclCommunicator::GetWorkspaceSubStreamNum(u64 &streamNum, u64 dataSi if (implAlg_ != nullptr && sliceNum >= MIN_PIPLINE_SLICE_NUM) { streamNum++; } - return HCCL_SUCCESS; } @@ -1749,6 +1775,11 @@ void HcclCommunicator::DestroyAlgResource(AlgResourceResponse &res) { for (auto &levelNSubCommTransport : res.opTransportResponse) { for (auto &singleSubCommTransport : levelNSubCommTransport) { + for (u32 i = 0; i < singleSubCommTransport.virtualLinks.size();i++) { + if (singleSubCommTransport.virtualLinks[i] != nullptr) { + singleSubCommTransport.virtualLinks[i]->DeInit(); + } + } for (u32 i = 0; i < singleSubCommTransport.links.size();i++) { if (singleSubCommTransport.transportRequests[i].isValid && singleSubCommTransport.links[i] != nullptr) { @@ -1805,7 +1836,6 @@ HcclResult HcclCommunicator::DestroyNetworkResources() } raResourceInit_ = false; - return HCCL_SUCCESS; } @@ -1850,7 +1880,6 @@ HcclResult HcclCommunicator::AtomicInitSet() "already been initialized", HCCL_ERROR_CODE(HCCL_E_INTERNAL)), HCCL_E_INTERNAL); - return HCCL_SUCCESS; } @@ -1887,7 +1916,6 @@ HcclResult HcclCommunicator::CheckDeviceType(const DevType deviceType) const return HCCL_E_PARA; } - return HCCL_SUCCESS; } @@ -1898,7 +1926,6 @@ HcclResult HcclCommunicator::CheckReductionOp(const HcclReduceOp op) const return HCCL_E_PARA; } - return HCCL_SUCCESS; } @@ -1909,7 +1936,6 @@ HcclResult HcclCommunicator::CheckUserRank(const u32 userRank) const HCCL_ERROR_CODE(HCCL_E_PARA), userRank, userRankSize_); return HCCL_E_PARA; } - return HCCL_SUCCESS; } @@ -1920,7 +1946,6 @@ HcclResult HcclCommunicator::CheckCount(const u64 count) const HCCL_ERROR_CODE(HCCL_E_PARA), count, SYS_MAX_COUNT); return HCCL_E_PARA; } - return HCCL_SUCCESS; } @@ -1957,7 +1982,6 @@ HcclResult 
HcclCommunicator::GetGroupRanksInfo(const std::vector &groupRank return HCCL_E_PARA; } } - return HCCL_SUCCESS; } @@ -1972,7 +1996,6 @@ HcclResult HcclCommunicator::GetGroupCommonData(WorldGroupInfo &groupCommonData) groupCommonData.worldRankInfoList = rankInfoList_; groupCommonData.ranksPort = ranksPort_; - return HCCL_SUCCESS; } @@ -2010,7 +2033,6 @@ HcclResult HcclCommunicator::InitProfiling() profilingInitiated_ = true; // isExecuteProfilingInit_用于记录本impl是否执行了taskInfoSaver的初始化,用于进行对应的释放 isExecuteProfilingInit_ = true; - return HCCL_SUCCESS; } @@ -2019,7 +2041,6 @@ HcclResult HcclCommunicator::DeinitProfiling() CHK_PRT_RET(!profilingInitiated_, HCCL_DEBUG("Profiling plugin has not been Initiated"), HCCL_SUCCESS); profilingInitiated_ = false; HCCL_INFO("Profiling is deinitiated."); - return HCCL_SUCCESS; } @@ -2027,7 +2048,6 @@ HcclResult HcclCommunicator::RegistTaskExceptionHandler() const { CHK_RET(TaskExceptionHandler::Init()); - return HCCL_SUCCESS; } @@ -2035,7 +2055,6 @@ HcclResult HcclCommunicator::UnRegistTaskExceptionHandler() const { CHK_RET(TaskExceptionHandler::DeInit()); - return HCCL_SUCCESS; } @@ -2059,7 +2078,6 @@ HcclResult HcclCommunicator::ReleaseCommInfos() if (implAlg_ != nullptr) { return implAlg_->ReleaseCommInfos(); } - return HCCL_SUCCESS; } @@ -2073,7 +2091,6 @@ HcclResult HcclCommunicator::InitProfiler() HCCL_INFO("[BASE][InitProfiler]Register CtrlCallBack success."); - return HCCL_SUCCESS; } @@ -2135,7 +2152,6 @@ HcclResult HcclCommunicator::InitNic() return HCCL_E_PARA; } nicInitialized_ = true; - return HCCL_SUCCESS; } @@ -2165,7 +2181,6 @@ HcclResult HcclCommunicator::DeinitNic() return HCCL_E_PARA; } nicInitialized_ = false; - return HCCL_SUCCESS; } @@ -2178,7 +2193,6 @@ HcclResult HcclCommunicator::RegisterToHeartBeat() HcclResult HcclCommunicator::SetGlobalWorkSpace(std::vector &globalWorkSpaceAddr) { CHK_RET(HcclSetGlobalWorkSpace(dispatcher_, globalWorkSpaceAddr)); - return HCCL_SUCCESS; } @@ -2189,14 +2203,12 @@ HcclResult HcclCommunicator::GetandClearOverFlowTasks(std::vector } else { HCCL_WARNING("[impl][GetDumpTask] profilerManager_ not set"); } - return HCCL_SUCCESS; } HcclResult HcclCommunicator::GetDeviceId(s32 &deviceId) const { deviceId = deviceLogicId_; - return HCCL_SUCCESS; } @@ -2222,7 +2234,6 @@ HcclResult HcclCommunicator::GetCqeError(HcclResult &result) { CHK_RET(HeartbeatPub::CheckErrorCqe(deviceLogicId_, identifier_, result)); - return HCCL_SUCCESS; } @@ -2236,7 +2247,6 @@ HcclResult HcclCommunicator::MrManagerInit() CHK_RET(mrManager_->Init()); mrManagerInit_ = true; } - return HCCL_SUCCESS; } @@ -2248,7 +2258,6 @@ HcclResult HcclCommunicator::MrManagerDeInit() mrManager_ = nullptr; mrManagerInit_ = false; } - return HCCL_SUCCESS; } @@ -2256,7 +2265,6 @@ HcclResult HcclCommunicator::SupportDeterministicOptim(bool &isDeterministicOpti { CHK_SMART_PTR_NULL(implAlg_); CHK_RET(implAlg_->SupportDeterministicOptim(isDeterministicOptim)); - return HCCL_SUCCESS; } @@ -2268,7 +2276,6 @@ HcclResult HcclCommunicator::GetHccsLinkNum(u32 &numHccsLink) return HCCL_E_PARA; } numHccsLink = iter->second.size(); - return HCCL_SUCCESS; } @@ -2276,7 +2283,6 @@ HcclResult HcclCommunicator::SetMeshAggregationRankSize(u32 size) { HCCL_INFO("[Set][HcclCommunicator][MeshAggregationRankSize]set MeshAggregationRankSize[%u].", size); meshAggregationRankSize_ = size; - return HCCL_SUCCESS; } @@ -2296,12 +2302,26 @@ HcclResult HcclCommunicator::AllGather(const std::string &tag, void *inputPtr, v CHK_RET(StarsCounter(dispatcher_, streamObj, HEAD)); 
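// [Editor's note, illustrative only - not part of the patch] This hunk is representative of the
// refactor applied throughout hccl_communicator.cc: instead of calling implAlg_->AllGather(...)
// directly, the collective fills an OpParam and routes it through ExecOp(HcclCMDType::..., opParam).
// The size fields follow from the data type and the rank count. Assuming, for example, count = 1024,
// a 2-byte data type (fp16) and userRankSize_ = 8:
//   inputSize  = count * SIZE_TABLE[dataType] = 1024 * 2 = 2048 bytes
//   outputSize = inputSize * userRankSize_    = 2048 * 8 = 16384 bytes
// i.e. every rank contributes one input-sized block to the gathered output buffer.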
implAlg_->SetHDCModeInfo(rankDevicePhyIdNicInfoMap_, groupRanksPort_, isSetHDCModeInfo_, isUseRankPort_); - CHK_RET(implAlg_->AllGather(tag, inputPtr, outputPtr, inputCount, dataType, streamObj, opInfo)); + + u32 perDataSize = SIZE_TABLE[dataType]; + u64 totalSize = inputCount * perDataSize; + + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = inputPtr; + opParam.inputSize = totalSize; + opParam.outputPtr = outputPtr; + opParam.outputSize = totalSize * userRankSize_; + opParam.DataDes.count = inputCount; + opParam.DataDes.dataType = dataType; + opParam.reduceType = HcclReduceOp::HCCL_REDUCE_RESERVED; + opParam.stream = streamObj; + opParam.syncMode = SyncMode::DEFAULT_TIMEWAITSYNCMODE; + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_ALLGATHER, opParam)); // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -2326,14 +2346,13 @@ HcclResult HcclCommunicator::AicpuUnfold(const std::string &tag, void *inputPtr, } } - ret = AicpuKfcTilingDataLaunch(inputPtr, outputPtr, count, dataType, op, stream, cmdType); + ret = AicpuKfcTilingDataLaunch(tag, inputPtr, outputPtr, count, dataType, op, stream, cmdType); if (ret != HCCL_SUCCESS) { HCCL_ERROR("[hcclImpl][TilingData]aicpu unfold tiling data launch failed. return[%d] inputPtr[%p]"\ "outputPtr[%p] count[%llu] dataType[%s] op[%s]", ret, inputPtr, outputPtr, count, GetDataTypeEnumStr(dataType).c_str(), GetReduceOpEnumStr(op).c_str()); return ret; } - return HCCL_SUCCESS; } @@ -2360,13 +2379,27 @@ HcclResult HcclCommunicator::AllGatherOutPlace(const std::string &tag, void *inp CHK_RET(StarsCounter(dispatcher_, streamObj, HEAD)); implAlg_->SetHDCModeInfo(rankDevicePhyIdNicInfoMap_, groupRanksPort_, isSetHDCModeInfo_, isUseRankPort_); - CHK_RET(implAlg_->AllGatherOutPlace(tag, inputPtr, outputPtr, inputCount, dataType, streamObj, - opBaseAtraceInfo_)); + + u32 perDataSize = SIZE_TABLE[dataType]; + u64 totalSize = inputCount * perDataSize; + + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = inputPtr; + opParam.inputSize = totalSize; + opParam.outputPtr = outputPtr; + opParam.outputSize = totalSize; + opParam.DataDes.count = inputCount; + opParam.DataDes.dataType = dataType; + opParam.reduceType = HcclReduceOp::HCCL_REDUCE_RESERVED; + opParam.stream = streamObj; + opParam.syncMode = SyncMode::DEFAULT_TIMEWAITSYNCMODE; + opParam.opBaseAtraceInfo = opBaseAtraceInfo_.get(); + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_ALLGATHER, opParam)); // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -2444,7 +2477,6 @@ HcclResult HcclCommunicator::AllReduce(const std::string &tag, void *inputPtr, v RestorePreSyncMode(preSyncMode, syncMode); - return HCCL_SUCCESS; } @@ -2467,14 +2499,13 @@ HcclResult HcclCommunicator::AllReduceAicpuUnfold(const std::string &tag, void * } } - ret = AicpuKfcTilingDataLaunch(inputPtr, outputPtr, count, dataType, op, stream, HcclCMDType::HCCL_CMD_ALLREDUCE); + ret = AicpuKfcTilingDataLaunch(tag, inputPtr, outputPtr, count, dataType, op, stream, HcclCMDType::HCCL_CMD_ALLREDUCE); if (ret != HCCL_SUCCESS) { HCCL_ERROR("[hcclImpl][TilingData]aicpu unfold tiling data launch failed. 
return[%d] inputPtr[%p]"\ "outputPtr[%p] count[%llu] dataType[%s] op[%s]", ret, inputPtr, outputPtr, count, GetDataTypeEnumStr(dataType).c_str(), GetReduceOpEnumStr(op).c_str()); return ret; } - return HCCL_SUCCESS; } @@ -2540,11 +2571,9 @@ HcclResult HcclCommunicator::AllReduceOutPlace(const std::string &tag, void *inp RestorePreSyncMode(preSyncMode, syncMode); - return HCCL_SUCCESS; } - HcclResult HcclCommunicator::AlltoAllV(const void *sendBuf, const void *sendCounts, const void *sdispls, HcclDataType sendType, const void *recvBuf, const void *recvCounts, const void *rdispls, HcclDataType recvType, rtStream_t stream, const std::string &tag) @@ -2570,13 +2599,26 @@ HcclResult HcclCommunicator::AlltoAllV(const void *sendBuf, const void *sendCoun CHK_RET(StarsCounter(dispatcher_, streamObj, HEAD)); implAlg_->SetHDCModeInfo(rankDevicePhyIdNicInfoMap_, groupRanksPort_, isSetHDCModeInfo_, isUseRankPort_); - CHK_RET(implAlg_->AlltoAllV(sendBuf, sendCounts, sdispls, sendType, recvBuf, recvCounts, rdispls, recvType, - streamObj, tag)); + CHK_RET(CreateCommCCLbuffer()); + + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = const_cast(sendBuf); + opParam.outputPtr = const_cast(recvBuf); + opParam.All2AllDataDes.sendType = sendType; + opParam.All2AllDataDes.recvType = recvType; + opParam.All2AllDataDes.sendCounts = const_cast(sendCounts); + opParam.All2AllDataDes.recvCounts = const_cast(recvCounts); + opParam.All2AllDataDes.sdispls = const_cast(sdispls); + opParam.All2AllDataDes.rdispls = const_cast(rdispls); + opParam.stream = streamObj; + opParam.opType = HcclCMDType::HCCL_CMD_ALLTOALLV; + + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_ALLTOALLV, opParam)); // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -2605,13 +2647,25 @@ HcclResult HcclCommunicator::AlltoAllVOutPlace(const void *sendBuf, const void * // 头计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, HEAD)); implAlg_->SetHDCModeInfo(rankDevicePhyIdNicInfoMap_, groupRanksPort_, isSetHDCModeInfo_, isUseRankPort_); - CHK_RET(implAlg_->AlltoAllVOutPlace(sendBuf, sendCounts, sdispls, sendType, recvBuf, recvCounts, rdispls, recvType, - streamObj, tag)); + + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = const_cast(sendBuf); + opParam.outputPtr = const_cast(recvBuf); + opParam.All2AllDataDes.sendType = sendType; + opParam.All2AllDataDes.recvType = recvType; + opParam.All2AllDataDes.sendCounts = const_cast(sendCounts); + opParam.All2AllDataDes.recvCounts = const_cast(recvCounts); + opParam.All2AllDataDes.sdispls = const_cast(sdispls); + opParam.All2AllDataDes.rdispls = const_cast(rdispls); + opParam.stream = streamObj; + opParam.opType = HcclCMDType::HCCL_CMD_ALLTOALLV; + + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_ALLTOALLV, opParam)); // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -2638,11 +2692,21 @@ HcclResult HcclCommunicator::AlltoAllVC(const void *sendBuf, const void *sendCou // 头计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, HEAD)); implAlg_->SetHDCModeInfo(rankDevicePhyIdNicInfoMap_, groupRanksPort_, isSetHDCModeInfo_, isUseRankPort_); - CHK_RET(implAlg_->AlltoAllVC(sendBuf, sendCountMatrix, sendType, recvBuf, recvType, streamObj, tag)); + + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = const_cast(sendBuf); + opParam.outputPtr = const_cast(recvBuf); + opParam.All2AllDataDes.sendType = sendType; + opParam.All2AllDataDes.recvType = recvType; + opParam.All2AllDataDes.sendCountMatrix = const_cast(sendCountMatrix); + opParam.stream = 
streamObj; + opParam.opType = HcclCMDType::HCCL_CMD_ALLTOALLVC; + + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_ALLTOALLVC, opParam)); // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -2670,14 +2734,30 @@ HcclResult HcclCommunicator::AlltoAllVCOutPlace(const void *sendBuf, const void // 头计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, HEAD)); implAlg_->SetHDCModeInfo(rankDevicePhyIdNicInfoMap_, groupRanksPort_, isSetHDCModeInfo_, isUseRankPort_); - CHK_RET(implAlg_->AlltoAllVCOutPlace(sendBuf, sendCountMatrix, sendType, recvBuf, recvType, streamObj, tag)); + + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = const_cast(sendBuf); + opParam.outputPtr = const_cast(recvBuf); + opParam.All2AllDataDes.sendType = sendType; + opParam.All2AllDataDes.recvType = recvType; + opParam.All2AllDataDes.sendCountMatrix = const_cast(sendCountMatrix); + opParam.stream = streamObj; + opParam.opType = HcclCMDType::HCCL_CMD_ALLTOALLVC; + + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_ALLTOALLVC, opParam)); // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } +std::vector HcclCommunicator::GenerateSendCountMatrix(u64 count, u32 rankSize) +{ + std::vector sendCountMatrix(rankSize * rankSize, count); + return sendCountMatrix; +} + HcclResult HcclCommunicator::AlltoAll(const void *sendBuf, u64 sendCount, HcclDataType sendType, const void *recvBuf, u64 recvCount, HcclDataType recvType, rtStream_t stream, const std::string &tag) { @@ -2698,14 +2778,26 @@ HcclResult HcclCommunicator::AlltoAll(const void *sendBuf, u64 sendCount, HcclDa Stream streamObj(stream); CHK_RET(callbackTask_->CallbackRegStream(stream)); + // 生成sendCountMatrix矩阵,alltoall的底层实现走alltoallvc + std::vector sendCountMatrix = GenerateSendCountMatrix(sendCount, userRankSize_); + + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = const_cast(sendBuf); + opParam.outputPtr = const_cast(recvBuf); + opParam.All2AllDataDes.sendType = sendType; + opParam.All2AllDataDes.recvType = recvType; + opParam.All2AllDataDes.sendCountMatrix = static_cast(sendCountMatrix.data()); + opParam.stream = streamObj; + opParam.opType = HcclCMDType::HCCL_CMD_ALLTOALL; + // 头计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, HEAD)); implAlg_->SetHDCModeInfo(rankDevicePhyIdNicInfoMap_, groupRanksPort_, isSetHDCModeInfo_, isUseRankPort_); - CHK_RET(implAlg_->AlltoAll(sendBuf, sendCount, sendType, recvBuf, recvCount, recvType, streamObj, tag)); + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_ALLTOALL, opParam)); // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -2728,11 +2820,29 @@ HcclResult HcclCommunicator::Broadcast(const std::string &tag, void *ptr, u64 co isSetHDCModeInfo_ = true; } implAlg_->SetHDCModeInfo(rankDevicePhyIdNicInfoMap_, groupRanksPort_, isSetHDCModeInfo_, isUseRankPort_); - CHK_RET(implAlg_->Broadcast(tag, ptr, count, dataType, root, streamObj)); + if (isHaveCpuRank_) { + CHK_RET(implAlg_->Broadcast(tag, ptr, count, dataType, root, streamObj)); + } else { + u32 perDataSize = SIZE_TABLE[dataType]; + u64 totalSize = count * perDataSize; + + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = ptr; + opParam.outputPtr = ptr; + opParam.inputSize = totalSize; + opParam.outputSize = totalSize; + opParam.DataDes.count = count; + opParam.DataDes.dataType = dataType; + opParam.root = root; + opParam.stream = streamObj; + opParam.opBaseAtraceInfo = opBaseAtraceInfo_.get(); + + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_BROADCAST, opParam)); + } // 尾计数任务 
CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -2754,12 +2864,28 @@ HcclResult HcclCommunicator::BroadcastOutPlace(const std::string &tag, void *ptr // 头计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, HEAD)); implAlg_->SetHDCModeInfo(rankDevicePhyIdNicInfoMap_, groupRanksPort_, isSetHDCModeInfo_, isUseRankPort_); - CHK_RET(implAlg_->BroadcastOutPlace(tag, ptr, count, dataType, root, streamObj, - opBaseAtraceInfo_)); + + if (isHaveCpuRank_) { + CHK_RET(implAlg_->Broadcast(tag, ptr, count, dataType, root, streamObj)); + } else { + u32 perDataSize = SIZE_TABLE[dataType]; + u64 totalSize = count * perDataSize; + + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = ptr; + opParam.outputPtr = ptr; + opParam.inputSize = totalSize; + opParam.outputSize = totalSize; + opParam.DataDes.count = count; + opParam.DataDes.dataType = dataType; + opParam.root = root; + opParam.stream = streamObj; + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_BROADCAST, opParam)); + } // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -2781,11 +2907,25 @@ HcclResult HcclCommunicator::Scatter(const std::string &tag, void *inputPtr, voi // 头计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, HEAD)); implAlg_->SetHDCModeInfo(rankDevicePhyIdNicInfoMap_, groupRanksPort_, isSetHDCModeInfo_, isUseRankPort_); - CHK_RET(implAlg_->Scatter(tag, inputPtr, outputPtr, recvCount, dataType, root, streamObj)); + + u32 perDataSize = SIZE_TABLE[dataType]; + u64 outputSize = recvCount * perDataSize; + u64 totalSize = outputSize * userRankSize_; + + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = inputPtr; + opParam.inputSize = totalSize; + opParam.outputPtr = outputPtr; + opParam.outputSize = totalSize; + opParam.DataDes.count = recvCount; + opParam.DataDes.dataType = dataType; + opParam.stream = streamObj; + opParam.root = root; + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_SCATTER, opParam)); // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -2807,12 +2947,26 @@ HcclResult HcclCommunicator::ScatterOutPlace(const std::string &tag, void *input // 头计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, HEAD)); implAlg_->SetHDCModeInfo(rankDevicePhyIdNicInfoMap_, groupRanksPort_, isSetHDCModeInfo_, isUseRankPort_); - CHK_RET(implAlg_->ScatterOutPlace(tag, inputPtr, outputPtr, recvCount, dataType, root, streamObj, - opBaseAtraceInfo_)); + + u32 perDataSize = SIZE_TABLE[dataType]; + u64 outputSize = recvCount * perDataSize; + u64 totalSize = outputSize * userRankSize_; + + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = inputPtr; + opParam.inputSize = totalSize; + opParam.outputPtr = outputPtr; + opParam.outputSize = totalSize; + opParam.DataDes.count = recvCount; + opParam.DataDes.dataType = dataType; + opParam.stream = streamObj; + opParam.root = root; + opParam.opBaseAtraceInfo = opBaseAtraceInfo_.get(); + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_SCATTER, opParam)); // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -2833,11 +2987,24 @@ HcclResult HcclCommunicator::Reduce(const std::string &tag, void *inputPtr, void // 头计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, HEAD)); - CHK_RET(implAlg_->Reduce(tag, inputPtr, outputPtr, count, dataType, op, root, streamObj)); + + u32 perDataSize = SIZE_TABLE[dataType]; + u64 totalSize = count * perDataSize; + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = inputPtr; + opParam.inputSize = totalSize; + opParam.outputPtr = outputPtr; + 
opParam.outputSize = totalSize; + opParam.DataDes.count = count; + opParam.DataDes.dataType = dataType; + opParam.reduceType = op; + opParam.root = root; + opParam.stream = streamObj; + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_REDUCE, opParam)); // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -2860,12 +3027,25 @@ HcclResult HcclCommunicator::ReduceOutPlace(const std::string &tag, void *inputP // 头计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, HEAD)); implAlg_->SetHDCModeInfo(rankDevicePhyIdNicInfoMap_, groupRanksPort_, isSetHDCModeInfo_, isUseRankPort_); - CHK_RET(implAlg_->ReduceOutPlace(tag, inputPtr, outputPtr, count, dataType, op, root, streamObj, - opBaseAtraceInfo_)); + + u32 perDataSize = SIZE_TABLE[dataType]; + u64 totalSize = count * perDataSize; + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = inputPtr; + opParam.inputSize = totalSize; + opParam.outputPtr = outputPtr; + opParam.inputSize = totalSize; + opParam.DataDes.count = count; + opParam.DataDes.dataType = dataType; + opParam.reduceType = op; + opParam.root = root; + opParam.stream = streamObj; + opParam.opBaseAtraceInfo = opBaseAtraceInfo_.get(); + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_REDUCE, opParam)); // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -2885,11 +3065,23 @@ HcclResult HcclCommunicator::ReduceScatter(const std::string &tag, void *inputPt // 头计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, HEAD)); implAlg_->SetHDCModeInfo(rankDevicePhyIdNicInfoMap_, groupRanksPort_, isSetHDCModeInfo_, isUseRankPort_); - CHK_RET(implAlg_->ReduceScatter(tag, inputPtr, outputPtr, count, dataType, op, streamObj, opInfo)); + + u32 perDataSize = SIZE_TABLE[dataType]; + + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = inputPtr; + opParam.inputSize = userRankSize_ * count * perDataSize; + opParam.outputPtr = outputPtr; + opParam.outputSize = count * perDataSize; + opParam.DataDes.count = count; + opParam.DataDes.dataType = dataType; + opParam.reduceType = op; + opParam.stream = streamObj; + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_REDUCE_SCATTER, opParam)); // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -2919,12 +3111,24 @@ HcclResult HcclCommunicator::ReduceScatterOutPlace(const std::string &tag, void // 头计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, HEAD)); implAlg_->SetHDCModeInfo(rankDevicePhyIdNicInfoMap_, groupRanksPort_, isSetHDCModeInfo_, isUseRankPort_); - CHK_RET(implAlg_->ReduceScatterOutPlace(tag, inputPtr, outputPtr, count, dataType, op, streamObj, - opBaseAtraceInfo_)); + + u32 perDataSize = SIZE_TABLE[dataType]; + + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = inputPtr; + opParam.inputSize = userRankSize_ * count * perDataSize; + opParam.outputPtr = outputPtr; + opParam.outputSize = count * perDataSize; + opParam.DataDes.count = count; + opParam.DataDes.dataType = dataType; + opParam.reduceType = op; + opParam.stream = streamObj; + opParam.opBaseAtraceInfo = opBaseAtraceInfo_.get(); + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_REDUCE_SCATTER, opParam)); // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -2956,7 +3160,6 @@ HcclResult HcclCommunicator::ProcessSendRecvTasks(const std::string &tag, std::v // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -2975,11 +3178,27 @@ HcclResult HcclCommunicator::Send(const std::string &tag, void *inputPtr, u64 co // 头计数任务 
CHK_RET(StarsCounter(dispatcher_, streamObj, HEAD)); - CHK_RET(implAlg_->Send(tag, inputPtr, count, dataType, destRank, streamObj)); - + if (isHaveCpuRank_) { + CHK_RET(implAlg_->Send(tag, inputPtr, count, dataType, destRank, streamObj)); + } else { + u32 perDataSize = SIZE_TABLE[dataType]; + u64 totalSize = count * perDataSize; + + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = inputPtr; + opParam.inputSize = totalSize; + opParam.outputPtr = inputPtr; + opParam.outputSize = totalSize; + opParam.DataDes.count = count; + opParam.DataDes.dataType = dataType; + opParam.stream = streamObj; + opParam.opBaseAtraceInfo = opBaseAtraceInfo_.get(); + opParam.dstRank = destRank; + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_SEND, opParam)); + } // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -3002,11 +3221,29 @@ HcclResult HcclCommunicator::SendOutPlace(const std::string &tag, void *inputPtr // 头计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, HEAD)); implAlg_->SetHDCModeInfo(rankDevicePhyIdNicInfoMap_, groupRanksPort_, isSetHDCModeInfo_, isUseRankPort_); - CHK_RET(implAlg_->SendOutPlace(tag, inputPtr, count, dataType, destRank, streamObj)); + + if (isHaveCpuRank_) { + CHK_RET(implAlg_->Send(tag, inputPtr, count, dataType, destRank, streamObj)); + } else { + u32 perDataSize = SIZE_TABLE[dataType]; + u64 totalSize = count * perDataSize; + + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = inputPtr; + opParam.inputSize = totalSize; + opParam.outputPtr = inputPtr; + opParam.outputSize = totalSize; + opParam.DataDes.count = count; + opParam.DataDes.dataType = dataType; + opParam.stream = streamObj; + opParam.opBaseAtraceInfo = opBaseAtraceInfo_.get(); + opParam.dstRank = destRank; + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_SEND, opParam)); + } // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -3025,11 +3262,27 @@ HcclResult HcclCommunicator::Receive(const std::string &tag, void *outputPtr, u6 // 头计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, HEAD)); - CHK_RET(implAlg_->Receive(tag, outputPtr, count, dataType, srcRank, streamObj)); - + if (isHaveCpuRank_) { + CHK_RET(implAlg_->Receive(tag, outputPtr, count, dataType, srcRank, streamObj)); + } else { + u32 perDataSize = SIZE_TABLE[dataType]; + u64 totalSize = count * perDataSize; + + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = outputPtr; + opParam.inputSize = totalSize; + opParam.outputPtr = outputPtr; + opParam.outputSize = totalSize; + opParam.DataDes.count = count; + opParam.DataDes.dataType = dataType; + opParam.stream = streamObj; + opParam.opBaseAtraceInfo = opBaseAtraceInfo_.get(); + opParam.srcRank = srcRank; + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_RECEIVE, opParam)); + } // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -3052,11 +3305,29 @@ HcclResult HcclCommunicator::ReceiveOutPlace(const std::string &tag, void *outpu // 头计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, HEAD)); implAlg_->SetHDCModeInfo(rankDevicePhyIdNicInfoMap_, groupRanksPort_, isSetHDCModeInfo_, isUseRankPort_); - CHK_RET(implAlg_->ReceiveOutPlace(tag, outputPtr, count, dataType, srcRank, streamObj)); + + if (isHaveCpuRank_) { + CHK_RET(implAlg_->Receive(tag, outputPtr, count, dataType, srcRank, streamObj)); + } else { + u32 perDataSize = SIZE_TABLE[dataType]; + u64 totalSize = count * perDataSize; + + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = outputPtr; + opParam.inputSize = totalSize; + opParam.outputPtr = 
outputPtr; + opParam.outputSize = totalSize; + opParam.DataDes.count = count; + opParam.DataDes.dataType = dataType; + opParam.stream = streamObj; + opParam.opBaseAtraceInfo = opBaseAtraceInfo_.get(); + opParam.srcRank = srcRank; + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_RECEIVE, opParam)); + } // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); - return HCCL_SUCCESS; } @@ -3084,6 +3355,78 @@ HcclResult HcclCommunicator::Gather(const std::string &tag, void *inputPtr, void // 尾计数任务 CHK_RET(StarsCounter(dispatcher_, streamObj, TAIL)); + return HCCL_SUCCESS; +} + +HcclResult HcclCommunicator::SetInfoToDevice(const OpParam &opParam, + const std::unique_ptr &preMetaInfo, + const HcclWorkflowMode &mode, Stream &stream) +{ + auto inAlltoAllvParaBuffer = cclBufferManager_.GetInAlltoAllvParaBuffer(); + auto outAlltoAllvParaBuffer = cclBufferManager_.GetOutAlltoAllvParaBuffer(); + if ((inAlltoAllvParaBuffer.ptr() == nullptr) || (outAlltoAllvParaBuffer.ptr() == nullptr)) { + CHK_RET( + cclBufferManager_.InitAlltoAllvParaBuffer(preMetaInfo->inputSize, preMetaInfo->outputSize)); + inAlltoAllvParaBuffer = cclBufferManager_.GetInAlltoAllvParaBuffer(); + } + if (mode != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + alltoalltream_ = Stream(StreamType::STREAM_TYPE_ONLINE); + stream = alltoalltream_; + } else { + stream = const_cast(opParam.stream); + } + CHK_RET(SetWorkflowMode(HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE)); + CHK_RET(hcclStreamSynchronize(stream.ptr())); + CHK_RET(hrtMemSyncCopy(inAlltoAllvParaBuffer.ptr(), preMetaInfo->inputSize, preMetaInfo->inputData.data(), + preMetaInfo->inputSize, HcclRtMemcpyKind::HCCL_RT_MEMCPY_KIND_HOST_TO_DEVICE)); + return HCCL_SUCCESS; +} + +HcclResult HcclCommunicator::GetInfoFromDevice(const OpParam &opParam, + const std::unique_ptr &preMetaInfo, + const HcclWorkflowMode &mode, Stream &stream, HostMem& hostCollectBuffer) +{ + CHK_RET(hrtMemSyncCopy(hostCollectBuffer.ptr(), preMetaInfo->outputSize, + cclBufferManager_.GetOutAlltoAllvParaBuffer().ptr(), preMetaInfo->outputSize, + HcclRtMemcpyKind::HCCL_RT_MEMCPY_KIND_DEVICE_TO_HOST)); + + // 非单算子场景,中转内存使用完之后直接释放 + if (mode != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + cclBufferManager_.ReleaseAlltoAllvParaBuffer(); + } + + return HCCL_SUCCESS; +} + +HcclResult HcclCommunicator::RegressCalPreOp(const std::unique_ptr& algOperator, + const OpParam &opParam, std::unique_ptr &preMetaInfo) +{ + Stream preProcessStream; + OpParam preProcessOpParam; + HcclWorkflowMode mode = GetWorkflowMode(); + CHK_PRT_RET(mode == HcclWorkflowMode::HCCL_WORKFLOW_MODE_RESERVED, HCCL_ERROR("Invalid Workflow Mode[%d]", + mode),HCCL_E_INTERNAL); + + // h to d + CHK_RET(SetInfoToDevice(opParam, preMetaInfo, mode, preProcessStream)); + // opParam准备 + CHK_RET(algOperator->PreparePreOpParam(preProcessOpParam, preMetaInfo, preProcessStream)); + + // 回归调用其它算子 + HCCL_INFO("[HcclCommunicator][RegressCalPreOp] Regression calls other operators and opType[%u]", + preMetaInfo->opType); + CHK_RET(ExecOp(preMetaInfo->opType, preProcessOpParam)); + CHK_RET(hcclStreamSynchronize(preProcessStream.ptr())); + HCCL_DEBUG("[HcclCommunicator][RegressCalPreOp] preProcess tag[%s].", preProcessOpParam.tag.c_str()); + SetWorkflowMode(mode); + + // d to h + HostMem hostCollectBuffer = HostMem::alloc(preMetaInfo->outputSize); + CHK_PTR_NULL(hostCollectBuffer.ptr()); + CHK_RET(GetInfoFromDevice(opParam, preMetaInfo, mode, preProcessStream, hostCollectBuffer)); + + algOperator->SetPreProcessResult(std::move(hostCollectBuffer)); + 
HCCL_INFO("[HcclCommunicator][RegressCalPreOp] run success!"); return HCCL_SUCCESS; } @@ -3095,12 +3438,22 @@ HcclResult HcclCommunicator::ExecOp(HcclCMDType opType, const OpParam &opParam) // 算法选择 std::string algName; std::string newTag; + + std::unique_ptr preMetaInfo = std::make_unique(); + CHK_SMART_PTR_NULL(preMetaInfo); + + bool preProcessFlag = algOperator->JudgeIfNeedPreProcessAndGetParam(opParam, preMetaInfo); + if (preProcessFlag) { + CHK_RET(RegressCalPreOp(algOperator, opParam, preMetaInfo)); + } + CHK_RET(algOperator->SelectAlg(opParam.tag, opParam, algName, newTag)); + // 资源创建 if (resMap_.find(newTag) == resMap_.end()) { AlgResourceRequest resRequest; CHK_RET(algOperator->CalcResRequest(algName, opParam, resRequest)); - CHK_RET(AllocAlgResource(newTag, opParam, resRequest, resMap_[newTag])); + CHK_RET(AllocAlgResource(newTag, opType, opParam, resRequest, resMap_[newTag])); if (!isHaveCpuRank_) { if (isUseRankPort_) { HeartbeatPub::SetRankPortInfo(deviceLogicId_, isUseRankPort_, groupRanksPort_); @@ -3111,14 +3464,40 @@ HcclResult HcclCommunicator::ExecOp(HcclCMDType opType, const OpParam &opParam) CHK_RET(RegisterToHeartBeat()); } } + } else { + bool needRecreateAlltoallComm = algOperator->CheckNeedRecreateComm(algName, resMap_[newTag].scratchMem.size()); + HCCL_INFO("resMap_ find this newTag[%s], and need to judge whether recreate comm [%d]", newTag.c_str(), + needRecreateAlltoallComm); + if (needRecreateAlltoallComm) { + AlgResourceRequest resRequest; + CHK_RET(algOperator->CalcResRequest(algName, opParam, resRequest)); + CHK_RET(AllocAlgResource(newTag, opType, opParam, resRequest, resMap_[newTag])); + if (!isHaveCpuRank_) { + if (isUseRankPort_) { + HeartbeatPub::SetRankPortInfo(deviceLogicId_, isUseRankPort_, groupRanksPort_); + } + if (opType != HcclCMDType::HCCL_CMD_SEND && + opType != HcclCMDType::HCCL_CMD_RECEIVE && + opType != HcclCMDType::HCCL_CMD_BATCH_SEND_RECV) { + CHK_RET(RegisterToHeartBeat()); + } + } + } else { + if (opParam.opType == HcclCMDType::HCCL_CMD_ALLTOALLV || opParam.opType == HcclCMDType::HCCL_CMD_ALLTOALLVC + || opParam.opType == HcclCMDType::HCCL_CMD_ALLTOALL) { + bool isAlltoAllZCopyMode = false; + DeviceMem tinySendRecvMem; + CHK_RET(implAlg_->GetAlltoAllStatus(tinySendRecvMem, isAlltoAllZCopyMode)); + CHK_RET(CalcTinySendRecvMem(opParam, resMap_[newTag], tinySendRecvMem)); + } + } } // 算法执行 CHK_RET(algOperator->Orchestrate(algName, opParam, resMap_[newTag])); - return HCCL_SUCCESS; } -// batchsendrecv需要根据任务来确定与哪些卡建链,因此复用tag,并在相应基础上实现增量建链 +// batchsendrecv需要根据任务来确定与哪些卡建链,因此复用tag,并在相应resmap里面实现增量建链 HcclResult HcclCommunicator::ExecOpExt(HcclCMDType opType, const OpParam &opParam) { std::unique_ptr algOperator = implAlg_->GetAlgOperator(opType); @@ -3130,8 +3509,8 @@ HcclResult HcclCommunicator::ExecOpExt(HcclCMDType opType, const OpParam &opPara if (resMap_.find(newTag) == resMap_.end()) { AlgResourceRequest resRequest; CHK_RET(algOperator->CalcResRequest(algName, opParam, resRequest)); - CHK_RET(AllocAlgResource(newTag, opParam, resRequest, resMap_[newTag])); - } else if (algOperator->NeedIncrCreateLink(algName, opParam)) { + CHK_RET(AllocAlgResource(newTag, opType, opParam, resRequest, resMap_[newTag])); + } else { // 增量建链 AlgResourceRequest resRequest; CHK_RET(algOperator->CalcIncreLinkRequest(algName, opParam, resRequest)); @@ -3139,17 +3518,66 @@ HcclResult HcclCommunicator::ExecOpExt(HcclCMDType opType, const OpParam &opPara } // 算法执行 CHK_RET(algOperator->Orchestrate(algName, opParam, resMap_[newTag])); + return HCCL_SUCCESS; +} + 
+HcclResult HcclCommunicator::CalcTinySendRecvMem(const OpParam &opParam, AlgResourceResponse &algResResponse, + DeviceMem &tinySendRecvMem) +{ + u64 sendCount = 0; + u64 recvCount = 0; + if (opParam.opType == HcclCMDType::HCCL_CMD_ALLTOALLV) { + for (u32 i = 0; i < userRankSize_; i++) { + u64 curSendCount = *(static_cast(opParam.All2AllDataDes.sendCounts) + i) + + *(static_cast(opParam.All2AllDataDes.sdispls) + i); + sendCount = std::max(sendCount, curSendCount); + u64 curRecvCount = *(static_cast(opParam.All2AllDataDes.recvCounts) + i) + + *(static_cast(opParam.All2AllDataDes.rdispls) + i); + recvCount = std::max(recvCount, curRecvCount); + } + } else { + for (u32 i = 0; i < userRankSize_; i++) { + sendCount += *(static_cast(opParam.All2AllDataDes.sendCountMatrix) + + userRank_ * userRankSize_ + i); + recvCount += *(static_cast(opParam.All2AllDataDes.sendCountMatrix) + + userRank_ + userRankSize_ * i); + } + } + + u32 sendTypeSize = 0, recvTypeSize = 0; + CHK_RET(SalGetDataTypeSize(opParam.All2AllDataDes.sendType, sendTypeSize)); + CHK_RET(SalGetDataTypeSize(opParam.All2AllDataDes.recvType, recvTypeSize)); + + // 在sendCount/recvCount全0时, 使用tinySendRecvMem, 避免使用空deviceMem + algResResponse.paramInputMem = sendCount == 0 ? + DeviceMem::create(tinySendRecvMem.ptr(), tinySendRecvMem.size()) : + DeviceMem::create(opParam.inputPtr, sendCount * sendTypeSize); + algResResponse.paramOutputMem = recvCount == 0 ? + DeviceMem::create(tinySendRecvMem.ptr(), tinySendRecvMem.size()) : + DeviceMem::create(opParam.outputPtr, recvCount * recvTypeSize); + + HCCL_INFO("[HcclCommunicator][CalcTinySendRecvMem] senMem addr[%p], sendSize[%llu]," \ + "RecvMem addr[%p], RecvSize[%llu],", algResResponse.paramInputMem.ptr(), + algResResponse.paramInputMem.size(), algResResponse.paramOutputMem.ptr(), + algResResponse.paramOutputMem.size()); return HCCL_SUCCESS; } -HcclResult HcclCommunicator::AllocAlgResource(const std::string &newTag, const OpParam &opParam, + +HcclResult HcclCommunicator::AllocAlgResource(const std::string &newTag, HcclCMDType opType, const OpParam &opParam, AlgResourceRequest &resRequest, AlgResourceResponse &algResResponse) { HcclResult ret = HCCL_SUCCESS; if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { if (resRequest.scratchMemSize > 0) { algResResponse.scratchMem = GetWorkspaceScracthMem(opParam.tag, resRequest.scratchMemSize); + if (opType == HcclCMDType::HCCL_CMD_REDUCE_SCATTER) { + // cce reduce地址32字节对齐,截取32字节对齐后的内存地址 + u32 addOffset = (reinterpret_cast(algResResponse.scratchMem.ptr())) % CCE_REDUCE_ALIGN_SIZE; + u64 totalSize = userRankSize_ * opParam.DataDes.count * SIZE_TABLE[opParam.DataDes.dataType]; + algResResponse.scratchMem = algResResponse.scratchMem.range(addOffset, totalSize); + } } if (resRequest.streamNum > 0) { algResResponse.streams = GetWorkspaceSubStreams(opParam.tag, resRequest.streamNum); @@ -3157,6 +3585,12 @@ HcclResult HcclCommunicator::AllocAlgResource(const std::string &newTag, const O } else if (GetWorkflowMode() == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { if (resRequest.scratchMemSize > 0) { algResResponse.scratchMem = DeviceMem::alloc(resRequest.scratchMemSize); + if (opType == HcclCMDType::HCCL_CMD_REDUCE_SCATTER) { + // cce reduce地址32字节对齐,截取32字节对齐后的内存地址 + u32 addOffset = (reinterpret_cast(algResResponse.scratchMem.ptr())) % CCE_REDUCE_ALIGN_SIZE; + algResResponse.scratchMem = algResResponse.scratchMem.range(addOffset, + cclBufferManager_.GetInCCLbufferSize()); + } } if (resRequest.streamNum > 0) { 
CHK_RET(opStreamManager_->RegisterMaster(opParam.stream)); @@ -3172,8 +3606,16 @@ HcclResult HcclCommunicator::AllocAlgResource(const std::string &newTag, const O algResResponse.cclInputMem = cclBufferManager_.GetInCCLbuffer(); algResResponse.cclOutputMem = cclBufferManager_.GetOutCCLbuffer(); - algResResponse.paramInputMem = DeviceMem::create(opParam.inputPtr, opParam.inputSize); - algResResponse.paramOutputMem = DeviceMem::create(opParam.outputPtr, opParam.outputSize); + if (opParam.opType == HcclCMDType::HCCL_CMD_ALLTOALLV || opParam.opType == HcclCMDType::HCCL_CMD_ALLTOALLVC + || opParam.opType == HcclCMDType::HCCL_CMD_ALLTOALL) { + bool isAlltoAllZCopyMode = false; + DeviceMem tinySendRecvMem; + CHK_RET(implAlg_->GetAlltoAllStatus(tinySendRecvMem, isAlltoAllZCopyMode)); + CHK_RET(CalcTinySendRecvMem(opParam, algResResponse, tinySendRecvMem)); + } else { + algResResponse.paramInputMem = DeviceMem::create(opParam.inputPtr, opParam.inputSize); + algResResponse.paramOutputMem = DeviceMem::create(opParam.outputPtr, opParam.outputSize); + } if (resRequest.needAivBuffer) { ret = cclBufferManager_.CreateCommAIVbuffer(); @@ -3186,11 +3628,16 @@ HcclResult HcclCommunicator::AllocAlgResource(const std::string &newTag, const O TransportIOMem transMem{algResResponse.cclInputMem, algResResponse.cclOutputMem, algResResponse.paramInputMem, algResResponse.paramOutputMem, algResResponse.scratchMem, algResResponse.aivInputMem, algResResponse.aivOutputMem}; + HCCL_DEBUG("algResResponse.cclInputMem[%p], size[%llu]; algResResponse.cclOutputMem[%p], " \ + "size[%llu]; algResResponse.paramInputMem[%p], size[%llu]; algResResponse.paramOutputMem[%p], size[%llu]", + algResResponse.cclInputMem.ptr(), algResResponse.cclInputMem.size(), + algResResponse.cclOutputMem.ptr(), algResResponse.cclOutputMem.size(), + algResResponse.paramInputMem.ptr(), algResResponse.paramInputMem.size(), + algResResponse.paramOutputMem.ptr(), algResResponse.paramOutputMem.size()); algResResponse.opTransportResponse = resRequest.opTransport; ret = transportManager_->Alloc(opParam.tag, transMem, algResResponse.opTransportResponse); CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[Alloc][AlgResource]Alloc transports failed, tag[%s]", newTag.c_str()), ret); - return HCCL_SUCCESS; } @@ -3199,7 +3646,7 @@ HcclResult HcclCommunicator::IncreAllocLink(const std::string &newTag, const OpP { algResResponse.cclInputMem = cclBufferManager_.GetInCCLbuffer(); algResResponse.cclOutputMem = cclBufferManager_.GetOutCCLbuffer(); - + TransportIOMem transMem{algResResponse.cclInputMem, algResResponse.cclOutputMem, algResResponse.paramInputMem, algResResponse.paramOutputMem, algResResponse.scratchMem, algResResponse.aivInputMem, algResResponse.aivOutputMem}; @@ -3208,7 +3655,6 @@ HcclResult HcclCommunicator::IncreAllocLink(const std::string &newTag, const OpP algResResponse.opTransportResponse); CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[IncreAlloc][Link]IncreAlloc transports failed, tag[%s]", newTag.c_str()), ret); - return HCCL_SUCCESS; } @@ -3229,7 +3675,6 @@ HcclResult HcclCommunicator::InitRecvMsgAndRequestBuffer() HCCL_INFO("InitRequestBuffer Success!"); } - return HCCL_SUCCESS; } @@ -3258,14 +3703,12 @@ HcclResult HcclCommunicator::InitMemBlocksAndRecvWrMem() HCCL_INFO("InitMemBlocksAndRecvWrMem Success!"); } - return HCCL_SUCCESS; } HcclResult HcclCommunicator::SetDevicePid(s32 devicePid) { devicePid_ = devicePid; - return HCCL_SUCCESS; } @@ -3288,7 +3731,6 @@ HcclResult HcclCommunicator::CreateWorkSpace(u64 size, DeviceMem &buffer) const 
CHK_PRT_RET(size && !buffer, HCCL_ERROR("[Create][WorkSpace]Create work space size[%llu] fail,"\ "please check workspace size.", size), HCCL_E_PTR); CHK_RET(hrtMemSet(buffer.ptr(), size, size)); - return HCCL_SUCCESS; } @@ -3296,7 +3738,6 @@ HcclResult HcclCommunicator::GetWorkSpace(u64 *workSpaceSize, u64 *workSpace) co { *workSpaceSize = workSpaceSize_; *workSpace = reinterpret_cast(workSpace_.ptr()); - return HCCL_SUCCESS; } @@ -3307,7 +3748,6 @@ HcclResult HcclCommunicator::InitWorkSpace() CHK_RET(CreateWorkSpace(workSpaceSize_, workSpace_)); } - return HCCL_SUCCESS; } @@ -3321,7 +3761,6 @@ HcclResult HcclCommunicator::CreateCommResource(const std::string &tag, rtStream CHK_RET(CreateCommAndStreamRes(tag, stream)); CHK_RET(Mc2CreateAndLaunchContext(aiCpuStream, isOpbaseMode, commContext)); - return HCCL_SUCCESS; } @@ -3346,10 +3785,15 @@ HcclResult HcclCommunicator::Mc2CreateAndLaunchContext(rtStream_t aiCpuStream, b combinOpara_.config.deterministic = GetDeterministicConfig(); // retryEnable 写入aicpu_ctx combinOpara_.config.retryEnable = static_cast(retryEnable_); + combinOpara_.config.retryHoldTime = GetExternalInputRetryHoldTime(); + combinOpara_.config.retryIntervalTime = GetExternalInputRetryIntervalTime(); combinOpara_.config.notifyWaitTime = (GetExternalInputHcclExecTimeoutSet() != HcclExecTimeoutSet::HCCL_EXEC_TIMEOUT_NOT_SET) ? GetExternalInputHcclExecTimeOut() : NOTIFY_DEFAULT_WAIT_TIME; + combinOpara_.kfcControlTransferH2DParams = kfcControlTransferH2D_->GetCommunicateParams(); + combinOpara_.kfcStatusTransferD2HParams = kfcStatusTransferD2H_->GetCommunicateParams(); + void *overflowAddr = nullptr; if (Is310P3Common()) { CHK_RET(hrtCtxGetOverflowAddr(&overflowAddr)); @@ -3374,7 +3818,6 @@ HcclResult HcclCommunicator::Mc2CreateAndLaunchContext(rtStream_t aiCpuStream, b } *commContext = commContext_.ptr(); - return HCCL_SUCCESS; } @@ -3390,7 +3833,6 @@ HcclResult HcclCommunicator::GetAiCpuNotifyData(const std::shared_ptrGetNotifyData(notifyInfo)); HCCL_INFO("[HcclCommunicator][GetAiCpuNotifyData]esId[%lld], addr[%lld], devId[%u], tsId[%u].", notifyInfo.resId, notifyInfo.addr, notifyInfo.devId, notifyInfo.tsId); - return HCCL_SUCCESS; } @@ -3410,7 +3852,6 @@ HcclResult HcclCommunicator::CreateAndGetAiCpuNotify(std::shared_ptrptr(); - return HCCL_SUCCESS; } HcclResult HcclCommunicator::SetAicpuNotifyInvaild() { combinOpara_.signalInfo.aicpuNotify.resId = INVALID_U64; - return HCCL_SUCCESS; } @@ -3620,7 +4056,6 @@ HcclResult HcclCommunicator::ReplaceCommInfoByTag(const std::string &tag, std::u std::unique_lock replLock(commLock_); tagCommInfo_.erase(tag); tagCommInfo_.insert(std::pair(tag, std::move(*commInfo))); - return HCCL_SUCCESS; } @@ -3651,6 +4086,7 @@ HcclResult HcclCommunicator::CreateMutiStreamResFor310P(const std::string &tag, for (auto &signal : streamInfo.ringDeviceSignalAux) { signal = nullptr; } + u32 notifyNum = resNum * 2; // 2:Signal + SignalAux std::vector> notifys(notifyNum, nullptr); CHK_RET(queueNotifyManager_->Alloc(tag, notifyNum, notifys, NotifyLoadType::DEVICE_NOTIFY)); @@ -3668,7 +4104,6 @@ HcclResult HcclCommunicator::CreateMutiStreamResFor310P(const std::string &tag, } } - return HCCL_SUCCESS; } @@ -3715,7 +4150,8 @@ HcclResult HcclCommunicator::CreateCommAndStreamRes(const std::string &tag, Stre if (!(IsExistMutiStreamRes(tag))) { innerStreamInfo_t streamInfo; std::unique_lock mutiStreamLock(tagStreamInfoLock_); - if (GetRankSize() == 2) { + // 2p场景下,mc2当前algType为518,streamInfo.ringNum走默认流程值为1导致资源申请不足,910_73 mc2固定在节点内默认用mesh + if (GetRankSize() == 2 
|| deviceType_ == DevType::DEV_TYPE_910_73) { algTypeTmp = AlgType::ALG_NP_MESH_PLUS_RING; } HcclResult ret = HCCL_SUCCESS; @@ -3731,6 +4167,7 @@ HcclResult HcclCommunicator::CreateCommAndStreamRes(const std::string &tag, Stre tag.c_str()), ret); tagStreamInfo_.insert(std::pair(tag, std::move(streamInfo))); + opRetryStreamPtr_->insert(std::make_pair(tag, tagStreamInfo_[tag].ringDeviceStreams)); mutiStreamLock.unlock(); } @@ -3746,7 +4183,6 @@ HcclResult HcclCommunicator::CreateCommAndStreamRes(const std::string &tag, Stre } CHK_RET(SetCommResource(commInputSize, commInputPtr, commOutputPtr, comm, tagStreamInfo_[tag], stream)); - return HCCL_SUCCESS; } @@ -3758,7 +4194,6 @@ HcclResult HcclCommunicator::GetComm(const std::string &tag, CommBase **comm) *comm = tagCommInfo_[tag].commOuter[0].get(); } - return HCCL_SUCCESS; } @@ -3872,7 +4307,6 @@ HcclResult HcclCommunicator::SetCommResource(u64 commBufferSize, void *commInPtr static_cast(stream.id()), rankSize); CHK_RET(ProfilingManagerPub::CallMsprofReportMc2CommInfo(hrtMsprofSysCycleTime(), &hcclMc2Info_, sizeof(hcclMc2Info_))); - return HCCL_SUCCESS; } @@ -3896,7 +4330,6 @@ HcclResult HcclCommunicator::CreateDeviceCommContext(u64 size, DeviceMem &buffer CHK_PRT_RET(size && !buffer, HCCL_ERROR("[Create][DeviceCommContext]Create device commContext size[%llu] fail,"\ "please check deviceCommContext size.", size), HCCL_E_PTR); } - return HCCL_SUCCESS; } @@ -3911,7 +4344,6 @@ void HcclCommunicator::Break() if (implAlg_ != nullptr) { implAlg_->Break(); } - return; } @@ -3977,7 +4409,6 @@ HcclResult HcclCommunicator::SetWorldGroupInfo( ranksPort_.push_back(rank); HCCL_DEBUG("ranksPort port[%u]", rank); } - return HCCL_SUCCESS; } @@ -3988,7 +4419,10 @@ HcclResult HcclCommunicator::GetTopoDesc(HcclTopoDescs *topoDescs, uint32_t topo return HCCL_E_PARA; } - if (deviceType_ == DevType::DEV_TYPE_910B) { + if (deviceType_ == DevType::DEV_TYPE_910_73) { + topoDescs[static_cast(HcclTopoLevel::HCCL_TOPO_L0)].algSets = HCCL_ALG_SWITCH | HCCL_ALG_RING; + topoDescs[static_cast(HcclTopoLevel::HCCL_TOPO_L1)].algSets = HCCL_ALG_RING; + } else if (deviceType_ == DevType::DEV_TYPE_910B) { topoDescs[static_cast(HcclTopoLevel::HCCL_TOPO_L0)].algSets = HCCL_ALG_MESH; topoDescs[static_cast(HcclTopoLevel::HCCL_TOPO_L1)].algSets = 0; } else if (deviceType_ == DevType::DEV_TYPE_310P3) { @@ -3998,12 +4432,12 @@ HcclResult HcclCommunicator::GetTopoDesc(HcclTopoDescs *topoDescs, uint32_t topo topoDescs[static_cast(HcclTopoLevel::HCCL_TOPO_L0)].rankSize = userRankSize_; topoDescs[static_cast(HcclTopoLevel::HCCL_TOPO_L1)].rankSize = 0; - return HCCL_SUCCESS; } HcclResult HcclCommunicator::CheckSuperDeviceId(const RankTable_t &rankTable) { + // 非910_73/910_73非超节点形态,Sdid为无效值0xFFFFFFFF,无需校验SDID合法性 if (!IsUseSdidForDeviceId(superDeviceId_)) { return HCCL_SUCCESS; } diff --git a/src/domain/collective_communication/framework/communicator/impl/hccl_communicator.h b/src/domain/collective_communication/framework/communicator/impl/hccl_communicator.h index c289132adcd634af96deddc79c05b6928c4c10c6..5b4f43ae2869881de3f833dca8da351d57be0e81 100644 --- a/src/domain/collective_communication/framework/communicator/impl/hccl_communicator.h +++ b/src/domain/collective_communication/framework/communicator/impl/hccl_communicator.h @@ -38,7 +38,7 @@ #include "device_capacity.h" #include "transport_manager.h" #include "coll_alg_operator.h" -#include "alltoall_operator.h" +#include "opretry_manage_pub.h" namespace hccl { using ServRankInfo_t = std::map >; @@ -148,6 +148,8 @@ public: virtual 
HcclResult ReleaseCommInfos(); + virtual std::vector GenerateSendCountMatrix(u64 count, u32 rankSize); + virtual HcclResult GetAlltoAllStagedWorkSpaceMemSize(u64 *sendCounts, u64 *sdispls, HcclDataType sendType, u64 *recvCounts, u64 *rdispls, HcclDataType recvType, u64 &memSize); @@ -318,7 +320,7 @@ private: CommBase *comm, innerStreamInfo_t &streamInfo, Stream &stream); HcclResult GetAicpuOpStreamAndNotify(HcclRtStream *opStream, void** aicpuNotify); HcclResult SetAicpuNotifyInvaild(); - HcclResult AicpuKfcTilingDataLaunch(void *inputPtr, void *outputPtr, u64 count, + HcclResult AicpuKfcTilingDataLaunch(const std::string &tag, void *inputPtr, void *outputPtr, u64 count, HcclDataType dataType, HcclReduceOp op, HcclRtStream stream, HcclCMDType opType); HcclResult AllReduceAicpuUnfold(const std::string &tag, void *inputPtr, void *outputPtr, u64 count, @@ -418,11 +420,19 @@ private: HcclResult GetModuleInfo(const std::vector &rankList); HcclResult SetInterModeInSuperPod(); HcclResult CheckSingleServerComm(const std::vector &rankList) const; + HcclResult SetInfoToDevice(const OpParam &opParam, const std::unique_ptr &preMetaInfo, + const HcclWorkflowMode &mode, Stream &stream); + HcclResult GetInfoFromDevice(const OpParam &opParam, const std::unique_ptr &preMetaInfo, + const HcclWorkflowMode &mode, Stream &stream, HostMem& hostCollectBuffer); + HcclResult RegressCalPreOp(const std::unique_ptr& algOperator, + const OpParam &opParam, std::unique_ptr &preMetaInfo); HcclResult ExecOp(HcclCMDType opType, const OpParam &opParam); // batchsendrecv专用,增量建链 HcclResult ExecOpExt(HcclCMDType opType, const OpParam &opParam); - HcclResult AllocAlgResource(const std::string &tag, const OpParam &opParam, + HcclResult CalcTinySendRecvMem(const OpParam &opParam, AlgResourceResponse &algResResponse, + DeviceMem &tinySendRecvMem); + HcclResult AllocAlgResource(const std::string &tag, HcclCMDType opType, const OpParam &opParam, AlgResourceRequest &resRequest, AlgResourceResponse &algResResponse); HcclResult IncreAllocLink(const std::string &newTag, const OpParam &opParam, AlgResourceRequest &resRequest, AlgResourceResponse &algResResponse); @@ -453,6 +463,8 @@ private: HcclDataType dataType, HcclReduceOp op, HcclRtStream stream, HcclCMDType cmdType); u32 GetLocalNicPort(); std::string GetSupportDataType(bool needReduce); + HcclResult InitHDCommunicate(); + HcclResult InitOpRetry(); HcclIpAddress loopBackIp_; bool profilingInitiated_; @@ -504,6 +516,14 @@ private: std::unique_ptr transportManager_ = { nullptr }; std::unordered_map resMap_; // tag : AlgResourceResponse bool retryEnable_ = false; + Stream alltoalltream_; + + std::unique_ptr opRetryManager_ = { nullptr }; + + std::shared_ptr kfcControlTransferH2D_; + std::shared_ptr kfcStatusTransferD2H_; + HcclCommConnections commConnections_; + std::shared_ptr opRetryStreamPtr_; }; } // end namespace hccl diff --git a/src/domain/collective_communication/framework/communicator/impl/resource_manager/transport_manager.cc b/src/domain/collective_communication/framework/communicator/impl/resource_manager/transport_manager.cc index 7e5bfa17f34cfcfa5478deb9d4c613154914d0df..82f991cf4c3ff339cbbd910f10ced240c85f6c20 100644 --- a/src/domain/collective_communication/framework/communicator/impl/resource_manager/transport_manager.cc +++ b/src/domain/collective_communication/framework/communicator/impl/resource_manager/transport_manager.cc @@ -86,6 +86,29 @@ HcclResult TransportManager::ExceptionHandle(const std::string &tag, OpCommTrans return HCCL_SUCCESS; } +HcclResult 
TransportManager::CreateVirturalTransport(SingleSubCommTransport& singleSubCommTransport) +{ + MachinePara machinePara; + std::chrono::milliseconds kdefaultTimeout = std::chrono::seconds( + GetExternalInputHcclLinkTimeOut()); + + singleSubCommTransport.virtualLinks.clear(); + singleSubCommTransport.virtualLinks.resize(singleSubCommTransport.transportRequests.size()); + + for (u32 i = 0; i < singleSubCommTransport.transportRequests.size(); i++) { + TransportPara para {}; + para.virtualFlag = true; + para.timeout = kdefaultTimeout; + para.index = i; + singleSubCommTransport.virtualLinks[i].reset(new (std::nothrow) Transport(TransportType::TRANS_TYPE_RESERVED, + para, dispatcher_, notifyPool_, machinePara)); + CHK_PRT_RET(!singleSubCommTransport.virtualLinks[i], HCCL_ERROR("[CreateVirturalTransport]In create link," \ + "new link failed"), HCCL_E_PTR); + } + + return HCCL_SUCCESS; +} + HcclResult TransportManager::Alloc(const std::string &tag, const TransportIOMem &transMem, OpCommTransport &opTransportResponse) { @@ -99,12 +122,19 @@ HcclResult TransportManager::Alloc(const std::string &tag, const TransportIOMem singleSubCommTransport.links.clear(); singleSubCommTransport.links.reserve(singleSubCommTransport.transportRequests.size()); + if (singleSubCommTransport.needVirtualLink) { + // task多线程并行下发,根据当前transport创建vtransport信息 + CHK_RET(CreateVirturalTransport(singleSubCommTransport)); + } + u32 linkIdx = 0; for (auto &transportRequest : singleSubCommTransport.transportRequests) { singleSubCommTransport.links.emplace_back(std::make_shared(nullptr)); if (transportRequest.isValid) { DeviceMem inputMem; DeviceMem outputMem; + HCCL_DEBUG("transportRequest.inputMemType[%d] transportRequest.outputMemType[%d]", + transportRequest.inputMemType, transportRequest.outputMemType); GetIOMem(transMem, transportRequest.inputMemType, transportRequest.outputMemType, inputMem, outputMem); @@ -164,6 +194,7 @@ HcclResult TransportManager::Alloc(const std::string &tag, const TransportIOMem } } CHK_RET(notifyPool_->UnregisterOp(tag)); + return HCCL_SUCCESS; } @@ -181,11 +212,11 @@ HcclResult TransportManager::IncreAlloc(const std::string &tag, const TransportI for (u32 rankIndex = 0; rankIndex < reqSingleSubComm.transportRequests.size(); rankIndex++) { TransportRequest &transportRequest = reqSingleSubComm.transportRequests[rankIndex]; CHK_PRT_RET(rankIndex >= respSingleSubComm.links.size(), - HCCL_ERROR("[IncreAlloc] The remote rank_id[%u] is larger than the exist respSingleSubComm map "\ + HCCL_ERROR("[IncreAlloc] The remote rank_id[%u] is larger than the existent respSingleSubComm map "\ "size[%u]", rankIndex, respSingleSubComm.links.size()), HCCL_E_PARA); if (respSingleSubComm.links[rankIndex] != nullptr && respSingleSubComm.links[rankIndex]->GetLinkType() != hccl::LinkType::LINK_RESERVED) { - HCCL_INFO("[IncreAlloc] The link to remote userRank[%u] is exist", transportRequest.remoteUserRank); + HCCL_INFO("[IncreAlloc] The link to remote userRank[%u] has existed", transportRequest.remoteUserRank); continue; } if (transportRequest.isValid) { @@ -197,7 +228,7 @@ HcclResult TransportManager::IncreAlloc(const std::string &tag, const TransportI std::vector > connectSockets; bool isInterRdma; HcclResult ret = CreateDestSockets(tag, transportRequest.remoteUserRank, reqSingleSubComm.taskNum, - connectSockets, isInterRdma, reqSingleSubComm.isUsedRdma); + connectSockets, isInterRdma); CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[IncreAlloc]Create dest sockets failed"), ret); MachineType machineType = 
transportRequest.localUserRank < transportRequest.remoteUserRank? @@ -215,7 +246,7 @@ HcclResult TransportManager::IncreAlloc(const std::string &tag, const TransportI } } for (u32 index = 0; index < linkThreads.size(); index++) { - if (linkThreads[index] != nullptr) { + if (linkThreads[index] != nullptr && linkThreads[index]->joinable()) { linkThreads[index]->join(); CHK_RET(hrtResetDevice(deviceLogicId_)); // 防止线程里面异常退出,在进程中reset } @@ -240,7 +271,7 @@ HcclResult TransportManager::ConstructTransTag(const std::string& tag, std::stri return HCCL_SUCCESS; } -void TransportManager::GetIOMem(const TransportIOMem &transMem, +HcclResult TransportManager::GetIOMem(const TransportIOMem &transMem, const TransportMemType inputMemType, const TransportMemType outputMemType, DeviceMem &inputMem, DeviceMem &outputMem) { @@ -252,6 +283,11 @@ void TransportManager::GetIOMem(const TransportIOMem &transMem, inputMem = transMem.paramInputMem; } else if (inputMemType == AIV_INPUT) { inputMem = transMem.aivInputMem; + } else if (inputMemType == CCL_OUTPUT) { + inputMem = transMem.cclOutputMem; + } else { + HCCL_ERROR("inputMemType is Invalid, inputMem not set"); + return HCCL_E_INTERNAL; } if (outputMemType == CCL_OUTPUT) { @@ -262,7 +298,15 @@ void TransportManager::GetIOMem(const TransportIOMem &transMem, outputMem = transMem.paramOutputMem; } else if (outputMemType == AIV_OUTPUT) { outputMem = transMem.aivOutputMem; + } else if (outputMemType == CCL_INPUT) { + outputMem = transMem.cclInputMem; + } else if (outputMemType == PARAM_INPUT) { + outputMem = transMem.paramInputMem; + } else { + HCCL_ERROR("outputMemType is Invalid, inputMem not set"); + return HCCL_E_INTERNAL; } + return HCCL_SUCCESS; } u32 TransportManager::GetRemoteNicPort(u32 remoteRank) @@ -367,6 +411,8 @@ HcclResult TransportManager::CreateLink(const std::string &tag, const ErrContext MachinePara machinePara; CHK_RET(SetMachinePara(tag, machineType, serverId, remoteRank, supportDataReceivedAck, linkMode, sockets, inputMem, outputMem, machinePara)); + HCCL_DEBUG("inputMem[%p],outputMem[%p], inputMem size[%llu], outputMem size[%llu]", inputMem.ptr(), outputMem.ptr(), + inputMem.size(), outputMem.size()); HCCL_INFO("[creakLink para]rank[%u]-localUserrank[%u]-localIpAddr[%s], linkMode[%d] " "dst_rank[%u]-remoteUserrank[%u]-remote_ip_addr[%s], machineType[%d], serverId[%s], nicDeploy[%d] ", userRank_, rankInfoList_[userRank_].worldRank, rankInfoList_[userRank_].serverId.c_str(), machinePara.linkMode, @@ -454,8 +500,14 @@ TransportType TransportManager::GetTransportType(const u32 dstRank, bool isUsedR if (isHaveCpuRank_) { transportType = TransportType::TRANS_TYPE_HETEROG_P2P; } else { + LinkTypeInServer linkType = LinkTypeInServer::RESERVED_LINK_TYPE; + hrtGetPairDeviceLinkType(rankInfoList_[userRank_].devicePhyId, rankInfoList_[dstRank].devicePhyId, + linkType); + if (linkType == LinkTypeInServer::SIO_TYPE && GetExternalInputEnableRdmaSdmaConcurrent() && isUsedRdma + && rankInfoList_[userRank_].deviceType == DevType::DEV_TYPE_910_73) { + transportType = TransportType::TRANS_TYPE_P2P; // Server内判断是否使用rdma - if (isUsedRdma) { + } else if (isUsedRdma) { transportType = TransportType::TRANS_TYPE_IBV_EXP; } else { transportType = TransportType::TRANS_TYPE_P2P; @@ -538,7 +590,16 @@ HcclResult TransportManager::TransportInit(const u32 dstRank, MachinePara &machi bool TransportManager::IsSupportInterHccs(const u32 dstRank) { - return false; + // 仅判断超节点内, 兼容打平通信域同时有server内和server间, 因此不判断server_id + bool isInterHccs = GetExternalInputInterHccsDisable() == 
false && + rankInfoList_[userRank_].deviceType == DevType::DEV_TYPE_910_73 && + rankInfoList_[userRank_].superPodId.empty() == false && + rankInfoList_[userRank_].superPodId == rankInfoList_[dstRank].superPodId; + + HCCL_INFO("[IsSupportInterHccs]rank[%u], superPodId[%s], dstRank[%u], dstSuperPodId[%s], isInterHccs[%d]", + userRank_, rankInfoList_[userRank_].superPodId.c_str(), dstRank, + rankInfoList_[dstRank].superPodId.c_str(), isInterHccs); + return isInterHccs; } void TransportManager::UpdateIsInterRdma(const u32 remoteRank, bool &isInterRdma, bool forceRdma) // 待确认判断是否完善 @@ -552,9 +613,17 @@ void TransportManager::UpdateIsInterRdma(const u32 remoteRank, bool &isInterRdma } LinkTypeInServer linkType; hrtGetPairDeviceLinkType(rankInfoList_[userRank_].devicePhyId, rankInfoList_[remoteRank].devicePhyId, linkType); - isInterRdma = rankInfoList_[userRank_].serverId != rankInfoList_[remoteRank].serverId || - (isUsedRdmaOuter_ && linkType == LinkTypeInServer::PXI_TYPE) || - (rankInfoList_[userRank_].serverId == rankInfoList_[remoteRank].serverId && forceRdma); + if (isConcurrent && forceRdma && rankInfoList_[userRank_].deviceType == DevType::DEV_TYPE_910_73) { + if (linkType == LinkTypeInServer::SIO_TYPE) { + isInterRdma = false; + } else { + isInterRdma = true; + } + } else { + isInterRdma = rankInfoList_[userRank_].serverId != rankInfoList_[remoteRank].serverId || + (isUsedRdmaOuter_ && linkType == LinkTypeInServer::PXI_TYPE) || + (rankInfoList_[userRank_].serverId == rankInfoList_[remoteRank].serverId && forceRdma); + } } u32 TransportManager::GetInterRemotePort(s32 devicePhyId, u32 dstUserRank) @@ -603,4 +672,4 @@ HcclResult TransportManager::MakeRemoteLinkInfo(const u32 remoteRank, bool isInt return HCCL_SUCCESS; } -} // namespace hccl \ No newline at end of file +} // namespace hccl diff --git a/src/domain/collective_communication/framework/communicator/impl/resource_manager/transport_manager.h b/src/domain/collective_communication/framework/communicator/impl/resource_manager/transport_manager.h index 02f1037ef1b4597e9bbe87ff9c4ccac555db026e..244d9308b313bc1a024b52f5a1770063bb7bd898 100644 --- a/src/domain/collective_communication/framework/communicator/impl/resource_manager/transport_manager.h +++ b/src/domain/collective_communication/framework/communicator/impl/resource_manager/transport_manager.h @@ -140,6 +140,7 @@ public: ~TransportManager(); + HcclResult CreateVirturalTransport(SingleSubCommTransport& singleSubCommTransport); HcclResult Alloc(const std::string &tag, const TransportIOMem &transMem, OpCommTransport &opTransportResponse); HcclResult IncreAlloc(const std::string &tag, const TransportIOMem &transMem, OpCommTransport &opTransportReq, OpCommTransport &opTransportResponse); @@ -149,7 +150,7 @@ public: TransportManager& operator=(TransportManager &&) = delete; // Move assign private: - void GetIOMem(const TransportIOMem &transMem, + HcclResult GetIOMem(const TransportIOMem &transMem, const TransportMemType inputMemType, const TransportMemType outputMemType, DeviceMem &inputMem, DeviceMem &outputMem); u32 GetRemoteNicPort(u32 remoteRank); diff --git a/src/domain/collective_communication/framework/hcom/hcom_common.cc b/src/domain/collective_communication/framework/hcom/hcom_common.cc index 169accf5fc1e21913a82860da107dfed9ce9e1de..0c4242d0f7e1cec93ff964c7ccbf31ab732c0ae0 100644 --- a/src/domain/collective_communication/framework/hcom/hcom_common.cc +++ b/src/domain/collective_communication/framework/hcom/hcom_common.cc @@ -206,7 +206,8 @@ void 
HcomTopoInfoFuncInstall(HcclResult (*p1)(const char *, uint32_t), void (*p2 HcclResult HcomRegRemoteAccessMem(const MemRegisterAddr* addrList, u32 count) { HcomInfo &hcomInfo = HcomGetCtxHomInfo(); - if (hcomInfo.params.deviceType == DevType::DEV_TYPE_910B) { + if (hcomInfo.params.deviceType == DevType::DEV_TYPE_910B || + hcomInfo.params.deviceType == DevType::DEV_TYPE_910_73) { // 910_73场景临时使用SDMA模拟RDMA return HCCL_SUCCESS; } @@ -302,7 +303,8 @@ HcclResult GetRankList(u32 rankNum, const u32 *rankIds, HcclGroupParams ¶ms) bool isStandardCard = false; CHK_RET(hcomInfo.pComm->IsStandardCard(isStandardCard)); - if (!isStandardCard && hcomInfo.params.deviceType != DevType::DEV_TYPE_910B) { + if (!isStandardCard && hcomInfo.params.deviceType != DevType::DEV_TYPE_910B && + hcomInfo.params.deviceType != DevType::DEV_TYPE_910_73) { CHK_RET(CheckRankTableConfigInfo(rankList, rankNum, serverNum)); } return HCCL_SUCCESS; diff --git a/src/domain/collective_communication/framework/inc/comm.h b/src/domain/collective_communication/framework/inc/comm.h index 66149a228c30c4deaf4a3e0611affd4f2ad189ee..341dfdbbc649a336d77eed1c0c3c12eca3364f8a 100644 --- a/src/domain/collective_communication/framework/inc/comm.h +++ b/src/domain/collective_communication/framework/inc/comm.h @@ -13,6 +13,7 @@ #include "hccl_common.h" #include "common.h" +#include "hccl_socket.h" // profiling状态 enum class HcomProfilingMode { @@ -32,6 +33,12 @@ enum class HcclTopoLevel { }; namespace hccl { +using HcclCommConnections = struct HcclCommConnectionsDef { + bool isRoot{false}; + std::shared_ptr agentConnection{nullptr}; + std::map> serverConnections; +}; + using HcclCommParams = struct TagHCCLCollectiveParams { /** 通信域的基本构建信息,通信域标识、节点数及本节点的编号 @@ -57,6 +64,7 @@ using HcclCommParams = struct TagHCCLCollectiveParams { CommAttr attr; WorkMode commWorkMode = WorkMode::HCCL_MODE_NORMAL; std::string identifier; + HcclCommConnections commConnections; TagHCCLCollectiveParams() : id{0}, rank(INVALID_VALUE_RANKID), userRank(INVALID_VALUE_RANKID), totalRanks(0xFFFFFFFF), logicDevId(-1), deviceType(DevType::DEV_TYPE_COUNT), profilingMode(HcomProfilingMode::PROFILING_CLOSE), diff --git a/src/domain/collective_communication/framework/op_base/src/op_base.cc b/src/domain/collective_communication/framework/op_base/src/op_base.cc index 2c34afee80e2a53fc0c926277138930f4a4733f3..3f807bd33bf0f39541c346a01ccf19b7e6913c2a 100644 --- a/src/domain/collective_communication/framework/op_base/src/op_base.cc +++ b/src/domain/collective_communication/framework/op_base/src/op_base.cc @@ -40,6 +40,8 @@ const std::string HCCL_ALLTOALL = "ALLTOALL"; const std::string HCCL_ALLTOALLV = "ALLTOALLV"; const std::string HCCL_ALLTOALLVC = "ALLTOALLVC"; +map> g_topoDetectServerPtrMap; + HcclResult CallMsprofReportHostApi(hccl::hcclComm* hcclComm, HcclCMDType cmdType, uint64_t beginTime, u64 count, HcclDataType dataType, std::string tag) { @@ -322,8 +324,9 @@ HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo) CHK_RET(InitExternalInput()); HcclRootHandle rootHandle; - TopoInfoDetect topoDetectServer; - CHK_RET(topoDetectServer.SetupServer(rootHandle)); + std::shared_ptr topoDetectServer = std::make_shared(); + CHK_SMART_PTR_NULL(topoDetectServer); + CHK_RET(topoDetectServer->SetupServer(rootHandle)); if (sizeof(HcclRootHandle) > HCCL_ROOT_INFO_BYTES) { HCCL_ERROR("[Get][RootInfo]hccl root info overflow. 
max length: %u, actual:%zu, identifier[%s]", @@ -336,6 +339,8 @@ HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo) sizeof(HcclRootHandle)), HCCL_E_MEMORY); } + HcclOpInfoCtx& opBaseInfo = GetHcclOpInfoCtx(); + opBaseInfo.hcclCommTopoInfoDetectServer.insert({rootHandle.identifier, topoDetectServer}); /* 首节点诊断信息记录 */ HCCL_RUN_INFO("[HCCL_TRACE]HcclGetRootInfo success, take time [%lld]us, identifier[%s]", DURATION_US(TIME_NOW() - startut), rootHandle.identifier); @@ -383,6 +388,35 @@ HcclResult HcclGetCommHandle(const char *commName, std::shared_ptrsecond->GetServerConnections(commConnections.serverConnections)); + } + + auto iterAgent = opBaseInfo.hcclCommTopoInfoDetectAgent.find(rootHandle.identifier); + if (iterAgent == opBaseInfo.hcclCommTopoInfoDetectAgent.end()) { + HCCL_ERROR("hccl get agent connections failed, rootHandle.identifier=%s", rootHandle.identifier); + return HCCL_E_PARA; + } else { + CHK_RET(iterAgent->second->GetAgentConnection(commConnections.agentConnection)); + } + return HCCL_SUCCESS; +} + +void HcclCloseCommConnections(const std::string &identifier) +{ + HcclOpInfoCtx& opBaseInfo = GetHcclOpInfoCtx(); + opBaseInfo.hcclCommTopoInfoDetectServer.erase(identifier); + opBaseInfo.hcclCommTopoInfoDetectAgent.erase(identifier); + return; +} + HcclResult InitCommRootInfo(const u32 nRanks, const u32 rank, const HcclRootHandle &rootHandle, const CommConfig &commConfig, HcclComm *comm) { @@ -397,36 +431,48 @@ HcclResult InitCommRootInfo(const u32 nRanks, const u32 rank, const HcclRootHand rootHandle.identifier)); CHK_SMART_PTR_NULL(pComm); - TopoInfoDetect topoDetectAgent; - ret = topoDetectAgent.SetupAgent(nRanks, rank, rootHandle); - CHK_PRT_BREAK(ret != HCCL_SUCCESS, - HCCL_ERROR("[InitCommRootInfo]errNo[0x%016llx] setup topo detect error", HCCL_ERROR_CODE(ret)), - errorFlag = true); + std::shared_ptr topoDetectAgent = std::make_shared(); + CHK_SMART_PTR_NULL(topoDetectAgent); + ret = topoDetectAgent->SetupAgent(nRanks, rank, rootHandle); + CHK_PRT_BREAK(ret != HCCL_SUCCESS, HCCL_ERROR("[Init][CommRootInfo]errNo[0x%016llx] setup topo detect error", + HCCL_ERROR_CODE(ret)), errorFlag = true); RankTable_t rankTable; - ret = topoDetectAgent.GetCluterInfo(rankTable); - CHK_PRT_BREAK(ret != HCCL_SUCCESS, - HCCL_ERROR("[InitCommRootInfo]errNo[0x%016llx] GetCluterInfo error", HCCL_ERROR_CODE(ret)), - errorFlag = true); + ret = topoDetectAgent->GetCluterInfo(rankTable); + CHK_PRT_BREAK(ret != HCCL_SUCCESS, HCCL_ERROR("[Init][CommRootInfo]errNo[0x%016llx] GetCluterInfo error", + HCCL_ERROR_CODE(ret)), errorFlag = true); /* 初始化hccl comm */ HcclBasicRankInfo localRankInfo; - ret = topoDetectAgent.GetLocalRankInfo(localRankInfo); - CHK_PRT_BREAK(ret != HCCL_SUCCESS, - HCCL_ERROR("[InitCommRootInfo]errNo[0x%016llx] GetLocalRankInfo error.", HCCL_ERROR_CODE(ret)), - errorFlag = true); + ret = topoDetectAgent->GetLocalRankInfo(localRankInfo); + CHK_PRT_BREAK(ret != HCCL_SUCCESS, HCCL_ERROR("[Init][CommRootInfo]errNo[0x%016llx] GetLocalRankInfo error.", + HCCL_ERROR_CODE(ret)), errorFlag = true); ret = GetSelfClusterInfo(localRankInfo, params); CHK_PRT_BREAK(ret != HCCL_SUCCESS, HCCL_ERROR("[InitCommRootInfo]errNo[0x%016llx] GetRankInfo error.", HCCL_ERROR_CODE(ret)), errorFlag = true); - ret = topoDetectAgent.Teardown(rootHandle); + ret = topoDetectAgent->WaitComplete(rootHandle); CHK_PRT_BREAK(ret != HCCL_SUCCESS, - HCCL_ERROR("[InitCommRootInfo]errNo[0x%016llx] teardown topo detect error", HCCL_ERROR_CODE(ret)), + HCCL_ERROR("[InitCommRootInfo]errNo[0x%016llx] wait complete topo detect 
error", HCCL_ERROR_CODE(ret)), errorFlag = true); CHK_RET(DisplayRanktableInfo(rankTable)); + bool retryEnable = GetExternalInputIntraServerRetryEnable() || GetExternalInputInterServerRetryEnable() || + GetExternalInputInterSuperPodRetryEnable(); + if (retryEnable) { + opBaseHcom.hcclCommTopoInfoDetectAgent.insert({ rootHandle.identifier, topoDetectAgent }); + ret = HcclGetCommConnections(rootHandle, params.commConnections); + CHK_PRT_BREAK(ret != HCCL_SUCCESS, HCCL_ERROR("[Init][RootInfo]HcclGetCommConnections failed."), + errorFlag = true); + } else { + ret = topoDetectAgent->Teardown(); + CHK_PRT_BREAK(ret != HCCL_SUCCESS, + HCCL_ERROR("[Init][RootInfo]errNo[0x%016llx] Teardown topo detect error", HCCL_ERROR_CODE(ret)), + errorFlag = true); + } + ret = InitWorkflowMode(HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE); CHK_PRT_BREAK(ret != HCCL_SUCCESS, HCCL_ERROR("[InitCommRootInfo]errNo[0x%016llx] init work flow mode error", HCCL_ERROR_CODE(ret)), @@ -476,36 +522,38 @@ HcclResult InitCommRootInfo(const u32 nRanks, const u32 rank, const HcclRootHand return HCCL_SUCCESS; } -HcclResult HcclCommInitRootInfo(uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm) +HcclResult HcclCommInitRootInfoInner(uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, + HcclComm *comm, string &identifier) { HcclResult ret = HCCL_SUCCESS; HcclUs startut = TIME_NOW(); - s32 deviceLogicId = 0; - CHK_RET(hrtGetDeviceRefresh(&deviceLogicId)); - - CHK_PRT_RET((nRanks == 0), HCCL_ERROR("[Init][CommRootInfo]errNo[0x%016llx] nRanks[%u] should be greater than 0.", - HCCL_ERROR_CODE(HCCL_E_PARA), nRanks), HCCL_E_PARA); - - CHK_PRT_RET((rank >= nRanks), HCCL_ERROR("[Init][CommRootInfo]errNo[0x%016llx] rank[%u] should be less than "\ - "nRanks[%u].", HCCL_ERROR_CODE(HCCL_E_PARA), rank, nRanks), HCCL_E_PARA); - CHK_SMART_PTR_NULL(comm); CHK_SMART_PTR_NULL(rootInfo); - HcclRootHandle rootHandle; s32 sRet = memcpy_s(&rootHandle, sizeof(HcclRootHandle), rootInfo->internal, sizeof(HcclRootHandle)); - CHK_PRT_RET(sRet != EOK, HCCL_ERROR("[Init][RootInfo]memcpy root info fail. errorno[%d] "\ + CHK_PRT_RET(sRet != EOK, HCCL_ERROR("[Init][RootInfoInner]memcpy root info fail. 
errorno[%d] "\ "params:destMaxSize[%u], count[%u]", sRet, sizeof(HcclRootHandle), sizeof(HcclRootHandle)), HCCL_E_MEMORY); rootHandle.identifier[ROOTINFO_INDENTIFIER_MAX_LENGTH - 1] = '\0'; + identifier = rootHandle.identifier; + + s32 deviceLogicId = 0; + CHK_RET(hrtGetDeviceRefresh(&deviceLogicId)); + + CHK_PRT_RET((nRanks == 0), HCCL_ERROR("[Init][CommRootInfoInner]errNo[0x%016llx] nRanks[%u] should "\ + "be greater than 0.", HCCL_ERROR_CODE(HCCL_E_PARA), nRanks), HCCL_E_PARA); + + CHK_PRT_RET((rank >= nRanks), HCCL_ERROR("[Init][CommRootInfoInner]errNo[0x%016llx] rank[%u] should "\ + "be less than nRanks[%u].", HCCL_ERROR_CODE(HCCL_E_PARA), rank, nRanks), HCCL_E_PARA); + CHK_SMART_PTR_NULL(comm); ret = InitExternalInput(); - CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[Init][CommRootInfo]errNo[0x%016llx] init external input error", - HCCL_ERROR_CODE(ret)), HCCL_E_PARA); + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[Init][CommRootInfoInner]errNo[0x%016llx] init "\ + "external input error", HCCL_ERROR_CODE(ret)), HCCL_E_PARA); /* 接口交互信息日志 */ - HCCL_RUN_INFO("Entry-HcclCommInitRootInfo:ranks[%u], rank[%u], rootinfo: host ip[%s] port[%u] nicDeploy[%d]" \ - " identifier[%s], deviceLogicId[%d]", nRanks, rank, rootHandle.ip, rootHandle.port, rootHandle.nicDeploy, - rootHandle.identifier, deviceLogicId); + HCCL_RUN_INFO("Entry-HcclCommInitRootInfoInner:ranks[%u], rank[%u], rootinfo: host ip[%s] port[%u] "\ + "nicDeploy[%d] identifier[%s], deviceLogicId[%d]", nRanks, rank, rootHandle.ip, rootHandle.port, + rootHandle.nicDeploy, rootHandle.identifier, deviceLogicId); CommConfig commConfig; @@ -517,74 +565,98 @@ HcclResult HcclCommInitRootInfo(uint32_t nRanks, const HcclRootInfo *rootInfo, u ret); /* 关键状态记录 */ - HCCL_RUN_INFO("[HCCL_TRACE]HcclCommInitRootInfo success, take time [%lld]us, rankNum[%u], rank[%u]", + HCCL_RUN_INFO("[HCCL_TRACE]HcclCommInitRootInfoInner success, take time [%lld]us, rankNum[%u], rank[%u]", DURATION_US(TIME_NOW() - startut), nRanks, rank); return HCCL_SUCCESS; } -HcclResult HcclCommInitRootInfoConfig(uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, - const HcclCommConfig *config, HcclComm *comm) +HcclResult HcclCommInitRootInfo(uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm) +{ + HcclResult ret = HCCL_SUCCESS; + string identifier; + ret = HcclCommInitRootInfoInner(nRanks, rootInfo, rank, comm, identifier); + if (g_topoDetectServerPtrMap.find(identifier) != g_topoDetectServerPtrMap.end()) { + g_topoDetectServerPtrMap[identifier] = nullptr; + } + return ret; +} + +HcclResult HcclCommInitRootInfoConfigInner(uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, + const HcclCommConfig *config, HcclComm *comm, string &identifier) { HcclResult ret = HCCL_SUCCESS; HcclUs startut = TIME_NOW(); + CHK_SMART_PTR_NULL(rootInfo); + + HcclRootHandle rootHandle; + s32 sRet = memcpy_s(&rootHandle, sizeof(HcclRootHandle), rootInfo->internal, sizeof(HcclRootHandle)); + CHK_PRT_RET(sRet != EOK, HCCL_ERROR("[Init][RootInfo]memcpy root info fail. 
errorno[%d] "\ + "params:destMaxSize[%u], count[%u]", sRet, sizeof(HcclRootHandle), + sizeof(HcclRootHandle)), HCCL_E_MEMORY); + rootHandle.identifier[ROOTINFO_INDENTIFIER_MAX_LENGTH - 1] = '\0'; + identifier = rootHandle.identifier; + // 检查配置参数是否为空 RPT_INPUT_ERR(config == nullptr, "EI0003", std::vector({"ccl_op", "parameter", "value", "tips"}),\ - std::vector({"HcclCommInitRootInfoConfig", "config", "nullptr", "please check comm"})); + std::vector({"HcclCommInitRootInfoConfigInner", "config", "nullptr", "please check comm"})); CHK_SMART_PTR_NULL(config); s32 deviceLogicId = 0; CHK_RET(hrtGetDeviceRefresh(&deviceLogicId)); CHK_PRT_RET((nRanks == 0), - HCCL_ERROR("[Init][CommRootInfoConfig]errNo[0x%016llx] nRanks[%u] should be greater than 0.", + HCCL_ERROR("[Init][CommRootInfoConfigInner]errNo[0x%016llx] nRanks[%u] should be greater than 0.", HCCL_ERROR_CODE(HCCL_E_PARA), nRanks), HCCL_E_PARA); CHK_PRT_RET((rank >= nRanks), - HCCL_ERROR("[Init][CommRootInfoConfig]errNo[0x%016llx] rank[%u] should be less than " - "nRanks[%u].", + HCCL_ERROR("[Init][CommRootInfoConfigInner]errNo[0x%016llx] rank[%u] should be less than nRanks[%u].", HCCL_ERROR_CODE(HCCL_E_PARA), rank, nRanks), HCCL_E_PARA); CHK_SMART_PTR_NULL(comm); - CHK_SMART_PTR_NULL(rootInfo); - - HcclRootHandle rootHandle; - s32 sRet = memcpy_s(&rootHandle, sizeof(HcclRootHandle), rootInfo->internal, sizeof(HcclRootHandle)); - CHK_PRT_RET(sRet != EOK, HCCL_ERROR("[Init][RootInfo]memcpy root info fail. errorno[%d] "\ - "params:destMaxSize[%u], count[%u]", sRet, sizeof(HcclRootHandle), - sizeof(HcclRootHandle)), HCCL_E_MEMORY); - rootHandle.identifier[ROOTINFO_INDENTIFIER_MAX_LENGTH - 1] = '\0'; ret = InitExternalInput(); - CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[Init][CommRootInfoConfig]errNo[0x%016llx] init external input error", - HCCL_ERROR_CODE(ret)), HCCL_E_PARA); + CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[Init][CommRootInfoConfigInner]errNo[0x%016llx] init "\ + "external input error", HCCL_ERROR_CODE(ret)), HCCL_E_PARA); /* 读取用户配置 */ CommConfig commConfig; ret = commConfig.Load(config, rootHandle.identifier); CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[Init][CommRootInfoConfig]errNo[0x%016llx] load comm config failed.", HCCL_ERROR_CODE(ret)), - HCCL_E_PARA); + HCCL_ERROR("[Init][CommRootInfoConfigInner]errNo[0x%016llx] load comm config failed.", + HCCL_ERROR_CODE(ret)), HCCL_E_PARA); /* 接口交互信息日志 */ - HCCL_RUN_INFO("Entry-HcclCommInitRootInfoConfig:ranks[%u], rank[%u], rootinfo: host ip[%s] port[%u] nicDeploy[%d]" \ - " identifier[%s], deviceLogicId[%d]", nRanks, rank, rootHandle.ip, rootHandle.port, rootHandle.nicDeploy, - rootHandle.identifier, deviceLogicId); + HCCL_RUN_INFO("Entry-HcclCommInitRootInfoConfigInner:ranks[%u], rank[%u], rootinfo: host ip[%s] "\ + "port[%u] nicDeploy[%d] identifier[%s], deviceLogicId[%d]", nRanks, rank, rootHandle.ip, + rootHandle.port, rootHandle.nicDeploy, rootHandle.identifier, deviceLogicId); /* --------------初始化------------------------- */ ret = InitCommRootInfo(nRanks, rank, rootHandle, commConfig, comm); CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("[Init][CommRootInfoConfig]errNo[0x%016llx]HcclCommInitRootInfoConfig failed.", + HCCL_ERROR("[Init][CommRootInfoConfigInner]errNo[0x%016llx]HcclCommInitRootInfoConfigInner failed.", HCCL_ERROR_CODE(ret)), ret); - HCCL_RUN_INFO("[HCCL_TRACE]HcclCommInitRootInfoConfig success, take time [%lld]us, rankNum[%u], rank[%u]", - DURATION_US(TIME_NOW() - startut), nRanks, rank); + HCCL_RUN_INFO("[HCCL_TRACE]HcclCommInitRootInfoConfigInner 
success, take time [%lld]us, "\ + "rankNum[%u], rank[%u]", DURATION_US(TIME_NOW() - startut), nRanks, rank); return HCCL_SUCCESS; } +HcclResult HcclCommInitRootInfoConfig(uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, + const HcclCommConfig *config, HcclComm *comm) +{ + HcclResult ret = HCCL_SUCCESS; + string identifier; + ret = HcclCommInitRootInfoConfigInner(nRanks, rootInfo, rank, config, comm, identifier); + if (g_topoDetectServerPtrMap.find(identifier) != g_topoDetectServerPtrMap.end()) { + g_topoDetectServerPtrMap[identifier] = nullptr; + } + return ret; +} + HcclResult HcclSetConfig(HcclConfig config, HcclConfigValue configValue) { if (config == HCCL_DETERMINISTIC) { @@ -1110,11 +1182,7 @@ HcclResult HcclSend(void* sendBuf, uint64_t count, HcclDataType dataType, uint32 CHK_RET_AND_PRINT_IDE(SetDefaultQosConfig(hcclComm), tag.c_str()); - if (!GetExternalInputHcclEnableFfts()) { - CHK_RET_AND_PRINT_IDE(SendLoop(tag, sendBuf, count, dataType, destRank, hcclComm, stream), tag.c_str()); - } else { - CHK_RET_AND_PRINT_IDE(hcclComm->SendOutPlace(tag, sendBuf, count, dataType, destRank, stream), tag.c_str()); - } + CHK_RET_AND_PRINT_IDE(hcclComm->SendOutPlace(tag, sendBuf, count, dataType, destRank, stream), tag.c_str()); HCCL_PROFILER_DEL_STREAM(stream); HCCL_PROFILER_DEL_TAG(tag); @@ -1187,11 +1255,7 @@ HcclResult HcclRecv(void* recvBuf, uint64_t count, HcclDataType dataType, uint32 CHK_RET_AND_PRINT_IDE(SetDefaultQosConfig(hcclComm), tag.c_str()); - if (!GetExternalInputHcclEnableFfts()) { - CHK_RET_AND_PRINT_IDE(ReceiveLoop(tag, recvBuf, count, dataType, srcRank, hcclComm, stream), tag.c_str()); - } else { - CHK_RET_AND_PRINT_IDE(hcclComm->ReceiveOutPlace(tag, recvBuf, count, dataType, srcRank, stream), tag.c_str()); - } + CHK_RET_AND_PRINT_IDE(hcclComm->ReceiveOutPlace(tag, recvBuf, count, dataType, srcRank, stream), tag.c_str()); HCCL_PROFILER_DEL_STREAM(stream); HCCL_PROFILER_DEL_TAG(tag); @@ -1231,6 +1295,7 @@ HcclResult HcclCommDestroy(HcclComm comm) if (comm == opBaseHcom.pComm.get()) { group = opBaseHcom.pComm->GetIdentifier(); opBaseHcom.pComm = nullptr; + HcclCloseCommConnections(group); } else { HCCL_RUN_INFO("com is not global com"); group = hcclComm->GetIdentifier(); @@ -1242,6 +1307,7 @@ HcclResult HcclCommDestroy(HcclComm comm) auto iter = opBaseHcom.opGroup2CommMap.find(group); if (iter != opBaseHcom.opGroup2CommMap.end()) { opBaseHcom.opGroup2CommMap.erase(group); + HcclCloseCommConnections(group); } else { HCCL_ERROR("[HcclCommDestroy] comm is not exist, comm=%p, group=%s, deviceLogicId=%d", comm, group.c_str(), deviceLogicId); @@ -1369,118 +1435,6 @@ HcclResult ReduceScatterLoop(const std::string &tag, void *inputPtr, void *outpu return HCCL_SUCCESS; } -HcclResult SendLoop(const std::string &tag, void *inputPtr, const u64 &count, - HcclDataType dataType, int destRank, hccl::hcclComm *hcclComm, rtStream_t stream) -{ - void *commInputPtr = nullptr; - u64 commInputSize; - HcclResult ret; - - ret = hcclComm->GetInCCLbuffer(commInputPtr, commInputSize); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("errNo[0x%016llx] Op_base get comm inbuffer error", HCCL_ERROR_CODE(ret)), ret); - - u32 qosCfg = INVALID_QOSCFG; // qos不使能的情况下为全F - CHK_RET(hcclComm->GetQosCfg(qosCfg)); - - u32 unitSize; - ret = SalGetDataTypeSize(dataType, unitSize); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("errNo[0x%016llx] get datatype size error", HCCL_ERROR_CODE(ret)), ret); - - char *curInputPtr = static_cast(inputPtr); - u64 inputOffset = 0; - u64 countLeft = count; - while 
(countLeft > 0) { - curInputPtr += inputOffset; - HCCL_DEBUG("-OP_BASE-SendLoop:inputOffset[%llu]", inputOffset); - u64 curCount = ((countLeft * unitSize) > commInputSize) ? (commInputSize / unitSize) : countLeft; - u64 curSize = curCount * unitSize; // 单位 byte - - HCCL_DEBUG("-OP_BASE-SendLoop:curInputPtr[%p], curCount[%llu], curSize[%llu]", - curInputPtr, curCount, curSize); - ret = hrtMemAsyncCopyByQos(commInputPtr, curSize, curInputPtr, curSize, - HcclRtMemcpyKind::HCCL_RT_MEMCPY_KIND_DEVICE_TO_DEVICE, stream, qosCfg); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("In OP_BASE inputbuffer transit, memcopy failed."), HCCL_E_MEMORY); - - /* 记录指令信息用于一致性校验 */ - ret = RankConsistent::GetInstance().RecordOpPara(HcclCMDType::HCCL_CMD_SEND, tag, - curCount, dataType, 0, commInputSize, HCCL_WORLD_GROUP); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("errNo[0x%016llx] record CMD with parameter error", HCCL_ERROR_CODE(ret)), ret); - /* 入参的正确性由HCCL确保 */ - ret = hcclComm->send(const_cast(tag.c_str()), commInputPtr, curCount, dataType, destRank, stream); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("errNo[0x%016llx] op_base hcclComm send error, tag[%s], input_ptr[%p], "\ - "count[%llu], data_type[%s]", HCCL_ERROR_CODE(ret), tag.c_str(), - commInputPtr, curCount, GetDataTypeEnumStr(dataType).c_str()), ret); - ret = RankConsistent::GetInstance().DelOpPara(tag); - CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("errNo[0x%016llx] delete CMD with parameters error. tag[%s]", - HCCL_ERROR_CODE(ret), tag.c_str()), ret); - CHK_PRT_RET((curCount == 0), HCCL_ERROR("In OP_BASE curCount is zero"), HCCL_E_PARA); - countLeft -= curCount; - inputOffset = curSize; - } - - return HCCL_SUCCESS; -} - -HcclResult ReceiveLoop(const std::string &tag, void *outputPtr, const u64 &count, - HcclDataType dataType, int srcRank, hccl::hcclComm *hcclComm, rtStream_t stream) -{ - void *commOutputPtr = nullptr; - u64 commOutputSize; - HcclResult ret; - - ret = hcclComm->GetOutCCLbuffer(commOutputPtr, commOutputSize); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("errNo[0x%016llx] Op_base get comm outbuffer error", HCCL_ERROR_CODE(ret)), ret); - u32 qosCfg = INVALID_QOSCFG; // qos不使能的情况下为全F - CHK_RET(hcclComm->GetQosCfg(qosCfg)); - - u32 unitSize; - ret = SalGetDataTypeSize(dataType, unitSize); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("errNo[0x%016llx] get datatype size error", HCCL_ERROR_CODE(ret)), ret); - - char *curOutputPtr = static_cast(outputPtr); - u64 outputOffset = 0; - u64 countLeft = count; - while (countLeft > 0) { - curOutputPtr += outputOffset; - HCCL_DEBUG("-OP_BASE-ReceiveLoop:outputOffset[%llu]", outputOffset); - u64 curCount = ((countLeft * unitSize) > commOutputSize) ? 
(commOutputSize / unitSize) : countLeft; - u64 curSize = curCount * unitSize; // 单位 byte - - HCCL_DEBUG("-OP_BASE-ReceiveLoop:curOutputPtr[%p], curCount[%llu], curSize[%llu]", - curOutputPtr, curCount, curSize); - /* 记录指令信息用于一致性校验 */ - ret = RankConsistent::GetInstance().RecordOpPara(HcclCMDType::HCCL_CMD_RECEIVE, tag, - curCount, dataType, 0, commOutputSize, HCCL_WORLD_GROUP); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("errNo[0x%016llx] record CMD with parameter error", HCCL_ERROR_CODE(ret)), ret); - /* 入参的正确性由HCCL确保 */ - ret = hcclComm->receive(const_cast(tag.c_str()), commOutputPtr, - curCount, dataType, srcRank, stream); - CHK_PRT_RET(ret != HCCL_SUCCESS, - HCCL_ERROR("errNo[0x%016llx] op_base hcclComm receive error, tag[%s], "\ - "output_ptr[%p], count[%llu], data_type[%s]", HCCL_ERROR_CODE(ret), tag.c_str(), - commOutputPtr, curCount, GetDataTypeEnumStr(dataType).c_str()), ret); - ret = RankConsistent::GetInstance().DelOpPara(tag); - CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("errNo[0x%016llx] delete CMD with parameters error. tag[%s]", - HCCL_ERROR_CODE(ret), tag.c_str()), ret); - ret = hrtMemAsyncCopyByQos(curOutputPtr, curSize, commOutputPtr, curSize, - HcclRtMemcpyKind::HCCL_RT_MEMCPY_KIND_DEVICE_TO_DEVICE, stream, qosCfg); - CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("In OP_BASE outputbuffer transit, memcopy failed."), HCCL_E_MEMORY); - CHK_PRT_RET((curCount == 0), HCCL_ERROR("In OP_BASE curCount is zero"), HCCL_E_PARA); - countLeft -= curCount; - outputOffset = curSize; - } - - return HCCL_SUCCESS; -} - // 获取算子所需workspace memory大小[byte] HcclResult HcclGetOpBasedMemSize(const HcclCMDType &opType, u64 &size, const HcomCollOpInfo &opInfo) @@ -1972,11 +1926,6 @@ HcclResult ReduceLoop(const std::string &tag, void *inputPtr, void *outputPtr, c return HCCL_SUCCESS; } -/* - * ********************************************************************** - * 单算子GatherAllToAllV的函数接口,目前不对外开放,仅图模式动态shape使用 - * ********************************************************************** - */ HcclResult HcclGatherAlltoAllV(HcomGatherAllToAllVParams params, HcclComm comm, aclrtStream stream) { HcclUs startut = TIME_NOW(); diff --git a/src/domain/collective_communication/framework/op_base/src/op_base.h b/src/domain/collective_communication/framework/op_base/src/op_base.h index d3214e5ad7390b4ba8b113c5617baded1951ce8d..77bb26a8d41a4ebad4b3f0b92dbaaefc19f54d0c 100644 --- a/src/domain/collective_communication/framework/op_base/src/op_base.h +++ b/src/domain/collective_communication/framework/op_base/src/op_base.h @@ -18,6 +18,7 @@ #include "op_base_pub.h" #include "hccl_comm_pub.h" #include "config.h" +#include "../common/src/topo/topoinfo_detect.h" using HcclOpInfoCtx = struct HcclInfoTag { HcclCommPtr pComm; @@ -27,6 +28,8 @@ using HcclOpInfoCtx = struct HcclInfoTag { bool isUsed; std::mutex opGroupMapMutex; std::map> opGroup2CommMap; + std::map> hcclCommTopoInfoDetectServer; + std::map> hcclCommTopoInfoDetectAgent; HcclInfoTag() :isUsed(false) {} }; @@ -40,12 +43,6 @@ HcclResult CallMsprofReportHostApi(hccl::hcclComm* hcclComm, HcclCMDType cmdType HcclResult ReduceScatterLoop(const std::string &tag, void *inputPtr, void *outputPtr, const u64 &count, HcclDataType dataType, HcclReduceOp op, hccl::hcclComm *hcclComm, rtStream_t stream); -HcclResult ReceiveLoop(const std::string &tag, void *outputPtr, const u64 &count, - HcclDataType dataType, int srcRank, hccl::hcclComm *hcclComm, rtStream_t stream); - -HcclResult SendLoop(const std::string &tag, void *inputPtr, const u64 &count, - HcclDataType 
dataType, int destRank, hccl::hcclComm *hcclComm, rtStream_t stream); - HcclResult HcclGetOpBasedMemSize(const HcclCMDType &opType, u64 &size, const HcomCollOpInfo &opInfo); diff --git a/version.info b/version.info index 017198b7ba509e7d464203064baa39f80d9f3b50..4923b3d5b77fc65c885addd76e471672ee09eb00 100644 --- a/version.info +++ b/version.info @@ -1 +1 @@ -Version=7.3.0.1 +Version=7.3.T10.0
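Editor's note on the transport_manager.cc hunks above: GetIOMem is changed from void to HcclResult, gains the CCL_OUTPUT-as-input and CCL_INPUT / PARAM_INPUT-as-output mappings, and reports HCCL_E_INTERNAL for an unrecognized TransportMemType instead of silently leaving the DeviceMem unset. A minimal sketch of that map-or-fail pattern, reusing the enum values and TransportIOMem members named in the diff (the helper itself is hypothetical and not exhaustive):

    // Sketch only: resolve one TransportMemType to the matching member of TransportIOMem,
    // failing loudly on an unknown value rather than handing back stale memory.
    static HcclResult PickMem(const TransportIOMem &mem, TransportMemType type, DeviceMem &out)
    {
        switch (type) {
            case CCL_INPUT:     out = mem.cclInputMem;    break;
            case CCL_OUTPUT:    out = mem.cclOutputMem;   break;
            case PARAM_INPUT:   out = mem.paramInputMem;  break;
            case PARAM_OUTPUT:  out = mem.paramOutputMem; break;
            case AIV_INPUT:     out = mem.aivInputMem;    break;
            case AIV_OUTPUT:    out = mem.aivOutputMem;   break;
            default:
                HCCL_ERROR("mem type is invalid, DeviceMem not set");
                return HCCL_E_INTERNAL;
        }
        return HCCL_SUCCESS;
    }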