From 945b4425de4fd41a0022b71354ad3b336ff5281c Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Tue, 14 Oct 2025 20:59:25 +0800 Subject: [PATCH 1/2] Support handle wait. --- examples/rdma_demo/main.cpp | 6 +- .../low_level/shmem_device_low_level_roce.h | 4 + include/device/shmem_device_sync.h | 1 + include/host/shmem_host_sync.h | 6 ++ include/host_device/shmem_types.h | 15 ++++ .../device/sync/shmemi_device_handle.h | 90 +++++++++++++++++++ .../internal/device/sync/shmemi_device_p2p.h | 21 +++++ src/device/shmemi_device_intf.h | 2 + src/device/shmemi_handle.cpp | 27 ++++++ src/host/sync/shmemi_sync.cpp | 4 + src/host/team/shmem_team.cpp | 6 +- 11 files changed, 177 insertions(+), 5 deletions(-) create mode 100644 include/internal/device/sync/shmemi_device_handle.h create mode 100644 src/device/shmemi_handle.cpp diff --git a/examples/rdma_demo/main.cpp b/examples/rdma_demo/main.cpp index b225e274..58b9150a 100644 --- a/examples/rdma_demo/main.cpp +++ b/examples/rdma_demo/main.cpp @@ -55,8 +55,10 @@ int test_shmem_team_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size // AllGather allgather_demo(1, stream, (uint8_t *)ptr, trans_size * sizeof(int32_t)); + shmem_handle_t handle; + handle.team_id = SHMEM_TEAM_WORLD; + shmem_handle_wait(handle, stream); status = aclrtSynchronizeStream(stream); - shm::shmemi_control_barrier_all(); // 结果校验打印 int32_t *y_host; @@ -93,7 +95,7 @@ int main(int argc, char *argv[]) int rank_id = atoi(argv[argIdx++]); ipport = argv[argIdx++]; g_npus = atoi(argv[argIdx++]); - f_rank = atoi(argv[argIdx]); + f_rank = atoi(argv[argIdx++]); f_npu = atoi(argv[argIdx++]); uint64_t local_mem_size = 1024UL * 1024UL * 1024; status = test_shmem_team_all_gather(rank_id, n_ranks, local_mem_size); diff --git a/include/device/low_level/shmem_device_low_level_roce.h b/include/device/low_level/shmem_device_low_level_roce.h index 66f8a0bc..ec8d3f45 100644 --- a/include/device/low_level/shmem_device_low_level_roce.h +++ b/include/device/low_level/shmem_device_low_level_roce.h @@ -120,6 +120,10 @@ SHMEM_DEVICE uint32_t shmemi_roce_poll_cq(uint32_t remoteRankId, uint32_t qpIdx, AscendC::LocalTensor ubLocal64, AscendC::LocalTensor ubLocal32) { + if (idx == 0) { + return 0; + } + __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + SMEM_SHM_DEVICE_GLOBAL_META_SIZE); __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); diff --git a/include/device/shmem_device_sync.h b/include/device/shmem_device_sync.h index bce97b19..0979b2c7 100644 --- a/include/device/shmem_device_sync.h +++ b/include/device/shmem_device_sync.h @@ -39,6 +39,7 @@ #include "internal/device/sync/shmemi_device_quiet.h" #include "internal/device/sync/shmemi_device_p2p.h" #include "internal/device/sync/shmemi_device_barrier.h" +#include "internal/device/sync/shmemi_device_handle.h" #ifdef __cplusplus extern "C" { diff --git a/include/host/shmem_host_sync.h b/include/host/shmem_host_sync.h index 939cb7c0..56a4e730 100644 --- a/include/host/shmem_host_sync.h +++ b/include/host/shmem_host_sync.h @@ -37,6 +37,12 @@ extern "C" { */ SHMEM_HOST_API uint64_t shmemx_get_ffts_config(); +/** + * @fn SHMEM_HOST_API void shmem_handle_wait(shmem_handle_t handle) + * @brief Wait asynchronous RMA operations to finish. + */ +SHMEM_HOST_API void shmem_handle_wait(shmem_handle_t handle, aclrtStream stream); + #ifdef __cplusplus } #endif diff --git a/include/host_device/shmem_types.h b/include/host_device/shmem_types.h index 030cf36c..22d23ed7 100644 --- a/include/host_device/shmem_types.h +++ b/include/host_device/shmem_types.h @@ -100,6 +100,21 @@ typedef int shmem_team_t; /**@} */ // end of group_typedef +/** + * @addtogroup group_structs + * @{ +*/ +/** + * @struct shmem_handle_t + * @brief Handle info used for non-blocking API synchronization. + * + * - shmem_team_t team_id: Team ID used for synchronization. +*/ +struct shmem_handle_t { + shmem_team_t team_id; +}; + +/**@} */ // end of group_structs #ifdef __cplusplus } #endif diff --git a/include/internal/device/sync/shmemi_device_handle.h b/include/internal/device/sync/shmemi_device_handle.h new file mode 100644 index 00000000..0f3ea61e --- /dev/null +++ b/include/internal/device/sync/shmemi_device_handle.h @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef SHEMEI_DEVICE_HANDLE_H +#define SHEMEI_DEVICE_HANDLE_H + +#include "shmemi_device_barrier.h" + +#include "kernel_operator.h" + +SHMEM_DEVICE void shmemi_barrier_cross_host(shmemi_team_t *team) +{ + if (AscendC::GetBlockIdx() != 0) + return; + + int my_pe = shmemi_get_state()->team_pools[SHMEM_TEAM_WORLD]->mype; + int start = team->start; + int stride = team->stride; + int size = team->size; + auto sync_array = shmemi_get_team_sync_array(team->team_idx); + auto sync_counter = shmemi_get_team_sync_counter(team->team_idx); + + int shift = 1; + int my_pe_in_team = (my_pe - start) / stride; + int32_t count = shmemi_load((__gm__ int32_t *)sync_counter) + 1; + + while (shift < size) { + int pre_pe_in_team = (my_pe_in_team - shift + size) % size; + int next_pe_in_team = (my_pe_in_team + shift) % size; + + int pre_pe = start + pre_pe_in_team * stride; + int next_pe = start + next_pe_in_team * stride; + + // signal next pe + shmemi_highlevel_signal_set((__gm__ int32_t *)(sync_array + my_pe), (__gm__ int32_t *)sync_counter, next_pe); + + // wait pre pe + shmemi_signal_wait_until_eq_for_barrier((__gm__ int32_t *)(sync_array + pre_pe), count); + + shift *= SHIFT_MULTIPLIER; + } + + shmemi_store((__gm__ int32_t *)sync_counter, count); +} + +SHMEM_DEVICE void shmemi_handle(shmem_team_t tid) +{ + shmemi_team_t *team = shmemi_get_state()->team_pools[tid]; + + int mype = shmemi_get_state()->team_pools[SHMEM_TEAM_WORLD]->mype; + int start = team->start; + int stride = team->stride; + int size = team->size; + + if ((mype - start) % stride != 0) { + // not in this team + return; + } + + AscendC::LocalTensor ub_tensor_32; + ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); + ub_tensor_32.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR); + ub_tensor_32.address_.dataLen = UB_ALIGN_SIZE; + AscendC::LocalTensor ub_tensor_64; + ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); + ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR + + UB_ALIGN_SIZE); + ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; + + for (int i = 0; i < size; i++) { + int peer = start + i * stride; + if (peer == mype) { + continue; + } + shmemi_roce_quiet(peer, 0, ub_tensor_64, ub_tensor_32); + } + + if ASCEND_IS_AIV { + shmemi_barrier_cross_host(team); + } +} + +#endif \ No newline at end of file diff --git a/include/internal/device/sync/shmemi_device_p2p.h b/include/internal/device/sync/shmemi_device_p2p.h index 0d6da313..2ea70b0b 100644 --- a/include/internal/device/sync/shmemi_device_p2p.h +++ b/include/internal/device/sync/shmemi_device_p2p.h @@ -25,6 +25,27 @@ SHMEM_DEVICE void shmemi_signal_set(__gm__ int32_t *addr, int pe, int32_t val) shmemi_signal_set(shmemi_ptr(addr, pe), val); } +SHMEM_DEVICE void shmemi_highlevel_signal_set(__gm__ int32_t *dst, __gm__ int32_t *src, int pe) +{ + AscendC::LocalTensor ub_tensor_32; + ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); + ub_tensor_32.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR); + ub_tensor_32.address_.dataLen = UB_ALIGN_SIZE; + AscendC::LocalTensor ub_tensor_64; + ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); + ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR + + UB_ALIGN_SIZE); + ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; + dcci_cacheline((__gm__ uint8_t *)src); + int32_t count = shmemi_load(src) + 1; + shmemi_store(src, count); + // flush data cache to GM after signal to ensure it is visiable to other ranks + dcci_cacheline((__gm__ uint8_t *)src); + shmemi_roce_write((__gm__ uint8_t*)shmem_ptr(dst, pe), (__gm__ uint8_t*)src, pe, 0, sizeof(int32_t), + ub_tensor_64, ub_tensor_32); + shmemi_roce_quiet(pe, 0, ub_tensor_64, ub_tensor_32); +} + SHMEM_DEVICE void shmemi_signal_add(__gm__ int32_t *addr, int pe, int32_t val) { // ensure previous atomic operations end diff --git a/src/device/shmemi_device_intf.h b/src/device/shmemi_device_intf.h index e7af385c..94bfeb89 100644 --- a/src/device/shmemi_device_intf.h +++ b/src/device/shmemi_device_intf.h @@ -18,4 +18,6 @@ int32_t shmemi_memset(int32_t *array, int32_t len, int32_t val, int32_t count); int32_t shmemi_barrier_on_stream(shmem_team_t tid, void *stream); +int32_t shmemi_handle_wait_on_stream(shmem_handle_t handle, aclrtStream stream); + #endif \ No newline at end of file diff --git a/src/device/shmemi_handle.cpp b/src/device/shmemi_handle.cpp new file mode 100644 index 00000000..633e0b8a --- /dev/null +++ b/src/device/shmemi_handle.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "acl/acl.h" +#include "kernel_operator.h" + +#include "shmem_api.h" + +// kernels +SHMEM_GLOBAL void k_shmem_handle_wait(int32_t tid) +{ + shmemi_handle(tid); +} + +// interfaces +int32_t shmemi_handle_wait_on_stream(shmem_handle_t handle, aclrtStream stream) +{ + // call barrier kernel + k_shmem_handle_wait<<<1, nullptr, stream>>>((int32_t)handle.team_id); + return aclrtSynchronizeStream(stream); +} \ No newline at end of file diff --git a/src/host/sync/shmemi_sync.cpp b/src/host/sync/shmemi_sync.cpp index 62fe991f..72c7c8af 100644 --- a/src/host/sync/shmemi_sync.cpp +++ b/src/host/sync/shmemi_sync.cpp @@ -55,4 +55,8 @@ void shmem_barrier_on_stream(shmem_team_t tid, aclrtStream stream) void shmem_barrier_all_on_stream(aclrtStream stream) { shmemi_barrier_on_stream(SHMEM_TEAM_WORLD, stream); +} + +void shmem_handle_wait(shmem_handle_t handle, aclrtStream stream) { + shmemi_handle_wait_on_stream(handle, stream); } \ No newline at end of file diff --git a/src/host/team/shmem_team.cpp b/src/host/team/shmem_team.cpp index 942337f7..bd7f0b5d 100644 --- a/src/host/team/shmem_team.cpp +++ b/src/host/team/shmem_team.cpp @@ -89,13 +89,13 @@ int32_t shmemi_team_init_sync_pool() int32_t shmemi_team_init_sync_counter() { - auto ret = aclrtMalloc((void **)&(g_state.sync_counter), SYNC_COUNTERS_SIZE, ACL_MEM_MALLOC_HUGE_FIRST); - if (ret != 0 || g_state.sync_counter == 0) { + g_state.sync_counter = (uint64_t)shmem_malloc(SYNC_COUNTERS_SIZE); + if (g_state.sync_counter == 0) { shmemi_team_finalize(); SHM_LOG_ERROR("malloc sync counter failed."); return SHMEM_INNER_ERROR; } - ret = aclrtMemset((void *)g_state.sync_counter, SYNC_COUNTERS_SIZE, 0, SYNC_COUNTERS_SIZE); + auto ret = aclrtMemset((void *)g_state.sync_counter, SYNC_COUNTERS_SIZE, 0, SYNC_COUNTERS_SIZE); if (ret != 0) { shmemi_team_finalize(); SHM_LOG_ERROR("memset sync counter failed."); -- Gitee From dc37f8db31219eb631dd0afc5586912cca800dc2 Mon Sep 17 00:00:00 2001 From: Qi Gao Date: Thu, 16 Oct 2025 16:10:55 +0800 Subject: [PATCH 2/2] Remove sync within handle API. --- examples/rdma_demo/main.cpp | 13 +++++-------- include/internal/device/sync/shmemi_device_handle.h | 5 ++--- include/internal/device/sync/shmemi_device_p2p.h | 5 ----- include/internal/host_device/shmemi_types.h | 2 +- src/device/shmemi_device_intf.h | 2 +- src/device/shmemi_handle.cpp | 3 +-- src/host/team/shmem_team.cpp | 2 +- 7 files changed, 11 insertions(+), 21 deletions(-) diff --git a/examples/rdma_demo/main.cpp b/examples/rdma_demo/main.cpp index 58b9150a..3e548752 100644 --- a/examples/rdma_demo/main.cpp +++ b/examples/rdma_demo/main.cpp @@ -67,16 +67,13 @@ int test_shmem_team_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size status = aclrtMemcpy(y_host, input_size, ptr, input_size, ACL_MEMCPY_DEVICE_TO_HOST); for (int i = 0; i < n_ranks; i++) { - if (y_host[trans_size * i] != num10 + i) { - std::cout << y_host[trans_size * i] << " != " << num10 + i << std::endl; - std::exit(EXIT_FAILURE); + for (int j = 0; j < 16; j++) { + if (y_host[trans_size * i + trans_size / 16 * j] != num10 + i) { + std::cout << y_host[trans_size * i + trans_size / 16 * j] << " != " << num10 + i << std::endl; + std::exit(EXIT_FAILURE); + } } } - std::cout << "rank: " << rank_id << " ["; - for (int j = 0; j < trans_size * n_ranks; j++) { - std::cout << y_host[j] << ", "; - } - std::cout << "]" << std::endl; // 去初始化 status = aclrtFreeHost(y_host); shmem_free(ptr); diff --git a/include/internal/device/sync/shmemi_device_handle.h b/include/internal/device/sync/shmemi_device_handle.h index 0f3ea61e..3f18538e 100644 --- a/include/internal/device/sync/shmemi_device_handle.h +++ b/include/internal/device/sync/shmemi_device_handle.h @@ -30,7 +30,8 @@ SHMEM_DEVICE void shmemi_barrier_cross_host(shmemi_team_t *team) int shift = 1; int my_pe_in_team = (my_pe - start) / stride; int32_t count = shmemi_load((__gm__ int32_t *)sync_counter) + 1; - + shmemi_store((__gm__ int32_t *)sync_counter, count); + dcci_cacheline((__gm__ uint8_t *)sync_counter); while (shift < size) { int pre_pe_in_team = (my_pe_in_team - shift + size) % size; int next_pe_in_team = (my_pe_in_team + shift) % size; @@ -46,8 +47,6 @@ SHMEM_DEVICE void shmemi_barrier_cross_host(shmemi_team_t *team) shift *= SHIFT_MULTIPLIER; } - - shmemi_store((__gm__ int32_t *)sync_counter, count); } SHMEM_DEVICE void shmemi_handle(shmem_team_t tid) diff --git a/include/internal/device/sync/shmemi_device_p2p.h b/include/internal/device/sync/shmemi_device_p2p.h index 2ea70b0b..d972d7d7 100644 --- a/include/internal/device/sync/shmemi_device_p2p.h +++ b/include/internal/device/sync/shmemi_device_p2p.h @@ -36,11 +36,6 @@ SHMEM_DEVICE void shmemi_highlevel_signal_set(__gm__ int32_t *dst, __gm__ int32_ ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR + UB_ALIGN_SIZE); ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; - dcci_cacheline((__gm__ uint8_t *)src); - int32_t count = shmemi_load(src) + 1; - shmemi_store(src, count); - // flush data cache to GM after signal to ensure it is visiable to other ranks - dcci_cacheline((__gm__ uint8_t *)src); shmemi_roce_write((__gm__ uint8_t*)shmem_ptr(dst, pe), (__gm__ uint8_t*)src, pe, 0, sizeof(int32_t), ub_tensor_64, ub_tensor_32); shmemi_roce_quiet(pe, 0, ub_tensor_64, ub_tensor_32); diff --git a/include/internal/host_device/shmemi_types.h b/include/internal/host_device/shmemi_types.h index 2607d83c..d835301e 100644 --- a/include/internal/host_device/shmemi_types.h +++ b/include/internal/host_device/shmemi_types.h @@ -43,7 +43,7 @@ extern "C" { #define SHMEM_CORE_SYNC_COUNTER_SIZE SHMEMI_SYNCBIT_SIZE // Total extra -#define SHMEM_EXTRA_SIZE_UNALIGHED SYNC_POOL_SIZE +#define SHMEM_EXTRA_SIZE_UNALIGHED (SYNC_POOL_SIZE + SYNC_COUNTERS_SIZE) #define SHMEM_EXTRA_SIZE ALIGH_TO(SHMEM_EXTRA_SIZE_UNALIGHED, SHMEM_PAGE_SIZE) // synchronization diff --git a/src/device/shmemi_device_intf.h b/src/device/shmemi_device_intf.h index 94bfeb89..e0386be3 100644 --- a/src/device/shmemi_device_intf.h +++ b/src/device/shmemi_device_intf.h @@ -18,6 +18,6 @@ int32_t shmemi_memset(int32_t *array, int32_t len, int32_t val, int32_t count); int32_t shmemi_barrier_on_stream(shmem_team_t tid, void *stream); -int32_t shmemi_handle_wait_on_stream(shmem_handle_t handle, aclrtStream stream); +void shmemi_handle_wait_on_stream(shmem_handle_t handle, aclrtStream stream); #endif \ No newline at end of file diff --git a/src/device/shmemi_handle.cpp b/src/device/shmemi_handle.cpp index 633e0b8a..547af387 100644 --- a/src/device/shmemi_handle.cpp +++ b/src/device/shmemi_handle.cpp @@ -19,9 +19,8 @@ SHMEM_GLOBAL void k_shmem_handle_wait(int32_t tid) } // interfaces -int32_t shmemi_handle_wait_on_stream(shmem_handle_t handle, aclrtStream stream) +void shmemi_handle_wait_on_stream(shmem_handle_t handle, aclrtStream stream) { // call barrier kernel k_shmem_handle_wait<<<1, nullptr, stream>>>((int32_t)handle.team_id); - return aclrtSynchronizeStream(stream); } \ No newline at end of file diff --git a/src/host/team/shmem_team.cpp b/src/host/team/shmem_team.cpp index bd7f0b5d..fcfa811b 100644 --- a/src/host/team/shmem_team.cpp +++ b/src/host/team/shmem_team.cpp @@ -203,7 +203,7 @@ int32_t shmemi_team_finalize() } if (g_state.sync_counter != 0) { - aclrtFree(reinterpret_cast(g_state.sync_counter)); + shmem_free(reinterpret_cast(g_state.sync_counter)); g_state.sync_counter = 0; } if (g_state.sync_pool != 0) { -- Gitee