From 783eb0f64b37bc0f3f0c95132c3273ef5e0a8deb Mon Sep 17 00:00:00 2001 From: oniond Date: Sun, 12 Oct 2025 12:52:29 +0800 Subject: [PATCH 1/2] add barrier --- src/comm/all_gather.cpp | 5 +++++ src/comm/transpose.cpp | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/src/comm/all_gather.cpp b/src/comm/all_gather.cpp index 88d9b8b..ea44f01 100644 --- a/src/comm/all_gather.cpp +++ b/src/comm/all_gather.cpp @@ -16,6 +16,7 @@ #include "../tensor/tensor.h" #include "kutacc.h" #include "kupl.h" +#include namespace kutacc { void af2_all_gather_kernel(Tensor &data, Tensor &out) @@ -61,6 +62,8 @@ void af2_all_gather_kernel(Tensor &data, Tensor &out) }); }); + MPI_Barrier(MPI_COMM_WORLD); + parallel_for(0, par_size, 1, [&](int64_t start, int64_t end) { collapse_for(start, end, world_size, size_m, n, [&](int64_t ri, int64_t sbmi, int64_t ni) { int64_t bmi = sbmi + block_mi_start; @@ -75,6 +78,8 @@ void af2_all_gather_kernel(Tensor &data, Tensor &out) } }); }); + + MPI_Barrier(MPI_COMM_WORLD); } } } diff --git a/src/comm/transpose.cpp b/src/comm/transpose.cpp index c47604d..0003f0a 100644 --- a/src/comm/transpose.cpp +++ b/src/comm/transpose.cpp @@ -16,6 +16,7 @@ #include "../tensor/tensor.h" #include "kutacc.h" #include "kupl.h" +#include namespace kutacc { void af2_transpose_kernel(Tensor &data, Tensor &out) @@ -64,6 +65,8 @@ void af2_transpose_kernel(Tensor &data, Tensor &out) }); }); + MPI_Barrier(MPI_COMM_WORLD); + parallel_for(0, par_size, 1, [&](int64_t start, int64_t end) { collapse_for(start, end, world_size, size_m, block_n, [&](int64_t ri, int64_t sbmi, int64_t bni) { int64_t bmi = sbmi + block_mi_start; @@ -79,6 +82,8 @@ void af2_transpose_kernel(Tensor &data, Tensor &out) } }); }); + + MPI_Barrier(MPI_COMM_WORLD); } } } -- Gitee From 9150687a6e34c52393b5e9438b4702121b5e3132 Mon Sep 17 00:00:00 2001 From: oniond Date: Sun, 12 Oct 2025 12:53:54 +0800 Subject: [PATCH 2/2] improve perf using ref --- src/tensor/tensor.cpp | 2 +- src/tensor/tensor.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp index c0cb728..2c5fc1d 100644 --- a/src/tensor/tensor.cpp +++ b/src/tensor/tensor.cpp @@ -25,7 +25,7 @@ std::vector Tensor::sizes() const return data.sizes; } -std::vector Tensor::strides() const +const std::vector& Tensor::strides() const { return data.strides; } diff --git a/src/tensor/tensor.h b/src/tensor/tensor.h index 433bca0..b7c2c6e 100644 --- a/src/tensor/tensor.h +++ b/src/tensor/tensor.h @@ -32,7 +32,7 @@ struct Tensor { : data(SimpleTensor(data_ptr, sizes, strides, dim, dtype)) {}; void* data_ptr() const; std::vector sizes() const; - std::vector strides() const; + const std::vector& strides() const; DType dtype() const; int64_t dim() const; }; -- Gitee