From 66316b560884db252b4e07a46968e0dd5abc8664 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=98=AE=E8=B5=A2=E6=B4=8B?= Date: Fri, 20 Dec 2024 10:30:32 +0800 Subject: [PATCH] fix matmul api constant bug --- ...tCustom.json => MatmulApiConstantCustom.json} | 0 .../op_kernel/matmul_api_constant_custom.cpp | 16 ++++++++-------- 2 files changed, 8 insertions(+), 8 deletions(-) rename operator/ascendc/2_features/14_matmul_api_constant/{MatmulApiContantCustom.json => MatmulApiConstantCustom.json} (100%) diff --git a/operator/ascendc/2_features/14_matmul_api_constant/MatmulApiContantCustom.json b/operator/ascendc/2_features/14_matmul_api_constant/MatmulApiConstantCustom.json similarity index 100% rename from operator/ascendc/2_features/14_matmul_api_constant/MatmulApiContantCustom.json rename to operator/ascendc/2_features/14_matmul_api_constant/MatmulApiConstantCustom.json diff --git a/operator/ascendc/2_features/14_matmul_api_constant/MatmulApiConstantCustom/op_kernel/matmul_api_constant_custom.cpp b/operator/ascendc/2_features/14_matmul_api_constant/MatmulApiConstantCustom/op_kernel/matmul_api_constant_custom.cpp index c95abf299..4d1a85455 100644 --- a/operator/ascendc/2_features/14_matmul_api_constant/MatmulApiConstantCustom/op_kernel/matmul_api_constant_custom.cpp +++ b/operator/ascendc/2_features/14_matmul_api_constant/MatmulApiConstantCustom/op_kernel/matmul_api_constant_custom.cpp @@ -17,13 +17,6 @@ __aicore__ inline uint32_t Ceiling(uint32_t a, uint32_t b) return (a + b - 1) / b; } -// The specified value remains consistent with the runtime tiling paramters. -// singleCoreM, singleCoreN, singleCoreK, baseM, baseN, baseK. -constexpr static MatmulShapeParams shapeParams = {512, 640, 256, 128, 128, 128}; -constexpr static MatmulConfig mmConfig = GetMMConfig(shapeParams); -// Get the fully constant template parameters. -constexpr static MatmulApiStaticTiling staticConfig = GetMatmulApiTiling(mmConfig); - template class MatmulApiConstantKernel { public: __aicore__ inline MatmulApiConstantKernel(){}; @@ -39,6 +32,13 @@ public: typedef MatmulType cMatmulType; typedef MatmulType biasMatmulType; + // The specified value remains consistent with the runtime tiling paramters. + // singleCoreM, singleCoreN, singleCoreK, baseM, baseN, baseK. + constexpr static MatmulShapeParams shapeParams = {512, 640, 256, 128, 128, 128}; + constexpr static MatmulConfig mmConfig = GetMMConfig(shapeParams); + // Get the fully constant template parameters. + constexpr static MatmulApiStaticTiling staticConfig = GetMatmulApiTiling(mmConfig); + Matmul matmulObj; AscendC::GlobalTensor aGlobal; @@ -111,7 +111,7 @@ extern "C" __global__ __aicore__ void matmul_api_constant_custom( AscendC::TPipe pipe; // With the fully constant template parameters, nullptr can be passed into REGIST_MATMUL_OBJ to replace tiling. - REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), matmulApiConstantKernel.matmulObj, nullptr); + REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), matmulApiConstantKernel.matmulObj, &tilingData.cubeTilingData); matmulApiConstantKernel.Init(a, b, bias, c, workspace, tilingData.cubeTilingData); matmulApiConstantKernel.Process(&pipe); } \ No newline at end of file -- Gitee