From 7e2a9d1dfb0c4bfccfe6fb9136fbe5b910056029 Mon Sep 17 00:00:00 2001 From: liuzihan000 Date: Fri, 17 Oct 2025 15:32:20 +0800 Subject: [PATCH 1/2] add all_to_all_v op --- inferrt/src/ops/ascend/hccl/hccl_adapter.cc | 1 + .../src/ops/ascend/hccl/hccl_all_to_all.cc | 59 +++++++++++-- inferrt/src/ops/ascend/hccl/hccl_all_to_all.h | 1 + inferrt/src/ops/op_base/op_all_to_all.cc | 7 +- .../distributed/check_distributed_ops.py | 86 +++++++++++++------ .../distributed/test_check_distributed_ops.py | 22 ++++- 6 files changed, 137 insertions(+), 39 deletions(-) diff --git a/inferrt/src/ops/ascend/hccl/hccl_adapter.cc b/inferrt/src/ops/ascend/hccl/hccl_adapter.cc index 58c7ef81..5fc223e3 100644 --- a/inferrt/src/ops/ascend/hccl/hccl_adapter.cc +++ b/inferrt/src/ops/ascend/hccl/hccl_adapter.cc @@ -56,6 +56,7 @@ void HcclAdapter::InitPlugin() { launch_hccl_reduce_scatter_ = DlsymFuncObj(HcclReduceScatter, plugin_handle_); launch_hccl_all_gather_ = DlsymFuncObj(HcclAllGather, plugin_handle_); launch_hccl_all_to_all_ = DlsymFuncObj(HcclAlltoAll, plugin_handle_); + launch_hccl_all_to_allv_ = DlsymFuncObj(HcclAlltoAllV, plugin_handle_); } void HcclAdapter::FinalizePlugin() { diff --git a/inferrt/src/ops/ascend/hccl/hccl_all_to_all.cc b/inferrt/src/ops/ascend/hccl/hccl_all_to_all.cc index e657f90a..47e074e6 100644 --- a/inferrt/src/ops/ascend/hccl/hccl_all_to_all.cc +++ b/inferrt/src/ops/ascend/hccl/hccl_all_to_all.cc @@ -17,19 +17,50 @@ #include #include "ops/ascend/hccl/hccl_all_to_all.h" -#include "ops/ascend/hccl/hccl_adapter.h" #include "hardware/hardware_abstract/collective/collective_manager.h" +#include "hardware/ascend/res_manager/ascend_stream_manager.h" #include "ops/ascend/hccl/hcom_utils.h" +#include "ops/ascend/hccl/hccl_adapter.h" #include "hccl/hccl_types.h" #include "hccl/hccl.h" #include "common/logger.h" #include "ops/op_register.h" -#include "hardware/ascend/res_manager/ascend_stream_manager.h" - namespace mrt { namespace ops { +bool is_all_to_all_v(const ir::TuplePtr &send_numel_list, const ir::TuplePtr &recv_numel_list) { + for (size_t i = 0; i < send_numel_list->Size(); i++) { + if (send_numel_list->operator[](i)->ToInt() != send_numel_list->operator[](0)->ToInt()) { + return true; + } + } + for (size_t i = 0; i < recv_numel_list->Size(); i++) { + if (recv_numel_list->operator[](i)->ToInt() != recv_numel_list->operator[](0)->ToInt()) { + return true; + } + } + return false; +} + +void GetAllToAllVParam(const ir::TuplePtr &send_numel_list, const ir::TuplePtr &recv_numel_list, + HcclAllToAllVParams *params) { + uint64_t offset = 0; + for (size_t i = 0; i < send_numel_list->Size(); i++) { + auto count = static_cast(send_numel_list->operator[](i)->ToInt()); + params->sendcounts.push_back(count); + params->sdispls.push_back(offset); + offset += count; + } + offset = 0; + for (size_t i = 0; i < recv_numel_list->Size(); i++) { + auto count = static_cast(recv_numel_list->operator[](i)->ToInt()); + params->recvcounts.push_back(count); + params->rdispls.push_back(offset); + offset += count; + } +} + OpsErrorCode HcclAllToAll::CalcWorkspace(const std::vector &input, const ir::Value *output, size_t *workspace_size) { LOG_OUT << "HcclAllToAll CalcWorkspace"; @@ -38,10 +69,9 @@ OpsErrorCode HcclAllToAll::CalcWorkspace(const std::vector &i HcclAdapter::GetInstance().InitHccl(); auto [hccl_count, hccl_data_type] = HcomUtil::GetHcclCountAndTypeFromTensor(input[kIndex0]->ToTensor()); hcclKernel.hccl_count_ = hccl_count / rank_size; - hcclKernel.hccl_data_type_ = hccl_data_type; hcclKernel.comm_ = HcomUtil::LoadHcclLibrary(group_name); - + useAllToAllV = is_all_to_all_v(input[kIndex2]->ToTuple(), input[kIndex1]->ToTuple()); return SUCCESS; } @@ -49,10 +79,21 @@ OpsErrorCode HcclAllToAll::Launch(const std::vector &input, v ir::Value *output, void *stream) { LOG_OUT << "HcclAllToAll launch"; auto out_tensor = output->ToTensor(); - HcclAllToAllParams params = {hcclKernel.hccl_count_, hcclKernel.hccl_count_}; - auto hccl_result = HcclAdapter::GetInstance().HcclAllToAll(const_cast(input[kIndex0]->ToTensor()->DataPtr()), - out_tensor->DataPtr(), params, hcclKernel.hccl_data_type_, - stream, hcclKernel.comm_); + ::HcclResult hccl_result; + if (useAllToAllV) { + LOG_OUT << "HcclAllToAll launch AllToAllV Kernel"; + HcclAllToAllVParams params; + GetAllToAllVParam(input[kIndex2]->ToTuple(), input[kIndex1]->ToTuple(), ¶ms); + hccl_result = HcclAdapter::GetInstance().HcclAlltoAllV(const_cast(input[kIndex0]->ToTensor()->DataPtr()), + out_tensor->DataPtr(), params, hcclKernel.hccl_data_type_, + stream, hcclKernel.comm_); + } else { + LOG_OUT << "HcclAllToAll launch AllToAll Kernel"; + HcclAllToAllParams params = {hcclKernel.hccl_count_, hcclKernel.hccl_count_}; + hccl_result = HcclAdapter::GetInstance().HcclAllToAll(const_cast(input[kIndex0]->ToTensor()->DataPtr()), + out_tensor->DataPtr(), params, hcclKernel.hccl_data_type_, + stream, hcclKernel.comm_); + } if (hccl_result != ::HcclResult::HCCL_SUCCESS) { LOG_ERROR << "HcclAllReduce failed, hccl_result: " << hccl_result; diff --git a/inferrt/src/ops/ascend/hccl/hccl_all_to_all.h b/inferrt/src/ops/ascend/hccl/hccl_all_to_all.h index 87f7abe0..2456e679 100644 --- a/inferrt/src/ops/ascend/hccl/hccl_all_to_all.h +++ b/inferrt/src/ops/ascend/hccl/hccl_all_to_all.h @@ -36,6 +36,7 @@ class HcclAllToAll : public OpAllToAll { private: HcclKernel hcclKernel; + bool useAllToAllV; }; } // namespace ops } // namespace mrt diff --git a/inferrt/src/ops/op_base/op_all_to_all.cc b/inferrt/src/ops/op_base/op_all_to_all.cc index e84573c8..aad4411b 100644 --- a/inferrt/src/ops/op_base/op_all_to_all.cc +++ b/inferrt/src/ops/op_base/op_all_to_all.cc @@ -25,7 +25,12 @@ OpsErrorCode OpAllToAll::InferShape(const std::vector &input, auto &input0Shape = input[kIndex0]->ToTensor()->Shape(); auto outputShape = input0Shape; - + auto output_split_sizes = input[kIndex1]->ToTuple(); + int64_t output_size = 0; + for (size_t i = 0; i < output_split_sizes->Size(); ++i) { + output_size += output_split_sizes->operator[](i)->ToInt(); + } + outputShape[0] = output_size; auto outputTensor = output->ToTensor(); CHECK_IF_NULL(outputTensor); outputTensor->SetShape(outputShape); diff --git a/tests/st/inferrt/distributed/check_distributed_ops.py b/tests/st/inferrt/distributed/check_distributed_ops.py index 0d7a654e..58aeee31 100644 --- a/tests/st/inferrt/distributed/check_distributed_ops.py +++ b/tests/st/inferrt/distributed/check_distributed_ops.py @@ -78,6 +78,12 @@ def check_group_info(pg=None): return +def check_op_output_with_mul(output, expect_out): + output = output.cpu() + expect_output = (expect_out * expect_out).cpu() + assert (output == expect_output).all(), f"expected output is {expect_output}, but got {output}" + + def test_all_gather(): """ Feature: Check all_gather op launch @@ -93,11 +99,6 @@ def test_all_gather(): output = torch.mul(gathered, gathered) return output - def check_allgather_output(output, expect_output): - output = output.cpu() - expect_output = (expect_output * expect_output).cpu() - assert (output == expect_output).all(), f"expected output is {expect_output}, but got {output}" - setup_distributed() rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() @@ -122,7 +123,7 @@ def test_all_gather(): output = compiled_model(example_input, gathered, new_pg) check_group_info(new_pg) - check_allgather_output(output, expect_output) + check_op_output_with_mul(output, expect_output) dist.destroy_process_group() @@ -141,11 +142,6 @@ def test_reduce_scatter(): output = torch.mul(tensor_out, tensor_out) return output - def check_reduce_scatter_output(output, expect_out): - output = output.cpu() - expect_output = (expect_out * expect_out).cpu() - assert (output == expect_output).all(), f"expected output is {expect_output}, but got {output}" - setup_distributed() rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() @@ -169,7 +165,7 @@ def test_reduce_scatter(): output = compiled_model(tensor_out, tensor_in, new_pg) check_group_info(new_pg) - check_reduce_scatter_output(output, expect_out) + check_op_output_with_mul(output, expect_out) dist.destroy_process_group() @@ -188,10 +184,6 @@ def test_all_reduce(): output = torch.mul(tensor_in, tensor_in) return output - def check_all_reduce_output(output, expect_out): - output = output.cpu() - expect_output = (expect_out * expect_out).cpu() - assert (output == expect_output).all(), f"expected output is {expect_output}, but got {output}" setup_distributed() rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() @@ -214,11 +206,11 @@ def test_all_reduce(): output = compiled_model(tensor_in, new_pg) check_group_info(new_pg) - check_all_reduce_output(output, expect_out) + check_op_output_with_mul(output, expect_out) dist.destroy_process_group() -def test_all_to_all(): +def test_all_to_all_single(): """ Feature: Check all_to_all op launch Description: Check all_to_all op launch with cache @@ -233,11 +225,6 @@ def test_all_to_all(): output = torch.mul(tensor_out, tensor_out) return output - def check_all_to_all_output(output, expect_out): - output = output.cpu() - expect_output = (expect_out * expect_out).cpu() - assert (output == expect_output).all(), f"expected output is {expect_output}, but got {output}" - setup_distributed() rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() @@ -260,13 +247,62 @@ def test_all_to_all(): output = compiled_model(tensor_out, tensor_in, new_pg) check_group_info(new_pg) - check_all_to_all_output(output, expect_out) + check_op_output_with_mul(output, expect_out) dist.destroy_process_group() +def test_all_to_all_v_single(): + """ + Feature: Check all_to_all op launch + Description: Check all_to_all op launch with cache + Expectation: The result is correct + """ + class SimpleNetwork(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, tensor_out, tensor_in, output_split_sizes, input_split_sizes, pg=None): + dist.all_to_all_single(tensor_out, tensor_in, output_split_sizes, input_split_sizes, group=pg) + output = torch.mul(tensor_out, tensor_out) + return output + + setup_distributed() + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() + group_list = [ii for ii in range(world_size)] + new_pg = dist.new_group(group_list) + + model = SimpleNetwork().npu() + if rank == 0: + tensor = torch.arange(6, dtype=torch.int64).npu() + input_split_sizes = [2, 4] + output_split_sizes = [2, 3] + output = torch.zeros(size=[5], dtype=torch.int64).npu() + expect_out = torch.zeros(size=[5], dtype=torch.int64).npu() + + else: + tensor = torch.arange(4, dtype=torch.int64).npu() + 6 + input_split_sizes = [3, 1] + output_split_sizes = [4, 1] + output = torch.zeros(size=[5], dtype=torch.int64).npu() + expect_out = torch.zeros(size=[5], dtype=torch.int64).npu() + + dist.all_to_all_single(expect_out, tensor, output_split_sizes, input_split_sizes, group=new_pg) + compiled_model = torch.compile( + model, + backend=backend, + fullgraph=True, + ) + + output = compiled_model(output, tensor, output_split_sizes, input_split_sizes, new_pg) + + check_group_info(new_pg) + check_op_output_with_mul(output, expect_out) + dist.destroy_process_group() + if __name__ == "__main__": import sys test_name = sys.argv[1] if len(sys.argv) > 1 else None assert test_name is not None, f"test case name is None" exit_code = pytest.main([f"tests/st/inferrt/distributed/check_distributed_ops.py::{test_name}"]) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/tests/st/inferrt/distributed/test_check_distributed_ops.py b/tests/st/inferrt/distributed/test_check_distributed_ops.py index 2d55a98f..327b3ce3 100644 --- a/tests/st/inferrt/distributed/test_check_distributed_ops.py +++ b/tests/st/inferrt/distributed/test_check_distributed_ops.py @@ -65,14 +65,28 @@ def test_check_all_reduce_op(pipeline, monkeypatch): @arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential") @pytest.mark.parametrize("pipeline", (True, False)) -def test_check_all_to_all_op(pipeline, monkeypatch): +def test_check_all_to_all_single_op(pipeline, monkeypatch): """ - Feature: Check all_to_all op launch - Description: Check all_to_all op launch with cache + Feature: Check all_to_all_single op launch + Description: Check all_to_all_single op launch with cache Expectation: The result is correct """ if pipeline: monkeypatch.setenv("MRT_ENABLE_PIPELINE", "on") - command = "torchrun --nproc_per_node=2 tests/st/inferrt/distributed/check_distributed_ops.py test_all_to_all" + command = "torchrun --nproc_per_node=2 tests/st/inferrt/distributed/check_distributed_ops.py test_all_to_all_single" + return_code = os.system(command) + assert return_code == 0 + +@arg_mark(plat_marks=["platform_ascend"], level_mark="level0", card_mark="onecard", essential_mark="essential") +@pytest.mark.parametrize("pipeline", (True, False)) +def test_check_all_to_all_v_single_op(pipeline, monkeypatch): + """ + Feature: Check all_to_all_v_single op launch + Description: Check all_to_all_v_single op launch with cache + Expectation: The result is correct + """ + if pipeline: + monkeypatch.setenv("MRT_ENABLE_PIPELINE", "on") + command = "torchrun --nproc_per_node=2 tests/st/inferrt/distributed/check_distributed_ops.py test_all_to_all_v_single" return_code = os.system(command) assert return_code == 0 -- Gitee From d2971da98a855e71c9ef143b767378ce9a224b6e Mon Sep 17 00:00:00 2001 From: liuzihan000 Date: Fri, 17 Oct 2025 16:56:18 +0800 Subject: [PATCH 2/2] fix name style --- inferrt/src/ops/ascend/hccl/hccl_adapter.cc | 437 +++++++++--------- inferrt/src/ops/ascend/hccl/hccl_adapter.h | 224 ++++----- .../src/ops/ascend/hccl/hccl_all_gather.cc | 16 +- inferrt/src/ops/ascend/hccl/hccl_all_gather.h | 4 +- .../src/ops/ascend/hccl/hccl_all_reduce.cc | 26 +- inferrt/src/ops/ascend/hccl/hccl_all_reduce.h | 4 +- .../src/ops/ascend/hccl/hccl_all_to_all.cc | 64 +-- inferrt/src/ops/ascend/hccl/hccl_all_to_all.h | 6 +- inferrt/src/ops/ascend/hccl/hccl_kernel.cc | 2 +- inferrt/src/ops/ascend/hccl/hccl_kernel.h | 12 +- .../ops/ascend/hccl/hccl_reduce_scatter.cc | 28 +- .../src/ops/ascend/hccl/hccl_reduce_scatter.h | 4 +- inferrt/src/ops/ascend/hccl/hcom_utils.cc | 102 ++-- inferrt/src/ops/ascend/hccl/hcom_utils.h | 24 +- inferrt/src/ops/ascend/hccl/tensor_copy.cc | 14 +- inferrt/src/ops/ascend/hccl/tensor_copy.h | 3 - inferrt/src/ops/ascend/hccl/wait_tensor.cc | 14 +- inferrt/src/ops/ascend/hccl/wait_tensor.h | 5 +- 18 files changed, 488 insertions(+), 501 deletions(-) diff --git a/inferrt/src/ops/ascend/hccl/hccl_adapter.cc b/inferrt/src/ops/ascend/hccl/hccl_adapter.cc index 5fc223e3..d3232323 100644 --- a/inferrt/src/ops/ascend/hccl/hccl_adapter.cc +++ b/inferrt/src/ops/ascend/hccl/hccl_adapter.cc @@ -40,277 +40,276 @@ HcclAdapter &HcclAdapter::GetInstance() { } void HcclAdapter::InitPlugin() { - if (plugin_handle_ != nullptr) { + if (pluginHandle_ != nullptr) { return; } #ifndef ENABLE_ASAN - plugin_handle_ = dlopen(kHcclPluginFileName, RTLD_DEEPBIND | RTLD_NOW | RTLD_LOCAL); + pluginHandle_ = dlopen(kHcclPluginFileName, RTLD_DEEPBIND | RTLD_NOW | RTLD_LOCAL); #else - plugin_handle_ = dlopen(kHcclPluginFileName, RTLD_NOW | RTLD_LOCAL); + pluginHandle_ = dlopen(kHcclPluginFileName, RTLD_NOW | RTLD_LOCAL); #endif - if (plugin_handle_ == nullptr) { + if (pluginHandle_ == nullptr) { LOG_EXCEPTION << "Dlopen " << kHcclPluginFileName << " failed, result = " << GetDlErrorMsg(); } - launch_hccl_all_reduce_ = DlsymFuncObj(HcclAllReduce, plugin_handle_); - launch_hccl_reduce_scatter_ = DlsymFuncObj(HcclReduceScatter, plugin_handle_); - launch_hccl_all_gather_ = DlsymFuncObj(HcclAllGather, plugin_handle_); - launch_hccl_all_to_all_ = DlsymFuncObj(HcclAlltoAll, plugin_handle_); - launch_hccl_all_to_allv_ = DlsymFuncObj(HcclAlltoAllV, plugin_handle_); + launchHcclAllReduce_ = DlsymFuncObj(HcclAllReduce, pluginHandle_); + launchHcclReduceScatter_ = DlsymFuncObj(HcclReduceScatter, pluginHandle_); + launchHcclAllGather_ = DlsymFuncObj(HcclAllGather, pluginHandle_); + launchHcclAllToAll_ = DlsymFuncObj(HcclAlltoAll, pluginHandle_); + launchHcclAllToAllV_ = DlsymFuncObj(HcclAlltoAllV, pluginHandle_); } void HcclAdapter::FinalizePlugin() { - if (plugin_handle_ == nullptr) { + if (pluginHandle_ == nullptr) { return; } - set_hccl_global_comm_info_ = nullptr; - init_hccl_root_info_config_ = nullptr; - init_hccl_global_comm_ranktable_ = nullptr; - init_hccl_sub_comm_ranktable_ = nullptr; - get_hccl_comm_config_capability_ = nullptr; - init_hccl_comm_ = nullptr; - finalize_hccl_comm_ = nullptr; - launch_hccl_broadcast_ = nullptr; - launch_hccl_all_reduce_ = nullptr; - launch_hccl_reduce_ = nullptr; - launch_hccl_scatter_ = nullptr; - launch_hccl_reduce_scatter_ = nullptr; - launch_hccl_all_gather_ = nullptr; - launch_hccl_send_ = nullptr; - launch_hccl_recv_ = nullptr; - launch_hccl_barrier_ = nullptr; - launch_hccl_batch_isend_irecv_ = nullptr; - hccl_create_group_ = nullptr; - hccl_destroy_group_ = nullptr; - hccl_get_rank_id_ = nullptr; - hccl_get_local_rank_id_ = nullptr; - hccl_get_local_rank_size_ = nullptr; - hccl_get_world_rank_by_group_rank_ = nullptr; - hccl_get_group_rank_by_world_rank_ = nullptr; - hccl_get_rank_size_ = nullptr; - hccl_exec_enqueue_op_ = nullptr; - hccl_exec_enqueue_all_to_all_v_ = nullptr; - hccl_comm_working_dev_nic_set_ = nullptr; - launch_hccl_all_to_allv_ = nullptr; - launch_hccl_reduce_scatterv_ = nullptr; - launch_hccl_all_gatherv_ = nullptr; - launch_hccl_comm_resume_ = nullptr; - hcom_destroy_ = nullptr; - (void)dlclose(plugin_handle_); - plugin_handle_ = nullptr; -} - -std::string HcclAdapter::GetHcclModeString(HcclMode hccl_mode) { + setHcclGlobalCommInfo_ = nullptr; + initHcclRootInfoConfig_ = nullptr; + initHcclGlobalCommRanktable_ = nullptr; + initHcclSubCommRanktable_ = nullptr; + getHcclCommConfigCapability_ = nullptr; + initHcclComm_ = nullptr; + finalizeHcclComm_ = nullptr; + launchHcclBroadcast_ = nullptr; + launchHcclAllReduce_ = nullptr; + launchHcclReduce_ = nullptr; + launchHcclScatter_ = nullptr; + launchHcclReduceScatter_ = nullptr; + launchHcclAllGather_ = nullptr; + launchHcclSend_ = nullptr; + launchHcclRecv_ = nullptr; + launchHcclBarrier_ = nullptr; + launchHcclBatchISendIRecv_ = nullptr; + hcclCreateGroup_ = nullptr; + hcclDestroyGroup_ = nullptr; + hcclGetRankId_ = nullptr; + hcclGetLocalRankId_ = nullptr; + hcclGetLocalRankSize_ = nullptr; + hcclGetWorldRankByGroupRank_ = nullptr; + hcclGetGroupRankByWorldRank_ = nullptr; + hcclGetRankSize_ = nullptr; + hcclExecEnqueueOp_ = nullptr; + hcclExecEnqueueAllToAllV_ = nullptr; + hcclCommWorkingDevNicSet_ = nullptr; + launchHcclAllToAllV_ = nullptr; + launchHcclReduceScatterV_ = nullptr; + launchHcclAllGatherV_ = nullptr; + launchHcclCommResume_ = nullptr; + hcomDestroy_ = nullptr; + (void)dlclose(pluginHandle_); + pluginHandle_ = nullptr; +} + +std::string HcclAdapter::GetHcclModeString(HcclMode hcclMode) { static std::map kHcclModeString = {{HcclMode::kGraph, "GE_MODE"}, {HcclMode::kPynative, "PYNATIVE_MODE"}, {HcclMode::kKernelByKernel, "KERNEL_BY_KERNEL_MODE"}}; - return kHcclModeString.at(hccl_mode); + return kHcclModeString.at(hcclMode); } bool HcclAdapter::InitHccl() { LOG_OUT << "Start init hccl adapter."; - std::lock_guard lock(init_mutex_); - if (init_flag_) { + std::lock_guard lock(initMutex_); + if (initFlag_) { LOG_OUT << "Hccl has been inited, skip."; return true; } InitPlugin(); - init_flag_ = true; + initFlag_ = true; LOG_OUT << "Init hccl adapter success."; return true; } -bool HcclAdapter::HcclWatchdogThread(HcclComm comm, std::string *error_info, bool *disable) { - if (!init_flag_) { +bool HcclAdapter::HcclWatchdogThread(HcclComm comm, std::string *errorInfo, bool *disable) { + if (!initFlag_) { LOG_OUT << "Hccl has never been inited, skip."; return true; } CHECK_IF_NULL(disable); - if (hccl_get_comm_async_error_ == nullptr) { + if (hcclGetCommAsyncError_ == nullptr) { LOG_OUT << "Hccl has never been inited, skip."; return true; } - if (hccl_get_error_string_ == nullptr) { + if (hcclGetErrorString_ == nullptr) { LOG_OUT << "Hccl has never been inited, skip."; return true; } - HcclResult hccl_async_error; - auto ret = hccl_get_comm_async_error_(comm, &hccl_async_error); + HcclResult hcclAsyncError; + auto ret = hcclGetCommAsyncError_(comm, &hcclAsyncError); if (ret != HCCL_SUCCESS) { LOG_OUT << "Call HcclGetCommAsyncError failed, close watchdog."; *disable = true; return true; } - if (hccl_async_error != HCCL_SUCCESS) { + if (hcclAsyncError != HCCL_SUCCESS) { std::ostringstream oss; - oss << "Hccl get comm async error failed, error code is: " << hccl_async_error - << ", detail info: " << hccl_get_error_string_(hccl_async_error); - *error_info = oss.str(); + oss << "Hccl get comm async error failed, error code is: " << hcclAsyncError + << ", detail info: " << hcclGetErrorString_(hcclAsyncError); + *errorInfo = oss.str(); return false; } return true; } bool HcclAdapter::FinalizeHccl() { - std::lock_guard lock(init_mutex_); - LOG_OUT << "Start destroy hccl adapter for " << GetHcclModeString(hccl_mode_); - if (!init_flag_) { + std::lock_guard lock(initMutex_); + LOG_OUT << "Start destroy hccl adapter for " << GetHcclModeString(hcclMode_); + if (!initFlag_) { LOG_OUT << "Hccl has never been inited, skip."; return true; } (void)FinalizeHcclExec(); (void)FinalizeKernelInfoStore(); (void)FinalizeHcclComm(); - if (hcom_destroy_ != nullptr) { - hcom_destroy_(); + if (hcomDestroy_ != nullptr) { + hcomDestroy_(); } FinalizePlugin(); - init_flag_ = false; + initFlag_ = false; LOG_OUT << "Destroy hccl adapter success."; return true; } HcclResult HcclAdapter::HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, - aclrtStream stream, HcclComm hccl_comm) const { - HcclResult ret = launch_hccl_broadcast_(buf, count, dataType, root, hccl_comm, stream); + aclrtStream stream, HcclComm hcclComm) const { + HcclResult ret = launchHcclBroadcast_(buf, count, dataType, root, hcclComm, stream); return ret; } -HcclResult HcclAdapter::HcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, - const HcclReduceOp op, const aclrtStream stream, HcclComm hccl_comm) const { - HcclResult ret = launch_hccl_all_reduce_(send_buf, recv_buf, count, dataType, op, hccl_comm, stream); +HcclResult HcclAdapter::HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, + const HcclReduceOp op, const aclrtStream stream, HcclComm hcclComm) const { + HcclResult ret = launchHcclAllReduce_(sendBuf, recvBuf, count, dataType, op, hcclComm, stream); return ret; } -HcclResult HcclAdapter::HcclReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, - HcclReduceOp op, uint32_t root, const aclrtStream stream, HcclComm hccl_comm) const { - HcclResult ret = launch_hccl_reduce_(send_buf, recv_buf, count, dataType, op, root, hccl_comm, stream); +HcclResult HcclAdapter::HcclReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, + HcclReduceOp op, uint32_t root, const aclrtStream stream, HcclComm hcclComm) const { + HcclResult ret = launchHcclReduce_(sendBuf, recvBuf, count, dataType, op, root, hcclComm, stream); return ret; } -HcclResult HcclAdapter::HcclScatter(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, - uint32_t root, HcclComm comm, aclrtStream stream) const { - HcclResult ret = launch_hccl_scatter_(send_buf, recv_buf, count, dataType, root, comm, stream); +HcclResult HcclAdapter::HcclScatter(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, + uint32_t root, const aclrtStream stream, HcclComm hcclComm) const { + HcclResult ret = launchHcclScatter_(sendBuf, recvBuf, count, dataType, root, hcclComm, stream); return ret; } -HcclResult HcclAdapter::HcclReduceScatter(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, - const HcclReduceOp op, const aclrtStream stream, HcclComm hccl_comm) const { - HcclResult ret = launch_hccl_reduce_scatter_(send_buf, recv_buf, count, dataType, op, hccl_comm, stream); +HcclResult HcclAdapter::HcclReduceScatter(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, + const HcclReduceOp op, const aclrtStream stream, HcclComm hcclComm) const { + HcclResult ret = launchHcclReduceScatter_(sendBuf, recvBuf, count, dataType, op, hcclComm, stream); return ret; } -HcclResult HcclAdapter::HcclAllGather(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, - const aclrtStream stream, HcclComm hccl_comm) const { - CHECK_SYMBOL_NULL(launch_hccl_all_gather_); - CHECK_IF_NULL(hccl_comm); - CHECK_IF_NULL(send_buf); - CHECK_IF_NULL(recv_buf); - HcclResult ret = launch_hccl_all_gather_(send_buf, recv_buf, count, dataType, hccl_comm, stream); +HcclResult HcclAdapter::HcclAllGather(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, + const aclrtStream stream, HcclComm hcclComm) const { + CHECK_SYMBOL_NULL(launchHcclAllGather_); + CHECK_IF_NULL(hcclComm); + CHECK_IF_NULL(sendBuf); + CHECK_IF_NULL(recvBuf); + HcclResult ret = launchHcclAllGather_(sendBuf, recvBuf, count, dataType, hcclComm, stream); return ret; } -HcclResult HcclAdapter::HcclSend(void *send_buf, uint64_t count, HcclDataType dataType, uint32_t destRank, - const aclrtStream stream, HcclComm hccl_comm) const { - HcclResult ret = launch_hccl_send_(send_buf, count, dataType, destRank, hccl_comm, stream); +HcclResult HcclAdapter::HcclSend(void *sendBuf, uint64_t count, HcclDataType dataType, uint32_t destRank, + const aclrtStream stream, HcclComm hcclComm) const { + HcclResult ret = launchHcclSend_(sendBuf, count, dataType, destRank, hcclComm, stream); return ret; } -HcclResult HcclAdapter::HcclRecv(void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t srcRank, - const aclrtStream stream, HcclComm hccl_comm) const { - HcclResult ret = launch_hccl_recv_(recv_buf, count, dataType, srcRank, hccl_comm, stream); +HcclResult HcclAdapter::HcclRecv(void *recvBuf, uint64_t count, HcclDataType dataType, uint32_t srcRank, + const aclrtStream stream, HcclComm hcclComm) const { + HcclResult ret = launchHcclRecv_(recvBuf, count, dataType, srcRank, hcclComm, stream); return ret; } -HcclResult HcclAdapter::HcclBarrier(const aclrtStream stream, HcclComm hccl_comm) const { - return launch_hccl_barrier_(hccl_comm, stream); +HcclResult HcclAdapter::HcclBarrier(const aclrtStream stream, HcclComm hcclComm) const { + return launchHcclBarrier_(hcclComm, stream); } -HcclResult HcclAdapter::HcclBatchISendIRecv(HcclSendRecvItem *sendRecvInfo, uint32_t itemNum, HcclComm comm, +HcclResult HcclAdapter::HcclBatchISendIRecv(HcclSendRecvItem *sendRecvInfo, uint32_t itemNum, HcclComm hcclComm, aclrtStream stream) const { - HcclResult ret = launch_hccl_batch_isend_irecv_(sendRecvInfo, itemNum, comm, stream); + HcclResult ret = launchHcclBatchISendIRecv_(sendRecvInfo, itemNum, hcclComm, stream); return ret; } -HcclResult HcclAdapter::HcclCommResume(HcclComm comm) const { - if (launch_hccl_comm_resume_ == nullptr) { +HcclResult HcclAdapter::HcclCommResume(HcclComm hcclComm) const { + if (launchHcclCommResume_ == nullptr) { LOG_EXCEPTION << "Dynamically load HcclCommResume failed."; } - return launch_hccl_comm_resume_(comm); + return launchHcclCommResume_(hcclComm); } uint32_t HcclAdapter::HcclGetCommConfigCapability() { - CHECK_IF_NULL(get_hccl_comm_config_capability_); - return get_hccl_comm_config_capability_(); + CHECK_IF_NULL(getHcclCommConfigCapability_); + return getHcclCommConfigCapability_(); } HcclResult HcclAdapter::HcclSetGlobalCommInfo(uint32_t masterIp, uint32_t masterPort, uint32_t totalRankSize, uint32_t nodeId, uint32_t localRankSize) { - if (set_hccl_global_comm_info_ == nullptr) { - set_hccl_global_comm_info_ = DlsymAscendFuncObj(HcclSetGlobalCommInfo, plugin_handle_); - if (set_hccl_global_comm_info_ == nullptr) { + if (setHcclGlobalCommInfo_ == nullptr) { + setHcclGlobalCommInfo_ = DlsymAscendFuncObj(HcclSetGlobalCommInfo, pluginHandle_); + if (setHcclGlobalCommInfo_ == nullptr) { LOG_OUT << "Func HcclSetGlobalCommInfo is not supported in CANN package."; return HCCL_E_NOT_SUPPORT; } } - return set_hccl_global_comm_info_(masterIp, masterPort, totalRankSize, nodeId, localRankSize); + return setHcclGlobalCommInfo_(masterIp, masterPort, totalRankSize, nodeId, localRankSize); } -HcclResult HcclAdapter::HcclCommInitClusterInfoConfig(const char *rank_table, uint32_t rank_id, HcclCommConfig *config, - HcclComm *hccl_comm) { - if (init_hccl_global_comm_ranktable_ == nullptr) { - init_hccl_global_comm_ranktable_ = DlsymFuncObj(HcclCommInitClusterInfoConfig, plugin_handle_); +HcclResult HcclAdapter::HcclCommInitClusterInfoConfig(const char *rankTable, uint32_t rankId, HcclCommConfig *config, + HcclComm *hcclComm) { + if (initHcclGlobalCommRanktable_ == nullptr) { + initHcclGlobalCommRanktable_ = DlsymFuncObj(HcclCommInitClusterInfoConfig, pluginHandle_); } - return init_hccl_global_comm_ranktable_(rank_table, rank_id, config, hccl_comm); + return initHcclGlobalCommRanktable_(rankTable, rankId, config, hcclComm); } - -HcclResult HcclAdapter::HcclCommInitRootInfoConfig(uint32_t n_ranks, const HcclRootInfo *root_info, uint32_t rank, - const HcclCommConfig *config, HcclComm *hccl_comm_) { - if (init_hccl_root_info_config_ == nullptr) { - init_hccl_root_info_config_ = DlsymFuncObj(HcclCommInitRootInfoConfig, plugin_handle_); - if (init_hccl_root_info_config_ == nullptr) { +HcclResult HcclAdapter::HcclCommInitRootInfoConfig(uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, + const HcclCommConfig *config, HcclComm *hcclComm) { + if (initHcclRootInfoConfig_ == nullptr) { + initHcclRootInfoConfig_ = DlsymFuncObj(HcclCommInitRootInfoConfig, pluginHandle_); + if (initHcclRootInfoConfig_ == nullptr) { // new api in CANN C20 - return HcclCommInitRootInfo(n_ranks, root_info, rank, hccl_comm_); + return HcclCommInitRootInfo(nRanks, rootInfo, rank, hcclComm); } } - return init_hccl_root_info_config_(n_ranks, root_info, rank, config, hccl_comm_); + return initHcclRootInfoConfig_(nRanks, rootInfo, rank, config, hcclComm); } -HcclResult HcclAdapter::HcclCreateSubCommConfig(HcclComm *global_comm, uint32_t rank_size, uint32_t *rank_ids, - uint64_t comm_id, uint32_t rank_id, HcclCommConfig *config, - HcclComm *hccl_comm) { - if (init_hccl_sub_comm_ranktable_ == nullptr) { - init_hccl_sub_comm_ranktable_ = DlsymFuncObj(HcclCreateSubCommConfig, plugin_handle_); +HcclResult HcclAdapter::HcclCreateSubCommConfig(HcclComm *globalComm, uint32_t rankSize, uint32_t *rankIds, + uint64_t commId, uint32_t rankId, HcclCommConfig *config, + HcclComm *hcclComm) { + if (initHcclSubCommRanktable_ == nullptr) { + initHcclSubCommRanktable_ = DlsymFuncObj(HcclCreateSubCommConfig, pluginHandle_); } - return init_hccl_sub_comm_ranktable_(global_comm, rank_size, rank_ids, comm_id, rank_id, config, hccl_comm); + return initHcclSubCommRanktable_(globalComm, rankSize, rankIds, commId, rankId, config, hcclComm); } -bool HcclAdapter::InitHcclComm(std::string_view rank_id, std::string_view rank_file) { +bool HcclAdapter::InitHcclComm(std::string_view rankId, std::string_view rankFile) { LOG_OUT << "Start init hccl comm."; - int rank_id_i = -1; + int rankIdI = -1; try { - rank_id_i = std::stoi(rank_id.data()); + rankIdI = std::stoi(rankId.data()); } catch (std::invalid_argument &) { - LOG_EXCEPTION << "Invalid rank id env:" << rank_id; + LOG_EXCEPTION << "Invalid rank id env:" << rankId; } - if (rank_id_i < 0) { + if (rankIdI < 0) { LOG_ERROR << "rank_id cannot be negative"; return false; } - CHECK_IF_NULL(init_hccl_comm_); - auto hccl_result = init_hccl_comm_(rank_file.data(), rank_id_i, &hccl_comm_); - if (hccl_result != HCCL_SUCCESS) { - LOG_ERROR << "HcclCommInitClusterInfo failed, ret:" << hccl_result; + CHECK_IF_NULL(initHcclComm_); + auto hcclResult = initHcclComm_(rankFile.data(), rankIdI, &hcclComm_); + if (hcclResult != HCCL_SUCCESS) { + LOG_ERROR << "HcclCommInitClusterInfo failed, ret:" << hcclResult; return false; } LOG_OUT << "InitHcclComm success"; @@ -319,104 +318,103 @@ bool HcclAdapter::InitHcclComm(std::string_view rank_id, std::string_view rank_f bool HcclAdapter::FinalizeHcclComm() { LOG_OUT << "Start finalize hccl comm."; - if (hccl_comm_ == nullptr) { + if (hcclComm_ == nullptr) { return true; } - CHECK_IF_NULL(finalize_hccl_comm_); - auto hccl_result = finalize_hccl_comm_(hccl_comm_); - if (hccl_result != HCCL_SUCCESS) { - LOG_ERROR << "HcclComm destroy failed, ret:" << hccl_result; + CHECK_IF_NULL(finalizeHcclComm_); + auto hcclResult = finalizeHcclComm_(hcclComm_); + if (hcclResult != HCCL_SUCCESS) { + LOG_ERROR << "HcclComm destroy failed, ret:" << hcclResult; return false; } - hccl_comm_ = nullptr; + hcclComm_ = nullptr; LOG_OUT << "HcclComm destroy success"; return true; } -HcclResult HcclAdapter::HcclCreateGroup(const std::string &group, uint32_t rank_num, uint32_t *rank_ids) const { - CHECK_SYMBOL_NULL(hccl_create_group_); - return hccl_create_group_(group.c_str(), rank_num, rank_ids); +HcclResult HcclAdapter::HcclCreateGroup(const std::string &group, uint32_t rankNum, uint32_t *rankIds) const { + CHECK_SYMBOL_NULL(hcclCreateGroup_); + return hcclCreateGroup_(group.c_str(), rankNum, rankIds); } - HcclResult HcclAdapter::HcclDestroyGroup(const std::string &group) const { - CHECK_SYMBOL_NULL(hccl_destroy_group_); - return hccl_destroy_group_(group.c_str()); + CHECK_SYMBOL_NULL(hcclDestroyGroup_); + return hcclDestroyGroup_(group.c_str()); } -HcclResult HcclAdapter::HcclGetRankId(const std::string &group, uint32_t *rank_id) const { - if (hccl_mode_ != HcclMode::kGraph) { - CHECK_SYMBOL_NULL(single_op_hccl_get_rank_id_); - return single_op_hccl_get_rank_id_(hccl_comm_, rank_id); +HcclResult HcclAdapter::HcclGetRankId(const std::string &group, uint32_t *rankId) const { + if (hcclMode_ != HcclMode::kGraph) { + CHECK_SYMBOL_NULL(singleOpHcclGetRankId_); + return singleOpHcclGetRankId_(hcclComm_, rankId); } else { - CHECK_SYMBOL_NULL(hccl_get_rank_id_); - return hccl_get_rank_id_(group.c_str(), rank_id); + CHECK_SYMBOL_NULL(hcclGetRankId_); + return hcclGetRankId_(group.c_str(), rankId); } } -HcclResult HcclAdapter::HcclGetRankSize(const std::string &group, uint32_t *rank_size) const { - if (hccl_mode_ != HcclMode::kGraph) { - CHECK_SYMBOL_NULL(single_op_hccl_get_rank_size_); - return single_op_hccl_get_rank_size_(hccl_comm_, rank_size); +HcclResult HcclAdapter::HcclGetRankSize(const std::string &group, uint32_t *rankSize) const { + if (hcclMode_ != HcclMode::kGraph) { + CHECK_SYMBOL_NULL(singleOpHcclGetRankSize_); + return singleOpHcclGetRankSize_(hcclComm_, rankSize); } else { - CHECK_SYMBOL_NULL(hccl_get_rank_size_); - return hccl_get_rank_size_(group.c_str(), rank_size); + CHECK_SYMBOL_NULL(hcclGetRankSize_); + return hcclGetRankSize_(group.c_str(), rankSize); } } -HcclResult HcclAdapter::HcclGetLocalRankId(const std::string &group, uint32_t *local_rank_id) const { - CHECK_SYMBOL_NULL(hccl_get_local_rank_id_); - return hccl_get_local_rank_id_(group.c_str(), local_rank_id); +HcclResult HcclAdapter::HcclGetLocalRankId(const std::string &group, uint32_t *localRankId) const { + CHECK_SYMBOL_NULL(hcclGetLocalRankId_); + return hcclGetLocalRankId_(group.c_str(), localRankId); } -HcclResult HcclAdapter::HcclGetLocalRankSize(const std::string &group, uint32_t *local_rank_size) const { - if (hccl_mode_ != HcclMode::kGraph) { - LOG_ERROR << "The pynative mode doesn't support get local rank szie."; +HcclResult HcclAdapter::HcclGetLocalRankSize(const std::string &group, uint32_t *localRankSize) const { + if (hcclMode_ != HcclMode::kGraph) { + LOG_ERROR << "The pynative mode doesn't support get local rank size."; return HCCL_E_NOT_SUPPORT; } else { - CHECK_SYMBOL_NULL(hccl_get_local_rank_size_); - return hccl_get_local_rank_size_(group.c_str(), local_rank_size); + CHECK_SYMBOL_NULL(hcclGetLocalRankSize_); + return hcclGetLocalRankSize_(group.c_str(), localRankSize); } } -HcclResult HcclAdapter::HcclGetWorldRankFromGroupRank(const std::string &group, uint32_t local_rank, - uint32_t *world_rank) const { - if (hccl_mode_ != HcclMode::kGraph) { +HcclResult HcclAdapter::HcclGetWorldRankFromGroupRank(const std::string &group, uint32_t localRank, + uint32_t *worldRank) const { + if (hcclMode_ != HcclMode::kGraph) { LOG_ERROR << "The pynative mode doesn't support get world rank by group rank."; return HCCL_E_NOT_SUPPORT; } else { - CHECK_SYMBOL_NULL(hccl_get_world_rank_by_group_rank_); - return hccl_get_world_rank_by_group_rank_(group.c_str(), local_rank, world_rank); + CHECK_SYMBOL_NULL(hcclGetWorldRankByGroupRank_); + return hcclGetWorldRankByGroupRank_(group.c_str(), localRank, worldRank); } } -HcclResult HcclAdapter::HcclGetGroupRankFromWorldRank(uint32_t world_rank, const std::string &group, - uint32_t *local_rank) const { - if (hccl_mode_ != HcclMode::kGraph) { +HcclResult HcclAdapter::HcclGetGroupRankFromWorldRank(uint32_t worldRank, const std::string &group, + uint32_t *localRank) const { + if (hcclMode_ != HcclMode::kGraph) { LOG_ERROR << "The pynative mode doesn't support get group rank by world rank."; return HCCL_E_NOT_SUPPORT; } else { - CHECK_SYMBOL_NULL(hccl_get_group_rank_by_world_rank_); - return hccl_get_group_rank_by_world_rank_(world_rank, group.c_str(), local_rank); + CHECK_SYMBOL_NULL(hcclGetGroupRankByWorldRank_); + return hcclGetGroupRankByWorldRank_(worldRank, group.c_str(), localRank); } } -HcclResult HcclAdapter::HcclCommWorkingDevNicSet(HcclComm comm, uint32_t *ranks, bool *useBackup, uint32_t nRanks) { - if (hccl_comm_working_dev_nic_set_ == nullptr) { - hccl_comm_working_dev_nic_set_ = DlsymFuncObj(HcclCommWorkingDevNicSet, plugin_handle_); +HcclResult HcclAdapter::HcclCommWorkingDevNicSet(HcclComm hcclComm, uint32_t *ranks, bool *useBackup, uint32_t nRanks) { + if (hcclCommWorkingDevNicSet_ == nullptr) { + hcclCommWorkingDevNicSet_ = DlsymFuncObj(HcclCommWorkingDevNicSet, pluginHandle_); } - CHECK_SYMBOL_NULL(hccl_comm_working_dev_nic_set_); - return hccl_comm_working_dev_nic_set_(comm, ranks, useBackup, nRanks); + CHECK_SYMBOL_NULL(hcclCommWorkingDevNicSet_); + return hcclCommWorkingDevNicSet_(hcclComm, ranks, useBackup, nRanks); } -HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const { - CHECK_SYMBOL_NULL(hccl_exec_enqueue_op_); - return hccl_exec_enqueue_op_(op_info, callback); +HcclResult HcclAdapter::HcclExecEnqueueOp(const ::HcomOperation &opInfo, const HExecCallBack &callback) const { + CHECK_SYMBOL_NULL(hcclExecEnqueueOp_); + return hcclExecEnqueueOp_(opInfo, callback); } HcclResult HcclAdapter::HcclExecAlltoAllV(const ::HcomAllToAllVParams ¶ms, const HExecCallBack &callback) const { - CHECK_SYMBOL_NULL(hccl_exec_enqueue_all_to_all_v_); - return hccl_exec_enqueue_all_to_all_v_(params, callback); + CHECK_SYMBOL_NULL(hcclExecEnqueueAllToAllV_); + return hcclExecEnqueueAllToAllV_(params, callback); } bool HcclAdapter::UseHcclCM() const { @@ -424,53 +422,52 @@ bool HcclAdapter::UseHcclCM() const { return false; } -HcclResult HcclAdapter::HcclAlltoAllV(void *send_buf, void *recv_buf, HcclAllToAllVParams params, HcclDataType dataType, - aclrtStream stream, HcclComm hccl_comm) const { - CHECK_SYMBOL_NULL(launch_hccl_all_to_allv_); - CHECK_IF_NULL(hccl_comm); - HcclResult ret = - launch_hccl_all_to_allv_(send_buf, params.sendcounts.data(), params.sdispls.data(), dataType, recv_buf, - params.recvcounts.data(), params.rdispls.data(), dataType, hccl_comm, stream); +HcclResult HcclAdapter::HcclAlltoAllV(void *sendBuf, void *recvBuf, HcclAllToAllVParams params, HcclDataType dataType, + aclrtStream stream, HcclComm hcclComm) const { + CHECK_SYMBOL_NULL(launchHcclAllToAllV_); + CHECK_IF_NULL(hcclComm); + HcclResult ret = launchHcclAllToAllV_(sendBuf, params.sendCounts.data(), params.sdispls.data(), dataType, recvBuf, + params.recvCounts.data(), params.rdispls.data(), dataType, hcclComm, stream); return ret; } -HcclResult HcclAdapter::HcclReduceScatterV(void *send_buf, void *recv_buf, HcclReduceScatterVParams params, - HcclDataType data_type, const HcclReduceOp op, const aclrtStream stream, - HcclComm hccl_comm) const { - CHECK_SYMBOL_NULL(launch_hccl_reduce_scatterv_); - CHECK_IF_NULL(hccl_comm); - HcclResult ret = launch_hccl_reduce_scatterv_(send_buf, params.send_counts.data(), params.sdispls.data(), recv_buf, - params.recv_count, data_type, op, hccl_comm, stream); +HcclResult HcclAdapter::HcclReduceScatterV(void *sendBuf, void *recvBuf, HcclReduceScatterVParams params, + HcclDataType dataType, const HcclReduceOp op, const aclrtStream stream, + HcclComm hcclComm) const { + CHECK_SYMBOL_NULL(launchHcclReduceScatterV_); + CHECK_IF_NULL(hcclComm); + HcclResult ret = launchHcclReduceScatterV_(sendBuf, params.sendCounts.data(), params.sdispls.data(), recvBuf, + params.recvCount, dataType, op, hcclComm, stream); return ret; } -HcclResult HcclAdapter::HcclAllGatherV(void *send_buf, void *recv_buf, HcclAllGatherVParams params, - HcclDataType data_type, const aclrtStream stream, HcclComm hccl_comm) const { - CHECK_SYMBOL_NULL(launch_hccl_all_gatherv_); - CHECK_IF_NULL(hccl_comm); - HcclResult ret = launch_hccl_all_gatherv_(send_buf, params.send_count, recv_buf, params.recv_counts.data(), - params.rdispls.data(), data_type, hccl_comm, stream); +HcclResult HcclAdapter::HcclAllGatherV(void *sendBuf, void *recvBuf, HcclAllGatherVParams params, HcclDataType dataType, + const aclrtStream stream, HcclComm hcclComm) const { + CHECK_SYMBOL_NULL(launchHcclAllGatherV_); + CHECK_IF_NULL(hcclComm); + HcclResult ret = launchHcclAllGatherV_(sendBuf, params.sendCount, recvBuf, + params.recvCounts.data(), params.rdispls.data(), dataType, hcclComm, stream); return ret; } -HcclResult HcclAdapter::HcclAllToAll(void *send_buf, void *recv_buf, HcclAllToAllParams params, HcclDataType dataType, - aclrtStream stream, HcclComm hccl_comm) const { - CHECK_SYMBOL_NULL(launch_hccl_all_to_all_); - CHECK_IF_NULL(hccl_comm); +HcclResult HcclAdapter::HcclAllToAll(void *sendBuf, void *recvBuf, HcclAllToAllParams params, HcclDataType dataType, + aclrtStream stream, HcclComm hcclComm) const { + CHECK_SYMBOL_NULL(launchHcclAllToAll_); + CHECK_IF_NULL(hcclComm); - HcclResult ret = launch_hccl_all_to_all_(send_buf, params.sendcount, dataType, recv_buf, params.recvcount, dataType, - hccl_comm, stream); + HcclResult ret = + launchHcclAllToAll_(sendBuf, params.sendCount, dataType, recvBuf, params.recvCount, dataType, hcclComm, stream); return ret; } -bool HcclAdapter::IsSameServer(const std::vector &rank_ids) const { - auto min_iter = min_element(rank_ids.begin(), rank_ids.end()); - uint32_t min = (min_iter != rank_ids.end()) ? *min_iter : 0; - auto max_iter = max_element(rank_ids.begin(), rank_ids.end()); - uint32_t max = (max_iter != rank_ids.end()) ? *max_iter : 0; +bool HcclAdapter::IsSameServer(const std::vector &rankIds) const { + auto minIter = min_element(rankIds.begin(), rankIds.end()); + uint32_t min = (minIter != rankIds.end()) ? *minIter : 0; + auto maxIter = max_element(rankIds.begin(), rankIds.end()); + uint32_t max = (maxIter != rankIds.end()) ? *maxIter : 0; return ((max - min < kDeviceNumOfServer) && (min / kDeviceNumOfServer == max / kDeviceNumOfServer)); } -} // namespace mrt::ops +} // namespace mrt::ops \ No newline at end of file diff --git a/inferrt/src/ops/ascend/hccl/hccl_adapter.h b/inferrt/src/ops/ascend/hccl/hccl_adapter.h index 8d977ff9..aa517aa0 100644 --- a/inferrt/src/ops/ascend/hccl/hccl_adapter.h +++ b/inferrt/src/ops/ascend/hccl/hccl_adapter.h @@ -34,27 +34,27 @@ struct HcclTaskInfo { }; struct HcclAllToAllVParams { - std::vector sendcounts; + std::vector sendCounts; std::vector sdispls; - std::vector recvcounts; + std::vector recvCounts; std::vector rdispls; }; struct HcclAllGatherVParams { - uint64_t send_count; - std::vector recv_counts; + uint64_t sendCount; + std::vector recvCounts; std::vector rdispls; }; struct HcclReduceScatterVParams { - std::vector send_counts; + std::vector sendCounts; std::vector sdispls; - uint64_t recv_count; + uint64_t recvCount; }; struct HcclAllToAllParams { - uint64_t sendcount; - uint64_t recvcount; + uint64_t sendCount; + uint64_t recvCount; }; enum HcclMode { kGraph, kPynative, kKernelByKernel }; @@ -64,74 +64,74 @@ class HcclAdapter { static HcclAdapter &GetInstance(); // common - bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file, HcclMode hccl_mode); + bool InitHccl(uint32_t deviceId, std::string_view rankId, std::string_view rankFile, HcclMode hcclMode); bool InitHccl(); uint32_t HcclGetCommConfigCapability(); HcclResult HcclSetGlobalCommInfo(uint32_t masterIp, uint32_t masterPort, uint32_t totalRankSize, uint32_t nodeId, uint32_t localRankSize); - HcclResult HcclCommInitClusterInfoConfig(const char *rank_table, uint32_t rank_id, HcclCommConfig *config, - HcclComm *hccl_comm_); - HcclResult HcclCommInitRootInfoConfig(uint32_t n_ranks, const HcclRootInfo *root_info, uint32_t rank, - const HcclCommConfig *config, HcclComm *hccl_comm_); - HcclResult HcclCreateSubCommConfig(HcclComm *global_comm, uint32_t rank_size, uint32_t *rank_ids, uint64_t comm_id, - uint32_t rank_id, HcclCommConfig *config, HcclComm *hccl_comm_); + HcclResult HcclCommInitClusterInfoConfig(const char *rankTable, uint32_t rankId, HcclCommConfig *config, + HcclComm *hcclComm); + HcclResult HcclCommInitRootInfoConfig(uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, + const HcclCommConfig *config, HcclComm *hcclComm); + HcclResult HcclCreateSubCommConfig(HcclComm *globalComm, uint32_t rankSize, uint32_t *rankIds, uint64_t commId, + uint32_t rankId, HcclCommConfig *config, HcclComm *hcclComm); bool FinalizeHccl(); - bool HcclWatchdogThread(HcclComm comm, std::string *error_info, bool *ret); - const bool Inited() const { return init_flag_; } - const HcclComm get_hccl_comm() const { return hccl_comm_; } - HcclResult HcclCreateGroup(const std::string &group, uint32_t rank_num, uint32_t *rank_ids) const; + bool HcclWatchdogThread(HcclComm comm, std::string *errorInfo, bool *ret); + const bool Inited() const { return initFlag_; } + const HcclComm get_hccl_comm() const { return hcclComm_; } + HcclResult HcclCreateGroup(const std::string &group, uint32_t rankNum, uint32_t *rankIds) const; HcclResult HcclDestroyGroup(const std::string &group) const; - HcclResult HcclGetRankId(const std::string &group, uint32_t *rank_id) const; - HcclResult HcclGetRankSize(const std::string &group, uint32_t *rank_size) const; - HcclResult HcclGetLocalRankId(const std::string &group, uint32_t *lcoal_rank_id) const; - HcclResult HcclGetLocalRankSize(const std::string &group, uint32_t *local_rank_size) const; - HcclResult HcclGetWorldRankFromGroupRank(const std::string &group, uint32_t local_rank, uint32_t *world_rank) const; - HcclResult HcclGetGroupRankFromWorldRank(uint32_t world_rank, const std::string &group, uint32_t *local_rank) const; + HcclResult HcclGetRankId(const std::string &group, uint32_t *rankId) const; + HcclResult HcclGetRankSize(const std::string &group, uint32_t *rankSize) const; + HcclResult HcclGetLocalRankId(const std::string &group, uint32_t *localRankId) const; + HcclResult HcclGetLocalRankSize(const std::string &group, uint32_t *localRankSize) const; + HcclResult HcclGetWorldRankFromGroupRank(const std::string &group, uint32_t localRank, uint32_t *worldRank) const; + HcclResult HcclGetGroupRankFromWorldRank(uint32_t worldRank, const std::string &group, uint32_t *localRank) const; // for single op HcclResult HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, aclrtStream stream, - HcclComm comm) const; - HcclResult HcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, HcclReduceOp op, - const aclrtStream stream, HcclComm comm) const; - HcclResult HcclReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, HcclReduceOp op, - uint32_t root, const aclrtStream stream, HcclComm comm) const; - HcclResult HcclScatter(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t root, - HcclComm comm, aclrtStream stream) const; - HcclResult HcclAllGather(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, - const aclrtStream stream, HcclComm comm) const; - HcclResult HcclReduceScatter(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType, HcclReduceOp op, - const aclrtStream stream, HcclComm comm) const; - HcclResult HcclSend(void *send_buf, uint64_t count, HcclDataType dataType, uint32_t destRank, - const aclrtStream stream, HcclComm comm) const; - HcclResult HcclRecv(void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t srcRank, const aclrtStream stream, - HcclComm comm) const; - HcclResult HcclAlltoAllV(void *send_buf, void *recv_buf, HcclAllToAllVParams params, HcclDataType dataType, - const aclrtStream stream, HcclComm comm) const; - - HcclResult HcclReduceScatterV(void *send_buf, void *recv_buf, HcclReduceScatterVParams params, HcclDataType data_type, - const HcclReduceOp op, const aclrtStream stream, HcclComm hccl_comm) const; - - HcclResult HcclAllGatherV(void *send_buf, void *recv_buf, HcclAllGatherVParams params, HcclDataType data_type, - const aclrtStream stream, HcclComm hccl_comm) const; - - HcclResult HcclAllToAll(void *send_buf, void *recv_buf, HcclAllToAllParams params, HcclDataType dataType, - const aclrtStream stream, HcclComm comm) const; - HcclResult HcclBarrier(const aclrtStream stream, HcclComm comm) const; - HcclResult HcclBatchISendIRecv(HcclSendRecvItem *sendRecvInfo, uint32_t itemNum, HcclComm comm, + HcclComm hcclComm) const; + HcclResult HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op, + const aclrtStream stream, HcclComm hcclComm) const; + HcclResult HcclReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op, + uint32_t root, const aclrtStream stream, HcclComm hcclComm) const; + HcclResult HcclScatter(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, uint32_t root, + HcclComm hcclComm, aclrtStream stream) const; + HcclResult HcclAllGather(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, + const aclrtStream stream, HcclComm hcclComm) const; + HcclResult HcclReduceScatter(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op, + const aclrtStream stream, HcclComm hcclComm) const; + HcclResult HcclSend(void *sendBuf, uint64_t count, HcclDataType dataType, uint32_t destRank, + const aclrtStream stream, HcclComm hcclComm) const; + HcclResult HcclRecv(void *recvBuf, uint64_t count, HcclDataType dataType, uint32_t srcRank, const aclrtStream stream, + HcclComm hcclComm) const; + HcclResult HcclAlltoAllV(void *sendBuf, void *recvBuf, HcclAllToAllVParams params, HcclDataType dataType, + const aclrtStream stream, HcclComm hcclComm) const; + + HcclResult HcclReduceScatterV(void *sendBuf, void *recvBuf, HcclReduceScatterVParams params, HcclDataType dataType, + const HcclReduceOp op, const aclrtStream stream, HcclComm hcclComm) const; + + HcclResult HcclAllGatherV(void *sendBuf, void *recvBuf, HcclAllGatherVParams params, HcclDataType dataType, + const aclrtStream stream, HcclComm hcclComm) const; + + HcclResult HcclAllToAll(void *sendBuf, void *recvBuf, HcclAllToAllParams params, HcclDataType dataType, + const aclrtStream stream, HcclComm hcclComm) const; + HcclResult HcclBarrier(const aclrtStream stream, HcclComm hcclComm) const; + HcclResult HcclBatchISendIRecv(HcclSendRecvItem *sendRecvInfo, uint32_t itemNum, HcclComm hcclComm, aclrtStream stream) const; // for enqueue op - HcclResult HcclExecEnqueueOp(const ::HcomOperation &op_info, const HExecCallBack &callback) const; + HcclResult HcclExecEnqueueOp(const ::HcomOperation &opInfo, const HExecCallBack &callback) const; HcclResult HcclExecAlltoAllV(const ::HcomAllToAllVParams ¶ms, const HExecCallBack &callback) const; - HcclResult HcclCommResume(HcclComm comm) const; + HcclResult HcclCommResume(HcclComm hcclComm) const; - HcclResult HcclCommWorkingDevNicSet(HcclComm comm, uint32_t *ranks, bool *useBackup, uint32_t nRanks); + HcclResult HcclCommWorkingDevNicSet(HcclComm hcclComm, uint32_t *ranks, bool *useBackup, uint32_t nRanks); // Return whether using CM to initialize HCCL. bool UseHcclCM() const; - static void AddCMEnvToHcclOption(std::map *hccl_opt_map); + static void AddCMEnvToHcclOption(std::map *hcclOptMap); - bool IsSameServer(const std::vector &rank_ids) const; + bool IsSameServer(const std::vector &rankIds) const; private: HcclAdapter() = default; @@ -142,67 +142,67 @@ class HcclAdapter { bool InitKernelInfoStore(const std::map options); bool FinalizeKernelInfoStore(); - bool InitHcclComm(std::string_view rank_id, std::string_view rank_file); + bool InitHcclComm(std::string_view rankId, std::string_view rankFile); bool FinalizeHcclComm(); bool InitHcclExec(); bool FinalizeHcclExec(); - static std::string GetHcclModeString(HcclMode hccl_mode); + static std::string GetHcclModeString(HcclMode hcclMode); static bool IsSimulation(); - void *plugin_handle_ = nullptr; - - HcomDestroyFunObj hcom_destroy_ = nullptr; - - HcclGetCommConfigCapabilityFunObj get_hccl_comm_config_capability_ = nullptr; - HcclSetGlobalCommInfoFunObj set_hccl_global_comm_info_ = nullptr; - HcclCommInitClusterInfoFunObj init_hccl_comm_ = nullptr; - HcclCommInitClusterInfoConfigFunObj init_hccl_global_comm_ranktable_ = nullptr; - HcclCommInitRootInfoConfigFunObj init_hccl_root_info_config_ = nullptr; - HcclCreateSubCommConfigFunObj init_hccl_sub_comm_ranktable_ = nullptr; - HcclCommDestroyFunObj finalize_hccl_comm_ = nullptr; - HcclBroadcastFunObj launch_hccl_broadcast_ = nullptr; - HcclAllReduceFunObj launch_hccl_all_reduce_ = nullptr; - HcclReduceFunObj launch_hccl_reduce_ = nullptr; - HcclScatterFunObj launch_hccl_scatter_ = nullptr; - HcclReduceScatterFunObj launch_hccl_reduce_scatter_ = nullptr; - HcclAllGatherFunObj launch_hccl_all_gather_ = nullptr; - HcclSendFunObj launch_hccl_send_ = nullptr; - HcclRecvFunObj launch_hccl_recv_ = nullptr; - HcclBarrierFunObj launch_hccl_barrier_ = nullptr; - HcclGetRankIdFunObj single_op_hccl_get_rank_id_ = nullptr; - HcclGetRankSizeFunObj single_op_hccl_get_rank_size_ = nullptr; - HcclAlltoAllVFunObj launch_hccl_all_to_allv_ = nullptr; - HcclReduceScatterVFunObj launch_hccl_reduce_scatterv_ = nullptr; - HcclAllGatherVFunObj launch_hccl_all_gatherv_ = nullptr; - HcclAlltoAllFunObj launch_hccl_all_to_all_ = nullptr; - HcclBatchSendRecvFunObj launch_hccl_batch_isend_irecv_ = nullptr; - HcclCommResumeFunObj launch_hccl_comm_resume_ = nullptr; - HcclGetCommAsyncErrorFunObj hccl_get_comm_async_error_ = nullptr; - HcclGetErrorStringFunObj hccl_get_error_string_ = nullptr; - HcomCreateGroupFunObj hccl_create_group_ = nullptr; - HcomDestroyGroupFunObj hccl_destroy_group_ = nullptr; - HcomGetRankIdFunObj hccl_get_rank_id_ = nullptr; - HcomGetRankSizeFunObj hccl_get_rank_size_ = nullptr; - HcomGetLocalRankIdFunObj hccl_get_local_rank_id_ = nullptr; - HcomGetLocalRankSizeFunObj hccl_get_local_rank_size_ = nullptr; - HcomGetWorldRankFromGroupRankFunObj hccl_get_world_rank_by_group_rank_ = nullptr; - HcomGetGroupRankFromWorldRankFunObj hccl_get_group_rank_by_world_rank_ = nullptr; - HcclCommWorkingDevNicSetFunObj hccl_comm_working_dev_nic_set_ = nullptr; - - HcomExecInitializeFunObj hccl_exec_initialize_ = nullptr; - HcomExecFinalizeFunObj hccl_exec_finalize_ = nullptr; - HcomExecEnqueueOperationFunObj hccl_exec_enqueue_op_ = nullptr; - HcomExecEnqueueAllToAllVFunObj hccl_exec_enqueue_all_to_all_v_ = nullptr; - - HcclComm hccl_comm_ = nullptr; - - bool init_flag_ = false; - bool init_kernel_info_store_ = false; - bool init_hccl_exec_ = false; - HcclMode hccl_mode_ = HcclMode::kGraph; - std::mutex init_mutex_; + void *pluginHandle_ = nullptr; + + HcomDestroyFunObj hcomDestroy_ = nullptr; + + HcclGetCommConfigCapabilityFunObj getHcclCommConfigCapability_ = nullptr; + HcclSetGlobalCommInfoFunObj setHcclGlobalCommInfo_ = nullptr; + HcclCommInitClusterInfoFunObj initHcclComm_ = nullptr; + HcclCommInitClusterInfoConfigFunObj initHcclGlobalCommRanktable_ = nullptr; + HcclCommInitRootInfoConfigFunObj initHcclRootInfoConfig_ = nullptr; + HcclCreateSubCommConfigFunObj initHcclSubCommRanktable_ = nullptr; + HcclCommDestroyFunObj finalizeHcclComm_ = nullptr; + HcclBroadcastFunObj launchHcclBroadcast_ = nullptr; + HcclAllReduceFunObj launchHcclAllReduce_ = nullptr; + HcclReduceFunObj launchHcclReduce_ = nullptr; + HcclScatterFunObj launchHcclScatter_ = nullptr; + HcclReduceScatterFunObj launchHcclReduceScatter_ = nullptr; + HcclAllGatherFunObj launchHcclAllGather_ = nullptr; + HcclSendFunObj launchHcclSend_ = nullptr; + HcclRecvFunObj launchHcclRecv_ = nullptr; + HcclBarrierFunObj launchHcclBarrier_ = nullptr; + HcclGetRankIdFunObj singleOpHcclGetRankId_ = nullptr; + HcclGetRankSizeFunObj singleOpHcclGetRankSize_ = nullptr; + HcclAlltoAllVFunObj launchHcclAllToAllV_ = nullptr; + HcclReduceScatterVFunObj launchHcclReduceScatterV_ = nullptr; + HcclAllGatherVFunObj launchHcclAllGatherV_ = nullptr; + HcclAlltoAllFunObj launchHcclAllToAll_ = nullptr; + HcclBatchSendRecvFunObj launchHcclBatchISendIRecv_ = nullptr; + HcclCommResumeFunObj launchHcclCommResume_ = nullptr; + HcclGetCommAsyncErrorFunObj hcclGetCommAsyncError_ = nullptr; + HcclGetErrorStringFunObj hcclGetErrorString_ = nullptr; + HcomCreateGroupFunObj hcclCreateGroup_ = nullptr; + HcomDestroyGroupFunObj hcclDestroyGroup_ = nullptr; + HcomGetRankIdFunObj hcclGetRankId_ = nullptr; + HcomGetRankSizeFunObj hcclGetRankSize_ = nullptr; + HcomGetLocalRankIdFunObj hcclGetLocalRankId_ = nullptr; + HcomGetLocalRankSizeFunObj hcclGetLocalRankSize_ = nullptr; + HcomGetWorldRankFromGroupRankFunObj hcclGetWorldRankByGroupRank_ = nullptr; + HcomGetGroupRankFromWorldRankFunObj hcclGetGroupRankByWorldRank_ = nullptr; + HcclCommWorkingDevNicSetFunObj hcclCommWorkingDevNicSet_ = nullptr; + + HcomExecInitializeFunObj hcclExecInitialize_ = nullptr; + HcomExecFinalizeFunObj hcclExecFinalize_ = nullptr; + HcomExecEnqueueOperationFunObj hcclExecEnqueueOp_ = nullptr; + HcomExecEnqueueAllToAllVFunObj hcclExecEnqueueAllToAllV_ = nullptr; + + HcclComm hcclComm_ = nullptr; + + bool initFlag_ = false; + bool initKernelInfoStore_ = false; + bool initHcclExec_ = false; + HcclMode hcclMode_ = HcclMode::kGraph; + std::mutex initMutex_; }; } // namespace mrt::ops #endif // OPS_ASCEND_HCCL_ADAPTER_H_ diff --git a/inferrt/src/ops/ascend/hccl/hccl_all_gather.cc b/inferrt/src/ops/ascend/hccl/hccl_all_gather.cc index 4665da0d..138c9115 100644 --- a/inferrt/src/ops/ascend/hccl/hccl_all_gather.cc +++ b/inferrt/src/ops/ascend/hccl/hccl_all_gather.cc @@ -32,14 +32,14 @@ namespace mrt { namespace ops { OpsErrorCode HcclAllGather::CalcWorkspace(const std::vector &input, const ir::Value *output, - size_t *workspace_size) { + size_t *workspaceSize) { LOG_OUT << "HcclAllGather CalcWorkspace"; HcclAdapter::GetInstance().InitHccl(); - auto [hccl_count, hccl_data_type] = HcomUtil::GetHcclCountAndTypeFromTensor(input[kIndex0]->ToTensor()); - hcclKernel.hccl_count_ = hccl_count; - hcclKernel.hccl_data_type_ = hccl_data_type; - const string &group_name = input[kIndex2]->ToString(); - hcclKernel.comm_ = HcomUtil::LoadHcclLibrary(group_name); + auto [hcclCount, hcclDataType] = HcomUtil::GetHcclCountAndTypeFromTensor(input[kIndex0]->ToTensor()); + hcclKernel_.hcclCount_ = hcclCount; + hcclKernel_.hcclDataType_ = hcclDataType; + const string &groupName = input[kIndex2]->ToString(); + hcclKernel_.comm_ = HcomUtil::LoadHcclLibrary(groupName); return SUCCESS; } @@ -49,8 +49,8 @@ OpsErrorCode HcclAllGather::Launch(const std::vector &input, LOG_OUT << "HcclAllGather launch"; auto hccl_result = HcclAdapter::GetInstance().HcclAllGather(const_cast(input[kIndex0]->ToTensor()->DataPtr()), - output->ToTensor()->DataPtr(), hcclKernel.hccl_count_, - hcclKernel.hccl_data_type_, stream, hcclKernel.comm_); + output->ToTensor()->DataPtr(), hcclKernel_.hcclCount_, + hcclKernel_.hcclDataType_, stream, hcclKernel_.comm_); if (hccl_result != ::HcclResult::HCCL_SUCCESS) { LOG_ERROR << "HcomAllGather failed, hccl_result: " << hccl_result; } diff --git a/inferrt/src/ops/ascend/hccl/hccl_all_gather.h b/inferrt/src/ops/ascend/hccl/hccl_all_gather.h index 34bba835..d301bf9d 100644 --- a/inferrt/src/ops/ascend/hccl/hccl_all_gather.h +++ b/inferrt/src/ops/ascend/hccl/hccl_all_gather.h @@ -32,12 +32,12 @@ class HcclAllGather : public OpAllGather { ~HcclAllGather() = default; OpsErrorCode CalcWorkspace(const std::vector &input, const ir::Value *output, - size_t *workspace_size) override; + size_t *workspaceSize) override; OpsErrorCode Launch(const std::vector &input, void *workspace, size_t workspaceSize, ir::Value *output, void *stream) override; private: - HcclKernel hcclKernel; + HcclKernel hcclKernel_; }; } // namespace ops } // namespace mrt diff --git a/inferrt/src/ops/ascend/hccl/hccl_all_reduce.cc b/inferrt/src/ops/ascend/hccl/hccl_all_reduce.cc index e248f7e4..b493cf78 100644 --- a/inferrt/src/ops/ascend/hccl/hccl_all_reduce.cc +++ b/inferrt/src/ops/ascend/hccl/hccl_all_reduce.cc @@ -31,14 +31,14 @@ namespace mrt { namespace ops { OpsErrorCode HcclAllReduce::CalcWorkspace(const std::vector &input, const ir::Value *output, - size_t *workspace_size) { + size_t *workspaceSize) { LOG_OUT << "HcclAllReduce CalcWorkspace"; HcclAdapter::GetInstance().InitHccl(); - auto [hccl_count, hccl_data_type] = HcomUtil::GetHcclCountAndTypeFromTensor(input[kIndex0]->ToTensor()); - hcclKernel.hccl_count_ = hccl_count; - hcclKernel.hccl_data_type_ = hccl_data_type; - const string &group_name = input[kIndex2]->ToString(); - hcclKernel.comm_ = HcomUtil::LoadHcclLibrary(group_name); + auto [hcclCount, hcclDataType] = HcomUtil::GetHcclCountAndTypeFromTensor(input[kIndex0]->ToTensor()); + hcclKernel_.hcclCount_ = hcclCount; + hcclKernel_.hcclDataType_ = hcclDataType; + const string &groupName = input[kIndex2]->ToString(); + hcclKernel_.comm_ = HcomUtil::LoadHcclLibrary(groupName); return SUCCESS; } @@ -46,15 +46,15 @@ OpsErrorCode HcclAllReduce::CalcWorkspace(const std::vector & OpsErrorCode HcclAllReduce::Launch(const std::vector &input, void *workspace, size_t workspaceSize, ir::Value *output, void *stream) { LOG_OUT << "HcclAllReduce launch"; - auto hccl_op_type = HcomUtil::GetHcomReduceOpType(input[kIndex1]->ToString()); - auto out_tensor = output->ToTensor(); + auto hcclOpType = HcomUtil::GetHcomReduceOpType(input[kIndex1]->ToString()); + auto outTensor = output->ToTensor(); - auto hccl_result = HcclAdapter::GetInstance().HcclAllReduce( - const_cast(input[kIndex0]->ToTensor()->DataPtr()), out_tensor->DataPtr(), hcclKernel.hccl_count_, - hcclKernel.hccl_data_type_, hccl_op_type, stream, hcclKernel.comm_); + auto hcclResult = HcclAdapter::GetInstance().HcclAllReduce( + const_cast(input[kIndex0]->ToTensor()->DataPtr()), outTensor->DataPtr(), hcclKernel_.hcclCount_, + hcclKernel_.hcclDataType_, hcclOpType, stream, hcclKernel_.comm_); - if (hccl_result != ::HcclResult::HCCL_SUCCESS) { - LOG_ERROR << "HcclAllReduce failed, hccl_result: " << hccl_result; + if (hcclResult != ::HcclResult::HCCL_SUCCESS) { + LOG_ERROR << "HcclAllReduce failed, hcclResult: " << hcclResult; } return SUCCESS; diff --git a/inferrt/src/ops/ascend/hccl/hccl_all_reduce.h b/inferrt/src/ops/ascend/hccl/hccl_all_reduce.h index 8f17cbdc..8b93140b 100644 --- a/inferrt/src/ops/ascend/hccl/hccl_all_reduce.h +++ b/inferrt/src/ops/ascend/hccl/hccl_all_reduce.h @@ -31,12 +31,12 @@ class HcclAllReduce : public OpAllReduce { ~HcclAllReduce() = default; OpsErrorCode CalcWorkspace(const std::vector &input, const ir::Value *output, - size_t *workspace_size) override; + size_t *workspaceSize) override; OpsErrorCode Launch(const std::vector &input, void *workspace, size_t workspaceSize, ir::Value *output, void *stream) override; private: - HcclKernel hcclKernel; + HcclKernel hcclKernel_; }; } // namespace ops } // namespace mrt diff --git a/inferrt/src/ops/ascend/hccl/hccl_all_to_all.cc b/inferrt/src/ops/ascend/hccl/hccl_all_to_all.cc index 47e074e6..c0a1a6fd 100644 --- a/inferrt/src/ops/ascend/hccl/hccl_all_to_all.cc +++ b/inferrt/src/ops/ascend/hccl/hccl_all_to_all.cc @@ -29,74 +29,74 @@ namespace mrt { namespace ops { -bool is_all_to_all_v(const ir::TuplePtr &send_numel_list, const ir::TuplePtr &recv_numel_list) { - for (size_t i = 0; i < send_numel_list->Size(); i++) { - if (send_numel_list->operator[](i)->ToInt() != send_numel_list->operator[](0)->ToInt()) { +bool is_all_to_all_v(const ir::TuplePtr &sendNumelList, const ir::TuplePtr &recvNumelList) { + for (size_t i = 0; i < sendNumelList->Size(); i++) { + if (sendNumelList->operator[](i)->ToInt() != sendNumelList->operator[](0)->ToInt()) { return true; } } - for (size_t i = 0; i < recv_numel_list->Size(); i++) { - if (recv_numel_list->operator[](i)->ToInt() != recv_numel_list->operator[](0)->ToInt()) { + for (size_t i = 0; i < recvNumelList->Size(); i++) { + if (recvNumelList->operator[](i)->ToInt() != recvNumelList->operator[](0)->ToInt()) { return true; } } return false; } -void GetAllToAllVParam(const ir::TuplePtr &send_numel_list, const ir::TuplePtr &recv_numel_list, +void GetAllToAllVParam(const ir::TuplePtr &sendNumelList, const ir::TuplePtr &recvNumelList, HcclAllToAllVParams *params) { uint64_t offset = 0; - for (size_t i = 0; i < send_numel_list->Size(); i++) { - auto count = static_cast(send_numel_list->operator[](i)->ToInt()); - params->sendcounts.push_back(count); + for (size_t i = 0; i < sendNumelList->Size(); i++) { + auto count = static_cast(sendNumelList->operator[](i)->ToInt()); + params->sendCounts.push_back(count); params->sdispls.push_back(offset); offset += count; } offset = 0; - for (size_t i = 0; i < recv_numel_list->Size(); i++) { - auto count = static_cast(recv_numel_list->operator[](i)->ToInt()); - params->recvcounts.push_back(count); + for (size_t i = 0; i < recvNumelList->Size(); i++) { + auto count = static_cast(recvNumelList->operator[](i)->ToInt()); + params->recvCounts.push_back(count); params->rdispls.push_back(offset); offset += count; } } OpsErrorCode HcclAllToAll::CalcWorkspace(const std::vector &input, const ir::Value *output, - size_t *workspace_size) { + size_t *workspaceSize) { LOG_OUT << "HcclAllToAll CalcWorkspace"; - const string &group_name = input[kIndex3]->ToString(); - auto rank_size = mrt::collective::CollectiveManager::Instance().GetGroupSize(group_name); + const string &groupName = input[kIndex3]->ToString(); + auto rankSize = mrt::collective::CollectiveManager::Instance().GetGroupSize(groupName); HcclAdapter::GetInstance().InitHccl(); - auto [hccl_count, hccl_data_type] = HcomUtil::GetHcclCountAndTypeFromTensor(input[kIndex0]->ToTensor()); - hcclKernel.hccl_count_ = hccl_count / rank_size; - hcclKernel.hccl_data_type_ = hccl_data_type; - hcclKernel.comm_ = HcomUtil::LoadHcclLibrary(group_name); - useAllToAllV = is_all_to_all_v(input[kIndex2]->ToTuple(), input[kIndex1]->ToTuple()); + auto [hcclCount, hcclDataType] = HcomUtil::GetHcclCountAndTypeFromTensor(input[kIndex0]->ToTensor()); + hcclKernel_.hcclCount_ = hcclCount / rankSize; + hcclKernel_.hcclDataType_ = hcclDataType; + hcclKernel_.comm_ = HcomUtil::LoadHcclLibrary(groupName); + useAllToAllV_ = is_all_to_all_v(input[kIndex2]->ToTuple(), input[kIndex1]->ToTuple()); return SUCCESS; } OpsErrorCode HcclAllToAll::Launch(const std::vector &input, void *workspace, size_t workspaceSize, ir::Value *output, void *stream) { LOG_OUT << "HcclAllToAll launch"; - auto out_tensor = output->ToTensor(); - ::HcclResult hccl_result; - if (useAllToAllV) { + auto outTensor = output->ToTensor(); + ::HcclResult hcclResult; + if (useAllToAllV_) { LOG_OUT << "HcclAllToAll launch AllToAllV Kernel"; HcclAllToAllVParams params; GetAllToAllVParam(input[kIndex2]->ToTuple(), input[kIndex1]->ToTuple(), ¶ms); - hccl_result = HcclAdapter::GetInstance().HcclAlltoAllV(const_cast(input[kIndex0]->ToTensor()->DataPtr()), - out_tensor->DataPtr(), params, hcclKernel.hccl_data_type_, - stream, hcclKernel.comm_); + hcclResult = HcclAdapter::GetInstance().HcclAlltoAllV(const_cast(input[kIndex0]->ToTensor()->DataPtr()), + outTensor->DataPtr(), params, hcclKernel_.hcclDataType_, + stream, hcclKernel_.comm_); } else { LOG_OUT << "HcclAllToAll launch AllToAll Kernel"; - HcclAllToAllParams params = {hcclKernel.hccl_count_, hcclKernel.hccl_count_}; - hccl_result = HcclAdapter::GetInstance().HcclAllToAll(const_cast(input[kIndex0]->ToTensor()->DataPtr()), - out_tensor->DataPtr(), params, hcclKernel.hccl_data_type_, - stream, hcclKernel.comm_); + HcclAllToAllParams params = {hcclKernel_.hcclCount_, hcclKernel_.hcclCount_}; + hcclResult = HcclAdapter::GetInstance().HcclAllToAll(const_cast(input[kIndex0]->ToTensor()->DataPtr()), + outTensor->DataPtr(), params, hcclKernel_.hcclDataType_, + stream, hcclKernel_.comm_); } - if (hccl_result != ::HcclResult::HCCL_SUCCESS) { - LOG_ERROR << "HcclAllReduce failed, hccl_result: " << hccl_result; + if (hcclResult != ::HcclResult::HCCL_SUCCESS) { + LOG_ERROR << "HcclAllToAll failed, hcclResult: " << hcclResult; } return SUCCESS; diff --git a/inferrt/src/ops/ascend/hccl/hccl_all_to_all.h b/inferrt/src/ops/ascend/hccl/hccl_all_to_all.h index 2456e679..edc48f38 100644 --- a/inferrt/src/ops/ascend/hccl/hccl_all_to_all.h +++ b/inferrt/src/ops/ascend/hccl/hccl_all_to_all.h @@ -30,13 +30,13 @@ class HcclAllToAll : public OpAllToAll { ~HcclAllToAll() = default; OpsErrorCode CalcWorkspace(const std::vector &input, const ir::Value *output, - size_t *workspace_size) override; + size_t *workspaceSize) override; OpsErrorCode Launch(const std::vector &input, void *workspace, size_t workspaceSize, ir::Value *output, void *stream) override; private: - HcclKernel hcclKernel; - bool useAllToAllV; + HcclKernel hcclKernel_; + bool useAllToAllV_; }; } // namespace ops } // namespace mrt diff --git a/inferrt/src/ops/ascend/hccl/hccl_kernel.cc b/inferrt/src/ops/ascend/hccl/hccl_kernel.cc index 6f1c98c3..7952d33f 100644 --- a/inferrt/src/ops/ascend/hccl/hccl_kernel.cc +++ b/inferrt/src/ops/ascend/hccl/hccl_kernel.cc @@ -26,7 +26,7 @@ namespace mrt { namespace ops { -HcclKernel::HcclKernel() : hccl_count_(0), root_id_(0), src_rank_(0), dest_rank_(0), comm_(nullptr) {} +HcclKernel::HcclKernel() : hcclCount_(0), rootId_(0), comm_(nullptr) {} } // namespace ops } // namespace mrt diff --git a/inferrt/src/ops/ascend/hccl/hccl_kernel.h b/inferrt/src/ops/ascend/hccl/hccl_kernel.h index bb6e896f..6d8d7f89 100644 --- a/inferrt/src/ops/ascend/hccl/hccl_kernel.h +++ b/inferrt/src/ops/ascend/hccl/hccl_kernel.h @@ -38,16 +38,12 @@ class HcclKernel { ~HcclKernel() = default; public: - HcclDataType hccl_data_type_; - uint64_t hccl_count_; - uint32_t root_id_; - uint32_t src_rank_; - uint32_t dest_rank_; + HcclDataType hcclDataType_; + uint64_t hcclCount_; + uint32_t rootId_; std::string group_; HcclComm comm_; - ulong loop_size_{0}; - bool is_graph_mode_{false}; - std::string hccl_inner_comm_name_; + std::string hcclInnerCommName_; }; } // namespace ops diff --git a/inferrt/src/ops/ascend/hccl/hccl_reduce_scatter.cc b/inferrt/src/ops/ascend/hccl/hccl_reduce_scatter.cc index 8b031ead..aa2f6381 100644 --- a/inferrt/src/ops/ascend/hccl/hccl_reduce_scatter.cc +++ b/inferrt/src/ops/ascend/hccl/hccl_reduce_scatter.cc @@ -30,15 +30,15 @@ namespace mrt { namespace ops { OpsErrorCode HcclReduceScatter::CalcWorkspace(const std::vector &input, const ir::Value *output, - size_t *workspace_size) { + size_t *workspaceSize) { LOG_OUT << "HcclReduceScatter CalcWorkspace"; HcclAdapter::GetInstance().InitHccl(); - auto rank_size = input[kIndex2]->ToInt(); - auto [hccl_count, hccl_data_type] = HcomUtil::GetHcclCountAndTypeFromTensor(input[kIndex0]->ToTensor(), rank_size); - hcclKernel.hccl_count_ = hccl_count; - hcclKernel.hccl_data_type_ = hccl_data_type; - const string &group_name = input[kIndex3]->ToString(); - hcclKernel.comm_ = HcomUtil::LoadHcclLibrary(group_name); + auto rankSize = input[kIndex2]->ToInt(); + auto [hcclCount, hcclDataType] = HcomUtil::GetHcclCountAndTypeFromTensor(input[kIndex0]->ToTensor(), rankSize); + hcclKernel_.hcclCount_ = hcclCount; + hcclKernel_.hcclDataType_ = hcclDataType; + const string &groupName = input[kIndex3]->ToString(); + hcclKernel_.comm_ = HcomUtil::LoadHcclLibrary(groupName); return SUCCESS; } @@ -46,15 +46,15 @@ OpsErrorCode HcclReduceScatter::CalcWorkspace(const std::vector &input, void *workspace, size_t workspaceSize, ir::Value *output, void *stream) { LOG_OUT << "HcclReduceScatter launch"; - auto hccl_op_type = HcomUtil::GetHcomReduceOpType(input[kIndex1]->ToString()); - auto out_tensor = output->ToTensor(); + auto hcclOpType = HcomUtil::GetHcomReduceOpType(input[kIndex1]->ToString()); + auto outTensor = output->ToTensor(); - auto hccl_result = HcclAdapter::GetInstance().HcclReduceScatter( - const_cast(input[kIndex0]->ToTensor()->DataPtr()), out_tensor->DataPtr(), hcclKernel.hccl_count_, - hcclKernel.hccl_data_type_, hccl_op_type, stream, hcclKernel.comm_); + auto hcclResult = HcclAdapter::GetInstance().HcclReduceScatter( + const_cast(input[kIndex0]->ToTensor()->DataPtr()), outTensor->DataPtr(), hcclKernel_.hcclCount_, + hcclKernel_.hcclDataType_, hcclOpType, stream, hcclKernel_.comm_); - if (hccl_result != ::HcclResult::HCCL_SUCCESS) { - LOG_ERROR << "HcclReduceScatter failed, hccl_result: " << hccl_result; + if (hcclResult != ::HcclResult::HCCL_SUCCESS) { + LOG_ERROR << "HcclReduceScatter failed, hccl_result: " << hcclResult; } return SUCCESS; diff --git a/inferrt/src/ops/ascend/hccl/hccl_reduce_scatter.h b/inferrt/src/ops/ascend/hccl/hccl_reduce_scatter.h index 0ceb0c24..4e5d36a3 100644 --- a/inferrt/src/ops/ascend/hccl/hccl_reduce_scatter.h +++ b/inferrt/src/ops/ascend/hccl/hccl_reduce_scatter.h @@ -30,12 +30,12 @@ class HcclReduceScatter : public OpReduceScatter { ~HcclReduceScatter() = default; OpsErrorCode CalcWorkspace(const std::vector &input, const ir::Value *output, - size_t *workspace_size) override; + size_t *workspaceSize) override; OpsErrorCode Launch(const std::vector &input, void *workspace, size_t workspaceSize, ir::Value *output, void *stream) override; private: - HcclKernel hcclKernel; + HcclKernel hcclKernel_; }; } // namespace ops } // namespace mrt diff --git a/inferrt/src/ops/ascend/hccl/hcom_utils.cc b/inferrt/src/ops/ascend/hccl/hcom_utils.cc index 1b312f9f..7f3011c7 100644 --- a/inferrt/src/ops/ascend/hccl/hcom_utils.cc +++ b/inferrt/src/ops/ascend/hccl/hcom_utils.cc @@ -46,33 +46,33 @@ inline size_t SizetMulWithOverflowCheck(size_t a, size_t b) { inline size_t LongToSizeClipNeg(int64_t u) { return u < 0 ? 0 : static_cast(u); } -::HcclDataType HcomUtil::ConvertHcclType(DataType type_id) { - auto iter = kConstOpHcomDataTypeMap.find(type_id); +::HcclDataType HcomUtil::ConvertHcclType(DataType typeId) { + auto iter = kConstOpHcomDataTypeMap.find(typeId); if (iter == kConstOpHcomDataTypeMap.end()) { - LOG_EXCEPTION << "HcomDataType can't support Current Ascend Data Type : " << type_id.ToString(); + LOG_EXCEPTION << "HcomDataType can't support Current Ascend Data Type : " << typeId.ToString(); } return iter->second; } -bool HcomUtil::GetHcclOpSize(const HcclDataType &data_type, const std::vector &shape, size_t *size) { +bool HcomUtil::GetHcclOpSize(const HcclDataType &dataType, const std::vector &shape, size_t *size) { CHECK_IF_NULL(size); - int64_t tmp_size = 1; - uint32_t type_size = 4; + int64_t tmpSize = 1; + uint32_t typeSize = 4; for (size_t i = 0; i < shape.size(); i++) { - tmp_size = LongMulWithOverflowCheck(tmp_size, shape[i]); + tmpSize = LongMulWithOverflowCheck(tmpSize, shape[i]); } - if (!GetHcomTypeSize(data_type, &type_size)) { + if (!GetHcomTypeSize(dataType, &typeSize)) { return false; } - *size = SizetMulWithOverflowCheck(LongToSizeClipNeg(tmp_size), type_size); + *size = SizetMulWithOverflowCheck(LongToSizeClipNeg(tmpSize), typeSize); return true; } -bool HcomUtil::GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size) { +bool HcomUtil::GetHcomTypeSize(const HcclDataType &dataType, uint32_t *size) { CHECK_IF_NULL(size); - auto iter = kConstOpHcomDataTypeSizeMap.find(data_type); + auto iter = kConstOpHcomDataTypeSizeMap.find(dataType); if (iter == kConstOpHcomDataTypeSizeMap.end()) { LOG_ERROR << "HcomUtil::HcomDataTypeSize, No DataTypeSize!"; return false; @@ -81,79 +81,79 @@ bool HcomUtil::GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size) { return true; } -bool HcomUtil::GetHcomCount(const std::vector &data_type_list, - const std::vector> &shape_list, const size_t input_tensor_num, - const std::optional rank_size_opt, uint64_t *total_count) { - CHECK_IF_NULL(total_count); - - const uint32_t align_size = 512; - const uint32_t filled_size = 32; - uint64_t total_size = 0; - size_t input_size; - uint32_t type_size = 4; - size_t rank_size = 1; - bool is_reduce_scatter = false; - if (rank_size_opt.has_value()) { - rank_size = rank_size_opt.value(); - is_reduce_scatter = true; +bool HcomUtil::GetHcomCount(const std::vector &dataTypeList, + const std::vector> &shapeList, const size_t inputTensorNum, + const std::optional rankSizeOpt, uint64_t *totalCount) { + CHECK_IF_NULL(totalCount); + + const uint32_t alignSize = 512; + const uint32_t filledSize = 32; + uint64_t totalSize = 0; + size_t inputSize; + uint32_t typeSize = 4; + size_t rankSize = 1; + bool isReduceScatter = false; + if (rankSizeOpt.has_value()) { + rankSize = rankSizeOpt.value(); + isReduceScatter = true; } - CHECK_IF_FAIL(data_type_list.size() == shape_list.size()); + CHECK_IF_FAIL(dataTypeList.size() == shapeList.size()); - for (size_t i = 0; i < data_type_list.size(); ++i) { - if (!GetHcomTypeSize(data_type_list[i], &type_size)) { + for (size_t i = 0; i < dataTypeList.size(); ++i) { + if (!GetHcomTypeSize(dataTypeList[i], &typeSize)) { return false; } - if (!GetHcclOpSize(data_type_list[i], shape_list[i], &input_size)) { + if (!GetHcclOpSize(dataTypeList[i], shapeList[i], &inputSize)) { LOG_ERROR << "Get GetHcclOpSize failed"; return false; } - if (input_tensor_num > 1) { + if (inputTensorNum > 1) { // communication operator with dynamic input should have continuous memory. - input_size = (input_size + align_size - 1 + filled_size) / align_size * align_size; + inputSize = (inputSize + alignSize - 1 + filledSize) / alignSize * alignSize; } - if (is_reduce_scatter) { - input_size /= rank_size; + if (isReduceScatter) { + inputSize /= rankSize; } - bool all_dynamic = std::all_of(shape_list[i].begin(), shape_list[i].end(), [](int64_t x) { return x == -1; }); - if (!all_dynamic && (type_size == 0 || input_size % type_size != 0)) { + bool allDynamic = std::all_of(shapeList[i].begin(), shapeList[i].end(), [](int64_t x) { return x == -1; }); + if (!allDynamic && (typeSize == 0 || inputSize % typeSize != 0)) { return false; } - total_size += input_size / type_size; + totalSize += inputSize / typeSize; } - *total_count = total_size; + *totalCount = totalSize; return true; } std::pair HcomUtil::GetHcclCountAndTypeFromTensor( - const ir::TensorPtr &tensor, const std::optional rank_size_opt) { - auto type_id = tensor->Dtype(); + const ir::TensorPtr &tensor, const std::optional rankSizeOpt) { + auto typeId = tensor->Dtype(); auto shape = tensor->Shape(); - auto hccl_type = ConvertHcclType(type_id); + auto hcclType = ConvertHcclType(typeId); - uint64_t hccl_count = 0; - constexpr size_t input_tensor_size = 1; - if (!GetHcomCount({hccl_type}, {shape}, input_tensor_size, rank_size_opt, &hccl_count)) { + uint64_t hcclCount = 0; + constexpr size_t inputTensorSize = 1; + if (!GetHcomCount({hcclType}, {shape}, inputTensorSize, rankSizeOpt, &hcclCount)) { LOG_EXCEPTION << "GetHcomCount fail!"; } - return std::make_pair(hccl_count, hccl_type); + return std::make_pair(hcclCount, hcclType); } -CollectiveOpReduceType HcomUtil::GetCollectiveOpReduceType(const std::string &reduce_op) { - auto iter = kConstOpCollectiveOpReduceTypeMap.find(reduce_op); +CollectiveOpReduceType HcomUtil::GetCollectiveOpReduceType(const std::string &reduceOp) { + auto iter = kConstOpCollectiveOpReduceTypeMap.find(reduceOp); if (iter == kConstOpCollectiveOpReduceTypeMap.end()) { - LOG_EXCEPTION << "HcomUtil::Get CollectiveOpReduceType fail, [" << reduce_op << "] not support!"; + LOG_EXCEPTION << "HcomUtil::Get CollectiveOpReduceType fail, [" << reduceOp << "] not support!"; } return iter->second; } -HcclReduceOp HcomUtil::GetHcomReduceOpType(const std::string &reduce_op) { - auto iter = kConstOpHcomReduceOpTypeMap.find(reduce_op); +HcclReduceOp HcomUtil::GetHcomReduceOpType(const std::string &reduceOp) { + auto iter = kConstOpHcomReduceOpTypeMap.find(reduceOp); if (iter == kConstOpHcomReduceOpTypeMap.end()) { - LOG_EXCEPTION << "HcomUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [" << reduce_op << "] not support!"; + LOG_EXCEPTION << "HcomUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [" << reduceOp << "] not support!"; } return iter->second; } diff --git a/inferrt/src/ops/ascend/hccl/hcom_utils.h b/inferrt/src/ops/ascend/hccl/hcom_utils.h index 7c81ba1b..3598f415 100644 --- a/inferrt/src/ops/ascend/hccl/hcom_utils.h +++ b/inferrt/src/ops/ascend/hccl/hcom_utils.h @@ -107,23 +107,23 @@ static const std::unordered_map kConstOpCol class HcomUtil { public: - static ::HcclDataType ConvertHcclType(DataType type_id); - static HcclComm LoadHcclLibrary(const std::string &group_name) { - int64_t hccl_comm = collective::CollectiveManager::Instance().GetCommunicationGroup(group_name)->communicator(); - return reinterpret_cast(static_cast(hccl_comm)); + static ::HcclDataType ConvertHcclType(DataType typeId); + static HcclComm LoadHcclLibrary(const std::string &groupName) { + int64_t hcclComm = collective::CollectiveManager::Instance().GetCommunicationGroup(groupName)->communicator(); + return reinterpret_cast(static_cast(hcclComm)); } // static bool GetHcomDataType(const std::string &kernel_name, const std::vector &inputs, // const std::vector &outputs, std::vector *data_type_list); - static bool GetHcclOpSize(const HcclDataType &data_type, const std::vector &shape, size_t *size); - static bool GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size); - static bool GetHcomCount(const std::vector &data_type_list, - const std::vector> &shape_list, const size_t input_tensor_num, - const std::optional rank_size_opt, uint64_t *total_count); + static bool GetHcclOpSize(const HcclDataType &dataType, const std::vector &shape, size_t *size); + static bool GetHcomTypeSize(const HcclDataType &dataType, uint32_t *size); + static bool GetHcomCount(const std::vector &dataTypeList, + const std::vector> &shapeList, const size_t inputTensorNum, + const std::optional rankSizeOpt, uint64_t *totalCount); static std::pair GetHcclCountAndTypeFromTensor( - const ir::TensorPtr &tensor, const std::optional rank_size_opt = std::nullopt); - static CollectiveOpReduceType GetCollectiveOpReduceType(const std::string &reduce_op); - static HcclReduceOp GetHcomReduceOpType(const std::string &reduce_op); + const ir::TensorPtr &tensor, const std::optional rankSizeOpt = std::nullopt); + static CollectiveOpReduceType GetCollectiveOpReduceType(const std::string &reduceOp); + static HcclReduceOp GetHcomReduceOpType(const std::string &reduceOp); }; } // namespace mrt::ops diff --git a/inferrt/src/ops/ascend/hccl/tensor_copy.cc b/inferrt/src/ops/ascend/hccl/tensor_copy.cc index 27d1cd5f..17934387 100644 --- a/inferrt/src/ops/ascend/hccl/tensor_copy.cc +++ b/inferrt/src/ops/ascend/hccl/tensor_copy.cc @@ -39,22 +39,22 @@ OpsErrorCode HcclTensorCopy::InferShape(const std::vector &in } OpsErrorCode HcclTensorCopy::CalcWorkspace(const std::vector &input, const ir::Value *output, - size_t *workspace_size) { + size_t *workspaceSize) { return SUCCESS; } OpsErrorCode HcclTensorCopy::Launch(const std::vector &input, void *workspace, size_t workspaceSize, ir::Value *output, void *stream) { LOG_OUT << "TensorCopy launch"; - auto src_tensor = input[kIndex1]->ToTensor(); - auto out_tensor = input[kIndex0]->ToTensor(); - auto dst_size = out_tensor->Numel() * out_tensor->Dtype().GetSize(); + auto srcTensor = input[kIndex1]->ToTensor(); + auto outTensor = input[kIndex0]->ToTensor(); + auto dstSize = outTensor->Numel() * outTensor->Dtype().GetSize(); // host_ptr, size, device_ptr, size, ACL_MEMCPY_DEVICE_TO_HOST, stream_ptr - auto ret = mrt::device::ascend::AscendResManager::MemcpyDeviceToDevice(out_tensor->DataPtr(), dst_size, - src_tensor->DataPtr(), dst_size, stream); + auto ret = mrt::device::ascend::AscendResManager::MemcpyDeviceToDevice(outTensor->DataPtr(), dstSize, + srcTensor->DataPtr(), dstSize, stream); if (ret == false) { - LOG_ERROR << " call aclrtMemcpyAsync in Op HcclTensorCopy failed"; + LOG_ERROR << " call aclrtMemcpyAsync in Op TensorCopy failed"; } return SUCCESS; diff --git a/inferrt/src/ops/ascend/hccl/tensor_copy.h b/inferrt/src/ops/ascend/hccl/tensor_copy.h index 6c13af49..244a12d6 100644 --- a/inferrt/src/ops/ascend/hccl/tensor_copy.h +++ b/inferrt/src/ops/ascend/hccl/tensor_copy.h @@ -35,9 +35,6 @@ class HcclTensorCopy : public Operator { size_t *workspace_size) override; OpsErrorCode Launch(const std::vector &input, void *workspace, size_t workspaceSize, ir::Value *output, void *stream) override; - - private: - HcclKernel hcclKernel; }; } // namespace ops } // namespace mrt diff --git a/inferrt/src/ops/ascend/hccl/wait_tensor.cc b/inferrt/src/ops/ascend/hccl/wait_tensor.cc index 7d9be8e2..0692ce34 100644 --- a/inferrt/src/ops/ascend/hccl/wait_tensor.cc +++ b/inferrt/src/ops/ascend/hccl/wait_tensor.cc @@ -38,7 +38,7 @@ OpsErrorCode HcclWaitTensor::InferShape(const std::vector &in } OpsErrorCode HcclWaitTensor::CalcWorkspace(const std::vector &input, const ir::Value *output, - size_t *workspace_size) { + size_t *workspaceSize) { return SUCCESS; } @@ -46,14 +46,14 @@ OpsErrorCode HcclWaitTensor::Launch(const std::vector &input, ir::Value *output, void *stream) { LOG_OUT << "WaitTensor launch"; - auto src_tensor = input[kIndex0]->ToTensor(); - auto out_tensor = output->ToTensor(); - auto dst_size = out_tensor->Numel() * out_tensor->Dtype().GetSize(); + auto srcTensor = input[kIndex0]->ToTensor(); + auto outTensor = output->ToTensor(); + auto dstSize = outTensor->Numel() * outTensor->Dtype().GetSize(); - auto ret = mrt::device::ascend::AscendResManager::MemcpyDeviceToDevice(out_tensor->DataPtr(), dst_size, - src_tensor->DataPtr(), dst_size, stream); + auto ret = mrt::device::ascend::AscendResManager::MemcpyDeviceToDevice(outTensor->DataPtr(), dstSize, + srcTensor->DataPtr(), dstSize, stream); if (ret == false) { - LOG_ERROR << " call aclrtMemcpyAsync in Op HcclTensorCopy failed"; + LOG_ERROR << " call aclrtMemcpyAsync in Op WaitTensor failed"; } mrt::device::ascend::AscendStreamMng::GetInstance().SyncStream(stream); diff --git a/inferrt/src/ops/ascend/hccl/wait_tensor.h b/inferrt/src/ops/ascend/hccl/wait_tensor.h index cbd86574..9985b128 100644 --- a/inferrt/src/ops/ascend/hccl/wait_tensor.h +++ b/inferrt/src/ops/ascend/hccl/wait_tensor.h @@ -32,12 +32,9 @@ class HcclWaitTensor : public Operator { OpsErrorCode InferShape(const std::vector &input, ir::Value *output) override; OpsErrorCode CalcWorkspace(const std::vector &input, const ir::Value *output, - size_t *workspace_size) override; + size_t *workspaceSize) override; OpsErrorCode Launch(const std::vector &input, void *workspace, size_t workspaceSize, ir::Value *output, void *stream) override; - - private: - HcclKernel hcclKernel; }; } // namespace ops } // namespace mrt -- Gitee