From 730d6cec02751fac50b31ab166a2c38eec7f010f Mon Sep 17 00:00:00 2001 From: ZhangLynn9 Date: Wed, 12 Nov 2025 16:37:13 +0800 Subject: [PATCH] A3 3 levels all_gather_fake demo --- inc/hccl/hccl_types.h | 1 + .../algorithm/CMakeLists.txt | 1 + .../base/alg_template/CMakeLists.txt | 1 + .../base/alg_template/alg_template_base.cc | 15 + .../base/alg_template/alg_template_base_pub.h | 5 + .../temp_all_gather_fake/CMakeLists.txt | 9 + .../all_gather_fake_mesh_aicpu.cc | 187 +++++++++ .../all_gather_fake_mesh_aicpu.h | 19 + .../all_gather_fake_ring_aicpu.cc | 148 +++++++ .../all_gather_fake_ring_aicpu.h | 19 + .../all_gather_fake_ring_aicpu_async.cc | 33 ++ .../all_gather_fake_ring_aicpu_async.h | 19 + .../base/inc/all_gather_fake_mesh_aicpu_pub.h | 62 +++ .../all_gather_fake_ring_aicpu_async_pub.h | 28 ++ .../base/inc/all_gather_fake_ring_aicpu_pub.h | 41 ++ .../impl/coll_executor/CMakeLists.txt | 1 + .../coll_all_gather_fake/CMakeLists.txt | 8 + .../coll_all_gather_fake_executor.cc | 390 ++++++++++++++++++ .../coll_all_gather_fake_executor.h | 48 +++ .../coll_all_gather_fake_mesh_executor.cc | 198 +++++++++ .../coll_all_gather_fake_mesh_executor.h | 39 ++ .../algorithm/impl/operator/CMakeLists.txt | 1 + .../impl/operator/all_gather_fake_operator.cc | 83 ++++ .../impl/operator/all_gather_fake_operator.h | 30 ++ .../impl/operator/coll_alg_operator.cc | 7 +- .../algorithm/pub_inc/alg_cmd_type.h | 1 + .../algorithm/pub_inc/ffts_common_pub.h | 37 ++ .../debug/profiling/inc/task_profiling_pub.h | 1 + .../framework/CMakeLists.txt | 9 + .../framework/communicator/hccl_comm.cc | 18 + .../communicator/impl/hccl_communicator.cc | 3 + .../communicator/impl/hccl_communicator.h | 2 + .../impl/hccl_communicator_device.cc | 6 + .../impl/hccl_communicator_host.cc | 48 +++ .../framework/inc/hccl_comm_pub.h | 2 + .../framework/op_base/src/op_base.cc | 80 ++++ .../framework/op_base/src/op_base.h | 2 + 37 files changed, 1601 insertions(+), 1 deletion(-) create mode 100644 src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/CMakeLists.txt create mode 100644 src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_mesh_aicpu.cc create mode 100644 src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_mesh_aicpu.h create mode 100644 src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_ring_aicpu.cc create mode 100644 src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_ring_aicpu.h create mode 100644 src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_ring_aicpu_async.cc create mode 100644 src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_ring_aicpu_async.h create mode 100644 src/domain/collective_communication/algorithm/base/inc/all_gather_fake_mesh_aicpu_pub.h create mode 100644 src/domain/collective_communication/algorithm/base/inc/all_gather_fake_ring_aicpu_async_pub.h create mode 100644 src/domain/collective_communication/algorithm/base/inc/all_gather_fake_ring_aicpu_pub.h create mode 100644 src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/CMakeLists.txt create mode 100644 src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/coll_all_gather_fake_executor.cc create mode 100644 src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/coll_all_gather_fake_executor.h create mode 100644 src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/coll_all_gather_fake_mesh_executor.cc create mode 100644 src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/coll_all_gather_fake_mesh_executor.h create mode 100644 src/domain/collective_communication/algorithm/impl/operator/all_gather_fake_operator.cc create mode 100644 src/domain/collective_communication/algorithm/impl/operator/all_gather_fake_operator.h diff --git a/inc/hccl/hccl_types.h b/inc/hccl/hccl_types.h index 691e72b..0f9750c 100644 --- a/inc/hccl/hccl_types.h +++ b/inc/hccl/hccl_types.h @@ -174,6 +174,7 @@ typedef enum { HCCL_CMD_ALLGATHER_V, HCCL_CMD_REDUCE_SCATTER_V, HCCL_CMD_BATCH_WRITE, + HCCL_CMD_ALLGATHER_FAKE, HCCL_CMD_ALL, HCCL_CMD_FINALIZE = 100, HCCL_CMD_INTER_GROUP_SYNC, diff --git a/src/domain/collective_communication/algorithm/CMakeLists.txt b/src/domain/collective_communication/algorithm/CMakeLists.txt index b0e42ec..84d5fb4 100644 --- a/src/domain/collective_communication/algorithm/CMakeLists.txt +++ b/src/domain/collective_communication/algorithm/CMakeLists.txt @@ -48,6 +48,7 @@ target_include_directories(hccl_alg PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/impl/coll_executor/coll_all_reduce/310P ${CMAKE_CURRENT_SOURCE_DIR}/impl/coll_executor/coll_all_to_all ${CMAKE_CURRENT_SOURCE_DIR}/impl/coll_executor/coll_all_gather + ${CMAKE_CURRENT_SOURCE_DIR}/impl/coll_executor/coll_all_gather_fake ${CMAKE_CURRENT_SOURCE_DIR}/impl/coll_executor/coll_all_gather/310P ${CMAKE_CURRENT_SOURCE_DIR}/impl/coll_executor/coll_all_gather_v ${CMAKE_CURRENT_SOURCE_DIR}/impl/coll_executor/coll_all_gather_v/310P diff --git a/src/domain/collective_communication/algorithm/base/alg_template/CMakeLists.txt b/src/domain/collective_communication/algorithm/base/alg_template/CMakeLists.txt index d8274aa..8ace5d7 100644 --- a/src/domain/collective_communication/algorithm/base/alg_template/CMakeLists.txt +++ b/src/domain/collective_communication/algorithm/base/alg_template/CMakeLists.txt @@ -18,6 +18,7 @@ target_sources(hccl_alg PRIVATE add_subdirectory(component) add_subdirectory(temp_alltoall) add_subdirectory(temp_all_gather) +add_subdirectory(temp_all_gather_fake) add_subdirectory(temp_all_reduce) add_subdirectory(temp_alltoallv) add_subdirectory(temp_broadcast) diff --git a/src/domain/collective_communication/algorithm/base/alg_template/alg_template_base.cc b/src/domain/collective_communication/algorithm/base/alg_template/alg_template_base.cc index 0dfaf00..b245ed4 100644 --- a/src/domain/collective_communication/algorithm/base/alg_template/alg_template_base.cc +++ b/src/domain/collective_communication/algorithm/base/alg_template/alg_template_base.cc @@ -894,6 +894,21 @@ void ExecutorBase::CalcRecursiveHdLinkRelationForSecondScene(u32 rank, } } +HcclResult ExecutorBase::ExecuteAicpuBarrier(const std::shared_ptr &preLink, + const std::shared_ptr &aftLink) +{ + // 同步与preLink保证数据收发已结束 + CHK_RET(preLink->PostReady(stream_)); + + CHK_RET(aftLink->WaitReady(stream_)); + + // 同步与aftLink保证数据收发已结束 + CHK_RET(aftLink->PostFin(stream_)); + + CHK_RET(preLink->WaitFin(stream_)); + return HCCL_SUCCESS; +} + HcclResult ExecutorBase::ExecuteBarrier(const std::shared_ptr &preLink, const std::shared_ptr &aftLink) { diff --git a/src/domain/collective_communication/algorithm/base/alg_template/alg_template_base_pub.h b/src/domain/collective_communication/algorithm/base/alg_template/alg_template_base_pub.h index 0cf0d08..b173362 100644 --- a/src/domain/collective_communication/algorithm/base/alg_template/alg_template_base_pub.h +++ b/src/domain/collective_communication/algorithm/base/alg_template/alg_template_base_pub.h @@ -164,6 +164,10 @@ enum TemplateType { TEMPLATE_ALL_REDUCE_DOUBLING_LOCAL_REDUCE = 99, // AllReduceDoublingLocalReduce AR 910A单机小数据量tbe reduce优化 + TEMPLATE_ALL_GATHER_FAKE_MESH_AICPU = 100, + TEMPLATE_ALL_GATHER_FAKE_RING_LEVEL1 = 101, + TEMPLATE_ALL_GATHER_FAKE_RING_LEVEL2 = 102, + TEMPLATE_NATIVE_MAX_NUM, // 内置template最大值 TEMPLATE_CUSTOM_BEGIN = 1000, // 用户自定义template起始值 @@ -586,6 +590,7 @@ public: const std::vector &links, AdjInfo& nslbAdjInfo); protected: + HcclResult ExecuteAicpuBarrier(const std::shared_ptr &preLink, const std::shared_ptr &aftLink); HcclResult ExecuteBarrier(const std::shared_ptr &preLink, const std::shared_ptr &aftLink); HcclResult ExecuteBarrier(const std::shared_ptr &preLink, const std::shared_ptr &aftLink, Stream &stream); diff --git a/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/CMakeLists.txt b/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/CMakeLists.txt new file mode 100644 index 0000000..5062d02 --- /dev/null +++ b/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/CMakeLists.txt @@ -0,0 +1,9 @@ +set(src_list_pub + ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_fake_mesh_aicpu.cc + ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_fake_ring_aicpu.cc + ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_fake_ring_aicpu_async.cc +) + +target_sources(hccl_alg PRIVATE + ${src_list_pub} +) \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_mesh_aicpu.cc b/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_mesh_aicpu.cc new file mode 100644 index 0000000..b8ee2e2 --- /dev/null +++ b/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_mesh_aicpu.cc @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "all_gather_fake_mesh_aicpu.h" +#include "alg_template_register.h" + +namespace hccl { +AllGatherFakeMeshAicpu::AllGatherFakeMeshAicpu(const HcclDispatcher dispatcher) + : AlgTemplateBase(dispatcher) +{} + +AllGatherFakeMeshAicpu::~AllGatherFakeMeshAicpu() {} + +HcclResult AllGatherFakeMeshAicpu::Prepare(std::vector &meshStreams, + std::vector> &meshSignal, std::vector> &meshSignalAux, + u32 userRank, HcomCollOpInfo *opInfo, u32 interRank, u32 interRankSize) +{ + meshStreams_ = meshStreams; + meshSignal_ = &meshSignal; + meshSignalAux_ = &meshSignalAux; + interRank_ = interRank; + interRankSize_ = interRankSize; + userRank_ = userRank; + return HCCL_SUCCESS; +} + +HcclResult AllGatherFakeMeshAicpu::RunAllGather(const std::vector &links, const std::vector &outputSlices, + const std::vector &inputSlices) +{ + HCCL_DEBUG("AllGatherFakeMeshAicpu::RunAllGather starts"); + for (u32 round = 1; round < interRankSize_; round++) { + u32 dstRank = BackwardRank(interRank_, interRankSize_, round); + Stream& subStream = (round == interRankSize_ - 1) ? stream_ : meshStreams_[round - 1]; + CHK_RET(links[dstRank]->PostReady(subStream)); + CHK_RET(links[dstRank]->WaitReady(subStream)); + } + + for (u32 round = 1; round < interRankSize_; round++) { + u32 dstRank = BackwardRank(interRank_, interRankSize_, round); + Stream& subStream = (round == interRankSize_ - 1) ? stream_ : meshStreams_[round - 1]; + profilerInput_.streamID = subStream.id(); + profilerInput_.planeID = round - 1; + profilerInput_.step = HCCL_EXEC_STEP_NOT_SET; + + if (round == interRankSize_ - 1) { + for (u32 signalIndex = 0; signalIndex < interRankSize_ - 2; signalIndex++) { // rankSize-2: stream num + CHK_RET(LocalNotify::Wait(subStream, dispatcher_, (*meshSignal_)[signalIndex], profilerInput_.stage)); + } + for (u32 signalIndex = 0; signalIndex < interRankSize_ - 2; signalIndex++) { // rankSize-2: stream num + CHK_RET(LocalNotify::Post(subStream, dispatcher_, (*meshSignalAux_)[signalIndex], + profilerInput_.stage)); + } + } else { + u32 signalIndex = round - 1; + CHK_RET(LocalNotify::Post(subStream, dispatcher_, (*meshSignal_)[signalIndex], + profilerInput_.stage)); + CHK_RET(LocalNotify::Wait(subStream, dispatcher_, (*meshSignalAux_)[signalIndex], profilerInput_.stage)); + } + // 本rank要收数据 + void *srcMemPtr = nullptr; + // 从对端的input内存拿数据,input==output也没有关系 + CHK_RET(links[dstRank]->GetRemoteMem(UserMemType::OUTPUT_MEM, &srcMemPtr)); + struct hccl::Transport::Buffer remoteBuf; + remoteBuf.addr = static_cast(srcMemPtr) + baseOffset_ + inputSlices[dstRank].offset; + remoteBuf.size = inputSlices[dstRank].size; + struct hccl::Transport::Buffer localBuf; + localBuf.addr = static_cast(outputMem_.ptr()) + outputSlices[dstRank].offset; + localBuf.size = outputSlices[dstRank].size; + CHK_RET(links[dstRank]->ReadSync(localBuf, remoteBuf, subStream)); + CHK_RET(links[dstRank]->PostFin(subStream)); + CHK_RET(links[dstRank]->WaitFin(subStream)); + } + return HCCL_SUCCESS; +} + +// allgather的入口函数 +HcclResult AllGatherFakeMeshAicpu::RunAsync(const u32 rank, const u32 rankSize, const std::vector &links) +{ + HcclResult ret = HCCL_SUCCESS; + CHK_SMART_PTR_NULL(dispatcher_); + CHK_PTR_NULL(stream_.ptr()); + HCCL_INFO("AllGatherFakeMeshAicpu run: rank[%u] totalrank[%u] inputMem[%p] outputMem[%p] count[%llu]", + rank, rankSize, inputMem_.ptr(), outputMem_.ptr(), count_); + + interRank_ = rank; + interRankSize_ = rankSize; + + if (interRankSize_ == 1) { + if (inputMem_ != outputMem_) { + HCCL_DEBUG("rank[%u] mem copy async from input to output", rank); + ret = HcclD2DMemcpyAsync(dispatcher_, outputMem_, inputMem_, stream_); + } + return ret; + } + + if (links.size() < rankSize) { + HCCL_ERROR("[AllGatherFakeMeshAicpu][RunAsync]rank[%u] linksize error", rank); + return HCCL_E_INTERNAL; + } + u32 subStreamSize = interRankSize_ - 2; // 子流大小等于ranksize-2 + if (meshStreams_.size() < subStreamSize || (*meshSignal_).size() < subStreamSize || + (*meshSignalAux_).size() < subStreamSize) { + HCCL_ERROR("[AllGatherFakeMeshAicpu][RunAsync]AllGatherFakeMeshAicpu stream size error: " + "rank[%u] totalrank:%u substreamsize[%llu] signalsize[%llu], signal_aux size[%llu]", + rank, rankSize, meshStreams_.size(), (*meshSignal_).size(), (*meshSignalAux_).size()); + return HCCL_E_PARA; + } + + u32 unitSize = DataUnitSize(dataType_); + if (unitSize == 0) { + HCCL_ERROR("[AllGatherFakeMeshAicpu][RunAsync]rank[%u] Unit Data Size is zero", rank); + return HCCL_E_INTERNAL; + } + + std::vector inputSlices(slices_); + if (slices_.size() == 0) { + slices_.resize(interRankSize_); + inputSlices.resize(rankSize); + + // 生成std::vector slices_ + u64 sliceSize = count_ * unitSize; + + for (u32 i = 0; i < interRankSize_; i++) { + slices_[i].size = sliceSize; + slices_[i].offset = (i * sliceSize); + + inputSlices[i].size = sliceSize; + inputSlices[i].offset = (inputMem_.size() < outputMem_.size()) ? 0 : (sliceSize * i); + HCCL_DEBUG("rank[%u], slices[%u].offset=%llu, slices[%u].size=%llu", rank, i, slices_[i].offset, i, + slices_[i].size); + } + } + + for (u32 i = 0; i < interRankSize_; i++) { + HCCL_DEBUG("[AllGatherFakeMeshAicpu][Outputslice]: size[%llu] offset[%llu] inputslice: size[%llu] offset[%llu]", + slices_[i].size, slices_[i].offset, inputSlices[i].size, inputSlices[i].offset); + } + + if (inputMem_ != outputMem_) { + DeviceMem dst = outputMem_.range(slices_[rank].offset, slices_[rank].size); + DeviceMem src = inputMem_.range(inputSlices[rank].offset, inputSlices[rank].size); + + HCCL_INFO("inputMem != outputMem: rank[%u] copy src[%p] offset[%llu] size[%llu] to dst[%p] offset[%llu]" + "size[%llu]", + rank, src.ptr(), inputSlices[rank].offset, inputSlices[rank].size, dst.ptr(), slices_[rank].offset, + slices_[rank].size); + + // 拷贝到自身rank的output_mem + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dst, src, stream_)); + } + + for (u32 streamIndex = 0; streamIndex < rankSize - 2; streamIndex++) { // rankSize-2: stream num + HCCL_DEBUG("rank[%u] streamindex[%u] wait signalaux[%p]", + rank, streamIndex, (*meshSignalAux_)[streamIndex]->ptr()); + CHK_RET(LocalNotify::Wait(meshStreams_[streamIndex], dispatcher_, (*meshSignalAux_)[streamIndex], + profilerInput_.stage)); + + HCCL_DEBUG("rank[%u] siganl_aux index[%u] signal record signalaux[%p]", rank, streamIndex, + (*meshSignalAux_)[streamIndex]->ptr()); + CHK_RET(LocalNotify::Post(stream_, dispatcher_, (*meshSignalAux_)[streamIndex], + profilerInput_.stage)); + } + + HCCL_DEBUG("RunAllGather starts: rank[%u]", rank); + CHK_RET(RunAllGather(links, slices_, inputSlices)); + + for (u32 streamIndex = 0; streamIndex < rankSize - 2; streamIndex++) { // rankSize - 2 stream num + HCCL_DEBUG("rank[%u] streamindex[%u] wait signal[%p] ", + rank, streamIndex, (*meshSignal_)[streamIndex]->ptr()); + CHK_RET(LocalNotify::Wait(stream_, dispatcher_, (*meshSignal_)[streamIndex], profilerInput_.stage)); + + HCCL_DEBUG("rank[%u] streamindex[%u] record signal[%p]", rank, streamIndex, meshStreams_[streamIndex].ptr()); + CHK_RET(LocalNotify::Post(meshStreams_[streamIndex], dispatcher_, (*meshSignal_)[streamIndex], + profilerInput_.stage)); + } + HCCL_DEBUG("AllGatherFakeMeshAicpu finished: rank[%u]", rank); + return HCCL_SUCCESS; +} +REGISTER_TEMPLATE(TemplateType::TEMPLATE_ALL_GATHER_FAKE_MESH_AICPU, AllGatherFakeMeshAicpu); +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_mesh_aicpu.h b/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_mesh_aicpu.h new file mode 100644 index 0000000..65cbaab --- /dev/null +++ b/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_mesh_aicpu.h @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef ALL_GATHER_FAKE_MESH_AICPU_H +#define ALL_GATHER_FAKE_MESH_AICPU_H + +#include "all_gather_fake_mesh_aicpu_pub.h" + +namespace hccl { +} // namespace hccl + +#endif /* * ALL_GATHER_FAKE_MESH_AICPU_H */ diff --git a/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_ring_aicpu.cc b/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_ring_aicpu.cc new file mode 100644 index 0000000..beedd16 --- /dev/null +++ b/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_ring_aicpu.cc @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "all_gather_fake_ring_aicpu.h" +#include "alg_template_register.h" + +namespace hccl { +AllGatherFakeRingAicpu::AllGatherFakeRingAicpu(const HcclDispatcher dispatcher) : ExecutorBase(dispatcher) +{ +} + +AllGatherFakeRingAicpu::~AllGatherFakeRingAicpu() +{ +} + +// 服务器间allgather的入口函数 +HcclResult AllGatherFakeRingAicpu::RunAsync(const u32 rank, const u32 rankSize, const std::vector &links) +{ + CHK_SMART_PTR_NULL(dispatcher_); + CHK_PTR_NULL(stream_.ptr()); + HCCL_DEBUG("AllGatherFakeRingAicpu run_async rank[%u] ranksize[%u] inputMem[%p] outputMem[%p] count[%llu]", \ + rank, rankSize, inputMem_.ptr(), outputMem_.ptr(), count_); + + if (rankSize == 1) { + if (inputMem_ != outputMem_) { + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, outputMem_, inputMem_, stream_)); + } + return HCCL_SUCCESS; + } + + // 获取ring algorithm所需的通信连接 + u32 ringPrevRank = (rank + rankSize - 1) % rankSize; + u32 ringNextRank = (rank + 1) % rankSize; + HCCL_DEBUG("[AllGatherFakeRingAicpu][RunAsync]rank[%u] linkSize[%zu] rankSize[%u] ringPrevRank[%u] ringNextRank[%u]", + rank, links.size(), rankSize, ringPrevRank, ringNextRank); + if (links.size() < rankSize) { + HCCL_ERROR("[AllGatherFakeRingAicpu][RunAsync]rank[%u] linkSize[%zu] is less than rankSize[%u]", rank, links.size(), rankSize); + return HCCL_E_INTERNAL; + } + + linkLeft_ = links[ringPrevRank]; + CHK_SMART_PTR_NULL(linkLeft_); + + linkRight_ = links[ringNextRank]; + CHK_SMART_PTR_NULL(linkRight_); + + u32 unitSize = DataUnitSize(dataType_); + if (unitSize == 0) { + HCCL_ERROR("[AllGatherFakeRingAicpu][RunAsync]unitSize is zero"); + return HCCL_E_INTERNAL; + } + + std::vector inputSlices(slices_); + if (slices_.size() == 0) { + slices_.resize(rankSize); + inputSlices.resize(rankSize); + + u64 sliceSize = count_ * unitSize; + for (u32 i = 0; i < rankSize; i++) { + slices_[i].size = sliceSize; + slices_[i].offset = sliceSize * i; + inputSlices[i].size = sliceSize; + inputSlices[i].offset = (inputMem_.size() < outputMem_.size()) ? 0 : (sliceSize * i); + HCCL_DEBUG("rank[%u], slices[%u].offset=%llu, slices[%u].size=%llu", \ + rank, i, slices_[i].offset, i, slices_[i].size); + } + } + + // 双buffer下, 先将input拷贝到output的合适位置 + if (inputMem_ != outputMem_) { + DeviceMem dst = outputMem_.range(slices_[rank].offset, slices_[rank].size); + DeviceMem src = inputMem_.range(inputSlices[rank].offset, inputSlices[rank].size); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dst, src, stream_)); + } + + // 运行all-gather, ring算法 + HCCL_DEBUG("RunAllGather starts. rank [%u], rankSize [%u], nicRankList_.size() [%zu]", + rank, rankSize, nicRankList_.size()); + CHK_RET(RunAllGather(rank, rankSize, slices_)); + HCCL_DEBUG("[RunAsync] barrierSwitchOn_ [%d] ", barrierSwitchOn_); + if (barrierSwitchOn_) { + // 执行barrier,保证数据收发完成 + CHK_RET(ExecuteAicpuBarrier(linkLeft_, linkRight_)); + } + + HCCL_INFO("AllGatherFakeRingAicpu finished: rank[%u] end", rank); + return HCCL_SUCCESS; +} + +HcclResult AllGatherFakeRingAicpu::WriteSync( + struct Transport::Buffer &remoteBuf, struct Transport::Buffer &localBuf, Stream &stream) +{ + // A3场景level1是SDMA + HCCL_DEBUG("AllGatherFakeRingAicpu::Level1Sync apply the WriteSync"); + CHK_RET(linkRight_->WriteSync(remoteBuf, localBuf, stream)); + return HCCL_SUCCESS; +} + +HcclResult AllGatherFakeRingAicpu::RunAllGather(u32 rank, u32 rankSize, const std::vector &outputSlices) +{ + HCCL_DEBUG("AllGatherFakeRingAicpu::RunAllGather starts: rank[%u] end", rank); + if (outputSlices.size() < rankSize || outputSlices.size() > rankSize) { + HCCL_ERROR("[AllGather]rank[%u] outputSlices.size() [%zu] is not equal to rankSize [%u]", + rank, outputSlices.size(), rankSize); + return HCCL_E_INTERNAL; + } + + // 首次传输,将本rank的数据发送到下游 + u32 txSliceIndex = rank; + for (u32 i = 0; i < rankSize - 1; i++) { + HCCL_DEBUG("rank[%u] i[%u] rankSize[%u] txSliceIndex[%u] outputSlices.size()[%zu]", + rank, i, rankSize, txSliceIndex, outputSlices.size()); + + CHK_RET(linkLeft_->PostReady(stream_)); + CHK_RET(linkRight_->WaitReady(stream_)); + + Slice txSlice = outputSlices[txSliceIndex]; + DeviceMem srcMem = outputMem_.range(txSlice.offset, txSlice.size); + HCCL_DEBUG("tx srcMem[%p] offset[%llu] size[%llu] baseOffset_[%llu]", + srcMem.ptr(), txSlice.offset, txSlice.size, baseOffset_); + + struct hccl::Transport::Buffer localBuf; + localBuf.addr = static_cast(outputMem_.ptr()) + txSlice.offset; + localBuf.size = txSlice.size; + + void *dstMemPtr = nullptr; + CHK_RET(linkRight_->GetRemoteMem(UserMemType::OUTPUT_MEM, &dstMemPtr)); + struct hccl::Transport::Buffer remoteBuf; + remoteBuf.addr = static_cast(dstMemPtr) + txSlice.offset + baseOffset_; + remoteBuf.size = txSlice.size; + // A3场景level1的链路类型是P2P + CHK_RET(WriteSync(remoteBuf, localBuf, stream_)); + CHK_RET(linkRight_->PostFin(stream_)); + CHK_RET(linkLeft_->WaitFin(stream_)); + // 更新下一轮的txSliceIndex + txSliceIndex = ForwordRank(txSliceIndex, rankSize, 1); + } + return HCCL_SUCCESS; +} +REGISTER_TEMPLATE(TemplateType::TEMPLATE_ALL_GATHER_FAKE_RING_LEVEL1, AllGatherFakeRingAicpu); +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_ring_aicpu.h b/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_ring_aicpu.h new file mode 100644 index 0000000..16e8742 --- /dev/null +++ b/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_ring_aicpu.h @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef ALL_GATHER_FAKE_RING_AICPU_H +#define ALL_GATHER_FAKE_RING_AICPU_H + +#include "all_gather_fake_ring_aicpu_pub.h" + +namespace hccl { +} // namespace hccl + +#endif /* * ALL_GATHER_FAKE_RING_AICPU_H */ diff --git a/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_ring_aicpu_async.cc b/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_ring_aicpu_async.cc new file mode 100644 index 0000000..f90063c --- /dev/null +++ b/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_ring_aicpu_async.cc @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "all_gather_fake_ring_aicpu_async.h" +#include "alg_template_register.h" + +namespace hccl { +AllGatherFakeRingAsync::AllGatherFakeRingAsync(const HcclDispatcher dispatcher) : AllGatherFakeRingAicpu(dispatcher) +{ +} + +AllGatherFakeRingAsync::~AllGatherFakeRingAsync() +{ +} + +HcclResult AllGatherFakeRingAsync::WriteSync( + struct Transport::Buffer &remoteBuf, struct Transport::Buffer &localBuf, Stream &stream) +{ + // A3场景level2是RDMA + HCCL_DEBUG("AllGatherFakeRingAsync::Level2Sync apply the WriteAsync"); + CHK_RET(linkRight_->WriteAsync(remoteBuf, localBuf, stream)); + return HCCL_SUCCESS; +} + +REGISTER_TEMPLATE(TemplateType::TEMPLATE_ALL_GATHER_FAKE_RING_LEVEL2, AllGatherFakeRingAsync); +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_ring_aicpu_async.h b/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_ring_aicpu_async.h new file mode 100644 index 0000000..1700764 --- /dev/null +++ b/src/domain/collective_communication/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_ring_aicpu_async.h @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef ALL_GATHER_FAKE_RING_AICPU_ASYNC_H +#define ALL_GATHER_FAKE_RING_AICPU_ASYNC_H + +#include "all_gather_fake_ring_aicpu_async_pub.h" + +namespace hccl { +} // namespace hccl + +#endif /* * ALL_GATHER_FAKE_RING_AICPU_ASYNC_H */ diff --git a/src/domain/collective_communication/algorithm/base/inc/all_gather_fake_mesh_aicpu_pub.h b/src/domain/collective_communication/algorithm/base/inc/all_gather_fake_mesh_aicpu_pub.h new file mode 100644 index 0000000..ff006ce --- /dev/null +++ b/src/domain/collective_communication/algorithm/base/inc/all_gather_fake_mesh_aicpu_pub.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef ALL_GATHER_FAKE_MESH_AICPU_PUB_H +#define ALL_GATHER_FAKE_MESH_AICPU_PUB_H + +#include "alg_template_base_pub.h" + +namespace hccl { +class AllGatherFakeMeshAicpu : public AlgTemplateBase { +public: + explicit AllGatherFakeMeshAicpu(const HcclDispatcher dispatcher); // 所有大环的rank个数,commcombine提供接口 + + ~AllGatherFakeMeshAicpu() override; + + // should be called soon after template AllGatherMesh instance created + HcclResult Prepare(std::vector &meshStreams, std::vector> &meshSignal, + std::vector> &meshSignalAux, u32 userRank = INVALID_VALUE_RANKID, + HcomCollOpInfo *opInfo = nullptr, u32 interRank = INVALID_VALUE_RANKID, u32 interRankSize = 0) override; + + HcclResult RunAsync(const u32 rank, const u32 rankSize, const std::vector &links) override; + +protected: + // 获取向该rank往前的第i个rank + inline u32 BackwardRank(u32 rank, u32 rankSize, u32 step) const + { + if (rankSize == 0) { + return 0; + } + return (rank + rankSize - step) % rankSize; + } + + inline u32 ForwardRank(u32 rank, u32 rankSize, u32 step) const + { + if (rankSize == 0) { + return 0; + } + return (rank + step) % rankSize; + } + virtual HcclResult RunAllGather(const std::vector &links, + const std::vector &outputSlices, + const std::vector &inputSlices); + std::vector meshStreams_; /** 多steam**/ + + std::vector> *meshSignal_{nullptr}; /* 每个ring创建一个signal */ + std::vector> *meshSignalAux_{nullptr}; /* 从stream wait,主steam record */ + u32 interRank_; // 在所有rank环上的rankid + u32 interRankSize_; + u32 userRank_; +private: + +}; +} // namespace hccl + +#endif /* ALL_GATHER_FAKE_MESH_AICPU_PUB_H */ \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/base/inc/all_gather_fake_ring_aicpu_async_pub.h b/src/domain/collective_communication/algorithm/base/inc/all_gather_fake_ring_aicpu_async_pub.h new file mode 100644 index 0000000..71763c4 --- /dev/null +++ b/src/domain/collective_communication/algorithm/base/inc/all_gather_fake_ring_aicpu_async_pub.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef ALL_GATHER_FAKE_RING_AICPU_ASYNC_PUB_H +#define ALL_GATHER_FAKE_RING_AICPU_ASYNC_PUB_H + +#include "all_gather_fake_ring_aicpu_pub.h" + +namespace hccl { +class AllGatherFakeRingAsync : public AllGatherFakeRingAicpu { +public: + explicit AllGatherFakeRingAsync(const HcclDispatcher dispatcher); + + ~AllGatherFakeRingAsync() override; +protected: + HcclResult WriteSync( + struct Transport::Buffer &remoteBuf, struct Transport::Buffer &localBuf, Stream &stream) override; +}; +} // namespace hccl + +#endif /* ALL_GATHER_FAKE_RING_AICPU_ASYNC_PUB_H */ \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/base/inc/all_gather_fake_ring_aicpu_pub.h b/src/domain/collective_communication/algorithm/base/inc/all_gather_fake_ring_aicpu_pub.h new file mode 100644 index 0000000..f81e463 --- /dev/null +++ b/src/domain/collective_communication/algorithm/base/inc/all_gather_fake_ring_aicpu_pub.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef ALL_GATHER_FAKE_RING_AICPU_PUB_H +#define ALL_GATHER_FAKE_RING_AICPU_PUB_H + +#include "alg_template_base_pub.h" + +namespace hccl { +class AllGatherFakeRingAicpu : public AlgTemplateBase { +public: + explicit AllGatherFakeRingAicpu(const HcclDispatcher dispatcher); + + ~AllGatherFakeRingAicpu() override; + + HcclResult RunAsync(const u32 rank, const u32 rankSize, const std::vector &links) override; + +protected: + // 获取向该rank往前的第i个rank + inline u32 ForwordRank(u32 rank, u32 rankSize, u32 preNum) const + { + return (rank + rankSize - preNum) % rankSize; + } + virtual HcclResult RunAllGather(u32 rank, u32 rankSize, const std::vector &outputSlices); + virtual HcclResult WriteSync( + struct Transport::Buffer &remoteBuf, struct Transport::Buffer &localBuf, Stream &stream); + + // 迭代6新增加 + std::shared_ptr linkLeft_; + std::shared_ptr linkRight_; +}; +} // namespace hccl + +#endif /* ALL_GATHER_FAKE_RING_AICPU_PUB_H */ \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/CMakeLists.txt b/src/domain/collective_communication/algorithm/impl/coll_executor/CMakeLists.txt index e383ab3..6e36d88 100644 --- a/src/domain/collective_communication/algorithm/impl/coll_executor/CMakeLists.txt +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/CMakeLists.txt @@ -10,6 +10,7 @@ target_sources(hccl_alg PRIVATE ) add_subdirectory(coll_all_gather) +add_subdirectory(coll_all_gather_fake) add_subdirectory(coll_all_gather_v) add_subdirectory(coll_all_reduce) add_subdirectory(coll_reduce_scatter) diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/CMakeLists.txt b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/CMakeLists.txt new file mode 100644 index 0000000..e3cf5d5 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/CMakeLists.txt @@ -0,0 +1,8 @@ +set(src_list + ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_gather_fake_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/coll_all_gather_fake_mesh_executor.cc +) + +target_sources(hccl_alg PRIVATE + ${src_list} +) \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/coll_all_gather_fake_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/coll_all_gather_fake_executor.cc new file mode 100644 index 0000000..0fe3d7b --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/coll_all_gather_fake_executor.cc @@ -0,0 +1,390 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "coll_all_gather_fake_executor.h" + +namespace hccl { +CollAllGatherFakeExecutor::CollAllGatherFakeExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollCommExecutor(dispatcher, topoMatcher) +{ +} + +HcclResult CollAllGatherFakeExecutor::Orchestrate(OpParam& param, AlgResourceResponse& algRes) +{ + HcclUs startut = TIME_NOW(); + tag_ = param.tag; + algResResp_ = &algRes; + + HCCL_PROFILER_ADD_TAG(param.tag, algoAttr_.identifier, workflowMode_); + HCCL_PROFILER_ADD_STREAM_BY_STREAMID(param.stream.id(), param.tag, 0, algType_); + CHK_RET(AddSubStreamToProfiling()); + if (workflowMode_ == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && !is310P3Common_) { + HCCL_PROFILER_ADD_OPDATA_OP(param.tag, param.DataDes.count, param.inputPtr, param.outputPtr, + param.DataDes.dataType, INVALID_VALUE_RANKID, algoAttr_.identifier, HcclReduceOp::HCCL_REDUCE_RESERVED); + HCCL_PROFILER_ADD_GROUPRANK(algoAttr_.identifier, topoAttr_.userRankSize, topoAttr_.userRank); + } + + HcclResult ret = HCCL_SUCCESS; + // 图模式和单卡场景下不需要Loop + if (workflowMode_ != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + u64 totalSize = param.DataDes.count * SIZE_TABLE[param.DataDes.dataType]; + ExecMem execMem; + execMem.count = param.DataDes.count; + execMem.inputMem = DeviceMem::create(algRes.paramInputMem.ptr(), totalSize); + execMem.outputMem = DeviceMem::create(algRes.paramOutputMem.ptr(), totalSize * topoAttr_.userRankSize); + execMem.scratchMem = algRes.scratchMem; + execMem.inputPtr = param.inputPtr; + execMem.outputPtr = param.outputPtr; + HCCL_DEBUG("[CollAllGatherFakeExecutor][Orchestrate]offload inputMem[%p][%llu], outputMem[%p][%llu]," \ + "scratchMem[%p][%llu], inputPtr[%p] outputPtr[%p], count[%llu]", + execMem.inputMem.ptr(), execMem.inputMem.size(), execMem.outputMem.ptr(), execMem.outputMem.size(), + execMem.scratchMem.ptr(), execMem.scratchMem.size(), execMem.inputPtr, execMem.outputPtr, execMem.count); + ret = KernelRun(param, execMem); + } else if (topoAttr_.userRankSize == 1) { + ExecMem execMem; + execMem.count = param.DataDes.count; + execMem.inputMem = algRes.cclInputMem; + execMem.outputMem = algRes.cclOutputMem; + execMem.scratchMem = algRes.scratchMem; + execMem.inputPtr = param.inputPtr; + execMem.outputPtr = param.outputPtr; + ret = KernelRun(param, execMem); + } else { + ret = RunLoop(param, algRes); + } + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollAllGatherFakeExecutor][Orchestrate]errNo[0x%016llx]all gather excutor kernel run failed", + HCCL_ERROR_CODE(ret)), ret); + + if (workflowMode_ == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE && !is310P3Common_) { + HCCL_PROFILER_DEL_STREAM_BY_STREAMID(param.stream.id()); + HCCL_PROFILER_DEL_TAG(param.tag); + HCCL_PROFILER_DEL_OPDATA(param.tag); + HCCL_PROFILER_DEL_GROUPRANK(algoAttr_.identifier); + } + HCCL_INFO("tag[%s], Allgather executor orchestrate success, take time [%lld]us.", + param.tag.c_str(), DURATION_US(TIME_NOW() - startut)); + return HCCL_SUCCESS; +} + + +u64 CollAllGatherFakeExecutor::CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize) +{ + // 中转内存单次最多能够接受的output count + u64 maxCountPerLoop = cclBuffSize / topoAttr_.userRankSize / HCCL_MIN_SLICE_ALIGN + * HCCL_MIN_SLICE_ALIGN / unitSize; + HCCL_DEBUG("[CollAllGatherFakeExecutor][CalcLoopMaxCount] " + "using default userRank [%u] maxCountPerLoop[%llu]", topoAttr_.userRank, maxCountPerLoop); + return maxCountPerLoop; +} + +bool CollAllGatherFakeExecutor::IsHugeData(const u64 curSize) +{ + bool hugeData = curSize * topoAttr_.userRankSize / HCCL_INTERNODE_MAX_DATA_RATE > RDMA_SEND_MAX_SIZE || + curSize > SDMA_SEND_MAX_SIZE; + return hugeData; +} + +bool CollAllGatherFakeExecutor::IsSmallData(const u64 size) +{ + HCCL_INFO("[CollAllGatherFakeExecutor][IsSmallData]opMeta is using the default option: not small data."); + return false; +} + +bool CollAllGatherFakeExecutor::IsDataSplitForRdmaSdmaConcurrent(const u64 curSize) +{ + HCCL_INFO("[CollAllGatherFakeExecutor]opMeta is using the default option: not data split."); + return false; +} + +// 基于性能考量,合并RunLoop和RunLoopInner +HcclResult CollAllGatherFakeExecutor::RunLoop(OpParam ¶m, AlgResourceResponse &algRes) +{ + u32 unitSize = SIZE_TABLE[param.DataDes.dataType]; + + u8 *curInputPtr = static_cast(param.inputPtr); + u8 *curOutputPtr = static_cast(param.outputPtr); + void *commInputPtr = algRes.cclInputMem.ptr(); + u8 *commOutputPtr = static_cast(algRes.cclOutputMem.ptr()); + CHK_PTR_NULL(curInputPtr); + CHK_PTR_NULL(curOutputPtr); + CHK_PTR_NULL(commInputPtr); + CHK_PTR_NULL(commOutputPtr); + + u64 maxCountPerLoop = CalcLoopMaxCount(algRes.cclInputMem.size(), unitSize); // override + CHK_PRT_RET(maxCountPerLoop == 0, + HCCL_ERROR("[CollAllGatherFakeExecutor][RunLoop]tag[%s], userRankSize is [%u], maxCountPerLoop is [%llu].", + param.tag.c_str(), topoAttr_.userRankSize, maxCountPerLoop), + HCCL_E_PARA); + + bool smallData = IsSmallData(param.DataDes.count * unitSize); + for (u64 countLeft = param.DataDes.count, curCount = 0, inputOffset = 0, outputOffset = 0; + countLeft > 0; countLeft -= curCount) { + curInputPtr += inputOffset; + curOutputPtr += outputOffset; + // 判断剩余数据量对应的output size是否大于中转output size + curCount = (countLeft > maxCountPerLoop) ? maxCountPerLoop : countLeft; + u64 curSize = curCount * unitSize; // 单位:字节 + + HCCL_DEBUG("[CollAllGatherFakeExecutor][RunLoop]tag[%s], inputOffset[%llu], outputOffset[%llu], " \ + "sendBuf[%p], recvBuf[%p], sendCount[%llu], dataType[%d].", + param.tag.c_str(), inputOffset, outputOffset, curInputPtr, curOutputPtr, curCount, param.DataDes.dataType); + + if (!is310P3Common_) { + /* 设置子图复用标志 */ + auto autoSelectedAlgTypeLevel1 = static_cast(algType_.algoLevel1); + bool hugeData = IsHugeData(curSize); // override + bool dataSplit = IsDataSplitForRdmaSdmaConcurrent(curSize); + auto opMeta = HcclOpMetaInfo::GetOneForAllGatherFake(autoSelectedAlgTypeLevel1, hugeData, smallData, + CopyPattern::BCOPY, dataSplit); + CHK_RET(InitTask(dispatcher_, param.stream, opMeta.isEnableCache, opMeta.GetCacheKeyFake())); + } + + // 执行 + if (!DMAReduceFlag_) { + // 如果使用in CCL buffer,需要将user buffer in中的结果拷贝到CCL buffer in + DeviceMem srcMem = DeviceMem::create(curInputPtr, curSize); + DeviceMem dstMem = DeviceMem::create(commInputPtr, curSize); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, param.stream)); + HCCL_DEBUG("[CollAllGatherFakeExecutor][RunLoop]copy from user in to ccl in."); + } + + // 使用当前Loop偏移到的地址作为当前的inputPtr和outputPtr + ExecMem execMem; + execMem.count = curCount; + execMem.inputMem = DeviceMem::create(commInputPtr, curSize); + execMem.outputMem = DeviceMem::create(commOutputPtr, curSize * topoAttr_.userRankSize); + execMem.scratchMem = algRes.scratchMem; + execMem.inputPtr = curInputPtr; + execMem.outputPtr = curOutputPtr; + HcclResult ret = KernelRun(param, execMem); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollAllGatherFakeExecutor][RunLoop]errNo[0x%016llx]kernel run error, tag[%s], " \ + "inputMem ptr[%p], outputMem ptr[%p], count[%llu], dataType[%d]", + HCCL_ERROR_CODE(ret), param.tag.c_str(), commInputPtr, commOutputPtr, + curCount, param.DataDes.dataType), + ret); + + if (!DMAReduceFlag_) { + // 如果使用CCL buffer,需要将CCL buffer out中的结果拷贝到user buffer out + for (u32 i = 0; i < topoAttr_.userRankSize; i++) { + // 拷贝中转output上每个slice的数据到output内存,目的端中每个slice的size固定为output的size + DeviceMem dstMem = DeviceMem::create(curOutputPtr + param.DataDes.count * unitSize * i, curSize); + DeviceMem srcMem = DeviceMem::create(commOutputPtr + curSize * i, curSize); + CHK_RET(HcclD2DMemcpyAsync(dispatcher_, dstMem, srcMem, param.stream)); + } + } + + if (!is310P3Common_) { + CHK_RET(LaunchTaskExtend(dispatcher_, param.stream, algResResp_->slaveStreams)); + } + + inputOffset = curSize; + outputOffset = curSize; + } + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherFakeExecutor::PrepareAllgatherSlice(u32 sliceNum, u64 inputMemSize, + std::vector &dataSegsSlice) const +{ + Slice sliceTemp; + for (u32 i = 0; i < sliceNum; i++) { // 根据数据量计算每个环上数据的偏移和大小 + sliceTemp.size = inputMemSize; + sliceTemp.offset = inputMemSize * i; + dataSegsSlice.push_back(sliceTemp); + } + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherFakeExecutor::CalculateLevel1AllgatherSlice(u64 inputMemSize, u32 level0RankSize, u32 level1RankSize, + std::vector> multRingsSliceZero, std::vector> &multRingsSlice) const +{ + for (u32 ringIndex = 0; ringIndex < multRingsSliceZero.size(); ringIndex++) { + std::vector level1DataSlice; + for (u32 level0Idx = 0; level0Idx < level0RankSize; level0Idx++) { + CHK_PRT_RET(multRingsSliceZero[ringIndex].size() < level0RankSize, + HCCL_ERROR("[CalculateLevel1AllgatherSlice]multRingsSliceZero[ringIndex]" \ + "size is smaller than level0RankSize."), HCCL_E_INTERNAL); + for (u32 level1Idx = 0; level1Idx < level1RankSize; level1Idx++) { + Slice tmpSlice; + tmpSlice.size = multRingsSliceZero[ringIndex][level0Idx].size; + tmpSlice.offset = + multRingsSliceZero[ringIndex][level0Idx].offset + level1Idx * level0RankSize * inputMemSize; + level1DataSlice.push_back(tmpSlice); + } + } + multRingsSlice.push_back(level1DataSlice); + } + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherFakeExecutor::CalculateLevel2AllgatherSlice(u64 inputMemSize, u32 level0RankSize, + u32 level1RankSize, u32 level2RankSize, std::vector> multRingsSliceZero, + std::vector &level2DataSlice, u32 ringIndex) const +{ + for (u32 level0Idx = 0; level0Idx < level0RankSize; level0Idx++) { + for (u32 level2Idx = 0; level2Idx < level2RankSize; level2Idx++) { + for (u32 level1Idx = 0; level1Idx < level1RankSize; level1Idx++) { + Slice tmpSlice; + tmpSlice.size = multRingsSliceZero[ringIndex][level0Idx].size; + tmpSlice.offset = multRingsSliceZero[ringIndex][level0Idx].offset + + (level1Idx * level0RankSize + level2Idx * level0RankSize * level1RankSize) *inputMemSize; + level2DataSlice.push_back(tmpSlice); + } + } + } + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherFakeExecutor::AllGatherLevel2(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, + u64 count, HcclDataType dataType, Stream &stream, HcomCollOpInfo *opInfo) +{ + u32 perDataSize = 0; + CHK_RET(SalGetDataTypeSize(dataType, perDataSize)); + + SubCommInfo level0CommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + u32 commIndex = level0CommInfo.localRank; + SubCommInfo level1CommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + CHK_RET(CheckCommSize(COMM_LEVEL2, COMM_INDEX_0)); + SubCommInfo level2CommInfo = GetSubCommInfo(COMM_LEVEL2, COMM_INDEX_0); + + u64 inputMemSize = inputMem.size(); + u32 level0RankSize = level0CommInfo.localRankSize; + u32 level1RankSize = level1CommInfo.localRankSize; + u32 level2RankSize = level2CommInfo.localRankSize; + u32 level0ServerIndex = level0CommInfo.localRank; + u32 level1ServerIndex = level1CommInfo.localRank; + + std::unique_ptr level2AGExecutor; + level2AGExecutor = AlgTemplateRegistry::Instance().GetAlgTemplate( + TemplateType::TEMPLATE_ALL_GATHER_RING, dispatcher_); + HCCL_INFO("allgather ring: using ring algo inter-server."); + CHK_SMART_PTR_NULL(level2AGExecutor); + + // 计算slice, 不同超节点相同slice + std::vector level2DataSegsSlice; + Slice sliceTemp; + for (u32 i = 0; i < level2RankSize; i++) { + sliceTemp.size = inputMemSize; + sliceTemp.offset = i * level1RankSize * level0RankSize * inputMemSize; + level2DataSegsSlice.push_back(sliceTemp); + } + // outputMem传整块,通过baseOffset偏移 + u64 level2BaseOffset = (level0ServerIndex + level1ServerIndex * level1RankSize) * inputMemSize; + CHK_RET(level2AGExecutor->Prepare(outputMem, outputMem, inputMem, count, dataType, stream, + HCCL_REDUCE_RESERVED, INVALID_VALUE_RANKID, level2DataSegsSlice, level2BaseOffset)); + + CHK_RET(level2AGExecutor->RegisterProfiler(( + level2RankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + level2CommInfo.localRank, + PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, stream)); + + CHK_RET(RunTemplate(level2AGExecutor, level2CommInfo)); + HCCL_INFO("allgather double ring [superpod] level2 allgather run success"); + + // 第二步,各个AI Server 间 all gather (ring/NHR) + HCCL_INFO("commIdx:%u Tag[%s].commLevel1.size():%u", commIndex, tag.c_str(), + level1RankSize); + + if (level1RankSize > 1) { + std::unique_ptr level1AGExecutor; + if (algType_.algoLevel1 == AlgTypeLevel1::ALG_LEVEL1_RING) { + level1AGExecutor = AlgTemplateRegistry::Instance().GetAlgTemplate( + TemplateType::TEMPLATE_ALL_GATHER_RING, dispatcher_); + HCCL_INFO("allgather ring: using ring algo inter-server."); + } else if (algType_.algoLevel1 == AlgTypeLevel1::ALG_LEVEL1_NB) { + level1AGExecutor = AlgTemplateRegistry::Instance().GetAlgTemplate( + TemplateType::TEMPLATE_ALL_GATHER_NB, dispatcher_); + HCCL_INFO("allgather ring: using nonuniform-bruck algo inter-server."); + } else if (algType_.algoLevel1 == AlgTypeLevel1::ALG_LEVEL1_NHR) { + level1AGExecutor = AlgTemplateRegistry::Instance().GetAlgTemplate( + TemplateType::TEMPLATE_ALL_GATHER_NHR, dispatcher_); + HCCL_INFO("allgather ring: using nonuniform-hierarchical-ring algo inter-server."); + } else { + HCCL_ERROR("allgather ring: unsupported algtype [%s].", AlgTypeToStr(algType_).c_str()); + return HCCL_E_NOT_SUPPORT; + } + CHK_SMART_PTR_NULL(level1AGExecutor); + + // 计算slice, 不同超节点相同slice + std::vector level1DataSegsSlice; + for (u32 j = 0; j < level2RankSize; j++) { + for (u32 i = 0; i < level1RankSize; i++) { + sliceTemp.size = inputMemSize; + sliceTemp.offset = + (i * level0RankSize + j * level1RankSize * level0RankSize + level0ServerIndex) *inputMemSize; + level1DataSegsSlice.push_back(sliceTemp); + } + } + + CHK_RET(level1AGExecutor->Prepare(outputMem, outputMem, inputMem, count, dataType, stream, + HCCL_REDUCE_RESERVED, INVALID_VALUE_RANKID, level1DataSegsSlice, 0)); + + CHK_RET(level1AGExecutor->RegisterProfiler(( + level1RankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + level2CommInfo.localRank, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, stream)); + + CHK_RET(RunTemplate(level1AGExecutor, level1CommInfo)); + HCCL_INFO("allgather double ring [superpod] level1 allgather run success"); + } + + // 节点内做all gather double ring + std::vector dataSegsSlice; + std::vector> multRingsSliceZero; // 数据基于该rank上环0的偏移 + CHK_RET(PrepareAllgatherSlice(level0RankSize, inputMemSize, dataSegsSlice)); + + // 多环数据切分 + if (topoType_ == TopoType::TOPO_TYPE_NP_DOUBLE_RING) { + multRingsSliceZero = PrepareMultiRingSlice(dataSegsSlice, tag, false, topoAttr_.nicList); + } else { + multRingsSliceZero.push_back(dataSegsSlice); + } + std::vector> multRingsSlice; + for (u32 ringIndex = 0; ringIndex < multRingsSliceZero.size(); ringIndex++) { + std::vector level2DataSlice; + CHK_RET(CalculateLevel2AllgatherSlice(inputMemSize, level0RankSize, level1RankSize, level2RankSize, + multRingsSliceZero, level2DataSlice, ringIndex)); + multRingsSlice.push_back(level2DataSlice); + } + + std::vector> multRingsUserMemSlice; + if (!DMAReduceFlag_) { + multRingsUserMemSlice = multRingsSlice; + } else { + for (u32 ringIndex = 0; ringIndex < multRingsSlice.size(); ringIndex++) { + std::vector level2UserMemSlice; + for (auto &cclSlice : multRingsSlice[ringIndex]) { + Slice tmpSlice; + tmpSlice.size = cclSlice.size; + tmpSlice.offset = + (cclSlice.offset / inputMemSize) * count * perDataSize + + multRingsSliceZero[ringIndex][0].offset; + level2UserMemSlice.push_back(tmpSlice); + HCCL_DEBUG("rank[%u], ringIndex[%u], tmpSlice.offset=[%llu], size=[%llu]", + topoAttr_.userRank, ringIndex, tmpSlice.offset, tmpSlice.size); + } + multRingsUserMemSlice.push_back(level2UserMemSlice); + } + } + + CHK_RET(ActiveSlaveStreams(stream)); + if (DMAReduceFlag_ && level1RankSize > 1) { + // allgather输入放在CCL buffer上,通过设置nullptr指示要从CCL buffer获取输入 + opInfo->inputAddr = nullptr; + } + CHK_RET(MultiRingAllGather(tag, inputMem, outputMem, count, + dataType, multRingsSlice, stream, PROF_STAGE_2, 0, opInfo, multRingsUserMemSlice)); + + HCCL_INFO("allgather double ring [superpod] level2 allgather run success"); + return HCCL_SUCCESS; +} + +} // namespace hccl \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/coll_all_gather_fake_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/coll_all_gather_fake_executor.h new file mode 100644 index 0000000..8d4f65b --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/coll_all_gather_fake_executor.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_ALLGATHER_FAKE_EXECUTOR_H +#define COLL_ALLGATHER_FAKE_EXECUTOR_H +#include "coll_comm_executor.h" +namespace hccl { +class CollAllGatherFakeExecutor : public CollCommExecutor { + +public: + explicit CollAllGatherFakeExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollAllGatherFakeExecutor() = default; + + HcclResult Orchestrate(OpParam& param, AlgResourceResponse& algRes) override; +protected: + // AllGather Loop Executor公共接口 + virtual u64 CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize); + virtual bool IsHugeData(const u64 curSize); + virtual bool IsDataSplitForRdmaSdmaConcurrent(const u64 curSize); + HcclResult RunLoop(OpParam ¶m, AlgResourceResponse &algRes); + virtual bool IsSmallData(const u64 size); + + // 工具类 + HcclResult PrepareAllgatherSlice(u32 sliceNum, u64 inputMemSize, std::vector &dataSegsSlice) const; + + HcclResult CalculateLevel1AllgatherSlice(u64 inputMemSize, u32 level0RankSize, u32 level1RankSize, + std::vector> multRingsSliceZero, std::vector> &multRingsSlice) const; + + HcclResult CalculateLevel2AllgatherSlice(u64 inputMemSize, u32 level0RankSize, + u32 level1RankSize, u32 level2RankSize, std::vector> multRingsSliceZero, + std::vector &level2DataSlice, u32 ringIndex) const; + + HcclResult AllGatherLevel2(const std::string &tag, DeviceMem &inputMem, DeviceMem &outputMem, + u64 count, HcclDataType dataType, Stream &stream, HcomCollOpInfo *opInfo = nullptr); + + bool DMAReduceFlag_{false}; // 是否DMA消减的标志 +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/coll_all_gather_fake_mesh_executor.cc b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/coll_all_gather_fake_mesh_executor.cc new file mode 100644 index 0000000..c0d4c72 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/coll_all_gather_fake_mesh_executor.cc @@ -0,0 +1,198 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "coll_all_gather_fake_mesh_executor.h" + +namespace hccl { +CollAllGatherFakeMeshExecutor::CollAllGatherFakeMeshExecutor(const HcclDispatcher dispatcher, + std::unique_ptr &topoMatcher) + : CollAllGatherFakeExecutor(dispatcher, topoMatcher) +{ + DMAReduceFlag_ = false; +} + +HcclResult CollAllGatherFakeMeshExecutor::CalcStreamNum(u32& streamNum) +{ + u32 totalStreamNum = topoAttr_.deviceNumPerAggregation > 1U ? topoAttr_.deviceNumPerAggregation - 1U : 1U; + streamNum = totalStreamNum - 1U; + HCCL_INFO("[CollAllGatherFakeMeshExecutor][CalcStreamNum] tag[%s] streamNum[%u]", + tag_.c_str(), streamNum); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherFakeMeshExecutor::CalcCommInfo(std::vector& opTransport) +{ + TransportMemType inputType = TransportMemType::RESERVED; + TransportMemType outputType = TransportMemType::RESERVED; + CHK_RET(CalcTransportMemType(inputType, outputType)); + CHK_RET(CalcLevel0CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel1CommInfo(inputType, outputType, opTransport)); + CHK_RET(CalcLevel2CommInfo(inputType, outputType, opTransport)); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherFakeMeshExecutor::CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType) +{ + if (workflowMode_ == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE) { + inputType = TransportMemType::CCL_INPUT; + outputType = TransportMemType::CCL_OUTPUT; + } else { + inputType = TransportMemType::PARAM_INPUT; + outputType = TransportMemType::PARAM_OUTPUT; + } + HCCL_INFO("[CollAllGatherFakeMeshExecutor][CalcTransportMemType] tag[%s] inputType[%d], outputType[%d]", + tag_.c_str(), inputType, outputType); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherFakeMeshExecutor::CalcLevel0CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel0(COMM_LEVEL0, CommType::COMM_TAG_MESH); + commParaLevel0.meshSinglePlane = (topoAttr_.deviceType == DevType::DEV_TYPE_910B) && + topoMatcher_->GetExternalInputHcclDeterministic() == DETERMINISTIC_DISABLE && + (workflowMode_ != HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel0, opTransport[COMM_LEVEL0], inputType, outputType)); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherFakeMeshExecutor::CalcLevel1CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel1(COMM_LEVEL1, CommType::COMM_TAG_RING_INNER); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel1, opTransport[commParaLevel1.commPlane], inputType, outputType)); + HCCL_INFO("[CollAllGatherFakeMeshExecutor][CalcLevel1CommInfo] Calc RingCommInfo Finish"); + return HCCL_SUCCESS; +} + +HcclResult CollAllGatherFakeMeshExecutor::CalcLevel2CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) +{ + CommParaInfo commParaLevel2(COMM_LEVEL2, CommType::COMM_TAG_RING_INNER); + CHK_RET(CalcCommPlaneInfo(tag_, commParaLevel2, opTransport[COMM_LEVEL2], inputType, outputType)); + HCCL_INFO("[CollAllGatherFakeMeshExecutor][CalcLevel2CommInfo] Calc RingCommInfo Finish"); + return HCCL_SUCCESS; +} + +u64 CollAllGatherFakeMeshExecutor::CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize) +{ + // 中转内存单次最多能够接受的output count + u64 maxCountPerLoop = cclBuffSize / topoAttr_.userRankSize / HCCL_MIN_SLICE_ALIGN + * HCCL_MIN_SLICE_ALIGN / unitSize; + HCCL_DEBUG("[CollAllGatherFakeMeshExecutor][CalcLoopMaxCount] " + "using default userRank [%u] maxCountPerLoop[%llu]", topoAttr_.userRank, maxCountPerLoop); + return maxCountPerLoop; +} + +HcclResult CollAllGatherFakeMeshExecutor::KernelRun(const OpParam ¶m, ExecMem &execMem) +{ + HCCL_CONFIG_INFO(HCCL_ALG, "[CollAllGatherFakeMeshExecutor][KernelRun]allgather mesh start"); + u32 perDataSize = SIZE_TABLE[param.DataDes.dataType]; + + // 获取子通信域信息 + CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1)); + SubCommInfo level0CommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0); + u32 level0RankSize = level0CommInfo.localRankSize; + u32 commIndex = level0CommInfo.localRank; + CHK_RET(CheckCommSize(COMM_LEVEL1, commIndex + 1)); + SubCommInfo level1CommInfo = GetSubCommInfo(COMM_LEVEL1, commIndex); + u32 serverIndex = level1CommInfo.localRank; + u32 level1RankSize = level1CommInfo.localRankSize; + SubCommInfo level2CommInfo = GetSubCommInfo(COMM_LEVEL2, COMM_INDEX_0); + u32 superPodIndex = level2CommInfo.localRank; + u32 level2RankSize = level2CommInfo.localRankSize; + + u64 inputMemSize = execMem.inputMem.size(); + u64 baseOffset = serverIndex * inputMemSize * level0RankSize + superPodIndex * inputMemSize * level0RankSize * level1RankSize; + u64 level0Offset = commIndex * inputMemSize; + DeviceMem dstMem = execMem.outputMem.range(baseOffset + level0Offset, inputMemSize); + CHK_SMART_PTR_NULL(dstMem); + // 第一步,将数据从input内存拷贝到output内存的对应位置 + HcclResult ret = HcclD2DMemcpyAsync(dispatcher_, dstMem, execMem.inputMem, const_cast(param.stream)); + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[CollAllGatherFakeMeshExecutor][KernelRun]all gather 4PmeshHD memcpy Failed, Offset[%llu], Size[%llu].", + baseOffset + level0Offset, inputMemSize), ret); + + // 第二步,各个AI Server 内 multi stream mesh all gather + std::vector dataSegsSlice; // 数据分成ranksize份,每份的起始偏移和大小 + std::vector> multiStreamSlice; // 每个stream使用的数据基于用户buffer的偏移 + u32 sliceNum = level0RankSize; + CHK_RET(PrepareAllgatherSlice(sliceNum, inputMemSize, dataSegsSlice)); + CHK_RET(ActiveSlaveStreams(param.stream)); + if (sliceNum > 1) { + // mesh算法stream数量为server内rank数减1 + CHK_RET(AlgTemplateBase::PrepareSliceMeshStreams(dataSegsSlice, sliceNum - 1, multiStreamSlice)); + + // 抽取当前用于多环all gather 的output内存数据 + DeviceMem currentOutputMem = execMem.outputMem.range(baseOffset, inputMemSize * level0RankSize); + CHK_SMART_PTR_NULL(currentOutputMem); + + std::unique_ptr level0TempAlg; + level0TempAlg = AlgTemplateRegistry::Instance().GetAlgTemplate(TemplateType::TEMPLATE_ALL_GATHER_FAKE_MESH_AICPU, + dispatcher_); + CHK_SMART_PTR_NULL(level0TempAlg); + HCCL_DEBUG("[CollAllGatherFakeMeshExecutor][KernelRun] using mesh algo inter-server."); + CHK_RET(level0TempAlg->Prepare(algResResp_->slaveStreams, algResResp_->notifiesMain, algResResp_->notifiesAux, + topoAttr_.userRank, nullptr, commIndex, level0RankSize)); + CHK_RET(level0TempAlg->Prepare(currentOutputMem, currentOutputMem, execMem.inputMem, + execMem.count * level0RankSize, param.DataDes.dataType, param.stream, HCCL_REDUCE_RESERVED, + LEVEL0_BRIDGE_RANK_ID, dataSegsSlice, baseOffset)); + u32 rankSize = level0RankSize; + CHK_RET(level0TempAlg->RegisterProfiler((rankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + commIndex, + PROF_STAGE_1, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(level0TempAlg, level0CommInfo)); + HCCL_DEBUG("[CollAllGatherFakeMeshExecutor][KernelRun] all gather mesh outer run success"); + } + + // 第三步, AI server 间 recursive halving doubling all gather + u64 hdSize = inputMemSize * level0RankSize; + u64 hdCount = hdSize / perDataSize; + + std::unique_ptr level1TempAlg; + level1TempAlg = AlgTemplateRegistry::Instance().GetAlgTemplate( + TemplateType::TEMPLATE_ALL_GATHER_FAKE_RING_LEVEL1, dispatcher_); + HCCL_DEBUG("[CollAllGatherFakeMeshExecutor][KernelRun] using ring algo inter-server."); + CHK_SMART_PTR_NULL(level1TempAlg); + + baseOffset = superPodIndex * inputMemSize * level0RankSize * level1RankSize; + DeviceMem currentlevel1OutputMem = execMem.outputMem.range(baseOffset, inputMemSize * level0RankSize * level1RankSize); + CHK_SMART_PTR_NULL(currentlevel1OutputMem); + // 此处虽然带入inputMem作为scratch mem, 但inputMem 不能被使用 + CHK_RET(level1TempAlg->Prepare(currentlevel1OutputMem, currentlevel1OutputMem, execMem.inputMem, hdCount, + param.DataDes.dataType, param.stream, HcclReduceOp::HCCL_REDUCE_RESERVED, INVALID_VALUE_RANKID, + std::vector(COMM_INDEX_0), baseOffset)); + + CHK_RET(level1TempAlg->RegisterProfiler((level1RankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + serverIndex, + PROF_STAGE_2, HCCL_EXEC_STEP_NOT_SET, param.stream)); + + CHK_RET(RunTemplate(level1TempAlg, level1CommInfo)); + + HCCL_DEBUG("[CollAllGatherFakeMeshExecutor][KernelRun] ring inner run success"); + + u64 level2hdSize = inputMemSize * level0RankSize * level1RankSize; + u64 level2hdCount = level2hdSize / perDataSize; + + std::unique_ptr level2TempAlg; + level2TempAlg = AlgTemplateRegistry::Instance().GetAlgTemplate( + TemplateType::TEMPLATE_ALL_GATHER_FAKE_RING_LEVEL2, dispatcher_); + HCCL_DEBUG("[CollAllGatherFakeMeshExecutor][KernelRun] using ring algo superpod-inter-server."); + CHK_SMART_PTR_NULL(level2TempAlg); + // 此处虽然带入inputMem作为scratch mem, 但inputMem 不能被使用 + CHK_RET(level2TempAlg->Prepare(execMem.outputMem, execMem.outputMem, execMem.inputMem, level2hdCount, + param.DataDes.dataType, param.stream, HcclReduceOp::HCCL_REDUCE_RESERVED, INVALID_VALUE_RANKID, + std::vector(COMM_INDEX_0), 0)); + + CHK_RET(level2TempAlg->RegisterProfiler((level2RankSize << PROF_RANKSIZE_OFFSET_OF_PLANEID) + superPodIndex, + PROF_STAGE_0, HCCL_EXEC_STEP_NOT_SET, param.stream)); + CHK_RET(RunTemplate(level2TempAlg, level2CommInfo)); + return HCCL_SUCCESS; +} + +REGISTER_EXEC("AllGatherFakeMeshExecutor", AllGatherFakeMesh, CollAllGatherFakeMeshExecutor); +} // namespace hccl diff --git a/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/coll_all_gather_fake_mesh_executor.h b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/coll_all_gather_fake_mesh_executor.h new file mode 100644 index 0000000..ecbfd4c --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/coll_executor/coll_all_gather_fake/coll_all_gather_fake_mesh_executor.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef COLL_ALLGATHER_FAKE_MESH_EXECUTOR_H +#define COLL_ALLGATHER_FAKE_MESH_EXECUTOR_H +#include "coll_all_gather_fake_executor.h" +namespace hccl { +class CollAllGatherFakeMeshExecutor : public CollAllGatherFakeExecutor { +public: + explicit CollAllGatherFakeMeshExecutor(const HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~CollAllGatherFakeMeshExecutor() = default; + +private: + /* *************** 资源计算 *************** */ + HcclResult CalcStreamNum(u32& streamNum) override; + HcclResult CalcCommInfo(std::vector& opTransport) override; + HcclResult CalcLevel0CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcLevel1CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcLevel2CommInfo(TransportMemType inputType, TransportMemType outputType, + std::vector& opTransport) override; + HcclResult CalcTransportMemType(TransportMemType &inputType, TransportMemType &outputType); + + /* *************** 算法编排 *************** */ + u64 CalcLoopMaxCount(const u64 cclBuffSize, const u32 unitSize) override; + HcclResult KernelRun(const OpParam ¶m, ExecMem &execMem) override; +}; + +} // namespace hccl + +#endif \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/operator/CMakeLists.txt b/src/domain/collective_communication/algorithm/impl/operator/CMakeLists.txt index ccf812f..bb505d5 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/CMakeLists.txt +++ b/src/domain/collective_communication/algorithm/impl/operator/CMakeLists.txt @@ -6,6 +6,7 @@ set(src_list ${CMAKE_CURRENT_SOURCE_DIR}/reduce_scatter_operator.cc ${CMAKE_CURRENT_SOURCE_DIR}/reduce_scatter_v_operator.cc ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_operator.cc + ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_fake_operator.cc ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_v_operator.cc ${CMAKE_CURRENT_SOURCE_DIR}/broadcast_operator.cc ${CMAKE_CURRENT_SOURCE_DIR}/alltoall_operator.cc diff --git a/src/domain/collective_communication/algorithm/impl/operator/all_gather_fake_operator.cc b/src/domain/collective_communication/algorithm/impl/operator/all_gather_fake_operator.cc new file mode 100644 index 0000000..4187634 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/operator/all_gather_fake_operator.cc @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "all_gather_fake_operator.h" +#include "device_capacity.h" +#include "rank_consistentcy_checker.h" +#include "executor_impl.h" +#include "coll_alg_op_registry.h" +#include "hccl_aiv.h" + +namespace hccl { +AllGatherFakeOperator::AllGatherFakeOperator(AlgConfigurator* algConfigurator, CCLBufferManager &cclBufferManager, + HcclDispatcher dispatcher, std::unique_ptr &topoMatcher) + : CollAlgOperator(algConfigurator, cclBufferManager, dispatcher, topoMatcher, HcclCMDType::HCCL_CMD_ALLGATHER_FAKE) +{ +} + +AllGatherFakeOperator::~AllGatherFakeOperator() +{ +} + +HcclResult AllGatherFakeOperator::SelectAlg(const std::string& tag, const OpParam& param, std::string& algName, + std::string& newTag) +{ + HcclResult ret; + + if (deviceType_ == DevType::DEV_TYPE_910_93) { + ret = SelectAlgfor91093(param, algName); + } else { + HCCL_ERROR("[SelectAlg] device type[%d] is out of range for selector.", deviceType_); + return HCCL_E_NOT_SUPPORT; + } + CHK_PRT_RET(ret != HCCL_SUCCESS, + HCCL_ERROR("[AllGatherFakeOperator][SelectAlg]tag[%s], all_gather_fake failed, return[%d]", tag.c_str(), ret), ret); + if (workflowMode_ == HcclWorkflowMode::HCCL_WORKFLOW_MODE_OPS_KERNEL_INFO_LIB) { + newTag = tag; + } else if (deviceType_ == DevType::DEV_TYPE_310P3) { + newTag = tag + algName; + } else { + AlgTypeLevel1 algType1 = algType_.algoLevel1; + auto level1Iter = HCCL_ALGO_LEVEL1_NAME_MAP.find(algType1); + CHK_PRT_RET(level1Iter == HCCL_ALGO_LEVEL1_NAME_MAP.end(), HCCL_ERROR("level1: algType1[%u] is invalid.", + algType1), HCCL_E_INTERNAL); + newTag = tag + level1Iter->second + algName; + } + newTag += (param.aicpuUnfoldMode ? "_device" : "_host"); + HCCL_INFO("[SelectAlg] all_gather_fake newTag is [%s]", newTag.c_str()); + + if (UNLIKELY(GetDebugConfig() & HCCL_ALG)) { + HCCL_CONFIG_INFO(HCCL_ALG, + "[AllGatherFakeOperator][SelectAlg]userRank_[%u], algName[%s] actual level1 algo[%d], level2 algo[%d]", + userRank_, algName.c_str(), algType_.algoLevel1, algType_.algoLevel2); + } + + return ret; +} + +HcclResult AllGatherFakeOperator::SelectAlgfor910B(const OpParam& param, std::string& algName) +{ + (void) param; + algName = "AllGatherFakeMeshExecutor"; + HCCL_INFO("[SelectAlgfor910B] all_gather_fake SelectAlgfor910B is algName [%s]", algName.c_str()); + return HCCL_SUCCESS; +} + +HcclResult AllGatherFakeOperator::SelectAlgfor91093(const OpParam& param, std::string& algName) +{ + (void) param; + algName = "AllGatherFakeMeshExecutor"; + HCCL_INFO("[SelectAlgfor91093] all_gather_fake SelectAlgfor91093 is algName [%s]", algName.c_str()); + return HCCL_SUCCESS; +} + +REGISTER_OP(HcclCMDType::HCCL_CMD_ALLGATHER_FAKE, AllGatherFake, AllGatherFakeOperator); + +} \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/operator/all_gather_fake_operator.h b/src/domain/collective_communication/algorithm/impl/operator/all_gather_fake_operator.h new file mode 100644 index 0000000..2d947d2 --- /dev/null +++ b/src/domain/collective_communication/algorithm/impl/operator/all_gather_fake_operator.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef ALL_GATHER_FAKE_OPERATOR_H +#define ALL_GATHER_FAKE_OPERATOR_H + +#include "coll_alg_operator.h" + +namespace hccl { +class AllGatherFakeOperator : public CollAlgOperator { +public: + AllGatherFakeOperator(AlgConfigurator* algConfigurator, CCLBufferManager &cclBufferManager, + HcclDispatcher dispatcher, std::unique_ptr &topoMatcher); + ~AllGatherFakeOperator(); + HcclResult SelectAlg(const std::string& tag, const OpParam& param, std::string& algName, std::string& newTag) override; +private: + HcclResult SelectAlgfor910B(const OpParam& param, std::string& algName); + HcclResult SelectAlgfor91093(const OpParam& param, std::string& algName); +}; + +} + +#endif /** __ALL_GATHER_FAKE_OPERATOR_H__ */ \ No newline at end of file diff --git a/src/domain/collective_communication/algorithm/impl/operator/coll_alg_operator.cc b/src/domain/collective_communication/algorithm/impl/operator/coll_alg_operator.cc index b79b6f1..a2008f7 100644 --- a/src/domain/collective_communication/algorithm/impl/operator/coll_alg_operator.cc +++ b/src/domain/collective_communication/algorithm/impl/operator/coll_alg_operator.cc @@ -353,7 +353,12 @@ HcclResult CollAlgOperator::GetDefaultAlgoLevel1V2(HcclCMDType hcclCMDType, u64 auto originalAlgTypeLevel0 = algType_.algoLevel0; bool disdeterniminsticWithInlineReduce = isInlineReduce && isRdmaReduce && topoMatcher_->GetDeterministicConfig() == DETERMINISTIC_DISABLE; - + if (hcclCMDType == HcclCMDType::HCCL_CMD_ALLGATHER_FAKE) { + algType = AlgTypeLevel1::ALG_LEVEL1_RING; + HCCL_DEBUG("[CollAlgOperator][GetDefaultAlgoLevel1V2] The algTypeLevel1 is ALG_LEVEL1_RING " + "since hcclCMDType[%d].", hcclCMDType); + return HCCL_SUCCESS; + } // 对于不支持Rdma Lite的场景,下发性能较差,RS和AG需要一个很大的数据量(AR的一半)才能掩盖下发时间 u64 pipelineMinSize = (isSupportRdmaLite_) ? (PIPELINE_MIN_SIZE) : (PIPELINE_MIN_SIZE_NO_LITE); if (((hcclCMDType == HcclCMDType::HCCL_CMD_REDUCE_SCATTER && disdeterniminsticWithInlineReduce) || diff --git a/src/domain/collective_communication/algorithm/pub_inc/alg_cmd_type.h b/src/domain/collective_communication/algorithm/pub_inc/alg_cmd_type.h index 0f69a81..765291f 100644 --- a/src/domain/collective_communication/algorithm/pub_inc/alg_cmd_type.h +++ b/src/domain/collective_communication/algorithm/pub_inc/alg_cmd_type.h @@ -33,6 +33,7 @@ const std::map HCOM_CMD_TYPE_STR_MAP{ {HcclCMDType::HCCL_CMD_BATCH_WRITE, "batch_write"}, {HcclCMDType::HCCL_CMD_BATCH_GET, "batch_get"}, {HcclCMDType::HCCL_CMD_BATCH_PUT, "batch_put"}, + {HcclCMDType::HCCL_CMD_ALLGATHER_FAKE, "allgather_fake"}, {HcclCMDType::HCCL_CMD_ALL, "all"}, {HcclCMDType::HCCL_CMD_MAX, "max"} }; diff --git a/src/domain/collective_communication/algorithm/pub_inc/ffts_common_pub.h b/src/domain/collective_communication/algorithm/pub_inc/ffts_common_pub.h index cf96925..90c34c4 100644 --- a/src/domain/collective_communication/algorithm/pub_inc/ffts_common_pub.h +++ b/src/domain/collective_communication/algorithm/pub_inc/ffts_common_pub.h @@ -281,6 +281,43 @@ using HcclOpMetaInfo = struct HcclOpMetaInfoDef { std::to_string(static_cast(deterministic)); #else return ""; +#endif + } + + static HcclOpMetaInfoDef GetOneForAllGatherFake(u32 algolevel1Type = 0, bool hugeData = false, bool smallCount = false, + CopyPattern copyPattern = CopyPattern::BCOPY, bool dataSplit = false, bool isAivMode = false) + { + HcclOpMetaInfoDef meta; + meta.opType = HcclCMDType::HCCL_CMD_ALLGATHER_FAKE; // 新增算子类型 + meta.isEnableCache = true; // 是否做子图复用,如不需要,则置成false + // 后续参数是用来保证不同算法key的唯一性的,如果需要,可以提供给GetCacheKeyFake使用 + meta.copyPattern = copyPattern; + meta.algolevel1Type = algolevel1Type; + meta.hugeData = hugeData; + meta.isSmallCount = smallCount; + meta.dataSplit = dataSplit; + meta.isAivMode = isAivMode; + return meta; + } + + std::string GetCacheKeyFake() const + { +#ifndef CCL_KERNEL_AICPU + std::string base = "MT_sustom_"; // 该前缀避免跟华为自研的冲突 + std::string type = std::to_string(static_cast(opType)); // 后续需要保持key的唯一性,可拓展不同的属性加以区分 + // 后续参数是用来保证不同算法key的唯一性的,如果需要,可以提供给GetCacheKeyFake使用 + std::string isRootRankStr = isRootRank ? "1" : "0"; + std::string isSmallCountStr = isSmallCount ? "1" : "0"; + std::string isDefaultPathStr = isDefaultPath ? "1" : "0"; + std::string dataSplitStr = dataSplit ? "1" : "0"; + std::string isAivModeStr = isAivMode ? "1" : "0"; + return base + type + isRootRankStr + std::to_string(static_cast(reduceType)) + + std::to_string(rootRank) + std::to_string(sliceNum) + std::to_string(static_cast(dataType)) + + isSmallCountStr + isDefaultPathStr + std::to_string(piplineSliceNum) + std::to_string(algolevel1Type) + + std::to_string(static_cast(copyPattern)) + dataSplitStr + isAivModeStr + + std::to_string(static_cast(deterministic));; +#else + return ""; #endif } }; diff --git a/src/domain/collective_communication/common/debug/profiling/inc/task_profiling_pub.h b/src/domain/collective_communication/common/debug/profiling/inc/task_profiling_pub.h index 4b650d4..d587ab5 100644 --- a/src/domain/collective_communication/common/debug/profiling/inc/task_profiling_pub.h +++ b/src/domain/collective_communication/common/debug/profiling/inc/task_profiling_pub.h @@ -246,6 +246,7 @@ const std::map PROF_OP_NAME = {{HcclCMDType::HCCL_CMD_ {HcclCMDType::HCCL_CMD_BROADCAST, "hcom_broadcast_"}, {HcclCMDType::HCCL_CMD_ALLREDUCE, "hcom_allReduce_"}, {HcclCMDType::HCCL_CMD_REDUCE, "hcom_reduce_"}, {HcclCMDType::HCCL_CMD_SEND, "hcom_send_"}, {HcclCMDType::HCCL_CMD_RECEIVE, "hcom_receive_"}, {HcclCMDType::HCCL_CMD_ALLGATHER, "hcom_allGather_"}, + {HcclCMDType::HCCL_CMD_ALLGATHER_FAKE, "hcom_allgather_fake_"}, {HcclCMDType::HCCL_CMD_REDUCE_SCATTER, "hcom_reduceScatter_"}, {HcclCMDType::HCCL_CMD_SCATTER, "hcom_scatter_"}, {HcclCMDType::HCCL_CMD_ALLTOALL, "hcom_alltoall_"}, {HcclCMDType::HCCL_CMD_ALLTOALLV, "hcom_alltoallv_"}, {HcclCMDType::HCCL_CMD_ALLGATHER_V, "hcom_allGatherv_"}, {HcclCMDType::HCCL_CMD_REDUCE_SCATTER_V, "hcom_reduceScatterv_"}, diff --git a/src/domain/collective_communication/framework/CMakeLists.txt b/src/domain/collective_communication/framework/CMakeLists.txt index 7eb0439..7bc02e5 100644 --- a/src/domain/collective_communication/framework/CMakeLists.txt +++ b/src/domain/collective_communication/framework/CMakeLists.txt @@ -251,6 +251,11 @@ if(NOT BUILD_OPEN_PROJECT OR (BUILD_OPEN_PROJECT AND KERNEL_MODE)) ${HCCL_BASE_DIR}/algorithm/base/alg_template/temp_all_gather/all_gather_hd_stage.cc ${HCCL_BASE_DIR}/algorithm/base/alg_template/temp_all_gather/all_gather_hccs_sio.cc + # temp_all_gather_fake + ${HCCL_BASE_DIR}/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_mesh_aicpu.cc + ${HCCL_BASE_DIR}/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_ring_aicpu.cc + ${HCCL_BASE_DIR}/algorithm/base/alg_template/temp_all_gather_fake/all_gather_fake_ring_aicpu_async.cc + ${HCCL_BASE_DIR}/algorithm/base/alg_template/temp_all_reduce/all_reduce_recursive_hd.cc ${HCCL_BASE_DIR}/algorithm/base/alg_template/temp_all_reduce/all_reduce_ring.cc ${HCCL_BASE_DIR}/algorithm/base/alg_template/temp_all_reduce/all_reduce_nhr_oneshot.cc @@ -379,6 +384,10 @@ if(NOT BUILD_OPEN_PROJECT OR (BUILD_OPEN_PROJECT AND KERNEL_MODE)) ${HCCL_BASE_DIR}/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_ring_zerocopy_exchange_executor.cc ${HCCL_BASE_DIR}/algorithm/impl/coll_executor/coll_all_gather/coll_all_gather_ring_zerocopy_pipeline_executor.cc + # coll_all_gather_fake + ${HCCL_BASE_DIR}/algorithm/impl/coll_executor/coll_all_gather_fake/coll_all_gather_fake_executor.cc + ${HCCL_BASE_DIR}/algorithm/impl/coll_executor/coll_all_gather_fake/coll_all_gather_fake_mesh_executor.cc + ${HCCL_BASE_DIR}/algorithm/impl/coll_executor/coll_all_gather_v/coll_all_gather_v_executor.cc ${HCCL_BASE_DIR}/algorithm/impl/coll_executor/coll_all_gather_v/coll_all_gather_v_ring_for_910_93_executor.cc ${HCCL_BASE_DIR}/algorithm/impl/coll_executor/coll_all_gather_v/coll_aligned_all_gather_v_double_ring_for_910_93_executor.cc diff --git a/src/domain/collective_communication/framework/communicator/hccl_comm.cc b/src/domain/collective_communication/framework/communicator/hccl_comm.cc index b654c0a..9388adf 100644 --- a/src/domain/collective_communication/framework/communicator/hccl_comm.cc +++ b/src/domain/collective_communication/framework/communicator/hccl_comm.cc @@ -349,6 +349,24 @@ HcclResult hcclComm::AllGatherOutPlace(const std::string &tag, void *inputPtr, v return HCCL_SUCCESS; } +HcclResult hcclComm::AllGatherFakeOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, u64 inputCount, + HcclDataType dataType, HcclRtStream stream) +{ + /* 增加输出日志关键字 */ + HCCL_DEBUG("HCCL_KEY_INFO: tag[%s], input_ptr[%p], output_ptr[%p], count[%llu], data_type[%s]", + tag.c_str(), inputPtr, outputPtr, inputCount, GetDataTypeEnumStr(dataType).c_str()); + + /* * 入参检查 */ + CHK_RET(communicator_->CheckDataType(dataType, false)); + HcclResult ret = communicator_->AllGatherFakeOutPlace(tag, inputPtr, outputPtr, inputCount, dataType, stream); + if (ret != HCCL_SUCCESS) { + PrintSubmittedOpCnt(tag, ret); + return ret; + } + + return HCCL_SUCCESS; +} + HcclResult hcclComm::AllGatherVOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, u64 inputCount, const void *outputCounts, const void *outputDispls, HcclDataType dataType, HcclRtStream stream) { diff --git a/src/domain/collective_communication/framework/communicator/impl/hccl_communicator.cc b/src/domain/collective_communication/framework/communicator/impl/hccl_communicator.cc index 0cf3f6d..cbc76da 100644 --- a/src/domain/collective_communication/framework/communicator/impl/hccl_communicator.cc +++ b/src/domain/collective_communication/framework/communicator/impl/hccl_communicator.cc @@ -103,6 +103,7 @@ namespace hccl { // 卸载自定义算子 // 请勿删除,该函数为用户自定义算子时使用,应释放句柄:UnloadBinary(binCustomHandle_); + UnloadBinary(binCustomHandle_); return; } @@ -110,6 +111,8 @@ namespace hccl { // 加载自定义算子 // 请勿删除,该函数为用户自定义算子时使用,应加载句柄 + const char *binPath1 = "/../aarch64-linux/lib64/device/lib64/libaicpu_custom.json"; + CHK_RET(LoadCustomFile(binPath1, ACL_RT_BINARY_LOAD_OPT_CPU_KERNEL_MODE, 1, binCustomHandle_)); return HCCL_SUCCESS; } diff --git a/src/domain/collective_communication/framework/communicator/impl/hccl_communicator.h b/src/domain/collective_communication/framework/communicator/impl/hccl_communicator.h index 6699a44..a901410 100644 --- a/src/domain/collective_communication/framework/communicator/impl/hccl_communicator.h +++ b/src/domain/collective_communication/framework/communicator/impl/hccl_communicator.h @@ -137,6 +137,8 @@ public: virtual HcclResult AllGatherOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, u64 inputCount, HcclDataType dataType, HcclRtStream stream); + virtual HcclResult AllGatherFakeOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, u64 inputCount, + HcclDataType dataType, HcclRtStream stream); virtual HcclResult AllGatherVOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, u64 inputCount, const void *outputCounts, const void *outputDispls, HcclDataType dataType, HcclRtStream stream); diff --git a/src/domain/collective_communication/framework/communicator/impl/hccl_communicator_device.cc b/src/domain/collective_communication/framework/communicator/impl/hccl_communicator_device.cc index aafb8a2..0ff4a9a 100644 --- a/src/domain/collective_communication/framework/communicator/impl/hccl_communicator_device.cc +++ b/src/domain/collective_communication/framework/communicator/impl/hccl_communicator_device.cc @@ -622,6 +622,12 @@ namespace hccl return HCCL_SUCCESS; } + HcclResult HcclCommunicator::AllGatherFakeOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, + u64 inputCount, HcclDataType dataType, HcclRtStream stream) + { + return HCCL_SUCCESS; + } + HcclResult HcclCommunicator::AllGatherVOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, u64 inputCount, const void *outputCounts, const void *outputDispls, HcclDataType dataType, HcclRtStream stream) { diff --git a/src/domain/collective_communication/framework/communicator/impl/hccl_communicator_host.cc b/src/domain/collective_communication/framework/communicator/impl/hccl_communicator_host.cc index b188f26..51849c7 100644 --- a/src/domain/collective_communication/framework/communicator/impl/hccl_communicator_host.cc +++ b/src/domain/collective_communication/framework/communicator/impl/hccl_communicator_host.cc @@ -2459,6 +2459,54 @@ namespace hccl return HCCL_SUCCESS; } + HcclResult HcclCommunicator::AllGatherFakeOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, + u64 inputCount, HcclDataType dataType, HcclRtStream stream) + { + CHK_RET(CheckSuspendingStatus()); + if (!IsAtomicInit()) { + HCCL_ERROR( + "[HcclCommunicator][AllGatherFakeOutPlace]errNo[0x%016llx] hccl init must be called before call this function", + HCCL_ERROR_CODE(HCCL_E_UNAVAIL)); + return HCCL_E_UNAVAIL; + } + + bool aicpuUnfoldMode = false; + if (GetExternalInputHcclAicpuUnfold() == true && (deviceType_ == DevType::DEV_TYPE_910_93 || deviceType_ == DevType::DEV_TYPE_910B) && (userRankSize_ != 1)) { + aicpuUnfoldMode = true; + HCCL_INFO("[LIP]unfold"); + } + + bool isCapture = StreamIsCapture(stream); + + Stream streamObj(stream); + CHK_RET(callbackTask_->CallbackRegStream(stream)); + + std::vector &ranksPorts = groupNicRanksPort_.empty() ? nicRanksPort_ : groupNicRanksPort_; + implAlg_->SetHDCModeInfo(rankDevicePhyIdNicInfoMap_, ranksPorts, isSetHDCModeInfo_, isUseRankPort_); + + u32 perDataSize = SIZE_TABLE[dataType]; + u64 totalSize = inputCount * perDataSize * userRankSize_; + + OpParam opParam; + opParam.tag = tag; + opParam.inputPtr = inputPtr; + opParam.inputSize = inputCount * perDataSize; + opParam.outputPtr = outputPtr; + opParam.outputSize = totalSize; + opParam.DataDes.count = inputCount; + opParam.DataDes.dataType = dataType; + opParam.reduceType = HcclReduceOp::HCCL_REDUCE_RESERVED; + opParam.stream = streamObj; + opParam.syncMode = SyncMode::DEFAULT_TIMEWAITSYNCMODE; + opParam.opBaseAtraceInfo = opBaseAtraceInfo_.get(); + opParam.aicpuUnfoldMode = aicpuUnfoldMode; + opParam.isCapture = isCapture; + opParam.opType = HcclCMDType::HCCL_CMD_ALLGATHER_FAKE; + + CHK_RET(ExecOp(HcclCMDType::HCCL_CMD_ALLGATHER_FAKE, opParam, true)); + return HCCL_SUCCESS; + } + HcclResult HcclCommunicator::AllGatherVOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, u64 inputCount, const void *outputCounts, const void *outputDispls, HcclDataType dataType, HcclRtStream stream) { diff --git a/src/domain/collective_communication/framework/inc/hccl_comm_pub.h b/src/domain/collective_communication/framework/inc/hccl_comm_pub.h index 4b62a65..72b5912 100644 --- a/src/domain/collective_communication/framework/inc/hccl_comm_pub.h +++ b/src/domain/collective_communication/framework/inc/hccl_comm_pub.h @@ -90,6 +90,8 @@ public: rtStream_t stream, HcomCollOpInfo *opInfo = nullptr); HcclResult AllGatherOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, u64 inputCount, HcclDataType dataType, rtStream_t stream); + HcclResult AllGatherFakeOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, u64 inputCount, + HcclDataType dataType, rtStream_t stream); HcclResult AllGatherVOutPlace(const std::string &tag, void *inputPtr, void *outputPtr, u64 inputCount, const void *outputCounts, const void *outputDispls, HcclDataType dataType, HcclRtStream stream); HcclResult AllGatherV(const std::string &tag, const void *sendBuf, u64 sendCount, const void *recvBuf, diff --git a/src/domain/collective_communication/framework/op_base/src/op_base.cc b/src/domain/collective_communication/framework/op_base/src/op_base.cc index 44104b9..db0c9ee 100644 --- a/src/domain/collective_communication/framework/op_base/src/op_base.cc +++ b/src/domain/collective_communication/framework/op_base/src/op_base.cc @@ -1942,6 +1942,86 @@ HcclResult HcclAllGather(void *sendBuf, void *recvBuf, uint64_t sendCount, HcclD return HCCL_SUCCESS; } +HcclResult HcclAllGatherFake(void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType, + HcclComm comm, aclrtStream stream) +{ + HcclUs startut = TIME_NOW(); + bool isCapture; + rtStreamCaptureStatus captureStatus = rtStreamCaptureStatus::RT_STREAM_CAPTURE_STATUS_NONE; + uint32_t modelId = 0xFFFFFFFF; + CHK_PRT(GetCaptureInfo(stream, captureStatus, modelId, isCapture)); + if (!isCapture) { + HcclSetIfProfile(); + } + s32 threadID = SalGetTid(); + ProfilingManagerPub::SetThreadCaptureStatus(threadID, isCapture); + uint64_t beginTime = hrtMsprofSysCycleTime(); + + CHK_PRT_RET(sendCount == 0, HCCL_WARNING("input sendCount is 0, return HcclAllGatherFake success"), HCCL_SUCCESS); + // 入参合法性校验 + RPT_INPUT_ERR(comm == nullptr, "EI0003", std::vector({"ccl_op", "parameter", "value", "tips"}),\ + std::vector({"HcclAllGatherFake", "comm", "nullptr", "please check comm"})); + CHK_PTR_NULL(comm); + RPT_INPUT_ERR(sendBuf == nullptr, "EI0003", std::vector({"ccl_op", "parameter", "value", "tips"}),\ + std::vector({"HcclAllGatherFake", "sendBuf", "nullptr", "please check sendBuf"})); + CHK_PTR_NULL(sendBuf); + RPT_INPUT_ERR(recvBuf == nullptr, "EI0003", std::vector({"ccl_op", "parameter", "value", "tips"}),\ + std::vector({"HcclAllGatherFake", "recvBuf", "nullptr", "please check recvBuf"})); + CHK_PTR_NULL(recvBuf); + hccl::hcclComm* hcclComm = static_cast(comm); + const std::lock_guard lock(hcclComm->operatorlock_); + StateGuard guard(hcclComm, HcclCommState::INUSE); + // 同通信域同算子复用tag + const std::string tag = "AllGatherFake_" + hcclComm->GetIdentifier(); + CHK_RET_AND_PRINT_IDE(HcomCheckOpParam(tag.c_str(), sendCount, dataType, stream), tag.c_str()); + + /* 接口交互信息日志 */ + char stackLogBuffer[LOG_TMPBUF_SIZE]; + if (GetExternalInputHcclEnableEntryLog()) { + s32 deviceLogicId = 0; + CHK_RET(hrtGetDeviceRefresh(&deviceLogicId)); + + s32 streamId = 0; + CHK_RET_AND_PRINT_IDE(hrtGetStreamId(stream, streamId), tag.c_str()); + + s32 ret = snprintf_s(stackLogBuffer, LOG_TMPBUF_SIZE, LOG_TMPBUF_SIZE - 1U, + "tag[%s], sendBuf[%p], recvBuf[%p], sendCount[%llu], dataType[%s], streamId[%d]," + "deviceLogicId[%d]", + tag.c_str(), sendBuf, recvBuf, sendCount, GetDataTypeEnumStr(dataType).c_str(), streamId, deviceLogicId); + + CHK_PRT_CONT(ret == -1, HCCL_WARNING("Failed to build log info, tag[%s].", tag.c_str())); + std::string logInfo = "Entry-HcclAllGatherFake:" + std::string(stackLogBuffer) + + ", capture status[" + to_string(captureStatus) + "], model id[" + to_string(modelId) + "]."; + CHK_RET_AND_PRINT_IDE(hcclComm->SaveTraceInfo(logInfo), tag.c_str()); + } + + CHK_RET_AND_PRINT_IDE(SetWorkflowMode(HcclWorkflowMode::HCCL_WORKFLOW_MODE_OP_BASE), tag.c_str()); + + CHK_RET_AND_PRINT_IDE(PrintMemoryAttr(sendBuf), tag.c_str()); + + CHK_RET_AND_PRINT_IDE(PrintMemoryAttr(recvBuf), tag.c_str()); + + CHK_RET_AND_PRINT_IDE(SetDefaultQosConfig(hcclComm), tag.c_str()); + + CHK_RET_AND_PRINT_IDE(hcclComm->AllGatherFakeOutPlace(tag, sendBuf, recvBuf, sendCount, dataType, stream), tag.c_str()); + CHK_RET(CallMsprofReportHostApi(hcclComm, HcclCMDType::HCCL_CMD_ALLGATHER_FAKE, beginTime, sendCount, dataType, + tag)); + if (!isCapture) { + HcclResetIfProfile(); + } + ProfilingManagerPub::DeleteThreadCaptureStatus(threadID); + + if (GetExternalInputHcclEnableEntryLog()) { + HcclUs endut = TIME_NOW(); + /* 关键状态记录 */ + std::string endInfo = "HcclAllGatherFake:success,take time: " + + std::to_string(DURATION_US(endut - startut).count()) + " us," + std::string(stackLogBuffer); + CHK_RET_AND_PRINT_IDE(hcclComm->SaveTraceInfo(endInfo), tag.c_str()); + } + + return HCCL_SUCCESS; +} + HcclResult HcclAllGatherV(void *sendBuf, uint64_t sendCount, void *recvBuf, const void *recvCounts, const void *recvDispls, HcclDataType dataType, HcclComm comm, aclrtStream stream) { diff --git a/src/domain/collective_communication/framework/op_base/src/op_base.h b/src/domain/collective_communication/framework/op_base/src/op_base.h index f79e443..88bd884 100644 --- a/src/domain/collective_communication/framework/op_base/src/op_base.h +++ b/src/domain/collective_communication/framework/op_base/src/op_base.h @@ -105,6 +105,8 @@ extern "C" { HcclResult HcclCommInitClusterInfoMemConfig(const char *rankTableString, uint32_t rank, HcclCommConfig *config, HcclComm *comm); +extern HcclResult HcclAllGatherFake(void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType, + HcclComm comm, aclrtStream stream); #ifdef __cplusplus } #endif // __cplusplus -- Gitee