diff --git a/examples/rdma_demo/main.cpp b/examples/rdma_demo/main.cpp index b225e27435fce65c5de69c9ac2d396f1c48c73f1..3e54875254c28f9e08c25d032fb2cbcf3f183f84 100644 --- a/examples/rdma_demo/main.cpp +++ b/examples/rdma_demo/main.cpp @@ -55,8 +55,10 @@ int test_shmem_team_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size // AllGather allgather_demo(1, stream, (uint8_t *)ptr, trans_size * sizeof(int32_t)); + shmem_handle_t handle; + handle.team_id = SHMEM_TEAM_WORLD; + shmem_handle_wait(handle, stream); status = aclrtSynchronizeStream(stream); - shm::shmemi_control_barrier_all(); // 结果校验打印 int32_t *y_host; @@ -65,16 +67,13 @@ int test_shmem_team_all_gather(int rank_id, int n_ranks, uint64_t local_mem_size status = aclrtMemcpy(y_host, input_size, ptr, input_size, ACL_MEMCPY_DEVICE_TO_HOST); for (int i = 0; i < n_ranks; i++) { - if (y_host[trans_size * i] != num10 + i) { - std::cout << y_host[trans_size * i] << " != " << num10 + i << std::endl; - std::exit(EXIT_FAILURE); + for (int j = 0; j < 16; j++) { + if (y_host[trans_size * i + trans_size / 16 * j] != num10 + i) { + std::cout << y_host[trans_size * i + trans_size / 16 * j] << " != " << num10 + i << std::endl; + std::exit(EXIT_FAILURE); + } } } - std::cout << "rank: " << rank_id << " ["; - for (int j = 0; j < trans_size * n_ranks; j++) { - std::cout << y_host[j] << ", "; - } - std::cout << "]" << std::endl; // 去初始化 status = aclrtFreeHost(y_host); shmem_free(ptr); @@ -93,7 +92,7 @@ int main(int argc, char *argv[]) int rank_id = atoi(argv[argIdx++]); ipport = argv[argIdx++]; g_npus = atoi(argv[argIdx++]); - f_rank = atoi(argv[argIdx]); + f_rank = atoi(argv[argIdx++]); f_npu = atoi(argv[argIdx++]); uint64_t local_mem_size = 1024UL * 1024UL * 1024; status = test_shmem_team_all_gather(rank_id, n_ranks, local_mem_size); diff --git a/include/device/low_level/shmem_device_low_level_roce.h b/include/device/low_level/shmem_device_low_level_roce.h index 66f8a0bcb621d9999477c47fa70facf5b971b57e..ec8d3f4514c390118495e4283035dd77dc69d90b 100644 --- a/include/device/low_level/shmem_device_low_level_roce.h +++ b/include/device/low_level/shmem_device_low_level_roce.h @@ -120,6 +120,10 @@ SHMEM_DEVICE uint32_t shmemi_roce_poll_cq(uint32_t remoteRankId, uint32_t qpIdx, AscendC::LocalTensor ubLocal64, AscendC::LocalTensor ubLocal32) { + if (idx == 0) { + return 0; + } + __gm__ SHMEMHybmDeviceMeta* metaPtr = (__gm__ SHMEMHybmDeviceMeta*)(SMEM_SHM_DEVICE_META_ADDR + SMEM_SHM_DEVICE_GLOBAL_META_SIZE); __gm__ SHMEMAIVRDMAInfo* RDMAInfo = (__gm__ SHMEMAIVRDMAInfo*)(metaPtr->qpInfoAddress); diff --git a/include/device/shmem_device_sync.h b/include/device/shmem_device_sync.h index bce97b199af325c804d65e228f79fe28e54fa8e4..0979b2c7bc6a3d9ce75fdb5f5babb57cdf19a238 100644 --- a/include/device/shmem_device_sync.h +++ b/include/device/shmem_device_sync.h @@ -39,6 +39,7 @@ #include "internal/device/sync/shmemi_device_quiet.h" #include "internal/device/sync/shmemi_device_p2p.h" #include "internal/device/sync/shmemi_device_barrier.h" +#include "internal/device/sync/shmemi_device_handle.h" #ifdef __cplusplus extern "C" { diff --git a/include/host/shmem_host_sync.h b/include/host/shmem_host_sync.h index 939cb7c0aa07eff5ab2eadbc0f888e4ba4b1d92c..56a4e730cede3a95ada3ed587ad399bb02632651 100644 --- a/include/host/shmem_host_sync.h +++ b/include/host/shmem_host_sync.h @@ -37,6 +37,12 @@ extern "C" { */ SHMEM_HOST_API uint64_t shmemx_get_ffts_config(); +/** + * @fn SHMEM_HOST_API void shmem_handle_wait(shmem_handle_t handle) + * @brief Wait asynchronous RMA operations to finish. + */ +SHMEM_HOST_API void shmem_handle_wait(shmem_handle_t handle, aclrtStream stream); + #ifdef __cplusplus } #endif diff --git a/include/host_device/shmem_types.h b/include/host_device/shmem_types.h index 030cf36c10068143fee91facbaf26592c32d6e77..22d23ed7440ddb37704cd47cb23d2f33482089d6 100644 --- a/include/host_device/shmem_types.h +++ b/include/host_device/shmem_types.h @@ -100,6 +100,21 @@ typedef int shmem_team_t; /**@} */ // end of group_typedef +/** + * @addtogroup group_structs + * @{ +*/ +/** + * @struct shmem_handle_t + * @brief Handle info used for non-blocking API synchronization. + * + * - shmem_team_t team_id: Team ID used for synchronization. +*/ +struct shmem_handle_t { + shmem_team_t team_id; +}; + +/**@} */ // end of group_structs #ifdef __cplusplus } #endif diff --git a/include/internal/device/sync/shmemi_device_handle.h b/include/internal/device/sync/shmemi_device_handle.h new file mode 100644 index 0000000000000000000000000000000000000000..3f18538ed3b4d663ddc232769046cea7a056ebe6 --- /dev/null +++ b/include/internal/device/sync/shmemi_device_handle.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef SHEMEI_DEVICE_HANDLE_H +#define SHEMEI_DEVICE_HANDLE_H + +#include "shmemi_device_barrier.h" + +#include "kernel_operator.h" + +SHMEM_DEVICE void shmemi_barrier_cross_host(shmemi_team_t *team) +{ + if (AscendC::GetBlockIdx() != 0) + return; + + int my_pe = shmemi_get_state()->team_pools[SHMEM_TEAM_WORLD]->mype; + int start = team->start; + int stride = team->stride; + int size = team->size; + auto sync_array = shmemi_get_team_sync_array(team->team_idx); + auto sync_counter = shmemi_get_team_sync_counter(team->team_idx); + + int shift = 1; + int my_pe_in_team = (my_pe - start) / stride; + int32_t count = shmemi_load((__gm__ int32_t *)sync_counter) + 1; + shmemi_store((__gm__ int32_t *)sync_counter, count); + dcci_cacheline((__gm__ uint8_t *)sync_counter); + while (shift < size) { + int pre_pe_in_team = (my_pe_in_team - shift + size) % size; + int next_pe_in_team = (my_pe_in_team + shift) % size; + + int pre_pe = start + pre_pe_in_team * stride; + int next_pe = start + next_pe_in_team * stride; + + // signal next pe + shmemi_highlevel_signal_set((__gm__ int32_t *)(sync_array + my_pe), (__gm__ int32_t *)sync_counter, next_pe); + + // wait pre pe + shmemi_signal_wait_until_eq_for_barrier((__gm__ int32_t *)(sync_array + pre_pe), count); + + shift *= SHIFT_MULTIPLIER; + } +} + +SHMEM_DEVICE void shmemi_handle(shmem_team_t tid) +{ + shmemi_team_t *team = shmemi_get_state()->team_pools[tid]; + + int mype = shmemi_get_state()->team_pools[SHMEM_TEAM_WORLD]->mype; + int start = team->start; + int stride = team->stride; + int size = team->size; + + if ((mype - start) % stride != 0) { + // not in this team + return; + } + + AscendC::LocalTensor ub_tensor_32; + ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); + ub_tensor_32.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR); + ub_tensor_32.address_.dataLen = UB_ALIGN_SIZE; + AscendC::LocalTensor ub_tensor_64; + ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); + ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR + + UB_ALIGN_SIZE); + ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; + + for (int i = 0; i < size; i++) { + int peer = start + i * stride; + if (peer == mype) { + continue; + } + shmemi_roce_quiet(peer, 0, ub_tensor_64, ub_tensor_32); + } + + if ASCEND_IS_AIV { + shmemi_barrier_cross_host(team); + } +} + +#endif \ No newline at end of file diff --git a/include/internal/device/sync/shmemi_device_p2p.h b/include/internal/device/sync/shmemi_device_p2p.h index 0d6da3133e464c5720e922920a073c91330b4e36..d972d7d7256a2c2ef7655313e0f09941d608356f 100644 --- a/include/internal/device/sync/shmemi_device_p2p.h +++ b/include/internal/device/sync/shmemi_device_p2p.h @@ -25,6 +25,22 @@ SHMEM_DEVICE void shmemi_signal_set(__gm__ int32_t *addr, int pe, int32_t val) shmemi_signal_set(shmemi_ptr(addr, pe), val); } +SHMEM_DEVICE void shmemi_highlevel_signal_set(__gm__ int32_t *dst, __gm__ int32_t *src, int pe) +{ + AscendC::LocalTensor ub_tensor_32; + ub_tensor_32.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); + ub_tensor_32.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR); + ub_tensor_32.address_.dataLen = UB_ALIGN_SIZE; + AscendC::LocalTensor ub_tensor_64; + ub_tensor_64.address_.logicPos = static_cast(AscendC::TPosition::VECOUT); + ub_tensor_64.address_.bufferAddr = reinterpret_cast(SHMEM_INTERNAL_UB_BUF_START_ADDR + + UB_ALIGN_SIZE); + ub_tensor_64.address_.dataLen = UB_ALIGN_SIZE; + shmemi_roce_write((__gm__ uint8_t*)shmem_ptr(dst, pe), (__gm__ uint8_t*)src, pe, 0, sizeof(int32_t), + ub_tensor_64, ub_tensor_32); + shmemi_roce_quiet(pe, 0, ub_tensor_64, ub_tensor_32); +} + SHMEM_DEVICE void shmemi_signal_add(__gm__ int32_t *addr, int pe, int32_t val) { // ensure previous atomic operations end diff --git a/include/internal/host_device/shmemi_types.h b/include/internal/host_device/shmemi_types.h index 2607d83c3c7cc54d5e155f65714c24c567311699..d835301ee9c230dfd8ecc5b66702224a163f4e16 100644 --- a/include/internal/host_device/shmemi_types.h +++ b/include/internal/host_device/shmemi_types.h @@ -43,7 +43,7 @@ extern "C" { #define SHMEM_CORE_SYNC_COUNTER_SIZE SHMEMI_SYNCBIT_SIZE // Total extra -#define SHMEM_EXTRA_SIZE_UNALIGHED SYNC_POOL_SIZE +#define SHMEM_EXTRA_SIZE_UNALIGHED (SYNC_POOL_SIZE + SYNC_COUNTERS_SIZE) #define SHMEM_EXTRA_SIZE ALIGH_TO(SHMEM_EXTRA_SIZE_UNALIGHED, SHMEM_PAGE_SIZE) // synchronization diff --git a/src/device/shmemi_device_intf.h b/src/device/shmemi_device_intf.h index e7af385cb2227cd10b1abdf2cd1dadaa6b177495..e0386be3ddee8d8d73d0f9b3f2b0463f2341df04 100644 --- a/src/device/shmemi_device_intf.h +++ b/src/device/shmemi_device_intf.h @@ -18,4 +18,6 @@ int32_t shmemi_memset(int32_t *array, int32_t len, int32_t val, int32_t count); int32_t shmemi_barrier_on_stream(shmem_team_t tid, void *stream); +void shmemi_handle_wait_on_stream(shmem_handle_t handle, aclrtStream stream); + #endif \ No newline at end of file diff --git a/src/device/shmemi_handle.cpp b/src/device/shmemi_handle.cpp new file mode 100644 index 0000000000000000000000000000000000000000..547af3879f35855b1bdfeef3dd0b6f75fc014abf --- /dev/null +++ b/src/device/shmemi_handle.cpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "acl/acl.h" +#include "kernel_operator.h" + +#include "shmem_api.h" + +// kernels +SHMEM_GLOBAL void k_shmem_handle_wait(int32_t tid) +{ + shmemi_handle(tid); +} + +// interfaces +void shmemi_handle_wait_on_stream(shmem_handle_t handle, aclrtStream stream) +{ + // call barrier kernel + k_shmem_handle_wait<<<1, nullptr, stream>>>((int32_t)handle.team_id); +} \ No newline at end of file diff --git a/src/host/sync/shmemi_sync.cpp b/src/host/sync/shmemi_sync.cpp index 62fe991f805c0d160d27bc1a3a9b119910d24f13..72c7c8af8124a60473e6dfaaa1e041ae89472eef 100644 --- a/src/host/sync/shmemi_sync.cpp +++ b/src/host/sync/shmemi_sync.cpp @@ -55,4 +55,8 @@ void shmem_barrier_on_stream(shmem_team_t tid, aclrtStream stream) void shmem_barrier_all_on_stream(aclrtStream stream) { shmemi_barrier_on_stream(SHMEM_TEAM_WORLD, stream); +} + +void shmem_handle_wait(shmem_handle_t handle, aclrtStream stream) { + shmemi_handle_wait_on_stream(handle, stream); } \ No newline at end of file diff --git a/src/host/team/shmem_team.cpp b/src/host/team/shmem_team.cpp index 942337f7c5c8a836a061d53a8b1dc403580cb640..fcfa811b949222fa26db5774d4fec0d5f7051c62 100644 --- a/src/host/team/shmem_team.cpp +++ b/src/host/team/shmem_team.cpp @@ -89,13 +89,13 @@ int32_t shmemi_team_init_sync_pool() int32_t shmemi_team_init_sync_counter() { - auto ret = aclrtMalloc((void **)&(g_state.sync_counter), SYNC_COUNTERS_SIZE, ACL_MEM_MALLOC_HUGE_FIRST); - if (ret != 0 || g_state.sync_counter == 0) { + g_state.sync_counter = (uint64_t)shmem_malloc(SYNC_COUNTERS_SIZE); + if (g_state.sync_counter == 0) { shmemi_team_finalize(); SHM_LOG_ERROR("malloc sync counter failed."); return SHMEM_INNER_ERROR; } - ret = aclrtMemset((void *)g_state.sync_counter, SYNC_COUNTERS_SIZE, 0, SYNC_COUNTERS_SIZE); + auto ret = aclrtMemset((void *)g_state.sync_counter, SYNC_COUNTERS_SIZE, 0, SYNC_COUNTERS_SIZE); if (ret != 0) { shmemi_team_finalize(); SHM_LOG_ERROR("memset sync counter failed."); @@ -203,7 +203,7 @@ int32_t shmemi_team_finalize() } if (g_state.sync_counter != 0) { - aclrtFree(reinterpret_cast(g_state.sync_counter)); + shmem_free(reinterpret_cast(g_state.sync_counter)); g_state.sync_counter = 0; } if (g_state.sync_pool != 0) {