diff --git a/atvc/include/common/atvc_op_check.h b/atvc/include/common/atvc_op_check.h index 06ae33e11bb27c4e9566c40ca6f0c07158dc27d4..095e6b1507be7e9122c0ec1c15477e24467edb91 100644 --- a/atvc/include/common/atvc_op_check.h +++ b/atvc/include/common/atvc_op_check.h @@ -63,7 +63,8 @@ bool DebugCheck() { if constexpr (templateType == ATVC::TemplateType::REDUCE || templateType == ATVC::TemplateType::BROADCAST) { if (!CheckSameDtype_()) { - printf("[ERROR]: Different input/output data types is not surpport in Reduce or Broadcast template.\n"); + printf("[ERROR]: [ATVC][OpTraits] Different input/output data types is not support " + "in Reduce or Broadcast template.\n"); return false; } } diff --git a/atvc/include/elewise/elewise_host.h b/atvc/include/elewise/elewise_host.h index 6b4d221fd534ee3d97a0ed288a271c0701b04efe..ca91426ecca4f45121bcce96432ddab9cd14a9e6 100644 --- a/atvc/include/elewise/elewise_host.h +++ b/atvc/include/elewise/elewise_host.h @@ -58,7 +58,7 @@ bool CheckEleWiseHyperParam(const EleWiseTilingHyperParam &hyperParam) "ubSizeLimitThreshold(%f) must be in [0.5, 0.96].\n", hyperParam.ubSizeLimitThreshold); return false; } - if(hyperParam.nBufferNum > MAX_BUF_NUM || hyperParam.singleCoreBaseLine < MIN_BUF_NUM) { + if(hyperParam.nBufferNum > MAX_BUF_NUM || hyperParam.nBufferNum < MIN_BUF_NUM) { printf("[ERROR]: [ATVC][EleWise] Tiling hyperParam is invalid: nBufferNum(%u) must be in [1, 2].\n", hyperParam.nBufferNum); return false; @@ -96,9 +96,6 @@ int32_t GetEleWiseBasicCnt(const EleWiseTilingHyperParam &hyperParam, if (blockNum == 0) { return 0; } - if (!CheckEleWiseHyperParam(hyperParam)) { - return 0; - } uint32_t avgElePerBlock = totalCnt / blockNum; for (uint32_t i =0; i < MAX_SHAPE_NODE; i++) { if (avgElePerBlock <= hyperParam.splitDataShape[i]) { @@ -132,6 +129,9 @@ template bool CalcEleWiseTiling(int32_t totalCnt, ATVC::EleWiseParam ¶m, EleWiseTilingHyperParam hyperParam = EleWiseTilingHyperParam()) { + if (!CheckEleWiseHyperParam(hyperParam)) { + return false; + } using Inputs = typename OpTraits::In::types; using Outputs = typename OpTraits::Out::types; using Temps = typename OpTraits::Temp::types;