From d5a31a3e1f5b2e2a90f3e88ee851dcc033990c4b Mon Sep 17 00:00:00 2001 From: wang-xiangX Date: Tue, 16 Jul 2024 15:06:35 +0800 Subject: [PATCH] dequant support V200 --- .../dequant/ascend_dequant_common_impl.h | 13 +++ lib/quantization/ascend_dequant.h | 2 +- tests/CMakeLists.txt | 1 + .../dequant/test_operator_dequant_v200.cpp | 94 +++++++++++++++++++ 4 files changed, 109 insertions(+), 1 deletion(-) create mode 100644 tests/quantization/dequant/test_operator_dequant_v200.cpp diff --git a/impl/quantization/dequant/ascend_dequant_common_impl.h b/impl/quantization/dequant/ascend_dequant_common_impl.h index 14c399cc..1ec0eb9e 100644 --- a/impl/quantization/dequant/ascend_dequant_common_impl.h +++ b/impl/quantization/dequant/ascend_dequant_common_impl.h @@ -97,16 +97,25 @@ __aicore__ inline constexpr bool IsTemplateValid() // dst dtype: half , float, float, bfloat16_t constexpr bool isValid1 = (IsSameType::value) && (IsSameType::value); constexpr bool isValid2 = (IsSameType::value) && (IsSameType::value); +#if defined(__CCE_AICORE__) && (__CCE_AICORE__ == 200) + return isValid1 || isValid2; +#else constexpr bool isValid3 = (IsSameType::value) && (IsSameType::value); constexpr bool isValid4 = (IsSameType::value) && (IsSameType::value); return isValid1 || isValid2 || isValid3 || isValid4; +#endif } else { // dtype only support deqScale dtype: bfloat16_t, bfloat16_t, float // dst dtype: bfloat16_t, float, float +#if defined(__CCE_AICORE__) && (__CCE_AICORE__ == 200) + constexpr bool isValid1 = (IsSameType::value) && (IsSameType::value); + return isValid1; +#else constexpr bool isValid1 = (IsSameType::value) && (IsSameType::value); constexpr bool isValid2 = (IsSameType::value) && (IsSameType::value); constexpr bool isValid3 = (IsSameType::value) && (IsSameType::value); return isValid1 || isValid2 || isValid3; +#endif } } @@ -150,8 +159,12 @@ __aicore__ inline void AscendDequantTmpCalc(const LocalTensor& srcTenso template __aicore__ inline RoundMode GetFP32CastMode() { +#if defined(__CCE_AICORE__) && (__CCE_AICORE__ == 200) + return RoundMode::CAST_NONE; +#else constexpr RoundMode castMode = IsSameType::value ? RoundMode::CAST_RINT: RoundMode::CAST_NONE; return castMode; +#endif } // Update dqParams if format is {1, m*n, n} diff --git a/lib/quantization/ascend_dequant.h b/lib/quantization/ascend_dequant.h index c79a57bc..11f7790d 100644 --- a/lib/quantization/ascend_dequant.h +++ b/lib/quantization/ascend_dequant.h @@ -14,7 +14,7 @@ */ #ifndef LIB_QUANTIZATION_ASCEND_DEQUANT_H #define LIB_QUANTIZATION_ASCEND_DEQUANT_H -#if __CCE_AICORE__ == 220 +#if defined(__CCE_AICORE__) && (__CCE_AICORE__ == 220 || __CCE_AICORE__ == 200) #include "kernel_tensor.h" #include "../../impl/quantization/dequant/ascend_dequant_common_impl.h" diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 8bce851b..93bc4d5b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -48,6 +48,7 @@ file(GLOB ASCENDC_TEST_ascend310p_CASE_SRC_FILES ${ASCENDC_TESTS_DIR}/activation/sigmoid/test_operator_vec_sigmoid.cpp ${ASCENDC_TESTS_DIR}/quantization/antiquant/test_ascend_quant.cpp ${ASCENDC_TESTS_DIR}/quantization/antiquant/test_ascend_quant_per_channel.cpp + ${ASCENDC_TESTS_DIR}/quantization/dequant/test_operator_dequant_v200.cpp ) # ascend910B1 aiv test cases diff --git a/tests/quantization/dequant/test_operator_dequant_v200.cpp b/tests/quantization/dequant/test_operator_dequant_v200.cpp new file mode 100644 index 00000000..b5deec79 --- /dev/null +++ b/tests/quantization/dequant/test_operator_dequant_v200.cpp @@ -0,0 +1,94 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include "kernel_operator.h" + +using namespace std; +using namespace AscendC; + +constexpr uint32_t FLOAT_PER_BLOCK = 8; +constexpr uint32_t DTYPE16_PER_BLOCK = 16; // half + +template +void AscendDequantKernel(__gm__ uint8_t* __restrict__ srcGm, __gm__ uint8_t* __restrict__ dstGm, + __gm__ uint8_t* __restrict__ deqScaleGm, __gm__ int32_t dataSize) +{ + TPipe tpipe; + TQue vecQue; + TQue vecDeqQue; + TQue vecOutQue; + GlobalTensor inputGlobal; + GlobalTensor outputGlobal; + GlobalTensor deqScaleGlobal; + inputGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(srcGm), dataSize); + outputGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ dstT*>(dstGm), dataSize); + deqScaleGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ scaleT*>(deqScaleGm), dataSize); + tpipe.InitBuffer(vecQue, 1, dataSize * sizeof(int32_t)); + tpipe.InitBuffer(vecDeqQue, 1, dataSize * sizeof(scaleT)); + tpipe.InitBuffer(vecOutQue, 1, dataSize * sizeof(dstT)); + LocalTensor inputLocal = vecQue.AllocTensor(); + + LocalTensor deqScaleLocal = vecDeqQue.AllocTensor(); + + LocalTensor outputLocal = vecOutQue.AllocTensor(); + + DataCopy(inputLocal, inputGlobal, dataSize); + PipeBarrier(); + DataCopy(deqScaleLocal, deqScaleGlobal, dataSize); + + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + AscendDequant(outputLocal, inputLocal, deqScaleLocal); + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + DataCopy(outputGlobal, outputLocal, dataSize); + PipeBarrier(); + vecQue.FreeTensor(inputLocal); + vecDeqQue.FreeTensor(deqScaleLocal); + vecOutQue.FreeTensor(outputLocal); +} + +struct AscendDequantTestParams { + int32_t dataSize; + int32_t dataBitSize; + void (*calFunc)(uint8_t*, uint8_t*, uint8_t*, int32_t); +}; + +class AscendDequantTestsuite : public testing::Test, public testing::WithParamInterface { +protected: + void SetUp() {} + void TearDown() {} +}; + +INSTANTIATE_TEST_CASE_P(TEST_ASCEND_DEQUANT, AscendDequantTestsuite, + ::testing::Values(AscendDequantTestParams { 256, 4, AscendDequantKernel }, + AscendDequantTestParams { 512, 4, AscendDequantKernel }, + AscendDequantTestParams { 24, 4, AscendDequantKernel }, + AscendDequantTestParams { 24, 4, AscendDequantKernel }, + AscendDequantTestParams { 256, 4, AscendDequantKernel }, + AscendDequantTestParams { 512, 4, AscendDequantKernel }, + AscendDequantTestParams { 24, 4, AscendDequantKernel }, + AscendDequantTestParams { 24, 4, AscendDequantKernel } + )); + +TEST_P(AscendDequantTestsuite, AscendDequantTestCase) +{ + auto param = GetParam(); + uint8_t srcGm[param.dataSize * sizeof(int32_t)] = {0}; + uint8_t dstGm[param.dataSize * sizeof(half)] = {0}; + uint8_t deqScaleGm[param.dataSize * sizeof(uint64_t)] = {0}; + + param.calFunc(srcGm, dstGm, deqScaleGm, param.dataSize); + for (int32_t i = 0; i < param.dataSize; i++) { + EXPECT_EQ(dstGm[i], 0x00); + } +} \ No newline at end of file -- Gitee