diff --git a/docs/README.md b/docs/README.md
index b688a7f7960ad7bec14e33c5eeb99b51da3e196b..b90de102a7bd4bd32cad0d34380f6b855122a0ba 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -334,10 +334,14 @@
Given two source operands src0 and src1, selects elements according to the value (not the bit) at the corresponding position in maskTensor, producing the destination operand dst. |
- | Transformation |
+ Transformation |
ConfusionTranspose |
Performs data layout rearrangement and reshape operations on the input data. |
+
+ | TransData |
+ Converts the input data layout into the layout required by the output. |
+
| Index operations |
ArithProgression |
diff --git a/impl/CMakeLists.txt b/impl/CMakeLists.txt
index c29d95d33e3a958ae66a96848fc92ab7cc2df526..1ab2cc72b6dc11b43344c7529011edcaa3002e4e 100644
--- a/impl/CMakeLists.txt
+++ b/impl/CMakeLists.txt
@@ -92,6 +92,7 @@ add_library(tiling_api STATIC
${CMAKE_CURRENT_SOURCE_DIR}/math/axpy/axpy_tiling_impl.cpp
${CMAKE_CURRENT_SOURCE_DIR}/math/ceil/ceil_tiling_impl.cpp
${CMAKE_CURRENT_SOURCE_DIR}/math/floor/floor_tiling_impl.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/transpose/transdata/transdata_tiling.cpp
${CMAKE_CURRENT_SOURCE_DIR}/math/fmod/fmod_tiling_impl.cpp
${CMAKE_CURRENT_SOURCE_DIR}/math/trunc/trunc_tiling_impl.cpp
$<$:$>
diff --git a/impl/reduce/reduce_tiling.cpp b/impl/reduce/reduce_tiling.cpp
index d6efe31fa567979eb885947e0253374ee4406501..71722928c3a4c898614e4c5f2ba0528dcbd113c0 100644
--- a/impl/reduce/reduce_tiling.cpp
+++ b/impl/reduce/reduce_tiling.cpp
@@ -102,7 +102,6 @@ void GetReduceCommonMaxMinTmpSize(const ge::Shape &srcShape,
}
inline void GetReduceSumMeanCommonTmpSize(const ge::Shape &srcShape,
- const ge::DataType dataType,
ReducePattern pattern, bool isSrcInnerPad, bool isReuseSource,
uint32_t &maxValue, uint32_t &minValue, std::string apiName, std::string funcName)
{
@@ -137,7 +136,6 @@ inline void GetReduceSumMeanCommonTmpSize(const ge::Shape &srcShape,
}
inline void GetReduceAnyAllCommonTmpSize(const ge::Shape &srcShape,
- const ge::DataType dataType,
ReducePattern pattern, bool isSrcInnerPad, bool isReuseSource,
uint32_t &maxValue, uint32_t &minValue, std::string apiName, std::string funcName)
{
@@ -229,7 +227,7 @@ void GetReduceAnyMaxMinTmpSize(const ge::Shape &srcShape,
return,
"[ReduceAny][GetReduceAnyMaxMinTmpSize] it only supports float and uint8_t type on this platform.");
if (dataType == ge::DT_UINT8) {
- GetReduceAnyAllCommonTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue,
+ GetReduceAnyAllCommonTmpSize(srcShape, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue,
"ReduceAny", "GetReduceAnyMaxMinTmpSize");
} else {
GetReduceCommonMaxMinTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue,
@@ -245,7 +243,7 @@ void GetReduceAllMaxMinTmpSize(const ge::Shape &srcShape,
ASCENDC_HOST_ASSERT((dataType == ge::DT_FLOAT || dataType == ge::DT_UINT8), return,
"[ReduceAll][GetReduceAllMaxMinTmpSize] it only supports float and uint8 type on this platform.");
if (dataType == ge::DT_UINT8) {
- GetReduceAnyAllCommonTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue,
+ GetReduceAnyAllCommonTmpSize(srcShape, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue,
"ReduceAll", "GetReduceAllMaxMinTmpSize");
} else {
GetReduceCommonMaxMinTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue,
@@ -260,7 +258,7 @@ void GetReduceSumMaxMinTmpSize(const ge::Shape &srcShape,
{
ASCENDC_HOST_ASSERT(dataType == ge::DT_FLOAT, return,
"[ReduceSum][GetReduceSumMaxMinTmpSize] it only supports float type on this platform.");
- GetReduceSumMeanCommonTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue,
+ GetReduceSumMeanCommonTmpSize(srcShape, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue,
"ReduceSum", "GetReduceSumMaxMinTmpSize");
}
@@ -271,7 +269,7 @@ void GetReduceMeanMaxMinTmpSize(const ge::Shape &srcShape,
{
ASCENDC_HOST_ASSERT(dataType == ge::DT_FLOAT, return,
"[ReduceMean][GetReduceMeanMaxMinTmpSize] it only supports float type on this platform.");
- GetReduceSumMeanCommonTmpSize(srcShape, dataType, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue,
+ GetReduceSumMeanCommonTmpSize(srcShape, pattern, isSrcInnerPad, isReuseSource, maxValue, minValue,
"ReduceMean", "GetReduceMeanMaxMinTmpSize");
}
} // namespace AscendC
diff --git a/impl/transpose/transdata/transdata_impl.h b/impl/transpose/transdata/transdata_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..31547e27ba9775f61b6e7fd3a1b22b45d0331049
--- /dev/null
+++ b/impl/transpose/transdata/transdata_impl.h
@@ -0,0 +1,524 @@
+/**
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef IMPL_TRANSPOSE_TRANSDATA_TRANSDATA_IMPL_H
+#define IMPL_TRANSPOSE_TRANSDATA_TRANSDATA_IMPL_H
+
+#include "kernel_tensor.h"
+#include "kernel_operator_intf.h"
+#include "kernel_tiling/kernel_tiling.h"
+#include "../../common/check.h"
+
+namespace AscendC {
+namespace Internal {
+
+namespace {
+constexpr int32_t n0 = 16;
+constexpr int32_t c0 = 16;
+constexpr int32_t hw0 = 16;
+constexpr int32_t ncdhwDims = 5;
+constexpr int32_t fractalZ3DDims = 7;
+constexpr int32_t ndc1hwc0Dims = 6;
+}
+
+struct TransDataTmpParams {
+ int32_t n;
+ int32_t c;
+ int32_t d;
+ int32_t h;
+ int32_t w;
+ int32_t n1;
+ int32_t c1;
+ int32_t padHw;
+};
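+// n1 and c1 count the 16-element blocks covering n and c (n1 = ceil(n / n0), c1 = ceil(c / c0));
+// padHw is h * w rounded up to a multiple of hw0 = 16 elements, i.e. one 32B block of 16-bit data.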
+
+constexpr int32_t DEFAULT_TRANSDATA_5HD_LIST = 16;
+
+template <typename T>
+__aicore__ inline void DC1Hwn1n0c0ToC1DHwn1n0c0HWAlign(const LocalTensor<T>& dst, const LocalTensor<T>& src,
+ const TransDataTmpParams& params)
+{
+ // d, c1, h w n1 n0 c0 -> c1, d, hw1*hw0 n1 n0 c0
+ int32_t d = params.d;
+ int32_t h = params.h;
+ int32_t w = params.w;
+ int32_t n1 = params.n1;
+ int32_t c1 = params.c1;
+ int32_t padHw = params.padHw;
+
+ uint32_t dim0 = d;
+ uint32_t dim1 = c1;
+ uint32_t lastDim = h * w * n1 * n0 * c0;
+
+ // dim0, dim1, lastDim -> dim1, dim0, lastDim
+ int32_t n1n0c0DimElems = n1 * n0 * c0;
+ int32_t hwAlignElems = padHw * n1n0c0DimElems;
+ int32_t hwPadElems = (padHw - h * w) * n1n0c0DimElems;
+
+ uint16_t blockCount = dim1;
+ uint16_t blockLen = lastDim * sizeof(T) / ONE_BLK_SIZE;
+ uint16_t srcGap = 0;
+ uint16_t dstGap = ((dim0 - 1) * hwAlignElems + hwPadElems) * sizeof(T) / ONE_BLK_SIZE;
+
+ uint32_t dstSize = c1 * d * padHw * n1 * n0 * c0;
+    Duplicate(dst, static_cast<T>(0), dstSize);
+    PipeBarrier<PIPE_V>();
+
+ DataCopyParams dataCopyParams = { blockCount, blockLen, srcGap, dstGap };
+ for (uint32_t d0 = 0; d0 < dim0; d0++) {
+ DataCopy(dst[d0 * hwAlignElems], src[d0 * dim1 * lastDim], dataCopyParams);
+ }
+    PipeBarrier<PIPE_V>();
+}
+
+template <typename T>
+__aicore__ inline void C1Dhwn1n0c0ToC1C0Dhwn1n0(const LocalTensor<T>& dst, const LocalTensor<T>& src,
+ const TransDataTmpParams& params)
+{
+ // C1 DHWN1N0 C0 -> C1 C0 DHWN1N0
+ int32_t d = params.d;
+ int32_t n1 = params.n1;
+ int32_t c1 = params.c1;
+ int32_t padHw = params.padHw;
+
+ TransDataTo5HDParams transDataParams;
+ transDataParams.dstHighHalf = false;
+ transDataParams.srcHighHalf = false;
+ transDataParams.repeatTimes = d * padHw * n1;
+ if (transDataParams.repeatTimes == 1) {
+ transDataParams.srcRepStride = 0;
+ transDataParams.dstRepStride = 0;
+ } else {
+ transDataParams.srcRepStride = DEFAULT_TRANSDATA_5HD_LIST * c0 * sizeof(T) / ONE_BLK_SIZE;
+ transDataParams.dstRepStride = n0 * sizeof(T) / ONE_BLK_SIZE;
+ }
+
+ uint64_t srcOffsetArr[DEFAULT_TRANSDATA_5HD_LIST];
+ uint64_t dstOffsetArr[DEFAULT_TRANSDATA_5HD_LIST];
+ uint64_t srcAddr = (uint64_t)src.GetPhyAddr();
+ uint64_t dstAddr = (uint64_t)dst.GetPhyAddr();
+ for (uint32_t j = 0; j < c1; j++) {
+ uint32_t outOffset = j * d * padHw * n1 * n0 * c0;
+ for (uint8_t i = 0; i < DEFAULT_TRANSDATA_5HD_LIST; i++) {
+ srcOffsetArr[i] = (uint64_t)(srcAddr + (outOffset + i * n0) * sizeof(T));
+ dstOffsetArr[i] = (uint64_t)(dstAddr + (outOffset + i * d * padHw * n1 * n0) * sizeof(T));
+ }
+        TransDataTo5HD<T>(dstOffsetArr, srcOffsetArr, transDataParams);
+    }
+    PipeBarrier<PIPE_V>();
+}
+
+template <typename T>
+__aicore__ inline void C1c0dhwN1n0ToNcdhw(const LocalTensor<T>& dst, const LocalTensor<T>& src,
+    const LocalTensor<T>& tmp, const TransDataTmpParams& params)
+{
+ // C1C0DHW N1N0 -> N CDHW
+ int32_t d = params.d;
+ int32_t n1 = params.n1;
+ int32_t padHw = params.padHw;
+ int32_t currN = params.n;
+ int32_t c = params.c;
+
+ TransDataTo5HDParams transDataParams;
+ transDataParams.dstHighHalf = false;
+ transDataParams.srcHighHalf = false;
+ transDataParams.repeatTimes = c * d * padHw / n0;
+ if (transDataParams.repeatTimes == 1) {
+ transDataParams.srcRepStride = 0;
+ transDataParams.dstRepStride = 0;
+ } else {
+ transDataParams.srcRepStride = DEFAULT_TRANSDATA_5HD_LIST * n1 * n0 * sizeof(T) / ONE_BLK_SIZE;
+ transDataParams.dstRepStride = c0 * sizeof(T) / ONE_BLK_SIZE;
+ }
+
+ uint64_t srcOffsetArr[DEFAULT_TRANSDATA_5HD_LIST];
+ uint64_t dstOffsetArr[DEFAULT_TRANSDATA_5HD_LIST];
+ uint64_t srcAddr = (uint64_t)src.GetPhyAddr();
+ uint64_t dstAddr = (uint64_t)dst.GetPhyAddr();
+ uint64_t tmpAddr = (uint64_t)tmp.GetPhyAddr();
+ for (uint32_t j = 0; j < n1; j++) {
+ if (n0 - currN > 0) {
+ for (uint8_t i = 0; i < currN; i++) {
+ dstOffsetArr[i] = (uint64_t)(dstAddr + (j * d * c * padHw * n0 + i * c * d * padHw) * sizeof(T));
+ }
+ for (uint8_t i = currN; i < DEFAULT_TRANSDATA_5HD_LIST; i++) {
+ dstOffsetArr[i] = (uint64_t)(tmpAddr + i * ONE_BLK_SIZE * sizeof(T));
+ }
+ } else {
+ for (uint8_t i = 0; i < DEFAULT_TRANSDATA_5HD_LIST; i++) {
+ dstOffsetArr[i] = (uint64_t)(dstAddr + (j * d * c * padHw * n0 + i * c * d * padHw) * sizeof(T));
+ }
+ }
+ for (uint8_t i = 0; i < DEFAULT_TRANSDATA_5HD_LIST; i++) {
+ srcOffsetArr[i] = (uint64_t)(srcAddr + (j * n0 + i * n0 * n1) * sizeof(T));
+ }
+        TransDataTo5HD<T>(dstOffsetArr, srcOffsetArr, transDataParams);
+        currN -= n0;
+    }
+    PipeBarrier<PIPE_V>();
+}
+
+template <typename T>
+__aicore__ inline void N1n0C1c0DHWToNCDHW(const LocalTensor<T>& dst, const LocalTensor<T>& src,
+ const TransDataTmpParams& params)
+{
+ // N1N0 C1C0 D H W -> N C D H W
+ int32_t n = params.n;
+ int32_t c = params.c;
+ int32_t d = params.d;
+ int32_t c1 = params.c1;
+ int32_t padHw = params.padHw;
+
+ uint16_t blockCount = n;
+ uint16_t blockLen = (c * (d * padHw)) * sizeof(T) / ONE_BLK_SIZE;
+    uint16_t srcGap = ((c1 * c0 - c) * (d * padHw)) * sizeof(T) / ONE_BLK_SIZE;
+ uint16_t dstGap = 0;
+ DataCopyParams dataCopyParams = { blockCount, blockLen, srcGap, dstGap };
+ DataCopy(dst, src, dataCopyParams);
+    PipeBarrier<PIPE_V>();
+}
+
+template <typename T>
+__aicore__ inline void TransDataFractalToNcdhw(const LocalTensor<T>& dst, const LocalTensor<T>& src,
+    const LocalTensor<uint8_t>& tmpBuffer, const TransDataTmpParams& params)
+{
+ int32_t d = params.d;
+ int32_t n1 = params.n1;
+ int32_t c1 = params.c1;
+ int32_t padHw = params.padHw;
+ int32_t n = params.n;
+ int32_t c = params.c;
+
+    LocalTensor<T> tmp = tmpBuffer.template ReinterpretCast<T>();
+    LocalTensor<T> srcTmp = src.template ReinterpretCast<T>();
+ if (c == c1 * c0 && n == n1 * n0) {
+        LocalTensor<T> dstTmp = dst.template ReinterpretCast<T>();
+ // D C1 HWN1N0C0 -> C1 D HWN1N0C0 (H*W 32B ALIGN -> HW1*HW0)
+ DC1Hwn1n0c0ToC1DHwn1n0c0HWAlign(dstTmp, srcTmp, params);
+ // C1 DHWN1N0 C0 -> C1 C0 DHWN1N0
+ C1Dhwn1n0c0ToC1C0Dhwn1n0(tmp, dstTmp, params);
+ // C1C0DHW N1N0 -> N CDHW
+ C1c0dhwN1n0ToNcdhw(dstTmp, tmp, tmp, params);
+ } else {
+        LocalTensor<T> transDataTmp = tmp[n1 * n0 * c1 * c0 * d * padHw];
+        LocalTensor<T> dstTmp = dst.template ReinterpretCast<T>();
+ // D C1 HWN1N0C0 -> C1 D HWN1N0C0 (H*W 32B ALIGN -> HW1*HW0)
+ DC1Hwn1n0c0ToC1DHwn1n0c0HWAlign(tmp, srcTmp, params);
+ // C1 DHWN1N0 C0 -> C1 C0 DHWN1N0
+ C1Dhwn1n0c0ToC1C0Dhwn1n0(transDataTmp, tmp, params);
+ // C1C0DHW N1N0 -> N CDHW
+ C1c0dhwN1n0ToNcdhw(dstTmp, transDataTmp, tmp, params);
+ }
+}
+
+// Transdata NCDHW -> FRACTAL_Z_3D
+template <typename T>
+__aicore__ inline void TransDataImplNcdhwToFractal(const LocalTensor<T>& dst, const LocalTensor<T>& src,
+    const LocalTensor<uint8_t>& tmpBuffer, const TransDataTmpParams& param)
+{
+ constexpr int32_t elePerBlk = ONE_BLK_SIZE / sizeof(T);
+ const int32_t n = param.n, c = param.c, d = param.d, h = param.h, w = param.w;
+ constexpr int32_t c0 = 16;
+ constexpr int32_t n0 = 16;
+ const int32_t c1 = DivCeil(c, c0);
+ const int32_t n1 = DivCeil(n, n0);
+ int32_t padHw = AlignUp(h * w, elePerBlk);
+ int32_t currAxis = c * d * padHw;
+    Duplicate(tmpBuffer.ReinterpretCast<T>(), static_cast<T>(0), currAxis);
+    PipeBarrier<PIPE_V>();
+    auto tmpDstTensor = tmpBuffer[currAxis * sizeof(T)].ReinterpretCast<T>();
+ uint64_t dstLocalList[DEFAULT_TRANSDATA_5HD_LIST];
+ uint64_t srcLocalList[DEFAULT_TRANSDATA_5HD_LIST];
+
+ uint64_t dstTensorAddr = (uint64_t)dst.GetPhyAddr();
+ uint64_t srcTensorAddr = (uint64_t)src.GetPhyAddr();
+ uint64_t tmpDstTensorAddr = (uint64_t)tmpDstTensor.GetPhyAddr();
+ uint64_t tmpBufferAddr = (uint64_t)tmpBuffer.GetPhyAddr();
+ // step1, NCDHW -> CDHW, N1, N0
+ // Do n1 times Transpose to split axis N, and fill with 0 on padding data.
+ TransDataTo5HDParams transDataParams;
+ transDataParams.dstHighHalf = false;
+ transDataParams.srcHighHalf = false;
+ transDataParams.repeatTimes = currAxis / elePerBlk;
+    // when repeatTimes == 1 the strides must be 0; otherwise each repeat advances by the configured stride.
+ transDataParams.dstRepStride = transDataParams.repeatTimes == 1 ? 0 : n1 * n0;
+ transDataParams.srcRepStride = transDataParams.repeatTimes == 1 ? 0 : 1;
+
+ bool isPadded = padHw != h * w;
+ // dst tensor is unable to fill all padded data.
+ auto tmpIfPadAddr = isPadded ? tmpDstTensorAddr : dstTensorAddr;
+ for (int j = 0; j < n1; j++) {
+ uint64_t currDstAddr = tmpIfPadAddr + j * n0 * sizeof(T);
+ uint64_t currSrcAddr = srcTensorAddr + j * currAxis * n0 * sizeof(T);
+ // handle the last axis if N is not even splited by n0.
+ int remain = j == n1 - 1 ? n - j * n0 : n0;
+ for (int32_t i = 0; i < n0; i++) {
+ dstLocalList[i] = currDstAddr + (i * n1 * n0) * sizeof(T);
+ }
+ for (int32_t i = 0; i < remain; i++) {
+ srcLocalList[i] = currSrcAddr + i * currAxis * sizeof(T);
+ }
+ for (int32_t i = remain; i < n0; i++) {
+ srcLocalList[i] = tmpBufferAddr;
+ }
+        TransDataTo5HD<T>(dstLocalList, srcLocalList, transDataParams);
+    }
+    PipeBarrier<PIPE_V>();
+ // step1.5 collapse padded H,W axis for CDHW, N1N0
+ DataCopyParams copyParams;
+ if (isPadded) {
+ currAxis = h * w * n1 * n0;
+ copyParams.blockCount = c * d;
+ copyParams.blockLen = currAxis / elePerBlk;
+ // Merge axis by skiping padded H,W.
+ copyParams.srcStride = (padHw - h * w) * n1 * n0 / elePerBlk;
+ copyParams.dstStride = 0;
+ DataCopy(dst, tmpDstTensor, copyParams);
+ }
+    PipeBarrier<PIPE_V>();
+
+ // step2, CDHWN1N0 -> C1DHW, N1N0, C0
+ currAxis = d * h * w * n1 * n0;
+ transDataParams.repeatTimes = currAxis / elePerBlk;
+ transDataParams.dstRepStride = transDataParams.repeatTimes == 1 ? 0 : c0;
+ transDataParams.srcRepStride = transDataParams.repeatTimes == 1 ? 0 : 1;
+ for (int32_t j = 0; j < c1; j++) {
+ uint64_t currDstAddr = tmpDstTensorAddr + j * currAxis * c0 * sizeof(T);
+ uint64_t currSrcAddr = dstTensorAddr + j * currAxis * c0 * sizeof(T);
+ int remain = j == c1 - 1 ? c - j * c0 : c0;
+ for (int32_t i = 0; i < DEFAULT_TRANSDATA_5HD_LIST; i++) {
+ dstLocalList[i] = currDstAddr + i * c0 * sizeof(T);
+ }
+ for (int32_t i = 0; i < remain; i++) {
+ srcLocalList[i] = currSrcAddr + i * currAxis * sizeof(T);
+ }
+ for (int32_t i = remain; i < DEFAULT_TRANSDATA_5HD_LIST; i++) {
+ srcLocalList[i] = tmpBufferAddr;
+ }
+        TransDataTo5HD<T>(dstLocalList, srcLocalList, transDataParams);
+    }
+    PipeBarrier<PIPE_V>();
+    // step3 C1DHW, N1N0, C0 -> DC1HW, N1N0, C0
+ currAxis = c0 * h * w * n1 * n0;
+ copyParams.blockCount = d;
+ copyParams.blockLen = currAxis / elePerBlk;
+    // Merge axes by skipping the padding: padHw -> h, w
+ copyParams.srcStride = 0;
+ copyParams.dstStride = (c1 - 1) * currAxis / elePerBlk;
+ for (int32_t i = 0; i < c1; i++) {
+ DataCopy(dst[i * currAxis], tmpDstTensor[i * d * currAxis], copyParams);
+ }
+    PipeBarrier<PIPE_V>();
+}
+
+// Transdata NCDHW -> NDC1HWC0
+template <typename T>
+__aicore__ inline void TransDataImplNcdhwTo6Hd(const LocalTensor<T>& dst, const LocalTensor<T>& src,
+    const LocalTensor<uint8_t>& tmpBuffer, const TransDataTmpParams& param)
+{
+ constexpr int32_t c0 = 16;
+ constexpr int32_t elePerBlk = ONE_BLK_SIZE / sizeof(T);
+ const int32_t n = param.n, c = param.c, d = param.d, h = param.h, w = param.w;
+ const int32_t c1 = DivCeil(c, c0);
+ const int32_t padHw = AlignUp(h * w, elePerBlk);
+ int32_t currAxis = d * padHw;
+
+ int32_t axisHwd = h * w * d;
+ int32_t axisHwc0 = h * w * c0;
+ int32_t axisC1hwc0 = axisHwc0 * c1;
+ int32_t axisC1hwdc0 = axisC1hwc0 * d;
+ int32_t axisPadHwd = padHw * d;
+ int32_t axisPadHwc0 = padHw * c0;
+ int32_t axisPadHwdc0 = padHw * c0 * d;
+    Duplicate(tmpBuffer.ReinterpretCast<T>(), static_cast<T>(0), axisPadHwd);
+    PipeBarrier<PIPE_V>();
+
+ // reserve for padded 0 on additional axis c.
+    auto tmpDstTensor = tmpBuffer[axisPadHwd * sizeof(T)].ReinterpretCast<T>();
+
+ uint64_t dstTensorAddr = (uint64_t)dst.GetPhyAddr();
+ uint64_t srcTensorAddr = (uint64_t)src.GetPhyAddr();
+ uint64_t tmpDstTensorAddr = (uint64_t)tmpDstTensor.GetPhyAddr();
+ uint64_t tmpBufferAddr = (uint64_t)tmpBuffer.GetPhyAddr();
+ uint64_t dstLocalList[DEFAULT_TRANSDATA_5HD_LIST];
+ uint64_t srcLocalList[DEFAULT_TRANSDATA_5HD_LIST];
+ TransDataTo5HDParams transDataParams;
+ transDataParams.dstHighHalf = false;
+ transDataParams.srcHighHalf = false;
+ transDataParams.repeatTimes = axisPadHwd / elePerBlk;
+ transDataParams.dstRepStride = transDataParams.repeatTimes == 1 ? 0 : c0;
+ transDataParams.srcRepStride = transDataParams.repeatTimes == 1 ? 0 : 1;
+
+ DataCopyParams copyParams;
+ copyParams.blockCount = d;
+ copyParams.blockLen = axisHwc0 / elePerBlk;
+ copyParams.srcStride = (padHw - h * w) * c0 / elePerBlk;
+ copyParams.dstStride = (c1 - 1) * axisHwc0 / elePerBlk;
+ // iterates N times CDHW -> C1DHWC0
+ for (int32_t k = 0; k < n; k++) {
+ int32_t currSrcStart = k * axisPadHwd * c;
+ int32_t currDstStart = k * axisC1hwdc0;
+        // the repeat count cannot exceed the hardware maximum of 255, given the total UB memory size.
+ // step1, CDHW -> C1DHWC0 with pad data
+ for (int32_t j = 0; j < c1; j++) {
+ uint64_t currDstAddr = tmpDstTensorAddr + j * axisPadHwdc0 * sizeof(T);
+ uint64_t currSrcAddr = srcTensorAddr + (currSrcStart + j * axisPadHwdc0) * sizeof(T);
+ int remain = j == c1 - 1 ? c - j * c0 : c0;
+ for (int32_t i = 0; i < DEFAULT_TRANSDATA_5HD_LIST; i++) {
+ dstLocalList[i] = currDstAddr + i * c0 * sizeof(T);
+ }
+ for (int32_t i = 0; i < remain; i++) {
+ srcLocalList[i] = currSrcAddr + i * axisPadHwd * sizeof(T);
+ }
+ for (int32_t i = remain; i < DEFAULT_TRANSDATA_5HD_LIST; i++) {
+ srcLocalList[i] = tmpBufferAddr;
+ }
+            TransDataTo5HD<T>(dstLocalList, srcLocalList, transDataParams);
+        }
+        PipeBarrier<PIPE_V>();
+ // step2, C1DHWC0 -> DC1HWC0
+ for (int32_t i = 0; i < c1; i++) {
+ DataCopy(dst[currDstStart + i * axisHwc0], tmpDstTensor[i * axisPadHwdc0], copyParams);
+ }
+        PipeBarrier<PIPE_V>();
+ }
+}
+
+// Transdata NDC1HWC0 -> NCDHW
+template <typename T>
+__aicore__ inline void TransDataImpl6HdToNcdhw(const LocalTensor<T>& dst, const LocalTensor<T>& src,
+    const LocalTensor<uint8_t>& tmpBuffer, const TransDataTmpParams& param)
+{
+ const int32_t n = param.n, c = param.c, d = param.d, h = param.h, w = param.w;
+ constexpr int32_t c0 = 16;
+ constexpr int32_t elePerBlk = ONE_BLK_SIZE / sizeof(T);
+ const int32_t c1 = DivCeil(c, c0);
+ const int32_t padHw = AlignUp(h * w, elePerBlk);
+ constexpr int32_t reservedDummy = 512;
+    auto tmpDstTensor = tmpBuffer[reservedDummy].template ReinterpretCast<T>();
+ uint64_t dstLocalList[DEFAULT_TRANSDATA_5HD_LIST];
+ uint64_t srcLocalList[DEFAULT_TRANSDATA_5HD_LIST];
+
+ uint64_t dstTensorAddr = (uint64_t)dst.GetPhyAddr();
+ uint64_t tmpDstTensorAddr = (uint64_t)tmpDstTensor.GetPhyAddr();
+ uint64_t tmpBufferAddr = (uint64_t)tmpBuffer.GetPhyAddr();
+
+ int32_t axisHwd = h * w * d;
+ int32_t axisHwc0 = h * w * c0;
+ int32_t axisC1hwc0 = axisHwc0 * c1;
+ int32_t axisC1hwdc0 = axisC1hwc0 * d;
+ int32_t axisPadHwd = padHw * d;
+ int32_t axisPadHwc0 = padHw * c0;
+ int32_t axisPadHwdc0 = padHw * c0 * d;
+ TransDataTo5HDParams transDataParams;
+ transDataParams.dstHighHalf = false;
+ transDataParams.srcHighHalf = false;
+ transDataParams.repeatTimes = padHw * d / elePerBlk;
+ transDataParams.srcRepStride = transDataParams.repeatTimes == 1 ? 0 : c0;
+ transDataParams.dstRepStride = transDataParams.repeatTimes == 1 ? 0 : 1;
+
+ DataCopyParams copyParams;
+ copyParams.blockCount = c1;
+ copyParams.blockLen = h * w * c0 / elePerBlk;
+ copyParams.srcStride = 0;
+ copyParams.dstStride = (d * padHw - h * w) * c0 / elePerBlk;
+ // iterates N times C1DHWC0 -> CDHW
+ for (int32_t k = 0; k < n; k++) {
+ // step1 DC1HWC0 -> C1DHWC0
+ int32_t currSrcStart = k * axisC1hwdc0;
+ int32_t currDstStart = k * axisPadHwd * c;
+ for (int32_t i = 0; i < d; i++) {
+ DataCopy(tmpDstTensor[i * axisPadHwc0], src[currSrcStart + i * axisC1hwc0], copyParams);
+ }
+        PipeBarrier<PIPE_V>();
+ // step2, C1DHWC0 -> C1C0DHW
+        // the repeat count cannot exceed the hardware maximum of 255, given the total UB memory size.
+ for (int32_t j = 0; j < c1; j++) {
+ int32_t remain = j == c1 - 1 ? c - j * c0 : c0;
+ uint64_t currDstAddr = dstTensorAddr + (currDstStart + j * axisPadHwdc0) * sizeof(T);
+ uint64_t currSrcAddr = tmpDstTensorAddr + j * axisPadHwdc0 * sizeof(T);
+ for (int32_t i = 0; i < remain; i++) {
+ dstLocalList[i] = currDstAddr + i * axisPadHwd * sizeof(T);
+ }
+ for (int32_t i = remain; i < DEFAULT_TRANSDATA_5HD_LIST; i++) {
+ // temp for reserve redundant data.
+ dstLocalList[i] = tmpBufferAddr + i * ONE_BLK_SIZE;
+ }
+ for (int32_t i = 0; i < DEFAULT_TRANSDATA_5HD_LIST; i++) {
+ srcLocalList[i] = currSrcAddr + i * c0 * sizeof(T);
+ }
+            TransDataTo5HD<T>(dstLocalList, srcLocalList, transDataParams);
+        }
+        PipeBarrier<PIPE_V>();
+ }
+}
+
+template <typename T, typename SrcLayout, typename DstLayout>
+__aicore__ inline void TransDataCheck(const TransDataParams<SrcLayout, DstLayout>& params)
+{
+    static_assert(SupportType<T, half, bfloat16_t, uint16_t, int16_t>(),
+        "TransData currently only supports half/bfloat16_t/uint16_t/int16_t types.");
+    static_assert(is_layout_v<SrcLayout>, "srcLayout must be a layout");
+    static_assert(is_layout_v<DstLayout>, "dstLayout must be a layout");
+    using SrcShapeTuple = Std::remove_cvref_t<decltype(params.srcLayout.GetShape())>;
+    using DstShapeTuple = Std::remove_cvref_t<decltype(params.dstLayout.GetShape())>;
+    static_assert(Std::is_tuple_v<SrcShapeTuple>, "srcLayout.GetShape() must be a shape.");
+    static_assert(Std::is_tuple_v<DstShapeTuple>, "dstLayout.GetShape() must be a shape.");
+}
+
+template <const TransDataConfig& config, typename T, typename SrcLayout, typename DstLayout>
+__aicore__ inline void TransDataImpl(const LocalTensor<T>& dstTensor, const LocalTensor<T>& srcTensor,
+    const LocalTensor<uint8_t>& sharedTmpBuffer, const TransDataParams<SrcLayout, DstLayout>& params)
+{
+    TransDataCheck<T>(params);
+ auto srcShape = params.srcLayout.GetShape();
+ auto dstShape = params.dstLayout.GetShape();
+    using srcType = decltype(srcShape);
+    using dstType = decltype(dstShape);
+    constexpr uint32_t srcShapeSize = static_cast<uint32_t>(Std::tuple_size<srcType>::value);
+    constexpr uint32_t dstShapeSize = static_cast<uint32_t>(Std::tuple_size<dstType>::value);
+    using ncdhwType = Std::conditional_t<config.srcFormat == DataFormat::NCDHW, srcType, dstType>;
+ ncdhwType ncdhwShape;
+ if constexpr (config.srcFormat == DataFormat::NCDHW) {
+ ncdhwShape = params.srcLayout.GetShape();
+ } else {
+ ncdhwShape = params.dstLayout.GetShape();
+ }
+ int32_t n = Std::get<0>(ncdhwShape);
+ int32_t c = Std::get<1>(ncdhwShape);
+ int32_t d = Std::get<2>(ncdhwShape);
+ int32_t h = Std::get<3>(ncdhwShape);
+ int32_t w = Std::get<4>(ncdhwShape);
+ int32_t n1 = (n + n0 - 1) / n0;
+ int32_t c1 = (c + c0 - 1) / c0;
+ int32_t hw1 = (h * w + hw0 - 1) / hw0;
+ int32_t padHw = hw1 * hw0;
+ TransDataTmpParams tmpParams = { n, c, d, h, w, n1, c1, padHw };
+ if constexpr (config.srcFormat == DataFormat::NCDHW && config.dstFormat == DataFormat::FRACTAL_Z_3D) {
+ static_assert(srcShapeSize == ncdhwDims, "srcLayout's shape dims must be equal to 5!");
+ static_assert(dstShapeSize == fractalZ3DDims, "dstLayout's shape dims must be equal to 7!");
+ TransDataImplNcdhwToFractal(dstTensor, srcTensor, sharedTmpBuffer, tmpParams);
+ } else if constexpr (config.srcFormat == DataFormat::FRACTAL_Z_3D && config.dstFormat == DataFormat::NCDHW) {
+ static_assert(srcShapeSize == fractalZ3DDims, "srcLayout's shape dims must be equal to 7!");
+ static_assert(dstShapeSize == ncdhwDims, "dstLayout's shape dims must be equal to 5!");
+ TransDataFractalToNcdhw(dstTensor, srcTensor, sharedTmpBuffer, tmpParams);
+ } else if constexpr (config.srcFormat == DataFormat::NCDHW && config.dstFormat == DataFormat::NDC1HWC0) {
+ static_assert(srcShapeSize == ncdhwDims, "srcLayout's shape dims must be equal to 5!");
+ static_assert(dstShapeSize == ndc1hwc0Dims, "dstLayout's shape dims must be equal to 6!");
+ TransDataImplNcdhwTo6Hd(dstTensor, srcTensor, sharedTmpBuffer, tmpParams);
+ } else if constexpr (config.srcFormat == DataFormat::NDC1HWC0 && config.dstFormat == DataFormat::NCDHW) {
+ static_assert(srcShapeSize == ndc1hwc0Dims, "srcLayout's shape dims must be equal to 6!");
+ static_assert(dstShapeSize == ncdhwDims, "dstLayout's shape dims must be equal to 5!");
+ TransDataImpl6HdToNcdhw(dstTensor, srcTensor, sharedTmpBuffer, tmpParams);
+ }
+}
+
+} // namespace Internal
+} // namespace AscendC
+#endif // IMPL_TRANSPOSE_TRANSDATA_TRANSDATA_IMPL_H
\ No newline at end of file
diff --git a/impl/transpose/transdata/transdata_tiling.cpp b/impl/transpose/transdata/transdata_tiling.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..dbfc281e48b3c8e9770a2f55bb2bea36d545ae38
--- /dev/null
+++ b/impl/transpose/transdata/transdata_tiling.cpp
@@ -0,0 +1,197 @@
+/**
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#include "lib/transpose/transdata_tiling.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "graph/tensor.h"
+#include "impl/host_log.h"
+#include "tiling/platform/platform_ascendc.h"
+namespace AscendC {
+namespace {
+constexpr int32_t PAD_ELE_FOR_HALF = 16;
+constexpr int32_t N_INDEX = 0;
+constexpr int32_t C_INDEX = 1;
+constexpr int32_t D_INDEX = 2;
+constexpr int32_t H_INDEX = 3;
+constexpr int32_t W_INDEX = 4;
+
+struct TmpTransDataParams {
+ int32_t n = 0;
+ int32_t c = 0;
+ int32_t d = 0;
+ int32_t h = 0;
+ int32_t w = 0;
+};
+
+int32_t DivCeil(int32_t a, int32_t b)
+{
+ if (b == 0) {
+ return a;
+ }
+ return (a + b - 1) / b;
+}
+
+int32_t AlignUp(int32_t a, int32_t b)
+{
+ return DivCeil(a, b) * b;
+}
+
+bool GenerateFractalZ3DToNcdhwShapeInfo(const std::vector<int64_t>& dstDims, const std::vector<int64_t>& srcDims,
+ TmpTransDataParams ¶m, const int32_t c0, const int32_t n0)
+{
+ ASCENDC_HOST_ASSERT(srcDims.size() == 7 && dstDims.size() == 5, return false,
+ "[TransData][GetTransDataMaxMinTmpSize] input shapes are not matched with DataFormat.");
+ param.n = dstDims[N_INDEX];
+ param.c = dstDims[C_INDEX];
+ param.d = dstDims[D_INDEX];
+ param.h = dstDims[H_INDEX];
+ param.w = dstDims[W_INDEX];
+ // validate d, h, w
+ ASCENDC_HOST_ASSERT(param.d == srcDims[0] && param.h == srcDims[2] && param.w == srcDims[3], return false,
+ "[TransData][GetTransDataMaxMinTmpSize] shapeInfo d,h,w is not matched.");
+ ASCENDC_HOST_ASSERT(srcDims[6] == c0 && srcDims[1] * c0 == AlignUp(param.c, c0), return false,
+ "[TransData][GetTransDataMaxMinTmpSize] src c0, c1 is not able to be converted to c.");
+ ASCENDC_HOST_ASSERT(srcDims[5] == n0 && srcDims[4] * n0 == AlignUp(param.n, n0), return false,
+ "[TransData][GetTransDataMaxMinTmpSize] src n0, n1 is not able to be converted to n.");
+ return true;
+}
+
+bool GenerateNcdhwToFractalZ3DShapeInfo(const std::vector<int64_t>& dstDims, const std::vector<int64_t>& srcDims,
+ TmpTransDataParams ¶m, const int32_t c0, const int32_t n0)
+{
+ ASCENDC_HOST_ASSERT(srcDims.size() == 5 && dstDims.size() == 7, return false,
+ "[TransData][GetTransDataMaxMinTmpSize] input shapes are not matched with DataFormat.");
+ param.n = srcDims[N_INDEX];
+ param.c = srcDims[C_INDEX];
+ param.d = srcDims[D_INDEX];
+ param.h = srcDims[H_INDEX];
+ param.w = srcDims[W_INDEX];
+ // validate d, h, w
+ ASCENDC_HOST_ASSERT(param.d == dstDims[0] && param.h == dstDims[2] && param.w == dstDims[3], return false,
+ "[TransData][GetTransDataMaxMinTmpSize] shapeInfo d,h,w is not matched.");
+ ASCENDC_HOST_ASSERT(dstDims[6] == c0 && dstDims[1] * c0 == AlignUp(param.c, c0), return false,
+ "[TransData][GetTransDataMaxMinTmpSize] dst c0, c1 is not able to be converted to c.");
+ ASCENDC_HOST_ASSERT(dstDims[5] == n0 && dstDims[4] * n0 == AlignUp(param.n, n0), return false,
+ "[TransData][GetTransDataMaxMinTmpSize] dst n0, n1 is not able to be converted to n.");
+ return true;
+}
+
+bool GenerateShapeInfo(const TransDataConfig &config, const ge::Shape &srcShape, const ge::Shape &dstShape, ge::DataType type,
+ TmpTransDataParams ¶m)
+{
+ (void)type;
+ constexpr int32_t c0 = 16, n0 = 16;
+    std::vector<int64_t> srcDims = srcShape.GetDims();
+    std::vector<int64_t> dstDims = dstShape.GetDims();
+ if (config.srcFormat == DataFormat::NCDHW && config.dstFormat == DataFormat::NDC1HWC0) {
+ ASCENDC_HOST_ASSERT(srcDims.size() == 5 && dstDims.size() == 6, return false,
+ "[TransData][GetTransDataMaxMinTmpSize] input shapes are not matched with DataFormat.");
+ param.n = srcDims[N_INDEX];
+ param.c = srcDims[C_INDEX];
+ param.d = srcDims[D_INDEX];
+ param.h = srcDims[H_INDEX];
+ param.w = srcDims[W_INDEX];
+ // validate n, d, h, w
+ ASCENDC_HOST_ASSERT(param.n == dstDims[0] && param.d == dstDims[1] && param.h == dstDims[3] && param.w == dstDims[4],
+ return false, "[TransData][GetTransDataMaxMinTmpSize] shapeInfo n,d,h,w is not matched.");
+ ASCENDC_HOST_ASSERT(dstDims[5] == c0 && dstDims[2] * c0 == AlignUp(param.c, c0), return false,
+ "[TransData][GetTransDataMaxMinTmpSize] dst c0, c1 is not able to be converted to c.");
+ return true;
+ }
+ if (config.srcFormat == DataFormat::NCDHW && config.dstFormat == DataFormat::FRACTAL_Z_3D) {
+ return GenerateNcdhwToFractalZ3DShapeInfo(dstDims, srcDims, param, c0, n0);
+ }
+ if (config.srcFormat == DataFormat::FRACTAL_Z_3D && config.dstFormat == DataFormat::NCDHW) {
+ return GenerateFractalZ3DToNcdhwShapeInfo(dstDims, srcDims, param, c0, n0);
+ }
+ if (config.srcFormat == DataFormat::NDC1HWC0 && config.dstFormat == DataFormat::NCDHW) {
+ ASCENDC_HOST_ASSERT(srcDims.size() == 6 && dstDims.size() == 5, return false,
+ "[TransData][GetTransDataMaxMinTmpSize] input shapes are not matched with DataFormat.");
+ param.n = dstDims[N_INDEX];
+ param.c = dstDims[C_INDEX];
+ param.d = dstDims[D_INDEX];
+ param.h = dstDims[H_INDEX];
+ param.w = dstDims[W_INDEX];
+ // validate n, d, h, w
+ ASCENDC_HOST_ASSERT(param.n == srcDims[0] && param.d == srcDims[1] && param.h == srcDims[3] && param.w == srcDims[4],
+ return false, "[TransData][GetTransDataMaxMinTmpSize] shapeInfo n,d,h,w is not matched.");
+ ASCENDC_HOST_ASSERT(srcDims[5] == c0 && srcDims[2] * c0 == AlignUp(param.c, c0), return false,
+ "[TransData][GetTransDataMaxMinTmpSize] src c0, c1 is not able to be converted to c.");
+ return true;
+ }
+ return false;
+}
+
+int32_t GetTmpBufferSize(const TransDataConfig &config, const TmpTransDataParams ¶m)
+{
+ constexpr int32_t dataSize = 2;
+ int32_t n = param.n, c = param.c, d = param.d, h = param.h, w = param.w;
+ constexpr int32_t c0 = 16, n0 = 16;
+ int32_t c1 = DivCeil(c, c0), n1 = DivCeil(n, n0);
+    int32_t padHw = AlignUp(h * w, PAD_ELE_FOR_HALF);
+    if (config.srcFormat == DataFormat::NCDHW && config.dstFormat == DataFormat::NDC1HWC0) {
+        return d * padHw * dataSize + d * c1 * c0 * padHw * dataSize;
+    }
+    if (config.srcFormat == DataFormat::NDC1HWC0 && config.dstFormat == DataFormat::NCDHW) {
+        constexpr int32_t redundantDataBuffer = 512;
+        return d * c1 * c0 * padHw * dataSize + redundantDataBuffer;
+    }
+    if (config.srcFormat == DataFormat::NCDHW && config.dstFormat == DataFormat::FRACTAL_Z_3D) {
+        return c * d * padHw * dataSize + n1 * n0 * d * c1 * c0 * padHw * dataSize;
+    }
+    if (config.srcFormat == DataFormat::FRACTAL_Z_3D && config.dstFormat == DataFormat::NCDHW) {
+ constexpr int32_t doubleTmpSize = 2;
+ if (n == n0 * n1 && c == c0 * c1) {
+ return n1 * n0 * c1 * c0 * d * padHw * dataSize;
+ }
+ return n1 * n0 * c1 * c0 * d * padHw * dataSize * doubleTmpSize;
+ }
+ return 0;
+}
+} // namespace
+
+bool GetTransDataMaxMinTmpSize(const platform_ascendc::PlatformAscendC &platform,
+ const ge::Shape &srcShape,
+ const ge::Shape &dstShape,
+ const ge::DataType dataType,
+ const TransDataConfig &config,
+ uint32_t &maxValue, uint32_t &minValue)
+{
+ ASCENDC_HOST_ASSERT(dataType == ge::DataType::DT_FLOAT16 || dataType == ge::DataType::DT_BF16 ||
+ dataType == ge::DataType::DT_UINT16 || dataType == ge::DataType::DT_INT16, return false,
+ "[TransData][GetTransDataMaxMinTmpSize] it only supports DT_FLOAT16/DT_BF16/DT_UINT16/DT_INT16 data type");
+
+ ASCENDC_HOST_ASSERT(((config.srcFormat == DataFormat::NCDHW && config.dstFormat == DataFormat::FRACTAL_Z_3D) ||
+ (config.srcFormat == DataFormat::FRACTAL_Z_3D && config.dstFormat == DataFormat::NCDHW) ||
+ (config.srcFormat == DataFormat::NCDHW && config.dstFormat == DataFormat::NDC1HWC0) ||
+ (config.srcFormat == DataFormat::NDC1HWC0 && config.dstFormat == DataFormat::NCDHW)), return false,
+ "[TransData][GetTransDataMaxMinTmpSize] The parameter config srcFormat/dstFormat only supports "
+ "(NCDHW, FRACTAL_Z_3D)/(FRACTAL_Z_3D, NCDHW)/(NCDHW, NDC1HWC0)/(NDC1HWC0, NCDHW)!");
+
+ platform_ascendc::SocVersion socVersion = platform.GetSocVersion();
+ ASCENDC_HOST_ASSERT(socVersion == platform_ascendc::SocVersion::ASCEND910B, return false,
+ "[TransData][GetTransDataMaxMinTmpSize] Unsupported SocVersion for TransData API.");
+
+ TmpTransDataParams tmpParam;
+
+ ASCENDC_HOST_ASSERT(GenerateShapeInfo(config, srcShape, dstShape, dataType, tmpParam), return false,
+ "[TransData][GetTransDataMaxMinTmpSize] failed to validate inputs informations.");
+ maxValue = GetTmpBufferSize(config, tmpParam);
+ minValue = maxValue;
+ return true;
+}
+} // namespace AscendC
diff --git a/lib/transpose/transdata.h b/lib/transpose/transdata.h
new file mode 100644
index 0000000000000000000000000000000000000000..c0075cf5276392e5ca5399c7e122203059d393ca
--- /dev/null
+++ b/lib/transpose/transdata.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef LIB_TRANSPOSE_TRANSDATA_H
+#define LIB_TRANSPOSE_TRANSDATA_H
+#if __CCE_AICORE__ == 220
+#include "transdata_common.h"
+#include "kernel_tensor.h"
+#include "kernel_operator_intf.h"
+#include "kernel_pop_stack_buffer.h"
+#include "../../impl/transpose/transdata/transdata_impl.h"
+#if ASCENDC_CPU_DEBUG
+#include "kernel_log.h"
+#include
+#endif
+
+namespace AscendC {
+
+template <const TransDataConfig& config, typename T, typename SrcLayout, typename DstLayout>
+__aicore__ inline void TransData(const LocalTensor<T>& dstTensor, const LocalTensor<T>& srcTensor,
+    const LocalTensor<uint8_t>& sharedTmpBuffer, const TransDataParams<SrcLayout, DstLayout>& params)
+{
+    Internal::TransDataImpl<config>(dstTensor, srcTensor, sharedTmpBuffer, params);
+}
+
+template <const TransDataConfig& config, typename T, typename SrcLayout, typename DstLayout>
+__aicore__ inline void TransData(const LocalTensor<T>& dstTensor, const LocalTensor<T>& srcTensor,
+    const TransDataParams<SrcLayout, DstLayout>& params)
+{
+    // Only for AI Vector Core.
+    if ASCEND_IS_AIC {
+        return;
+    }
+    LocalTensor<uint8_t> tmp;
+    const bool ret = PopStackBuffer<uint8_t, TPosition::LCM>(tmp);
+    ASCENDC_ASSERT((ret), { KERNEL_LOG(KERNEL_ERROR, "PopStackBuffer Error!"); });
+
+    TransData<config>(dstTensor, srcTensor, tmp, params);
+}
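+
+// Minimal usage sketch (mirrors tests/transpose/transdata/test_operator_transdata.cpp; the config
+// object, shapes, and tensor names below are illustrative only):
+//   constexpr TransDataConfig cfg = {DataFormat::NCDHW, DataFormat::NDC1HWC0};
+//   auto srcLayout = MakeLayout(MakeShape(n, c, d, h, w), MakeStride());
+//   auto dstLayout = MakeLayout(MakeShape(n, d, c1, h, w, 16), MakeStride());
+//   TransDataParams<decltype(srcLayout), decltype(dstLayout)> params = {srcLayout, dstLayout};
+//   TransData<cfg>(dstLocal, srcLocal, params); // temporary buffer is popped from the stack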
+} // namespace AscendC
+#endif
+#endif // LIB_TRANSPOSE_TRANSDATA_H
\ No newline at end of file
diff --git a/lib/transpose/transdata_common.h b/lib/transpose/transdata_common.h
new file mode 100644
index 0000000000000000000000000000000000000000..0421a3ca724cf4f6feaae1ab14e7f8fe3e89692d
--- /dev/null
+++ b/lib/transpose/transdata_common.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef LIB_TRANSPOSE_TRANSDATA_COMMON_H
+#define LIB_TRANSPOSE_TRANSDATA_COMMON_H
+
+namespace AscendC {
+template <typename T, typename U>
+struct TransDataParams {
+ T srcLayout;
+ U dstLayout;
+};
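+// The two members hold the source and destination layouts. A typical kernel-side construction
+// (illustrative): TransDataParams<decltype(srcLayout), decltype(dstLayout)> params = {srcLayout, dstLayout};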
+
+#ifndef ASCC_PARAM_TRANSDATACONFIG
+#define ASCC_PARAM_TRANSDATACONFIG
+struct TransDataConfig {
+ DataFormat srcFormat;
+ DataFormat dstFormat;
+};
+#endif // ASCC_PARAM_TRANSDATACONFIG
+} // namespace AscendC
+
+#endif // LIB_TRANSPOSE_TRANSDATA_COMMON_H
\ No newline at end of file
diff --git a/lib/transpose/transdata_tiling.h b/lib/transpose/transdata_tiling.h
new file mode 100644
index 0000000000000000000000000000000000000000..f2d722218a7c5556d30910c293a1f808c54b3d7e
--- /dev/null
+++ b/lib/transpose/transdata_tiling.h
@@ -0,0 +1,67 @@
+/**
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file transdata_tiling.h
+ * \brief
+ */
+#ifndef LIB_TRANSPOSE_TRANSDATA_TILING_H
+#define LIB_TRANSPOSE_TRANSDATA_TILING_H
+#include <cstdint>
+#include "graph/tensor.h"
+#include "tiling/platform/platform_ascendc.h"
+
+namespace AscendC {
+/*
+ * @brief DataFormat
+ */
+#ifndef ASCC_ENUM_DATAFORMAT
+#define ASCC_ENUM_DATAFORMAT
+enum class DataFormat : uint8_t {
+ ND = 0,
+ NZ,
+ NCHW,
+ NC1HWC0,
+ NHWC,
+ NCDHW,
+ NDC1HWC0,
+ FRACTAL_Z_3D,
+};
+#endif // ASCC_ENUM_DATAFORMAT
+
+#ifndef ASCC_PARAM_TRANSDATACONFIG
+#define ASCC_PARAM_TRANSDATACONFIG
+struct TransDataConfig {
+ DataFormat srcFormat;
+ DataFormat dstFormat;
+};
+#endif // ASCC_PARAM_TRANSDATACONFIG
+
+/*!
+ * \brief Obtains the maximum and minimum temporary buffer sizes that must be reserved or allocated.
+ * The developer selects a proper buffer size within this range and passes it in as a tiling parameter.
+ *
+ * \param [in] platform, targeted platform information
+ * \param [in] srcShape, src tensor shape
+ * \param [in] dstShape, dst tensor shape
+ * \param [in] dataType, actual data type of the input
+ * \param [in] config, transdata config
+ * \param [out] maxValue, maximum temporary space required
+ * \param [out] minValue, minimum temporary space required
+ * \return whether the max/min values were obtained successfully
+ */
+bool GetTransDataMaxMinTmpSize(const platform_ascendc::PlatformAscendC &platform,
+ const ge::Shape &srcShape,
+ const ge::Shape &dstShape,
+ const ge::DataType dataType,
+ const TransDataConfig &config,
+ uint32_t &maxValue, uint32_t &minValue);
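+
+/*
+ * Host-side usage sketch (shapes follow tests/tiling/test_tiling.cpp; variable names are illustrative):
+ *   ge::Shape srcShape({n, c, d, h, w});                   // NCDHW
+ *   ge::Shape dstShape({n, d, (c + 15) / 16, h, w, 16});   // NDC1HWC0
+ *   TransDataConfig config = {DataFormat::NCDHW, DataFormat::NDC1HWC0};
+ *   uint32_t maxValue = 0, minValue = 0;
+ *   bool ok = GetTransDataMaxMinTmpSize(platform, srcShape, dstShape, ge::DataType::DT_FLOAT16,
+ *                                       config, maxValue, minValue);
+ */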
+} // namespace AscendC
+#endif // LIB_TRANSPOSE_TRANSDATA_TILING_H
\ No newline at end of file
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index ccf15d0139ab4d86cbb3d296ac2206d0303a62a1..10b2f42cfd9948141b6c34892b5dd9e5452c17dc 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -154,6 +154,7 @@ file(GLOB ASCENDC_TEST_ascend910B1_AIV_CASE_SRC_FILES
${ASCENDC_TESTS_DIR}/normalization/welfordfinalize/test_operator_welfordfinalize.cpp
${ASCENDC_TESTS_DIR}/utils/init_global_memory/test_operator_init_global_memory.cpp
${ASCENDC_TESTS_DIR}/normalization/layernormV2/test_operator_layernormV2.cpp
+    ${ASCENDC_TESTS_DIR}/transpose/transdata/*.cpp
)
# ascend910B1 aic test cases
@@ -406,6 +407,7 @@ file(GLOB ASCENDC_TILING_SRC_FILES
${ASCENDC_API_DIR}/impl/sort/topk/*.cpp
${ASCENDC_API_DIR}/impl/reduce/reduce_tiling.cpp
${ASCENDC_API_DIR}/impl/normalization/layernormV2/*.cpp
+ ${ASCENDC_API_DIR}/impl/transpose/transdata/transdata_tiling.cpp
)
# ascendc_tiling_utest
diff --git a/tests/tiling/test_tiling.cpp b/tests/tiling/test_tiling.cpp
index 34be9f4c0e5ef4ed13e6b76d1bf571026340b59e..73dc3e4b174625d4549a3315b10fb996471699ca 100644
--- a/tests/tiling/test_tiling.cpp
+++ b/tests/tiling/test_tiling.cpp
@@ -32,7 +32,7 @@ protected:
void TearDown() {}
};
-
+extern void platfrom_stub_set_chip_version(const char *num);
TEST_F(TestTiling, MultiCoreSmallMN)
{
matmul_tiling::MultiCoreMatmulTiling rnnMatmul3,rnnMatmul4,rnnMatmul5;
@@ -5418,6 +5418,106 @@ TEST_F(TestTiling, TestOneElementBroadCast200)
}
#endif
+TEST_F(TestTiling, testTransDataTilingUnalignedHw)
+{
+ platfrom_stub_set_chip_version("Ascend910B");
+ uint32_t maxSize;
+ uint32_t minSize;
+ int32_t n = 16;
+ int32_t c = 16;
+ int32_t d = 3;
+ int32_t h = 3;
+ int32_t w = 3;
+ int32_t c0 = 16;
+ int32_t n0 = 16;
+ int32_t c1 = (c + c0 - 1) / c0;
+ int32_t n1 = (n + n0 - 1) / n0;
+ int32_t hw0 = 16;
+ int32_t hw1 = (h * w + hw0 - 1) / hw0;
+ auto ncdhwShape = ge::Shape({ n, c, d, h, w });
+ auto ndc1hwc0Shape = ge::Shape({ n, d, c1, h, w, c0});
+ auto fractalzShape = ge::Shape({ d, c1, h, w, n1, n0, c0});
+ fe::PlatFormInfos platform_info;
+ auto plat = platform_ascendc::PlatformAscendC(&platform_info);
+ TransDataConfig config = {DataFormat::NCDHW, DataFormat::NDC1HWC0};
+ bool ret = GetTransDataMaxMinTmpSize(plat, ncdhwShape, ndc1hwc0Shape, ge::DataType::DT_FLOAT16, config, maxSize, minSize);
+
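+    // Expected per GetTmpBufferSize for NCDHW -> NDC1HWC0: (d * padHw + d * c1 * c0 * padHw) elements
+    // of half = (3 * 16 + 3 * 1 * 16 * 16) * 2B = 1632B, with padHw = AlignUp(3 * 3, 16) = 16 and c1 = 1.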
+ EXPECT_TRUE(ret);
+ EXPECT_EQ(maxSize, 1632);
+ EXPECT_EQ(minSize, 1632);
+
+ config = {DataFormat::NDC1HWC0, DataFormat::NCDHW};
+ ret = GetTransDataMaxMinTmpSize(plat, ndc1hwc0Shape, ncdhwShape, ge::DataType::DT_FLOAT16, config, maxSize, minSize);
+
+ EXPECT_TRUE(ret);
+ EXPECT_EQ(maxSize, 2048);
+ EXPECT_EQ(minSize, 2048);
+
+ config = {DataFormat::NCDHW, DataFormat::FRACTAL_Z_3D};
+ ret = GetTransDataMaxMinTmpSize(plat, ncdhwShape, fractalzShape, ge::DataType::DT_FLOAT16, config, maxSize, minSize);
+
+ EXPECT_TRUE(ret);
+ EXPECT_EQ(maxSize, 26112);
+ EXPECT_EQ(minSize, 26112);
+
+ config = {DataFormat::FRACTAL_Z_3D, DataFormat::NCDHW};
+ ret = GetTransDataMaxMinTmpSize(plat, fractalzShape, ncdhwShape, ge::DataType::DT_FLOAT16, config, maxSize, minSize);
+
+ EXPECT_TRUE(ret);
+ EXPECT_EQ(maxSize, n1 * n0 * c1 * c0 * d * hw0 * hw1 * 2);
+ EXPECT_EQ(minSize, n1 * n0 * c1 * c0 * d * hw0 * hw1 * 2);
+}
+
+TEST_F(TestTiling, testTransDataTilingAlignedHw)
+{
+ platfrom_stub_set_chip_version("Ascend910B");
+ uint32_t maxSize;
+ uint32_t minSize;
+ int32_t n = 5;
+ int32_t c = 30;
+ int32_t d = 2;
+ int32_t h = 4;
+ int32_t w = 8;
+ int32_t c0 = 16;
+ int32_t n0 = 16;
+ int32_t c1 = (c + c0 - 1) / c0;
+ int32_t n1 = (n + n0 - 1) / n0;
+ int32_t hw0 = 16;
+ int32_t hw1 = (h * w + hw0 - 1) / hw0;
+ auto ncdhwShape = ge::Shape({ n, c, d, h, w });
+ auto ndc1hwc0Shape = ge::Shape({ n, d, c1, h, w, c0});
+ auto fractalzShape = ge::Shape({ d, c1, h, w, n1, n0, c0});
+ fe::PlatFormInfos platform_info;
+ auto plat = platform_ascendc::PlatformAscendC(&platform_info);
+ TransDataConfig config = {DataFormat::NCDHW, DataFormat::NDC1HWC0};
+ bool ret = GetTransDataMaxMinTmpSize(plat, ncdhwShape, ndc1hwc0Shape, ge::DataType::DT_FLOAT16, config, maxSize, minSize);
+
+ EXPECT_TRUE(ret);
+ EXPECT_EQ(maxSize, 4224);
+ EXPECT_EQ(minSize, 4224);
+
+ config = {DataFormat::NDC1HWC0, DataFormat::NCDHW};
+ ret = GetTransDataMaxMinTmpSize(plat, ndc1hwc0Shape, ncdhwShape, ge::DataType::DT_FLOAT16, config, maxSize, minSize);
+
+ EXPECT_TRUE(ret);
+ EXPECT_EQ(maxSize, 4608);
+ EXPECT_EQ(minSize, 4608);
+
+ config = {DataFormat::NCDHW, DataFormat::FRACTAL_Z_3D};
+ ret = GetTransDataMaxMinTmpSize(plat, ncdhwShape, fractalzShape, ge::DataType::DT_FLOAT16, config, maxSize, minSize);
+
+ EXPECT_TRUE(ret);
+ EXPECT_EQ(maxSize, 69376);
+ EXPECT_EQ(minSize, 69376);
+
+ config = {DataFormat::FRACTAL_Z_3D, DataFormat::NCDHW};
+ ret = GetTransDataMaxMinTmpSize(plat, fractalzShape, ncdhwShape, ge::DataType::DT_FLOAT16, config, maxSize, minSize);
+
+ EXPECT_TRUE(ret);
+ EXPECT_EQ(maxSize, n1 * n0 * c1 * c0 * d * hw0 * hw1 * 2 * 2);
+ EXPECT_EQ(minSize, n1 * n0 * c1 * c0 * d * hw0 * hw1 * 2 * 2);
+}
+
TEST_F(TestTiling, TestReduceXorSumTilingInt16)
{
std::vector shapeDims = { 128, 128 };
diff --git a/tests/transpose/transdata/test_operator_transdata.cpp b/tests/transpose/transdata/test_operator_transdata.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d50408a97e8ab49666de9d50fb3fb65d5d74e982
--- /dev/null
+++ b/tests/transpose/transdata/test_operator_transdata.cpp
@@ -0,0 +1,267 @@
+/**
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#include <gtest/gtest.h>
+#include "kernel_operator.h"
+
+#include <cstdint>
+#include <vector>
+
+namespace AscendC {
+
+namespace {
+
+constexpr uint32_t NCDHW_FractalZ3D = 1;
+constexpr uint32_t FractalZ3D_NCDHW = 2;
+constexpr uint32_t NCDHW_NDC1HWC0 = 3;
+constexpr uint32_t NDC1HWC0_NCDHW = 4;
+
+
+constexpr TransDataConfig config1 = {DataFormat::NCDHW, DataFormat::FRACTAL_Z_3D};
+constexpr TransDataConfig config2 = {DataFormat::FRACTAL_Z_3D, DataFormat::NCDHW};
+constexpr TransDataConfig config3 = {DataFormat::NCDHW, DataFormat::NDC1HWC0};
+constexpr TransDataConfig config4 = {DataFormat::NDC1HWC0, DataFormat::NCDHW};
+
+}
+
+template <typename T, uint32_t mode>
+class KernelTransData {
+public:
+__aicore__ inline KernelTransData() {}
+__aicore__ inline void Init(GM_ADDR srcGm, GM_ADDR dstGm,
+ int32_t n, int32_t c, int32_t d, int32_t h, int32_t w, TPipe *tpipe)
+{
+ this->d = d;
+ this->c = c;
+ this->h = h;
+ this->w = w;
+ this->n = n;
+ this->c1 = (c + c0 - 1) / c0;
+ this->n1 = (n + n0 - 1) / n0;
+    this->hw1 = (h * w + hw0 - 1) / hw0;
+
+ if (mode == NDC1HWC0_NCDHW) {
+ this->srcShapeSize = n * c1 * c0 * d * h * w;
+ this->dstShapeSize = n * d * c * hw0;
+ this->tmpShapeSize = 512 + d * c1 * c0 * hw0 * hw1;
+ uint32_t dstGmSize = n * c * d * h * w;
+ srcGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(srcGm), srcShapeSize * sizeof(T));
+ dstGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(dstGm), dstGmSize * sizeof(T));
+ } else {
+ if constexpr (mode == NCDHW_FractalZ3D) {
+ srcShapeSize = n * d * c * hw0 * hw1;
+ dstShapeSize = n1 * n0 * c1 * c0 * d * h * w;
+            if ((h * w) % 16 != 0) {
+ needPad = true;
+ dstShapeSize = n1 * n0 * c1 * c0 * d * hw0 * hw1;
+ }
+ tmpShapeSize = c * d * hw0 * hw1 + n1 * n0 * d * c1 * c0 * hw0 * hw1;
+ } else if constexpr (mode == FractalZ3D_NCDHW) {
+ this->srcShapeSize = d * c1 * h * w * n1 * n0 * c0;
+ this->dstShapeSize = n * c * d * (hw1 * hw0);
+ this->tmpShapeSize = d * c1 * (hw1 * hw0) * n1 * n0 * c0 * 2;
+ } else if constexpr (mode == NCDHW_NDC1HWC0) {
+ this->srcShapeSize = n * d * c * hw0;
+ this->dstShapeSize = n * c1 * c0 * d * h * w;
+ this->tmpShapeSize = d * hw0 * hw1 + d * c1 * c0 * hw0 * hw1;
+ }
+ srcGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(srcGm), srcShapeSize * sizeof(T));
+ dstGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(dstGm), dstShapeSize * sizeof(T));
+ }
+
+
+ this->pipe = tpipe;
+ pipe->InitBuffer(inQueue, 1, srcShapeSize * sizeof(T));
+ pipe->InitBuffer(outQueue, 1, dstShapeSize * sizeof(T));
+ pipe->InitBuffer(tmpBuf, tmpShapeSize * sizeof(T));
+
+}
+__aicore__ inline void Process()
+{
+ CopyIn();
+ Compute();
+ CopyOut();
+}
+
+private:
+__aicore__ inline void CopyIn()
+{
+    LocalTensor<T> srcLocal = inQueue.AllocTensor<T>();
+ if constexpr (mode == NCDHW_FractalZ3D || mode == NCDHW_NDC1HWC0) {
+        DataCopyExtParams extParam = {static_cast<uint16_t>(n * c * d),
+            static_cast<uint32_t>(h * w * sizeof(T)), 0, 0, 0};
+        DataCopyPadExtParams<T> padParam = {true, 0, 0, 0};
+ if (needPad) {
+ DataCopyPad(srcLocal, srcGlobal, extParam, padParam);
+ } else {
+ DataCopy(srcLocal, srcGlobal, srcShapeSize);
+ }
+ } else if constexpr (mode == FractalZ3D_NCDHW || mode == NDC1HWC0_NCDHW) {
+ DataCopy(srcLocal, srcGlobal, srcShapeSize);
+ }
+
+ inQueue.EnQue(srcLocal);
+}
+__aicore__ inline void Compute()
+{
+    LocalTensor<T> dstLocal = outQueue.AllocTensor<T>();
+    LocalTensor<uint8_t> tmp = tmpBuf.Get<uint8_t>();
+    LocalTensor<T> srcLocal = inQueue.DeQue<T>();
+    PipeBarrier<PIPE_V>();
+
+    auto ncdhwLayout = MakeLayout(MakeShape(n, c, d, h, w), MakeStride());
+    auto ndc1hwc0Layout = MakeLayout(MakeShape(n, d, c1, h, w, c0), MakeStride());
+    auto fractalLayout = MakeLayout(MakeShape(d, c1, h, w, n1, n0, c0), MakeStride());
+
+    if constexpr (mode == NCDHW_FractalZ3D) {
+        TransDataParams<decltype(ncdhwLayout), decltype(fractalLayout)> params = {ncdhwLayout, fractalLayout};
+        TransData<config1>(dstLocal, srcLocal, tmp, params);
+    } else if constexpr (mode == FractalZ3D_NCDHW) {
+        TransDataParams<decltype(fractalLayout), decltype(ncdhwLayout)> params = {fractalLayout, ncdhwLayout};
+        TransData<config2>(dstLocal, srcLocal, tmp, params);
+    } else if constexpr (mode == NCDHW_NDC1HWC0) {
+        TransDataParams<decltype(ncdhwLayout), decltype(ndc1hwc0Layout)> params = {ncdhwLayout, ndc1hwc0Layout};
+        TransData<config3>(dstLocal, srcLocal, tmp, params);
+    } else if constexpr (mode == NDC1HWC0_NCDHW) {
+        TransDataParams<decltype(ndc1hwc0Layout), decltype(ncdhwLayout)> params = {ndc1hwc0Layout, ncdhwLayout};
+        TransData<config4>(dstLocal, srcLocal, tmp, params);
+    }
+
+ outQueue.EnQue(dstLocal);
+ inQueue.FreeTensor(srcLocal);
+
+}
+__aicore__ inline void CopyOut()
+{
+    LocalTensor<T> dstLocal = outQueue.DeQue<T>();
+    DataCopyExtParams extParam {static_cast<uint16_t>(n * c * d), static_cast<uint32_t>(h * w * sizeof(T)), 0, 0, 0};
+ if constexpr (mode == NCDHW_FractalZ3D) {
+ DataCopy(dstGlobal, dstLocal, n1 * n0 * c1);
+ } else if constexpr (mode == FractalZ3D_NCDHW) {
+ DataCopy(dstGlobal, dstLocal, dstShapeSize);
+ } else if constexpr (mode == NCDHW_NDC1HWC0) {
+ DataCopy(dstGlobal, dstLocal, dstShapeSize);
+ } else if constexpr (mode == NDC1HWC0_NCDHW) {
+ DataCopyPad(dstGlobal, dstLocal, extParam);
+ }
+ outQueue.FreeTensor(dstLocal);
+}
+
+private:
+    GlobalTensor<T> srcGlobal;
+    GlobalTensor<T> dstGlobal;
+    TPipe *pipe;
+    TQue<TPosition::VECIN, 1> inQueue;
+    TQue<TPosition::VECOUT, 1> outQueue;
+    TBuf<TPosition::VECCALC> tmpBuf;
+ bool needPad = false;
+ int32_t n = 0;
+ int32_t c = 0;
+ int32_t d = 0;
+ int32_t h = 0;
+ int32_t w = 0;
+ int32_t n1 = 0;
+ int32_t c1 = 0;
+ int32_t hw1 = 0;
+ int32_t c0 = 16;
+ int32_t n0 = 16;
+ int32_t hw0 = 16;
+ uint32_t srcShapeSize = 0;
+ uint32_t dstShapeSize = 0;
+ uint32_t tmpShapeSize = 0;
+};
+} // namespace AscendC
+
+template <typename T, uint32_t mode>
+__global__ __aicore__ void MainTransdata(
+ __gm__ uint8_t* dstGm, __gm__ uint8_t* srcGm, uint64_t n, uint64_t c, uint64_t d, uint64_t h, uint64_t w)
+{
+ if (g_coreType == AscendC::AIC || AscendC::GetBlockIdx() > 0) {
+ return;
+ }
+ AscendC::TPipe pipe;
+    AscendC::KernelTransData<T, mode> op;
+ op.Init(srcGm, dstGm, n, c, d, h, w, &pipe);
+ op.Process();
+}
+
+struct TransDataTestParams {
+ int32_t n;
+ int32_t c;
+ int32_t d;
+ int32_t h;
+ int32_t w;
+ uint32_t mode;
+ void (*cal_func)(uint8_t*, uint8_t*, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t);
+};
+
+class TransDataTestsuite : public testing::Test, public testing::WithParamInterface {
+protected:
+ void SetUp()
+ {
+ AscendC::SetGCoreType(2);
+ }
+
+ void TearDown()
+ {
+ AscendC::SetGCoreType(0);
+ }
+};
+
+INSTANTIATE_TEST_CASE_P(TEST_OPERATION_TRANSDATA, TransDataTestsuite,
+    ::testing::Values(
+        // first group exercises half, second group uint16_t
+        TransDataTestParams { 5, 32, 2, 1, 16, 1, MainTransdata<half, 1> },
+        TransDataTestParams { 4, 31, 1, 6, 7, 2, MainTransdata<half, 2> },
+        TransDataTestParams { 4, 20, 2, 3, 1, 3, MainTransdata<half, 3> },
+        TransDataTestParams { 8, 14, 2, 1, 16, 4, MainTransdata<half, 4> },
+        TransDataTestParams { 5, 32, 2, 1, 16, 1, MainTransdata<uint16_t, 1> },
+        TransDataTestParams { 4, 31, 1, 6, 7, 2, MainTransdata<uint16_t, 2> },
+        TransDataTestParams { 4, 20, 2, 3, 1, 3, MainTransdata<uint16_t, 3> },
+        TransDataTestParams { 8, 14, 2, 1, 16, 4, MainTransdata<uint16_t, 4> }
+    ));
+
+TEST_P(TransDataTestsuite, TransDataOpTestCase)
+{
+ auto params = GetParam();
+ auto n = params.n;
+ auto c = params.c;
+ auto d = params.d;
+ auto h = params.h;
+ auto w = params.w;
+ auto mode = params.mode;
+ uint32_t srcShapeSize;
+ uint32_t dstShapeSize;
+ int32_t hw0 = 16;
+ int32_t hw1 = (h * w + hw0 - 1) / hw0;
+ int32_t c0 = 16;
+ int32_t n0 = 16;
+ int32_t c1 = (c + c0 - 1) / c0;
+ int32_t n1 = (n + n0 - 1) / n0;
+ if (mode == 1) {
+ srcShapeSize = n * d * c * hw0 * hw1;
+ dstShapeSize = n1 * n0 * c1 * c0 * d * h * w;
+        if ((h * w) % 16 != 0) {
+ dstShapeSize = n1 * n0 * c1 * c0 * d * hw0 * hw1;
+ }
+ } else if (mode == 2) {
+ srcShapeSize = d * c1 * h * w * n1 * n0 * c0;
+ dstShapeSize = n * c * d * (hw1 * hw0);
+ } else if (mode == 3) {
+ srcShapeSize = n * d * c * hw0;
+ dstShapeSize = n * c1 * c0 * d * h * w;
+ } else if (mode == 4) {
+ srcShapeSize = n * c1 * c0 * d * h * w;
+ dstShapeSize = n * d * c * hw0;
+ }
+    uint8_t srcGm[srcShapeSize * sizeof(half)] = {0}; // the caller guarantees the inner axis is 32B-aligned
+ uint8_t dstGm[dstShapeSize * sizeof(half)] = {0};
+ params.cal_func(dstGm, srcGm, n, c, d, h, w);
+ EXPECT_EQ(dstGm[0], 0);
+}